In [150]:
from deap import tools, base, creator, algorithms
import jax
from diversity_algorithms.controllers.fixed_structure_nn_flax import SimpleNeuralControllerFlax
import jax.numpy as jnp
from functools import partial
from diversity_algorithms.environments.brax_env import EvaluationFunctor
from diversity_algorithms.algorithms.novelty_search import set_creator
from diversity_algorithms.environments.behavior_descriptors import feet_contact_descriptor
creator.create("FitnessMax", base.Fitness, weights=(1.0,)*1)
import numpy as np
creator.create("Individual", np.ndarray, typecode="d", fitness=creator.FitnessMax)
set_creator(creator)


from scipy.spatial import KDTree
import numpy as np

import pickle
import time
from diversity_algorithms.algorithms.utils import *
from diversity_algorithms.analysis.population_analysis import *
from diversity_algorithms.analysis.data_utils import *

from diversity_algorithms.algorithms.novelty_management import *

import alphashape
from shapely.geometry import Point, Polygon, LineString
from diversity_algorithms.algorithms.jax_utils import *
from diversity_algorithms.environments.brax_env import create

from IPython.display import HTML
from brax.v1.io import html


%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload




In [151]:
env = create("ant-uni", episode_length=300)
random_key = jax.random.PRNGKey(0)

In [152]:
controller = SimpleNeuralControllerFlax(env.observation_size, env.action_size, n_hidden_layers=2, n_neurons_per_hidden=64)
eval_functor = EvaluationFunctor("ant", controller=controller, bd_function=feet_contact_descriptor)

Environment set to ant


In [153]:
def dist_to_shape(pp, s):
	p=Point(pp)
	d=p.distance(s)
	if (d==0.0):
		d=-p.distance(s.exterior)
	return d

def dist_to_shapes(pp, ls):
	if (not hasattr(ls, '__iter__')):
		ls=[ls] 
	p=Point(pp)
	imin=-1
	dmin=sys.float_info.max
	for i in range(len(ls)):
		d=p.distance(ls[i])
		if (d<dmin):
			imin=i
			dmin=d
	if (dmin==0.0):
		d=-p.distance(ls[i].exterior)
	else:
		d=dmin
	return d

In [154]:
params = {"nb_gen": 5,
		  "pop_size": 5000,
			"geno_type":"realarray", 
		  "variant":"NS", 
		  "ind_size":controller.n_weights, 
		  "eta_m":15.0,
		  "indpb":0.1,
		  "mutpb":1,
		  "cxpb":0,
		  "min":-1,
		  "max":1,
		  "k":15,
		  "add_strategy":"novel",
		  "lambda_nov":300,
		  "verbosity":"none",
		  "variant":"NS",
	}

In [155]:
toolbox = base.Toolbox()
toolbox.register("population", init_pop_controller, controller=controller)
toolbox.register("mutate", mutate, eta=params["eta_m"], min_val=params["min"], max_val=params["max"], indpb=params["indpb"])
toolbox.register("mate", cxBLend, alpha=params["indpb"])
toolbox.register("map_eval", eval_functor)
v=str(params["variant"])
variant=v.replace(",","")
if (variant == "NS"): 
	toolbox.register("select", selBest, fit_attr='novelty')
elif (variant == "Fit"):
	toolbox.register("select", selBest, fit_attr='fitness')
elif (variant == "Random"):
	toolbox.register("select", random.sample)
elif (variant == "DistExplArea"):
	toolbox.register("select", selBest, fit_attr='dist_to_explored_area')
else:
	print("Variant not among the authorized variants (NS, Fit, Random, DistExplArea), assuming multi-objective variant")
	toolbox.register("select", tools.selNSGA2)

In [156]:
population, random_key = toolbox.population(random_key, params["pop_size"])

In [157]:
fit, bd, random_key = toolbox.map_eval(jnp.asarray(population), random_key)

In [158]:
bd_dimension = bd.shape[-1]

for ind, f, b in zip(population, fit, bd):
	ind.fit = f # fit is an attribute just used to store the fitness value
	ind.parent_bd = None
	ind.bd = b
	ind.id = generate_uuid()
	ind.parent_id = None

for ind in population:
	ind.am_parent=0

In [159]:
archive=updateNovelty(population,population,None,params)
alpha_shape = None
if bd_dimension == 2 or bd_dimension == 3:
	alpha_shape = alphashape.alphashape(archive.all_bd, 5)
