In [1]:
from deap import tools, base, algorithms, creator
import diversity_algorithms
import jax
from brax.v1 import envs
from diversity_algorithms.controllers.fixed_structure_nn_flax import SimpleNeuralControllerFlax
import jax.numpy as jnp
from functools import partial
from diversity_algorithms.environments.brax_env import EvaluationFunctor
from diversity_algorithms.algorithms.novelty_search import set_creator
from diversity_algorithms.environments.behavior_descriptors import feet_contact_descriptor
creator.create("FitnessMax", base.Fitness, weights=(1.0,)*1)
import numpy as np
creator.create("Individual", np.ndarray, fitness=creator.FitnessMax)
set_creator(creator)

#!/usr/bin python -w

from scipy.spatial import KDTree
import numpy as np

import pickle

from deap import tools, base, algorithms

from diversity_algorithms.algorithms.utils import *
from diversity_algorithms.analysis.population_analysis import *
from diversity_algorithms.analysis.data_utils import *

import alphashape
from shapely.geometry import Point, Polygon, LineString

import jax
from jax import numpy as jnp
from diversity_algorithms.algorithms.jax_utils import *
from diversity_algorithms.algorithms.quality_diversity import *
from diversity_algorithms.environments.brax_env import create

from IPython.display import HTML
from brax.v1.io import html

%load_ext autoreload
%autoreload 2


In [2]:
env = create("ant-uni", episode_length=300)
random_key = jax.random.PRNGKey(0)

In [3]:
controller = SimpleNeuralControllerFlax(env.observation_size, env.action_size, n_hidden_layers=2, n_neurons_per_hidden=64)
eval_functor = EvaluationFunctor("ant", output="total_reward", controller=controller, bd_function=feet_contact_descriptor)

Environment set to ant


In [30]:
params = {
        "verbosity": None,
        "pop_size": 8192,
        "n_add": -1,
        "initial_seed_size": 8192,
        "variant": "QD",
        "archive_type": "grid",
        "grid_n_bin": -1,
        "unstructured_neighborhood_radius": -1.0,
        "replace_strategy": "fitness",
        "sample_strategy": "random",
        "kdtree_update": "default",
        "env_name": "ant-uni",
        "nb_gen": 20,
        "dump_period_evolvability": 0,
        "extra_evolvability_gens": [],
        "dump_period_offspring": 1,
        "dump_period_population": 1,
        "dump_period_archive_full": 100,
        "dump_period_archive_small": 1,
        "cxpb": 0,
        "mutpb": 1,
        "indpb": 0.1,
        "eta_m": 15.0,
        "min": -1.0,
        "max": 1.0,
        "k_nov": 15,
        "geno_type": "realarray",
        "eval_budget": -1,
        "seed": 0,
        "episode_length": 300,
        "ind_size": controller.n_weights,
        "evolvability_nb_samples": 5000,
}

In [31]:
toolbox = base.Toolbox()
toolbox.register("population", init_pop_controller, controller=eval_functor.get_controller())
toolbox.register("mutate", mutate, eta=params["eta_m"], min_val=params["min"], max_val=params["max"], indpb=params["indpb"])
toolbox.register("mate", cxBLend, alpha=0.5)
toolbox.register("map_eval", eval_functor)

v=str(params["variant"])
variant=v.replace(",","")
if (variant == "NS"): 
    toolbox.register("select", tools.selBest, fit_attr='novelty')
elif (variant == "Fit"):
    toolbox.register("select", tools.selBest, fit_attr='fitness')
else:
    toolbox.register("select", tools.selNSGA2)

In [32]:
population, random_key = toolbox.population(random_key, params["pop_size"])

In [33]:
fit, bd, random_key = toolbox.map_eval(jnp.asarray(population), random_key)

In [34]:
for ind, f, b in zip(population, fit, bd):
	ind.fitness.values = f
	ind.fit = f
	ind.parent_bd=None
	ind.bd=b
	ind.id = generate_uuid()
	ind.parent_id = None
	ind.dist_parent = -1
	ind.gen_created = 0

for ind in population:
	ind.am_parent=0

