In [171]:
from deap import tools, base, algorithms, creator
import diversity_algorithms
import jax
from brax.v1 import envs
from diversity_algorithms.controllers.fixed_structure_nn_flax import SimpleNeuralControllerFlax
import jax.numpy as jnp
from functools import partial
from diversity_algorithms.environments.brax_env import EvaluationFunctor
from diversity_algorithms.algorithms.novelty_search import set_creator
from diversity_algorithms.environments.behavior_descriptors import ant_behavior_descriptor
# deap's `creator` registers classes as module-level globals, and calling
# creator.create again for an existing name warns and overwrites the class.
# Guard the calls so re-running this cell is idempotent and does not swap the
# class out from under individuals that were already built with it.
if not hasattr(creator, "FitnessMax"):
    creator.create("FitnessMax", base.Fitness, weights=(1.0,))  # one maximised objective
import numpy as np
if not hasattr(creator, "Individual"):
    creator.create("Individual", np.ndarray, fitness=creator.FitnessMax)
# Pass the configured creator to diversity_algorithms.algorithms.novelty_search.
set_creator(creator)

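# NOTE: a shebang line has no effect inside a notebook cell (and the original "#!/usr/bin python -w" was malformed: it is missing "env").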

from scipy.spatial import KDTree
import numpy as np

import pickle

from deap import tools, base, algorithms

from diversity_algorithms.algorithms.utils import *
from diversity_algorithms.analysis.population_analysis import *
from diversity_algorithms.analysis.data_utils import *

from diversity_algorithms.algorithms.novelty_management import *

import alphashape
from shapely.geometry import Point, Polygon, LineString

import jax
from jax import numpy as jnp
from diversity_algorithms.algorithms.jax_utils import *

%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload




In [172]:
# Root PRNG key used to seed the population sampling and evaluation below,
# and the Brax (v1) "ant" environment the controller is sized for.
key = jax.random.PRNGKey(0)
env = envs.create("ant")

In [173]:
# Policy network sized from the environment's observation/action dimensions.
controller = SimpleNeuralControllerFlax(
    env.observation_size,
    env.action_size,
    n_hidden_layers=2,
    n_neurons_per_hidden=64,
)
# Evaluator for the "ant" environment; behaviour descriptors come from
# ant_behavior_descriptor.
eval_functor = EvaluationFunctor(
    "ant",
    controller,
    bd_function=ant_behavior_descriptor,
)

Environment set to ant


In [174]:
# Hyper-parameters of the novelty-search run.
# (The previous literal listed "variant" twice; both entries were "NS".)
params = {
    "nb_gen": 5,
    "pop_size": 5000,        # number of individuals kept after selection
    "geno_type": "realarray",
    "variant": "NS",         # selection variant, see the toolbox cell ("," = select from offspring only)
    "ind_size": controller.n_weights,
    "eta_m": 15.0,           # polynomial-mutation distribution index (passed to mutate as eta)
    "indpb": 0.1,            # fraction of genes mutated per offspring
    "mutpb": 1,              # varOr produces int(mutpb * lambda_) offspring
    "cxpb": 0,               # crossover probability; ignored by varOr
    "min": -5,               # lower gene bound (sampling and mutation clipping)
    "max": 5,                # upper gene bound (sampling and mutation clipping)
    "k": 15,
    "add_strategy": "random",
    "lambda_nov": 6,
    "verbosity": "none",
}

In [175]:
class Individual:
    """One genome plus the bookkeeping used by the novelty-search loop.

    Attributes initialised here:
        genotype: parameter vector (a row of the array sampled in init_pop).
        fitness: a deap ``creator.FitnessMax`` instance, initially without values.
        bd: behaviour descriptor; filled in after evaluation.
        parent_bd: the parent's behaviour descriptor (None for the initial population).
        am_parent: 1 while in the parent population, 0 for freshly produced offspring.
        rank_novelty: position in the novelty ordering (0 = most novel).
        dist_to_parent: distance between ``bd`` and ``parent_bd``.
        evolvability_samples: not set anywhere in this notebook.

    ``fit``, ``id`` and ``parent_id`` are attached later by the notebook cells;
    ``novelty`` is presumably attached by updateNovelty (confirm there).
    """

    def __init__(self, genotype):
        self.genotype = genotype
        self.fitness = creator.FitnessMax()
        self.bd = None
        self.parent_bd = None
        self.am_parent = 0
        self.rank_novelty = None
        self.dist_to_parent = None
        self.evolvability_samples = None

In [176]:
@partial(jax.jit, static_argnames=("eta", "min_val", "max_val", "indpb"))
def mutate(gen, random_key, eta, min_val, max_val, indpb):
    """Bounded polynomial mutation following deap's ``mutPolynomialBounded``.

    Unlike deap, which mutates each gene independently with probability
    ``indpb``, this version mutates exactly ``int(indpb * len(gen))`` distinct
    genes, so array shapes stay static under ``jax.jit`` / ``jax.vmap``.

    Args:
        gen: 1-D genotype array.
        random_key: JAX PRNG key.
        eta: distribution index (static); larger values keep mutants closer
            to the parent value.
        min_val, max_val: gene bounds (static); mutated genes are clipped to them.
        indpb: fraction of genes to mutate (static).

    Returns:
        A new genotype array in which the selected genes have been mutated.
    """
    # Select the genes to mutate (distinct indices).
    random_key, subkey = jax.random.split(random_key)
    n_mut = int(indpb * gen.shape[0])
    mut_id = jax.random.choice(subkey, jnp.arange(gen.shape[0]), (n_mut,), replace=False)

    mut_var = gen[mut_id]
    span = max_val - min_val
    delta_1 = (mut_var - min_val) / span
    delta_2 = (max_val - mut_var) / span
    mut_pow = 1.0 / (eta + 1.0)

    random_key, subkey = jax.random.split(random_key)
    rands = jax.random.uniform(subkey, mut_var.shape, jnp.float32, 0, 1)

    # As in deap, BOTH branches raise (1 - delta) to the power eta + 1.
    # (The lower branch previously used mut_pow + 1, which shrank downward
    # moves far below upward ones and biased the mutation.)
    val1 = 2.0 * rands + (1.0 - 2.0 * rands) * jnp.power(1.0 - delta_1, eta + 1.0)
    delta_q_low = jnp.power(val1, mut_pow) - 1.0
    val2 = 2.0 * (1.0 - rands) + 2.0 * (rands - 0.5) * jnp.power(1.0 - delta_2, eta + 1.0)
    delta_q_high = 1.0 - jnp.power(val2, mut_pow)
    delta_q = jnp.where(rands < 0.5, delta_q_low, delta_q_high)

    # Keep the mutated genes inside [min_val, max_val].
    new_val = jnp.clip(mut_var + delta_q * span, min_val, max_val)
    return gen.at[mut_id].set(new_val)

In [177]:
def varOr(population, random_key, toolbox, lambda_, cxpb, mutpb):
    """Create offspring by mutating randomly chosen parents (mutation-only varOr).

    Args:
        population: list of individuals exposing ``.genotype`` and ``.fitness``.
        random_key: JAX PRNG key.
        toolbox: object providing ``clone(ind)`` and ``mutate(genotype, key)``;
            ``mutate`` is vectorised over the offspring with ``jax.vmap``.
        lambda_: requested number of offspring.
        cxpb: crossover probability. Ignored: no crossover is applied.
        mutpb: ``int(mutpb * lambda_)`` offspring are produced by mutation.

    Returns:
        ``(offspring, random_key)``. Each offspring is a clone of a parent drawn
        uniformly with replacement, carrying the mutated genotype and an
        invalidated fitness. Parents are not modified.

    Raises:
        ValueError: if offspring are requested from an empty population.
    """
    # Same 4-way split as before so the random stream is unchanged; only
    # mut_key is consumed because crossover/reproduction are not implemented.
    random_key, mut_key, _mate_key, _rep_key = jax.random.split(random_key, 4)
    random_key, subkey = jax.random.split(random_key)

    n_offspring = int(mutpb * lambda_)
    if n_offspring <= 0:
        # vmap-ing the real mutate over an empty (0,) array would fail.
        return [], random_key
    if not population:
        raise ValueError("varOr: cannot draw parents from an empty population")

    # Indices of the parents to mutate (with replacement).
    parent_idx = jax.random.choice(mut_key, jnp.arange(len(population)), (n_offspring,)).tolist()
    parent_genotypes = jnp.asarray([population[i].genotype for i in parent_idx])

    # One independent key per offspring.
    keys = jax.random.split(subkey, n_offspring)
    mutated = jax.vmap(toolbox.mutate)(parent_genotypes, keys)

    offspring = [toolbox.clone(population[i]) for i in parent_idx]
    for child, genotype in zip(offspring, mutated):
        child.genotype = genotype
        del child.fitness.values  # mark the clone as needing evaluation

    return offspring, random_key

In [178]:
def init_pop(size, random_key, n_weights=None, min_val=None, max_val=None):
    """Sample ``size`` individuals whose genes are uniform in [min_val, max_val).

    Args:
        size: number of individuals.
        random_key: JAX PRNG key.
        n_weights: genotype length; defaults to ``controller.n_weights``.
        min_val, max_val: sampling bounds; default to ``params["min"]`` and
            ``params["max"]``.

    Returns:
        ``(population, random_key)`` where population is a list of Individual.
    """
    # Fall back to the notebook globals, which keeps existing calls working.
    if n_weights is None:
        n_weights = controller.n_weights
    if min_val is None:
        min_val = params["min"]
    if max_val is None:
        max_val = params["max"]
    random_key, subkey = jax.random.split(random_key)
    genotypes = jax.random.uniform(subkey, (size, n_weights), float, min_val, max_val)
    return [Individual(genotype) for genotype in genotypes], random_key

import random  # std-lib random.sample is the selector for the "Random" variant

toolbox = base.Toolbox()
# NOTE: "evaluate" used to be registered as Python's builtin `eval`, which can
# never evaluate an individual and would execute arbitrary expressions if
# called. No visible cell uses toolbox.evaluate (evaluation goes through
# toolbox.map_eval), so it is no longer registered.
toolbox.register("population", init_pop)
toolbox.register(
    "mutate",
    mutate,
    eta=params["eta_m"],
    min_val=params["min"],
    max_val=params["max"],
    indpb=params["indpb"],
)
toolbox.register("map_eval", eval_functor)

# A "," in the variant only changes which pool is selected from (see the main
# loop), so strip it before choosing the selection operator.
v = str(params["variant"])
variant = v.replace(",", "")
if variant == "NS":
    toolbox.register("select", tools.selBest, fit_attr='novelty')
elif variant == "Fit":
    toolbox.register("select", tools.selBest, fit_attr='fitness')
elif variant == "Random":
    toolbox.register("select", random.sample)
elif variant == "DistExplArea":
    toolbox.register("select", tools.selBest, fit_attr='dist_to_explored_area')
else:
    print("Variant not among the authorized variants (NS, Fit, Random, DistExplArea), assuming multi-objective variant")
    toolbox.register("select", tools.selNSGA2)

In [179]:
# Sample the initial population (size taken from params instead of a
# duplicated literal) and evaluate every individual without a valid fitness.
# Freshly created individuals have none yet, so the whole population is evaluated.
population, random_key = toolbox.population(params["pop_size"], key)
invalid_ind = [ind for ind in population if not ind.fitness.valid]
gens = [ind.genotype for ind in invalid_ind]
fitnesses, random_key = toolbox.map_eval(jnp.asarray(gens), random_key)

In [180]:
# Store evaluation results and lineage metadata on the initial individuals.
# map_eval results are indexed as fit[0] -> fitness value, fit[1] -> behaviour descriptor.
for ind, fit in zip(invalid_ind, fitnesses):
    ind.fit = fit[0]  # raw fitness, kept apart from the deap Fitness object
    ind.bd = fit[1]
    ind.parent_bd = None
    ind.id = generate_uuid()
    ind.parent_id = None

# Nobody in the initial population is flagged as a parent yet.
for ind in population:
    ind.am_parent = 0

archive = updateNovelty(population, population, None, params)
# Population indices ordered from most to least novel.
isortednov = sorted(
    range(len(population)),
    key=lambda idx: population[idx].novelty,
    reverse=True,
)
varian = params["variant"].replace(",", "")


In [181]:
# Raw variant string (may contain ","); the main loop uses it to pick the
# selection pool. A "+" marks a multi-objective variant.
variant = params["variant"]
emo = "+" in variant

# rank_of[i] = position of population[i] in isortednov (0 = most novel).
# Built once: calling isortednov.index(i) for every individual was O(n^2).
rank_of = [0] * len(isortednov)
for rank, idx in enumerate(isortednov):
    rank_of[idx] = rank

warned_unknown_variant = False
for i, ind in enumerate(population):
    #ind.dist_to_explored_area=dist_to_shapes(ind.bd,alpha_shape)
    ind.rank_novelty = rank_of[i]
    ind.dist_to_parent = 0  # the initial population has no parents
    if emo:
        if varian == "NS+Fit":
            ind.fitness.values = (ind.novelty, ind.fit)
        elif varian == "NS+BDDistP":
            ind.fitness.values = (ind.novelty, 0)
        elif varian == "NS+Fit+BDDistP":
            ind.fitness.values = (ind.novelty, ind.fit, 0)
        else:
            if not warned_unknown_variant:  # warn once, not once per individual
                print("WARNING: unknown variant: " + variant)
                warned_unknown_variant = True
            ind.fitness.values = ind.fit
    else:
        ind.fitness.values = ind.fit
    # if it is not a multi-objective experiment, the select tool from DEAP
    # has been configured above to take the right attribute into account
    # and the fitness.values is thus ignored
gen = 0

In [182]:
lambda_ = 10_000  # offspring requested from varOr per generation (times mutpb)

In [188]:
# Main novelty-search loop.
# NOTE(review): runs generations 1..2 only, not params["nb_gen"]; looks like a
# debugging shortcut, so confirm before a real run.
for gen in range(1, 3):
    offspring, random_key = varOr(population, random_key, toolbox, lambda_, params["cxpb"], params["mutpb"])
    invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
    gens = jnp.asarray([ind.genotype for ind in invalid_ind])
    fitnesses, random_key = toolbox.map_eval(gens, random_key)

    for ind, fit in zip(invalid_ind, fitnesses):
        ind.fit = fit[0]
        ind.fitness.values = fit[0]
        # Offspring are clones of their parent, so bd/id still hold the
        # parent's values at this point.
        ind.parent_bd = ind.bd
        ind.parent_id = ind.id
        ind.id = generate_uuid()
        ind.bd = fit[1]

    for ind in population:
        ind.am_parent = 1
    for ind in offspring:
        ind.am_parent = 0

    pq = population + offspring
    pop_for_novelty_estimation = []
    # NOTE(review): in the recorded run this call raised "TypeError:
    # jnp.linalg.norm requires ndarray ... got <class 'list'>". The archive's
    # all_bd appears to be a Python list of arrays; convert it to an array
    # inside updateNovelty (not visible here).
    archive = updateNovelty(pq, offspring, archive, params, pop_for_novelty_estimation)
    #alpha_shape = alphashape.alphashape(archive.all_bd, alphas)
    isortednov = sorted(range(len(pq)), key=lambda k: pq[k].novelty, reverse=True)

    # Inverse permutation: rank_of[i] is pq[i]'s position in isortednov
    # (0 = most novel). It replaces isortednov.index(i), which was O(n^2)
    # over len(pq) individuals.
    rank_of = [0] * len(pq)
    for rank, idx in enumerate(isortednov):
        rank_of[idx] = rank

    warned_unknown_variant = False
    for i, ind in enumerate(pq):
        #ind.dist_to_explored_area=dist_to_shapes(ind.bd,alpha_shape)
        ind.rank_novelty = rank_of[i]
        if ind.parent_bd is None:
            ind.dist_to_parent = 0
        else:
            ind.dist_to_parent = np.linalg.norm(ind.bd - ind.parent_bd)
        if emo:
            # The BDDistP objective is exactly dist_to_parent computed above.
            if varian == "NS+Fit":
                ind.fitness.values = (ind.novelty, ind.fit)
            elif varian == "NS+BDDistP":
                ind.fitness.values = (ind.novelty, ind.dist_to_parent)
            elif varian == "NS+Fit+BDDistP":
                ind.fitness.values = (ind.novelty, ind.fit, ind.dist_to_parent)
            else:
                if not warned_unknown_variant:  # warn once per generation
                    print("WARNING: unknown variant: " + variant)
                    warned_unknown_variant = True
                ind.fitness.values = ind.fit
        else:
            ind.fitness.values = ind.fit

    # "," variants select only among offspring; otherwise parents compete too.
    if "," in variant:
        population[:] = toolbox.select(offspring, params["pop_size"])
    else:
        population[:] = toolbox.select(pq, params["pop_size"])

nov


TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.

In [185]:
# Debug: offset of every archived behaviour descriptor from `bd`.
# NOTE(review): `bd` and `all_bd` are not defined in any visible cell (hidden
# kernel state); define them explicitly before re-running.
# jax.tree_map is deprecated; jax.tree_util.tree_map is the supported spelling.
jax.tree_util.tree_map(lambda x: x - bd, all_bd)

[Array([0., 0.], dtype=float32),
 Array([ 0.15698725, -0.04733457], dtype=float32),
 Array([-0.09858611,  0.27820536], dtype=float32),
 Array([-0.0309216 ,  0.18486066], dtype=float32),
 Array([0.04180362, 0.12217511], dtype=float32),
 Array([-0.11401062,  0.17287636], dtype=float32),
 Array([-0.05428775,  0.22015467], dtype=float32),
 Array([0.03390821, 0.15923178], dtype=float32),
 Array([0.00722026, 0.01505821], dtype=float32),
 Array([-0.05568022, -0.02349279], dtype=float32),
 Array([-0.03113294,  0.08967122], dtype=float32),
 Array([0.17812544, 0.15989606], dtype=float32),
 Array([0.05786125, 0.02708932], dtype=float32),
 Array([0.01724936, 0.10452253], dtype=float32),
 Array([0.06105787, 0.12738426], dtype=float32),
 Array([ 0.1004561 , -0.07355768], dtype=float32),
 Array([0.15078521, 0.07341236], dtype=float32),
 Array([0.11769016, 0.11750896], dtype=float32),
 Array([0.0780078 , 0.05486313], dtype=float32),
 Array([ 0.07795909, -0.09659717], dtype=float32),
 Array([0.1469988 

In [186]:
# Re-run the novelty update on the last generation's pq/offspring to debug it
# (it raised the same TypeError in the recorded run).
archive = updateNovelty(
    pq,
    offspring,
    archive,
    params,
    pop_for_novelty_estimation,
)


nov


TypeError: jnp.linalg.norm requires ndarray or scalar arguments, got <class 'list'> at position 0.