isortednov=sorted(range(len(population)), key=lambda k: population[k].novelty, reverse=True)
varian=params["variant"].replace(",","")

In [160]:
if ("+" in variant):
	emo=True
else:
	emo=False

In [161]:
for i,ind in enumerate(population):
	if alpha_shape:
		ind.dist_to_explored_area = dist_to_shapes(ind.bd,alpha_shape)
	ind.rank_novelty=isortednov.index(i)
	ind.dist_to_parent=0
	if (emo): 
		if (varian == "NS+Fit"):
			ind.fitness.values = (ind.novelty, ind.fit)
		elif (varian == "NS+BDDistP"):
			ind.fitness.values = (ind.novelty, 0)
		elif (varian == "NS+Fit+BDDistP"):
			ind.fitness.values = (ind.novelty, ind.fit, 0)
		else:
			print("WARNING: unknown variant: "+variant)
			ind.fitness.values=ind.fit
	else:
		ind.fitness.values=ind.fit
	# if it is not a multi-objective experiment, the select tool from DEAP 
	# has been configured above to take the right attribute into account
	# and the fitness.values is thus ignored
gen=0	

In [162]:
lambda_ = 10000

In [163]:
offspring, random_key = varOr(random_key, population, toolbox, lambda_, params["cxpb"], params["mutpb"])

In [164]:
fit, bd, random_key = toolbox.map_eval(jnp.asarray(offspring), random_key)

In [165]:
for ind, f, b in zip(offspring, fit, bd):
	ind.fit = f
	ind.fitness.values = f
	ind.parent_bd = ind.bd
	ind.parent_id = ind.id
	ind.id = generate_uuid()
	ind.bd = b
for ind in population:
	ind.am_parent=1
for ind in offspring:
	ind.am_parent=0
	
pq=population+offspring


pop_for_novelty_estimation=pq

In [166]:
archive = updateNovelty(pq,offspring,archive,params, pop_for_novelty_estimation)
if bd_dimension == 2 or bd_dimension == 3:
	alpha_shape = alphashape.alphashape(archive.all_bd, 5)

In [167]:
nov = [ind.novelty for ind in pq]
isortednov = np.argsort(nov)[::-1]
rank = np.empty_like(isortednov)
rank[isortednov] = np.arange(len(isortednov))

In [168]:
for i,ind in enumerate(pq):
	if bd_dimension == 2 or bd_dimension == 3:
		ind.dist_to_explored_area=dist_to_shapes(ind.bd,alpha_shape)
	ind.rank_novelty = rank[i]
	#print("Indiv #%d: novelty=%f rank=%d"%(i, ind.novelty, ind.rank_novelty))
	if (ind.parent_bd is None):
		ind.dist_to_parent=0
	else:
		ind.dist_to_parent=np.linalg.norm(np.asarray(ind.bd)-np.asarray(ind.parent_bd))
	if (emo):
		if (varian == "NS+Fit"):
			ind.fitness.values = (ind.novelty, ind.fit)
		elif (varian == "NS+BDDistP"):
			if (ind.parent_bd is None):
				bddistp=0
			else:
				bddistp=np.linalg.norm(np.asarray(ind.bd) - np.asarray(ind.parent_bd))
			ind.fitness.values = (ind.novelty, bddistp)
		elif (varian == "NS+Fit+BDDistP"):
			if (ind.parent_bd is None):
				bddistp=0
			else:
				bddistp=np.linalg.norm(np.asarray(ind.bd) - np.asarray(ind.parent_bd))
			ind.fitness.values = (ind.novelty, ind.fit, bddistp)
		else:
			print("WARNING: unknown variant: "+variant)
			ind.fitness.values=ind.fit
	else:
		ind.fitness.values=ind.fit

In [169]:
if ("," in variant):
	population[:] = toolbox.select(offspring, params["pop_size"])		
else:
	population[:] = toolbox.select(pq, params["pop_size"])  

In [170]:
offspring, random_key = varOr(random_key, population, toolbox, lambda_, params["cxpb"], params["mutpb"])

In [171]:
best_ind = []