In [35]:
if((params["archive_type"] == "unstructured") or (params["archive_type"] == "archive")):
	# If no ball size is given, take a diameter of average size of a dimension / nb_bin
	if(params["unstructured_neighborhood_radius"] < 0):
		#Fetch behavior space dimensions
		gridinfo = registered_environments[params["env_name"]]["grid_features"]
		avg_dim_sizes = np.mean(np.array(gridinfo["max_x"]) - np.array(gridinfo["min_x"]))
		params["unstructured_neighborhood_radius"] = avg_dim_sizes / (2*gridinfo["nb_bin"])
		print("Unstructured archive replace radius autoset to %f" % params["unstructured_neighborhood_radius"])
	archive = UnstructuredArchive(population, r_ball_replace=params["unstructured_neighborhood_radius"], replace_strategy=replace_strategies[params["replace_strategy"]], k_nov_knn=params["k_nov"], kd_update_scheme=params["kdtree_update"])
elif(params["archive_type"] == "grid"):
	#Fetch behavior space dimensions
	gridinfo = registered_environments[params["env_name"]]["grid_features"]
	dim_ranges = list(zip(gridinfo["min_x"],gridinfo["max_x"]))
	if(params["grid_n_bin"] <= 0):
		params["grid_n_bin"] = gridinfo["nb_bin"] # If no specific discretization is given, take the environment default
		print("Archive grid bin number autoset to %d" % params["grid_n_bin"])
	archive = StructuredGrid(population, bins_per_dim=params["grid_n_bin"], dims_ranges=dim_ranges, replace_strategy=replace_strategies[params["replace_strategy"]], compute_novelty=True, k_nov_knn=params["k_nov"], kd_update_scheme=params["kdtree_update"])
else:
	raise RuntimeError("Unknown archive type %s" % params["archive_type"])

Archive grid bin number autoset to 5


In [36]:
seed_population = archive.get_content_as_list()

Start of the loop

In [37]:
gen = 0

In [38]:
population = archive.sample_archive(params["pop_size"], strategy=params["sample_strategy"])
parents = population

In [39]:
# Mutate the geneotypes
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, len(parents))
mutate_gen = jax.vmap(toolbox.mutate)(keys, jnp.asarray(parents))
	
# Create the offsprings
offspring = [creator.Individual([x]) for x in np.asarray(mutate_gen)]
for i in range(len(offspring)):
	offspring[i] =  offspring[i][0]
	offspring[i].fitness = creator.FitnessMax()
	offspring[i].bd = parents[i].bd
	offspring[i].id = parents[i].id


In [40]:
fit, bd, random_key = toolbox.map_eval(jnp.array(population), random_key)

In [41]:
for ind, f, b in zip(offspring, fit, bd):
	ind.fitness.values = f
	ind.fit = f
	ind.parent_bd = ind.bd
	ind.bd = b
	ind.parent_id = ind.id
	ind.id = generate_uuid()
	ind.am_parent = 0
	ind.dist_parent = get_bd_dist_to_parent(ind)
	ind.gen_created = gen

In [42]:
if(len(offspring)) < params["n_add"]:
	print("WARNING: Not enough parents sampled to get %d offspring; will complete with %d random individuals" % (params["n_add"], params["n_add"]-len(offspring)))
	extra_random_indivs, random_key = toolbox.population(random_key, params["n_add"] - len(offspring))
	extrat_fit, extra_bd, random_key = toolbox.map_eval(jnp.array(extra_random_indivs), random_key)
 
	for ind, f, b in zip(extra_random_indivs, extrat_fit, extra_bd):
		ind.fitness.values = f
		ind.fit = f
		ind.parent_bd = None
		ind.bd = b
		ind.id = generate_uuid()
		ind.parent_id = None
		ind.am_parent = 0
		ind.dist_parent = -1
		ind.gen_created = gen
	offspring += extra_random_indivs
 
for ind in parents:
	ind.am_parent=1
for ind in offspring:
	ind.am_parent=0

In [43]:
n_added = 0
for ind in offspring:
	if(archive.try_add(ind)):
		n_added += 1
  
# Rebuild novelty for whole archive
archive.update_novelty()

In [44]:
import time

