In [438]:
from deap import tools, base, creator
import jax
from diversity_algorithms.controllers.fixed_structure_nn_flax import SimpleNeuralControllerFlax
import jax.numpy as jnp
from diversity_algorithms.environments.brax_env import EvaluationFunctor
from diversity_algorithms.algorithms.novelty_search import set_creator
creator.create("FitnessMax", base.Fitness, weights=(1.0,)*1)
creator.create("Individual", np.ndarray, fitness=creator.FitnessMax)
set_creator(creator)

import numpy as np

from diversity_algorithms.algorithms.utils import *
from diversity_algorithms.analysis.population_analysis import *
from diversity_algorithms.analysis.data_utils import *

import jax
from diversity_algorithms.algorithms.quality_diversity import *
from diversity_algorithms.environments.brax_env import create
from diversity_algorithms.environments.environments import registered_environments

from IPython.display import HTML
from brax.v1.io import html
import time

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload




In [439]:
params = {
        "env_name": "ant-omni",  # "ant-uni" or "ant-omni"
        "verbosity": None,
        "pop_size": 5000,
        "n_add": -1,
        "initial_seed_size": 5000,
        "variant": "QD",
        "archive_type": "grid",
        "grid_n_bin": -1,
        "unstructured_neighborhood_radius": -1.0,
        "replace_strategy": "fitness",
        "sample_strategy": "random",
        "kdtree_update": "default",
        "nb_gen": 20,
        "dump_period_evolvability": 0,
        "extra_evolvability_gens": [],
        "dump_period_offspring": 1,
        "dump_period_population": 1,
        "dump_period_archive_full": 100,
        "dump_period_archive_small": 1,
        "cxpb": 0,
        "mutpb": 1,
        "indpb": 0.1,
        "eta_m": 15.0,
        "min": -1.0,
        "max": 1.0,
        "k_nov": 15,
        "geno_type": "realarray",
        "eval_budget": -1,
        "seed": 0,
        "episode_length": 100,
        "evolvability_nb_samples": 5000,
}

# Initialise the controller and the Evalution function

In [440]:
env_params = registered_environments[params["env_name"]]

In [441]:
random_key = jax.random.PRNGKey(0)
controller = SimpleNeuralControllerFlax(87, 8, 2, 64)
eval_functor = EvaluationFunctor(params["env_name"], 
                                 controller=controller, 
                                 episode_length=params["episode_length"],
                                 bd_function=env_params["bd_func"],
                                 output=env_params["eval_params"]["output"])

Environment set to ant-omni


# Toolbox Initilisation

In [442]:
toolbox = base.Toolbox()
toolbox.register("population", init_pop_controller, controller=controller)
toolbox.register("mutate", mutate, eta=params["eta_m"], min_val=params["min"], max_val=params["max"], indpb=params["indpb"])
toolbox.register("mate", cxBLend, alpha=0.5)
toolbox.register("map_eval", eval_functor)

v=str(params["variant"])
variant=v.replace(",","")
if (variant == "NS"): 
    toolbox.register("select", tools.selBest, fit_attr='novelty')
elif (variant == "Fit"):
    toolbox.register("select", tools.selBest, fit_attr='fitness')
else:
    toolbox.register("select", tools.selNSGA2)

# Initialisation run

In [443]:
population, random_key = toolbox.population(random_key, params["pop_size"])
fit, bd, random_key = toolbox.map_eval(jnp.asarray(population), random_key)

for ind, f, b in zip(population, fit, bd):
	ind.fitness.values = f
	ind.fit = f
	ind.parent_bd=None
	ind.bd=b
	ind.id = generate_uuid()
	ind.parent_id = None
	ind.dist_parent = -1
	ind.gen_created = 0

for ind in population:
	ind.am_parent=0
 
if((params["archive_type"] == "unstructured") or (params["archive_type"] == "archive")):
	# If no ball size is given, take a diameter of average size of a dimension / nb_bin
	if(params["unstructured_neighborhood_radius"] < 0):
		#Fetch behavior space dimensions
		gridinfo = registered_environments[params["env_name"]]["grid_features"]
		avg_dim_sizes = np.mean(np.array(gridinfo["max_x"]) - np.array(gridinfo["min_x"]))
		params["unstructured_neighborhood_radius"] = avg_dim_sizes / (2*gridinfo["nb_bin"])
		print("Unstructured archive replace radius autoset to %f" % params["unstructured_neighborhood_radius"])
	archive = UnstructuredArchive(population, r_ball_replace=params["unstructured_neighborhood_radius"], replace_strategy=replace_strategies[params["replace_strategy"]], k_nov_knn=params["k_nov"], kd_update_scheme=params["kdtree_update"])
elif(params["archive_type"] == "grid"):
	#Fetch behavior space dimensions
	gridinfo = registered_environments[params["env_name"]]["grid_features"]
	dim_ranges = list(zip(gridinfo["min_x"],gridinfo["max_x"]))
	if(params["grid_n_bin"] <= 0):
		params["grid_n_bin"] = gridinfo["nb_bin"] # If no specific discretization is given, take the environment default
		print("Archive grid bin number autoset to %d" % params["grid_n_bin"])
	archive = StructuredGrid(population, bins_per_dim=params["grid_n_bin"], dims_ranges=dim_ranges, replace_strategy=replace_strategies[params["replace_strategy"]], compute_novelty=True, k_nov_knn=params["k_nov"], kd_update_scheme=params["kdtree_update"])