In [172]:
for gen in range(1, params["nb_gen"] + 1):
	print("Generation %d" % gen)
	t = time.time()
	offspring, random_key = varOr(random_key, population, toolbox, lambda_, params["cxpb"], params["mutpb"])
	print("varOr time: ", time.time() - t)	
	# Evaluate the individuals with an invalid fitness
	t=time.time()
	fit, bd, random_key = toolbox.map_eval(jnp.asarray(offspring), random_key)
	print("map_eval time: ", time.time() - t)

	for ind, f, b in zip(offspring, fit, bd):
		ind.fit = f
		ind.fitness.values = f
		ind.parent_bd = ind.bd
		ind.parent_id = ind.id
		ind.id = generate_uuid()
		ind.bd = b

	for ind in population:
		ind.am_parent=1

	for ind in offspring:
		ind.am_parent=0
		
	pq = population+offspring
	pop_for_novelty_estimation = pq
	t=time.time()
	archive = updateNovelty(pq, offspring, archive, params, pop_for_novelty_estimation)
	print("updateNovelty time: ", time.time() - t)
	if bd_dimension == 2 or bd_dimension == 3:
		alpha_shape = alphashape.alphashape(archive.all_bd, 5)
	
	# Compute the novelty rank
	nov = [ind.novelty for ind in pq]
	isortednov = np.argsort(nov)[::-1]
	rank = np.empty_like(isortednov)
	rank[isortednov] = np.arange(len(isortednov))
	
	for i,ind in enumerate(pq):
		if alpha_shape:
			ind.dist_to_explored_area = dist_to_shapes(ind.bd,alpha_shape)
		ind.rank_novelty = rank[i]
		#print("Indiv #%d: novelty=%f rank=%d"%(i, ind.novelty, ind.rank_novelty))
		if (ind.parent_bd is None):
			ind.dist_to_parent=0
		else:
			ind.dist_to_parent=np.linalg.norm(np.asarray(ind.bd)-np.asarray(ind.parent_bd))
		if (emo):
			if (varian == "NS+Fit"):
				ind.fitness.values = (ind.novelty, ind.fit)
			elif (varian == "NS+BDDistP"):
				if (ind.parent_bd is None):
					bddistp=0
				else:
					bddistp=np.linalg.norm(np.asarray(ind.bd) - np.asarray(ind.parent_bd))
				ind.fitness.values = (ind.novelty, bddistp)
			elif (varian == "NS+Fit+BDDistP"):
				if (ind.parent_bd is None):
					bddistp=0
				else:
					bddistp=np.linalg.norm(np.asarray(ind.bd) - np.asarray(ind.parent_bd))
				ind.fitness.values = (ind.novelty, ind.fit, bddistp)
			else:
				print("WARNING: unknown variant: "+variant)
				ind.fitness.values=ind.fit
		else:
			ind.fitness.values=ind.fit
	if (verbosity(params)):
		print("Gen %d"%(gen))
	else:
		if(gen%100==0):
			print(" %d "%(gen), end='', flush=True)
		elif(gen%10==0):
			print("+", end='', flush=True)
		else:
			print(".", end='', flush=True)
	
	best_ind.append(selBest(pq, 1, fit_attr="fitness")[0])
	# Select the next generation population
	if ("," in variant):
		population[:] = toolbox.select(offspring, params["pop_size"])		
	else:
		population[:] = toolbox.select(pq, params["pop_size"])		


Generation 1
varOr time:  2.103006362915039
map_eval time:  3.561311960220337
updateNovelty time:  4.704352855682373
.Generation 2
varOr time:  1.806886911392212
map_eval time:  3.347842216491699
updateNovelty time:  3.951786756515503
.Generation 3
varOr time:  1.7333312034606934
map_eval time:  3.302612781524658
updateNovelty time:  3.8863656520843506
.Generation 4
varOr time:  1.7177038192749023
map_eval time:  3.3701670169830322
updateNovelty time:  5.020559787750244
.Generation 5
varOr time:  1.7211337089538574
map_eval time:  3.2770793437957764
updateNovelty time:  3.97468900680542
.

In [174]:
archive.size()

2400

In [180]:
best = selBest(best_ind, 1, fit_attr="novelty")[0]
print(best.fitness, best.bd)
best = controller.array_to_dict(best)

(102.31854248046875,) [1.         0.32999998 0.98999995 0.91999996]


In [176]:
state = jax.jit(env.reset)(random_key)

In [177]:
jit_step = jax.jit(env.step)
jit_inf = jax.jit(controller.predict)

In [178]:
rollout = []
while not state.done:
    rollout.append(state)
    action = jit_inf(best, state.obs)
    state = jit_step(state, action)
    

In [179]:
HTML(html.render(env.sys, [s.qp for s in rollout]))