In [45]:
for gen in range(10):
    start = time.time()
    parents = archive.sample_archive(params["pop_size"], strategy=params["sample_strategy"])
    print("Sampling took %f" % (time.time() - start))

    # Mutate the geneotypes
    start = time.time()
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(subkey, len(parents))
    mutate_gen = jax.vmap(toolbox.mutate)(keys, jnp.asarray(parents))
    
    # Create the offsprings
    offspring = [creator.Individual([x]) for x in np.asarray(mutate_gen)]
    for i in range(len(offspring)):
        offspring[i] =  offspring[i][0]
        offspring[i].fitness = creator.FitnessMax()
        offspring[i].bd = parents[i].bd
        offspring[i].id = parents[i].id
    print("Mutating took %f" % (time.time() - start))
    
    start = time.time()
    fit, bd, random_key = toolbox.map_eval(jnp.array(offspring), random_key)
    print("Evaluating took %f" % (time.time() - start))
    
    start = time.time()
    for ind, f, b in zip(offspring, fit, bd):
        ind.fitness.values = f
        ind.fit = f
        ind.parent_bd = ind.bd
        ind.bd = b
        ind.parent_id = ind.id
        ind.id = generate_uuid()
        ind.am_parent = 0
        ind.dist_parent = get_bd_dist_to_parent(ind)
        ind.gen_created = gen

    if(len(offspring)) < params["n_add"]:
        print("WARNING: Not enough parents sampled to get %d offspring; will complete with %d random individuals" % (params["n_add"], params["n_add"]-len(offspring)))
        extra_random_indivs, random_key = toolbox.population(random_key, params["n_add"] - len(offspring))
        extrat_fit, extra_bd, random_key = toolbox.map_eval(jnp.array(extra_random_indivs), random_key)
    
        for ind, f, b in zip(extra_random_indivs, extrat_fit, extra_bd):
            ind.fitness.values = f
            ind.fit = f
            ind.parent_bd = None
            ind.bd = b
            ind.id = generate_uuid()
            ind.parent_id = None
            ind.am_parent = 0
            ind.dist_parent = -1
            ind.gen_created = gen
        offspring += extra_random_indivs
    
    for ind in parents:
        ind.am_parent=1
    for ind in offspring:
        ind.am_parent=0
    print("Creating offspring took %f" % (time.time() - start))
    
    start = time.time()
    n_added = 0
    for ind in offspring:
        if(archive.try_add(ind)):
            n_added += 1
    print("Adding to archive took %f" % (time.time() - start))
    
    # Rebuild novelty for whole archive
    start = time.time()
    archive.update_novelty()
    print("Novelty update took %f seconds" % (time.time() - start))

Sampling took 0.001275
Mutating took 1.110059
Evaluating took 3.395124
Creating offspring took 0.834028
Adding to archive took 0.066421
Novelty update took 0.006263 seconds
Sampling took 0.000589
Mutating took 0.981590
Evaluating took 3.216330
Creating offspring took 0.629973
Adding to archive took 0.065223
Novelty update took 0.006206 seconds
Sampling took 0.000595
Mutating took 1.296715
Evaluating took 3.174379
Creating offspring took 0.840863
Adding to archive took 0.068599
Novelty update took 0.007073 seconds
Sampling took 0.000752
Mutating took 1.015734
Evaluating took 3.334339
Creating offspring took 0.790992
Adding to archive took 0.066032
Novelty update took 0.006695 seconds
Sampling took 0.000648
Mutating took 0.959250
Evaluating took 3.069427
Creating offspring took 0.835851
Adding to archive took 0.067222
Novelty update took 0.007149 seconds
Sampling took 0.000754
Mutating took 1.210581
Evaluating took 3.068168
Creating offspring took 0.793265
Adding to archive took 0.066487

In [46]:
best = selBest(archive.get_content_as_list(), 1, fit_attr="fitness")[0]
print(best.fitness, best.bd)
best = controller.array_to_dict(jnp.array(best))

(103.70453643798828,) [0.88       1.         0.45       0.98999995]


In [47]:
random_key, subkey = jax.random.split(random_key)
state = jax.jit(env.reset)(subkey)

In [48]:
jit_step = jax.jit(env.step)
jit_inf = jax.jit(controller.predict)

In [49]:
rollout = []
while not state.done:
    rollout.append(state)
    action = jit_inf(best, state.obs)
    state = jit_step(state, action)


In [50]:
HTML(html.render(env.sys, [s.qp for s in rollout]))

In [51]:
from diversity_algorithms.analysis.population_analysis import *

In [52]:
all_bd = [ind.bd for ind in archive.get_content_as_list()]

In [54]:
len(all_bd)

273

In [55]:
np.sum([ind.fitness.values for ind in archive.get_content_as_list()])

26703.78701210022