else:
	raise RuntimeError("Unknown archive type %s" % params["archive_type"])

gen = 0

Archive grid bin number autoset to 100


# Learning loop

In [444]:
for gen in range(1, params["nb_gen"]):
    print("Generation %d" % gen)
    start = time.time()
    parents = archive.sample_archive(params["pop_size"], strategy=params["sample_strategy"])
    print("Sampling took %f" % (time.time() - start))

    # Mutate the geneotypes
    start = time.time()
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(subkey, len(parents))
    mutate_gen = jax.vmap(toolbox.mutate)(keys, jnp.asarray(parents))
    
    # Create the offsprings
    offspring = [creator.Individual([x]) for x in np.asarray(mutate_gen)]
    for i in range(len(offspring)):
        offspring[i] =  offspring[i][0]
        offspring[i].fitness = creator.FitnessMax()
        offspring[i].bd = parents[i].bd
        offspring[i].id = parents[i].id
    print("Mutating took %f" % (time.time() - start))
    
    start = time.time()
    fit, bd, random_key = toolbox.map_eval(jnp.array(offspring), random_key)
    print("Evaluating took %f" % (time.time() - start))
    
    for ind, f, b in zip(offspring, fit, bd):
        ind.fitness.values = f
        ind.fit = f
        ind.parent_bd = ind.bd
        ind.bd = b
        ind.parent_id = ind.id
        ind.id = generate_uuid()
        ind.am_parent = 0
        ind.dist_parent = get_bd_dist_to_parent(ind)
        ind.gen_created = gen

    if(len(offspring)) < params["n_add"]:
        print("WARNING: Not enough parents sampled to get %d offspring; will complete with %d random individuals" % (params["n_add"], params["n_add"]-len(offspring)))
        extra_random_indivs, random_key = toolbox.population(random_key, params["n_add"] - len(offspring))
        extrat_fit, extra_bd, random_key = toolbox.map_eval(jnp.array(extra_random_indivs), random_key)
    
        for ind, f, b in zip(extra_random_indivs, extrat_fit, extra_bd):
            ind.fitness.values = f
            ind.fit = f
            ind.parent_bd = None
            ind.bd = b
            ind.id = generate_uuid()
            ind.parent_id = None
            ind.am_parent = 0
            ind.dist_parent = -1
            ind.gen_created = gen
        offspring += extra_random_indivs
    
    for ind in parents:
        ind.am_parent=1
    for ind in offspring:
        ind.am_parent=0
    
    start = time.time()
    n_added = 0
    for ind in offspring:
        if(archive.try_add(ind)):
            n_added += 1
    print("Adding to archive took %f" % (time.time() - start))
    
    # Rebuild novelty for whole archive
    start = time.time()
    archive.update_novelty()
    print("Novelty update took %f seconds" % (time.time() - start))

Generation 1
Sampling took 0.001218
Mutating took 0.442703
Evaluating took 1.117199
Adding to archive took 0.039953
Novelty update took 0.005122 seconds
Generation 2
Sampling took 0.000272
Mutating took 0.426766
Evaluating took 1.107111
Adding to archive took 0.039541
Novelty update took 0.008573 seconds
Generation 3
Sampling took 0.000441
Mutating took 0.458552
Evaluating took 1.121341
Adding to archive took 0.040436
Novelty update took 0.012346 seconds
Generation 4
Sampling took 0.000515
Mutating took 0.464275
Evaluating took 1.115049
Adding to archive took 0.054951
Novelty update took 0.019311 seconds
Generation 5
Sampling took 0.001121
Mutating took 0.508281
Evaluating took 1.134260
Adding to archive took 0.040378
Novelty update took 0.016848 seconds
Generation 6
Sampling took 0.000614
Mutating took 0.457695
Evaluating took 1.110954
Adding to archive took 0.039009
Novelty update took 0.018536 seconds
Generation 7
Sampling took 0.000554
Mutating took 0.464534
Evaluating took 1.77551

In [445]:
best = selBest(archive.get_content_as_list(), 1, fit_attr="fitness")[0]
print(best.fitness, best.bd)
best = controller.array_to_dict(jnp.array(best))

(96.38684844970703,) [0.06173204 0.09549972]


In [446]:
env = create(params["env_name"], episode_length=params["episode_length"])

In [447]:
random_key, subkey = jax.random.split(random_key)
state = jax.jit(env.reset)(subkey)

In [448]:
jit_step = jax.jit(env.step)
jit_inf = jax.jit(controller.predict)

In [449]:
rollout = []
while not state.done:
    rollout.append(state)
    action = jit_inf(best, state.obs)
    state = jit_step(state, action)


In [450]:
HTML(html.render(env.sys, [s.qp for s in rollout]))

In [451]:
from diversity_algorithms.analysis.population_analysis import *

In [457]:
all_bd = [ind.bd for ind in archive.get_content_as_list()]
get_coverage([-15,-15], [15,15], 100, all_bd)

0.118

In [None]:
np.sum([ind.fitness.values for ind in archive.get_content_as_list()])

82822.84254837036