In [2]:
!nvidia-smi

Sat Aug 10 14:34:37 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   30C    P0              30W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [3]:
from dataclasses import dataclass
from functools import partial
from math import floor 
from typing import Callable, Tuple, Any

import jax
from jax import debug
import jax.numpy as jnp
import flax.linen as nn
import optax
from chex import ArrayTree
from qdax.core.containers.repertoire import Repertoire
from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from qdax.environments.base_wrappers import QDEnv
from qdax.core.neuroevolution.buffers.buffer import QDTransition, QDMCTransition
#from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer
import flashbax as fbx
import chex
from rein_related import *

from qdax.core.emitters.emitter import Emitter, EmitterState

In [4]:
import os

# Point matplotlib and wandb caches at /tmp, and request JAX compile logging.
# NOTE(review): jax is already imported in the previous cell, so JAX settings
# passed through environment variables here are read too late to take effect.
# JAX's own flag is also named JAX_LOG_COMPILES (config key "jax_log_compiles"),
# not JAX_LOG_COMPILATION. If compile logs are wanted, use
# jax.config.update("jax_log_compiles", True) instead — confirm intent.
os.environ.update(
    {
        'MPLCONFIGDIR': '/tmp/matplotlib',
        'WANDB_CACHE_DIR': '/tmp/wandb_cache',
        'JAX_LOG_COMPILATION': '1',
    }
)

import logging
import time
from dataclasses import dataclass
from functools import partial
from math import floor
from typing import Any, Dict, Tuple, List, Callable
import pickle
from flax import serialization
#logging.basicConfig(level=logging.DEBUG)
import hydra
from omegaconf import OmegaConf, DictConfig
import jax
import jax.numpy as jnp
from hydra.core.config_store import ConfigStore
from qdax.core.map_elites import MAPElites
from qdax.types import RNGKey, Genotype
from qdax.utils.sampling import sampling 
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire
from qdax.core.neuroevolution.networks.networks import MLPMCPG
from qdax.core.emitters.me_mcpg_emitter import MEMCPGConfig, MEMCPGEmitter
#from qdax.core.emitters.rein_emitter_advanced import REINaiveConfig, REINaiveEmitter
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.environments import behavior_descriptor_extractor
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from utils import Config, get_env
from qdax.core.emitters.mutation_operators import isoline_variation
import wandb
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results, plot_2d_map_elites_repertoire
import matplotlib.pyplot as plt
from set_up_brax import get_reward_offset_brax
from qdax import environments_v1, environments
import numpy as np

In [5]:
import matplotlib.pyplot as plt
%matplotlib inline
import jax.numpy as jnp  # Assuming you are using jax.numpy as jnp

In [6]:
def get_env(env_name: str, episode_length: int = 1000):
    """Create a QDax ``environments_v1`` environment by name.

    This definition shadows the ``get_env`` imported from ``utils`` in the
    import cell.

    Args:
        env_name: one of "hopper_uni", "halfcheetah_uni", "walker2d_uni",
            "ant_uni" or "humanoid_uni".
        episode_length: episode length forwarded to ``environments_v1.create``.
            Defaults to 1000, the value used for every supported environment.

    Returns:
        The environment returned by ``environments_v1.create``.

    Raises:
        ValueError: if ``env_name`` is not a supported environment. An unknown
            name used to fall through every branch and crash with
            UnboundLocalError instead.
    """
    # Extra keyword arguments forwarded to environments_v1.create, per environment.
    extra_kwargs = {
        "hopper_uni": {},
        "halfcheetah_uni": {},
        "walker2d_uni": {},
        "ant_uni": {
            "use_contact_forces": False,
            "exclude_current_positions_from_observation": True,
        },
        "humanoid_uni": {"exclude_current_positions_from_observation": True},
    }
    # An "ant_omni" variant (episode_length=250, built with environments.create and
    # exclude_current_positions_from_observation=False) used to sit here as a
    # disabled string literal; add it to the table explicitly if it is needed.
    if env_name not in extra_kwargs:
        raise ValueError(f"Environment {env_name} not supported.")
    return environments_v1.create(
        env_name, episode_length=episode_length, **extra_kwargs[env_name]
    )

In [10]:
@dataclass
class Config:
    """Configuration for the experiments run in this notebook.

    Redefines (and shadows) the ``Config`` imported from ``utils``. Fields are
    positional in declaration order. Several fields are not read by any cell
    visible here; they are kept so the instances mirror the full experiment
    configuration.
    """
    # Env config
    #alg_name: str
    # Seed for jax.random.PRNGKey at the start of a run.
    seed: int
    # Environment name, e.g. "walker2d_uni" or "ant_uni".
    env_name: str
    # NOTE(review): the notebook's get_env hard-codes the episode length and does
    # not read this field.
    episode_length: int
    # Hidden-layer widths passed to MLPMCPG(hidden_layers_size=...). Call sites pass
    # a list such as [128, 128] despite the Tuple annotation.
    policy_hidden_layer_sizes: Tuple[int, ...]   
    # ME config
    # Not read by any visible cell.
    num_evaluations: int
    # Not read by any visible cell (the visible training loops hard-code their length).
    num_iterations: int
    # Population size: number of per-agent init keys / leading axis of init_params;
    # also forwarded to MEMCPGConfig.no_agents.
    no_agents: int
    # Forwarded to qdax.utils.sampling.sampling(num_samples=...).
    num_samples: int
    # Not read by any visible cell.
    fixed_init_state: bool
    # Not read by any visible cell.
    discard_dead: bool
    # Emitter config
    # Forwarded to isoline_variation(iso_sigma=..., line_sigma=...).
    iso_sigma: float
    line_sigma: float
    #crossover_percentage: float
    # Grid config 
    # Not read by any visible cell (the repertoire uses CVT centroids instead).
    grid_shape: Tuple[int, ...]
    # Forwarded to compute_cvt_centroids.
    num_init_cvt_samples: int
    num_centroids: int
    # Log config
    # Logging fields below are not read by any visible cell.
    log_period: int
    store_repertoire: bool
    store_repertoire_log_period: int
    
    # REINFORCE Parameters
    # The following are forwarded to MEMCPGConfig (except discount_rate; see below).
    proportion_mutation_ga : float
    buffer_sample_batch_size : int
    buffer_add_batch_size: int
    adam_optimizer: bool
    learning_rate: float
    # NOTE(review): not forwarded to MEMCPGConfig in the visible cells.
    discount_rate: float
    #buffer_size: int
    clip_param: float
    no_epochs: int

In [11]:
# Single-run setup: build one config, the environment, the CVT centroids and an
# initial population of MCPG policies (inspected by the following cells).
no_epoch = 1
env_ = "walker2d_uni"

config = Config(
    seed=10,
    env_name=env_,
    episode_length=1000,
    policy_hidden_layer_sizes=[128, 128],
    num_evaluations=0,
    num_iterations=4000,
    no_agents=2,
    num_samples=8,
    fixed_init_state=False,
    discard_dead=False,
    iso_sigma=0.005,
    line_sigma=0.05,
    grid_shape=[50, 50],
    num_init_cvt_samples=50000,
    num_centroids=1024,
    log_period=400,
    store_repertoire=True,
    store_repertoire_log_period=800,
    proportion_mutation_ga=0.5,
    buffer_sample_batch_size=2,
    buffer_add_batch_size=2,
    adam_optimizer=True,
    learning_rate=3e-4,
    discount_rate=0.99,
    clip_param=0.2,
    no_epochs=no_epoch,
)

random_key = jax.random.PRNGKey(config.seed)

# Environment (config.env_name is env_) and its jitted reset function.
env = get_env(config.env_name)
reset_fn = jax.jit(env.reset)

# CVT centroids over the unit descriptor box [0, 1].
centroids, random_key = compute_cvt_centroids(
    num_descriptors=env.behavior_descriptor_length,
    num_init_cvt_samples=config.num_init_cvt_samples,
    num_centroids=config.num_centroids,
    minval=0,
    maxval=1,
    random_key=random_key,
)

# hidden_layers_size holds only the hidden widths; env.action_size is passed
# separately as action_size.
policy_layer_sizes = config.policy_hidden_layer_sizes
print(policy_layer_sizes)

policy_network = MLPMCPG(
    hidden_layers_size=policy_layer_sizes,
    action_size=env.action_size,
    activation=jax.nn.tanh,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    mean_init=jax.nn.initializers.orthogonal(scale=0.01),
)
# Other initialisers tried earlier: variance_scaling(mode="fan_in",
# distribution="uniform") and lecun_uniform for both hidden and mean layers.

# Population: one init key per agent; vmap stacks the params on a leading agent axis.
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=config.no_agents)
fake_batch_obs = jnp.zeros(shape=(config.no_agents, env.observation_size))
init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)


[128, 128]


In [24]:
# Shape of the first hidden kernel; the leading axis is the agent axis added by vmap.
# NOTE(review): the saved output (1, 18, 128) matches neither no_agents=2 (setup cell)
# nor no_agents=256 (sweep cell) — it comes from stale kernel state.
init_params["params"]["hidden_layers_0"]["kernel"].shape

(1, 18, 128)

In [25]:
# Stack the population with itself along the leading (agent) axis, doubling it.
duplicated_params = jax.tree_util.tree_map(
    lambda leaf: jnp.concatenate([leaf, leaf], axis=0),
    init_params,
)

In [26]:
# Leading axis should now be twice that of init_params.
duplicated_params["params"]["hidden_layers_0"]["kernel"].shape

(2, 18, 128)

In [12]:

# Sweep over environments x "transitions sampled" x gradient steps. Each run builds
# a fresh population, repeatedly applies the MCPG emitter only (no GA), and saves a
# grid of per-policy return curves to <env_dir>/no_GA.png.
envs = ["ant_uni"]
nums = [128, 256, 512, 1024]
grad_steps = [1, 4, 8, 16, 32, 64, 128]
# Number of emit/evaluate iterations performed for every configuration.
num_emit_iterations = 250


def _plot_returns_grid(returns, plot_filename):
    """Save a near-square grid with one return-vs-iteration subplot per policy.

    Args:
        returns: 2-D array indexed as ``returns[iteration, policy]`` (one row of
            fitnesses appended per training iteration).
        plot_filename: path of the image written with ``fig.savefig``.
    """
    returns = np.asarray(returns)
    num_policies = returns.shape[1]
    # Shared y-limits so the curves of different policies are comparable.
    ymin, ymax = returns.min(), returns.max()

    grid_size = max(num_policies, 1)
    rows = int(np.ceil(np.sqrt(grid_size)))
    cols = int(np.ceil(grid_size / rows))
    fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
    # plt.subplots returns a bare Axes for a 1x1 grid; always work with a flat array.
    axs = np.array(axs).reshape(-1)

    # Iterate over every axis so surplus grid cells are switched off too.
    for i, ax in enumerate(axs):
        if i < num_policies:
            ax.plot(returns[:, i])
            ax.set_title(f"Policy {i+1}", fontsize=10)
            ax.set_xlabel('Steps', fontsize=8)
            ax.set_ylabel('Returns', fontsize=8)
            ax.set_ylim([ymin, ymax])
            ax.tick_params(axis='both', which='major', labelsize=8)
        else:
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(plot_filename)
    plt.close(fig)  # many figures are created across the sweep; free each one


for env_ in envs:
    for num_ in nums:
        for grad_step in grad_steps:
            config = Config(
                seed=1,
                env_name=env_,
                episode_length=1000,
                policy_hidden_layer_sizes=[128, 128],
                num_evaluations=0,
                num_iterations=4000,
                num_samples=8,
                no_agents=256,
                fixed_init_state=False,
                discard_dead=False,
                grid_shape=[50, 50],
                num_init_cvt_samples=50000,
                num_centroids=1024,
                log_period=400,
                store_repertoire=True,
                store_repertoire_log_period=800,
                iso_sigma=0.005,
                line_sigma=0.05,
                proportion_mutation_ga=0.5,
                buffer_sample_batch_size=2,
                buffer_add_batch_size=256,
                no_epochs=grad_step,
                adam_optimizer=True,
                learning_rate=3e-4,
                discount_rate=0.99,
                clip_param=0.2
            )

            # NOTE(review): num_ only appears in env_dir; no Config field is set from
            # it, so every num_ value reruns the same configuration into a different
            # directory. Confirm which setting "trans_sampled" is meant to control.
            env_dir = f"grad_steps_experiments/no_value/buffer/no_agents={config.no_agents}/trans_sampled={num_}/grad_steps={grad_step}/{env_}"
            os.makedirs(env_dir, exist_ok=True)

            random_key = jax.random.PRNGKey(config.seed)

            # Init environment
            env = get_env(env_)
            reset_fn = jax.jit(env.reset)

            # Compute the centroids
            centroids, random_key = compute_cvt_centroids(
                num_descriptors=env.behavior_descriptor_length,
                num_init_cvt_samples=config.num_init_cvt_samples,
                num_centroids=config.num_centroids,
                minval=0,
                maxval=1,
                random_key=random_key,
            )
            # Init policy network (hidden widths only; action_size is passed separately)
            policy_layer_sizes = config.policy_hidden_layer_sizes
            print(policy_layer_sizes)

            policy_network = MLPMCPG(
                hidden_layers_size=policy_layer_sizes,
                action_size=env.action_size,
                activation=jax.nn.tanh,
                hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
                mean_init=jax.nn.initializers.orthogonal(scale=0.01),
            )
            # Other initialisers tried earlier: variance_scaling(mode="fan_in",
            # distribution="uniform") and lecun_uniform.

            # Init population of controllers: one init key per agent, params stacked
            # along a leading agent axis by vmap.
            random_key, subkey = jax.random.split(random_key)
            keys = jax.random.split(subkey, num=config.no_agents)
            fake_batch_obs = jnp.zeros(shape=(config.no_agents, env.observation_size))
            init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

            # Parameter count of a single policy (element 0 of each stacked leaf).
            param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
            print("Number of parameters in policy_network: ", param_count)

            # Define the function to play a step with the policy in the environment
            @jax.jit
            def play_step_fn(env_state, policy_params, random_key):
                """Step the env with the policy's action and record a QDMCTransition.

                The key is passed through unchanged; the policy's apply returns both
                the actions and their log-probabilities.
                """
                actions, logp = policy_network.apply(policy_params, env_state.obs)
                state_desc = env_state.info["state_descriptor"]
                next_state = env.step(env_state, actions)

                transition = QDMCTransition(
                    obs=env_state.obs,
                    next_obs=next_state.obs,
                    rewards=next_state.reward,
                    dones=next_state.done,
                    truncations=next_state.info["truncation"],
                    actions=actions,
                    state_desc=state_desc,
                    next_state_desc=next_state.info["state_descriptor"],
                    logp=logp,
                )

                return next_state, policy_params, random_key, transition

            # Prepare the scoring function. The descriptor extractor must match the
            # environment being run (it was previously hard-coded to 'ant_uni').
            bd_extraction_fn = behavior_descriptor_extractor[config.env_name]
            scoring_fn = partial(
                scoring_function,
                episode_length=env.episode_length,
                play_reset_fn=reset_fn,
                play_step_fn=play_step_fn,
                behavior_descriptor_extractor=bd_extraction_fn,
            )

            me_scoring_fn = partial(
                sampling,
                scoring_fn=scoring_fn,
                num_samples=config.num_samples,
            )

            reward_offset = 0

            # Define a metrics function
            metrics_function = partial(
                default_qd_metrics,
                qd_offset=reward_offset * env.episode_length,
            )

            # Define the PG-emitter config
            me_mcpg_config = MEMCPGConfig(
                proportion_mutation_ga=config.proportion_mutation_ga,
                no_agents=config.no_agents,
                buffer_sample_batch_size=config.buffer_sample_batch_size,
                buffer_add_batch_size=config.buffer_add_batch_size,
                no_epochs=config.no_epochs,
                learning_rate=config.learning_rate,
                adam_optimizer=config.adam_optimizer,
                clip_param=config.clip_param,
            )

            variation_fn = partial(
                isoline_variation, iso_sigma=config.iso_sigma, line_sigma=config.line_sigma
            )

            me_mcpg_emitter = MEMCPGEmitter(
                config=me_mcpg_config,
                policy_network=policy_network,
                env=env,
                variation_fn=variation_fn,
            )

            # Instantiate MAP Elites
            map_elites = MAPElites(
                scoring_function=scoring_fn,
                emitter=me_mcpg_emitter,
                metrics_function=metrics_function,
            )

            fitnesses, descriptors, extra_scores, random_key = scoring_fn(
                init_params, random_key
            )

            repertoire = MapElitesRepertoire.init(
                genotypes=init_params,
                fitnesses=fitnesses,
                descriptors=descriptors,
                centroids=centroids,
                extra_scores=extra_scores,
            )

            emitter_state, random_key = me_mcpg_emitter.init(
                random_key=random_key,
                repertoire=repertoire,
                genotypes=init_params,
                fitnesses=fitnesses,
                descriptors=descriptors,
                extra_scores=extra_scores,
            )

            emitter_state = me_mcpg_emitter.state_update(
                emitter_state=emitter_state,
                repertoire=repertoire,
                genotypes=init_params,
                fitnesses=fitnesses,
                descriptors=descriptors,
                extra_scores={**extra_scores},
            )

            # Keep only the MCPG sub-emitter's state; the loop below drives that
            # sub-emitter directly (no GA offspring).
            emitter_state = emitter_state.emitter_states[0]

            returns = []
            old_params = init_params
            # Fixed key for the training loop, independent of config.seed (as before).
            random_key = jax.random.PRNGKey(0)
            for _ in range(num_emit_iterations):
                # Split off a dedicated subkey for the emitter so the key handed to
                # scoring_fn below is never also used to derive the emitter's keys.
                random_key, emit_key = jax.random.split(random_key)
                random_keys = jax.random.split(emit_key, config.no_agents)
                new_params = me_mcpg_emitter.emitters[0].emit_mcpg(emitter_state, old_params, random_keys)

                fitnesses, descriptors, extra_scores, random_key = scoring_fn(
                    new_params, random_key
                )

                emitter_state = me_mcpg_emitter.emitters[0].state_update(
                    emitter_state=emitter_state,
                    repertoire=repertoire,
                    genotypes=new_params,
                    fitnesses=fitnesses,
                    descriptors=descriptors,
                    extra_scores=extra_scores,
                )
                old_params = new_params
                print(f"mean fitness: {fitnesses.mean()}")
                returns.append(fitnesses)

            # Shape (num_emit_iterations, no_agents): one row of fitnesses per iteration.
            returns = jnp.array(returns)
            _plot_returns_grid(returns, f"{env_dir}/no_GA.png")

[128, 128]
Number of parameters in policy_network:  21264


  repertoire = MapElitesRepertoire.init(


trans: QDMCTransition(obs=array([[[ 5.29044211e-01,  1.00000000e+00,  0.00000000e+00, ...,
          8.14421475e-02, -6.96336702e-02,  1.00000000e+00],
        [ 5.19541442e-01,  9.99999344e-01, -4.66448895e-04, ...,
         -6.05668016e-02,  1.53644040e-01,  9.99000072e-01],
        [ 5.08960366e-01,  9.99971449e-01, -4.49728267e-03, ...,
         -3.27226967e-02, -1.18919991e-01,  9.98000026e-01],
        ...,
        [ 3.56704026e-01,  9.99670923e-01,  2.16642555e-04, ...,
          4.26656147e-03, -6.06911555e-02,  3.00000003e-03],
        [ 3.56722116e-01,  9.99672115e-01,  1.53753339e-04, ...,
          2.24845763e-03, -2.28568669e-02,  2.00000009e-03],
        [ 3.56680900e-01,  9.99673247e-01,  1.58998169e-04, ...,
         -8.39541666e-04,  1.06303589e-02,  1.00000005e-03]],

       [[ 5.34505486e-01,  1.00000000e+00,  0.00000000e+00, ...,
         -4.29172665e-02,  8.84915218e-02,  1.00000000e+00],
        [ 5.24497390e-01,  9.99999523e-01, -6.93486189e-04, ...,
         -3.

2024-08-10 14:44:06.491570: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(381): _leading_trailing
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(386): _leading_trailing
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(534): _array2string
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(515): wrapper
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(736): array2string
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(1508): _array_r

XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(381): _leading_trailing
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(386): _leading_trailing
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(534): _array2string
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(515): wrapper
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(736): array2string
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/numpy/core/arrayprint.py(1508): _array_repr_implementation
  <string>(3): __repr__
  /usr/lib/python3.10/dataclasses.py(241): wrapper
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/debugging.py(268): _format_print_callback
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/debugging.py(247): _flat_callback
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/debugging.py(90): debug_callback_impl
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/debugging.py(154): _callback
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py(2367): _wrapped_callback
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py(1151): __call__
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/profiler.py(336): wrapper
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1185): _pjit_call_impl_python
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1229): call_impl_cache_miss
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1245): _pjit_call_impl
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/core.py(935): process_primitive
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/core.py(447): bind_with_trace
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/core.py(2740): bind
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(168): _python_pjit_helper
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(257): cache_miss
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/traceback_util.py(179): reraise_with_filtered_traceback
  /tmp/ipykernel_2609435/2829918755.py(260): <module>
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3577): run_code
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3517): run_ast_nodes
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3334): run_cell_async
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3130): _run_cell
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3075): run_cell
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/zmqshell.py(549): run_cell
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/ipkernel.py(449): do_execute
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelbase.py(778): execute_request
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/ipkernel.py(362): execute_request
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelbase.py(437): dispatch_shell
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelbase.py(534): process_one
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelbase.py(545): dispatch_queue
  /usr/lib/python3.10/asyncio/events.py(80): _run
  /usr/lib/python3.10/asyncio/base_events.py(1909): _run_once
  /usr/lib/python3.10/asyncio/base_events.py(603): run_forever
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/tornado/platform/asyncio.py(205): start
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelapp.py(739): start
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/traitlets/config/application.py(1075): launch_instance
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel_launcher.py(18): <module>
  /usr/lib/python3.10/runpy.py(86): _run_code
  /usr/lib/python3.10/runpy.py(196): _run_module_as_main
; current tracing scope: custom-call.152; current profiling annotation: XlaModule:#prefix=jit(emit_mcpg)/jit(main)/vmap(jit(_mutation_function_mcpg)),hlo_module=jit_emit_mcpg,program_id=144#.

In [17]:
# Environments and population sizes swept by the loop below; each size is used
# there as no_agents and as both buffer batch sizes.
envs = ["walker2d_uni", "ant_uni"]
nums = [2 ** exponent for exponent in range(1, 7)]


for env_ in envs:
    
    #env_dir = f"grad_steps_experiments/no_value/no_agents=16/epoch_1_reps=16/{env_}"
    #os.makedirs(env_dir, exist_ok=True)
    
    for num_ in nums:
        config = Config(
            seed=10,
            env_name=env_,
            episode_length=1000,
            policy_hidden_layer_sizes=[128, 128],
            num_evaluations=0,
            num_iterations=4000,
            num_samples=8,
            no_agents=num_,
            fixed_init_state=False,
            discard_dead=False,
            grid_shape=[50, 50],
            num_init_cvt_samples=50000,
            num_centroids=1024,
            log_period=400,
            store_repertoire=True,
            store_repertoire_log_period=800,
            iso_sigma=0.005,
            line_sigma=0.05,
            proportion_mutation_ga=0.5,
            buffer_sample_batch_size=num_,
            buffer_add_batch_size=num_,
            no_epochs=1,
            #buffer_size=64000,
            adam_optimizer=True,
            learning_rate=3e-4,
            discount_rate=0.99,
            clip_param=0.2
        )
        
        env_dir = f"grad_steps_experiments/no_value/no_buffer/no_agents={num_}/grad_steps=4/{env_}"
        os.makedirs(env_dir, exist_ok=True)

        random_key = jax.random.PRNGKey(config.seed)

        # Init environment
        env = get_env(env_)
        reset_fn = jax.jit(env.reset)

        # Compute the centroids
        centroids, random_key = compute_cvt_centroids(
            num_descriptors=env.behavior_descriptor_length,
            num_init_cvt_samples=config.num_init_cvt_samples,
            num_centroids=config.num_centroids,
            minval=0,
            maxval=1,
            random_key=random_key,
        )
        # Init policy network
        policy_layer_sizes = config.policy_hidden_layer_sizes #+ (env.action_size,)
        print(policy_layer_sizes)




        policy_network = MLPMCPG(
            hidden_layers_size=policy_layer_sizes,
            action_size=env.action_size,
            activation=jax.nn.tanh,
            hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
            mean_init=jax.nn.initializers.orthogonal(scale=0.01),
        )


            
        '''
        policy_network = MLPMCPG(
            hidden_layers_size=policy_layer_sizes,
            action_size=env.action_size,
            activation=jax.nn.tanh,
            hidden_init=jax.nn.initializers.variance_scaling(scale=jnp.sqrt(2), mode='fan_in', distribution='uniform'),
            mean_init=jax.nn.initializers.variance_scaling(scale=0.02*jnp.sqrt(2), mode='fan_in', distribution='uniform'),
        )
        '''


        '''
        policy_network = MLPMCPG(
            hidden_layers_size=policy_layer_sizes,
            action_size=env.action_size,
            activation=jax.nn.tanh,
            hidden_init=jax.nn.initializers.lecun_uniform(),
            mean_init=jax.nn.initializers.lecun_uniform(),
        )
        '''



        # Init population of controllers

        # maybe consider adding two random keys for each policy
        random_key, subkey = jax.random.split(random_key)
        keys = jax.random.split(subkey, num=config.no_agents)
        #split_keys = jax.vmap(lambda k: jax.random.split(k, 2))(keys)
        #keys1, keys2 = split_keys[:, 0], split_keys[:, 1]
        fake_batch_obs = jnp.zeros(shape=(config.no_agents, env.observation_size))
        init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

        param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
        print("Number of parameters in policy_network: ", param_count)

        # Define the fonction to play a step with the policy in the environment
        @jax.jit
        def play_step_fn(env_state, policy_params, random_key):
            #random_key, subkey = jax.random.split(random_key)
            actions, logp = policy_network.apply(policy_params, env_state.obs)
            #logp = policy_network.apply(policy_params, env_state.obs, actions, method=policy_network.logp)
            state_desc = env_state.info["state_descriptor"]
            next_state = env.step(env_state, actions)

            transition = TestTransition(
                obs=env_state.obs,
                #next_obs=next_state.obs,
                rewards=next_state.reward,
                dones=next_state.done,
                truncations=next_state.info["truncation"],
                actions=actions,
                state_desc=state_desc,
                #next_state_desc=next_state.info["state_descriptor"],
                logp=logp,
                #desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
                #desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            )

            return next_state, policy_params, random_key, transition


        # Prepare the scoring function
        bd_extraction_fn = behavior_descriptor_extractor['ant_uni']
        scoring_fn = partial(
            scoring_function,
            episode_length=env.episode_length,
            play_reset_fn=reset_fn,
            play_step_fn=play_step_fn,
            behavior_descriptor_extractor=bd_extraction_fn,
        )
        #reward_offset = get_reward_offset_brax(env, config.env_name)
        #print(f"Reward offset: {reward_offset}")

        me_scoring_fn = partial(
        sampling,
        scoring_fn=scoring_fn,
        num_samples=config.num_samples,
        )

        reward_offset = 0

        # Define a metrics function
        metrics_function = partial(
            default_qd_metrics,
            qd_offset=reward_offset * env.episode_length,
        )

        # Define the PG-emitter config

        me_mcpg_config = MEMCPGConfig(
            proportion_mutation_ga=config.proportion_mutation_ga,
            no_agents=config.no_agents,
            buffer_sample_batch_size=config.buffer_sample_batch_size,
            buffer_add_batch_size=config.buffer_add_batch_size,
            #batch_size=config.batch_size,
            #mini_batch_size=config.mini_batch_size,
            no_epochs=config.no_epochs,
            #buffer_size=config.buffer_size,
            learning_rate=config.learning_rate,
            adam_optimizer=config.adam_optimizer,
            clip_param=config.clip_param,
        )

        variation_fn = partial(
            isoline_variation, iso_sigma=config.iso_sigma, line_sigma=config.line_sigma
        )

        me_mcpg_emitter = MEMCPGEmitter(
            config=me_mcpg_config,
            policy_network=policy_network,
            env=env,
            variation_fn=variation_fn,
            )

        '''
        rein_emitter = REINaiveEmitter(
            config=rein_emitter_config,
            policy_network=policy_network,
            env=env,
            )
        '''
        '''
        me_scoring_fn = partial(
            sampling,
            scoring_fn=scoring_fn,
            num_samples=config.num_samples,
        )
        '''

        # Instantiate MAP Elites
        map_elites = MAPElites(
            scoring_function=scoring_fn,
            emitter=me_mcpg_emitter,
            metrics_function=metrics_function,
        )

        #duplicated_params = jax.tree_util.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), init_params, init_params)
        fitnesses, descriptors, extra_scores, random_key = scoring_fn(
            init_params, random_key
        )
        
        repertoire = MapElitesRepertoire.init(
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            extra_scores=extra_scores,
        )
        

        emitter_state, random_key = me_mcpg_emitter.init(
            random_key=random_key,
            repertoire=repertoire,
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )
        
        emitter_state = me_mcpg_emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores={**extra_scores}#, **extra_info},
        )
        
        emitter_state = emitter_state.emitter_states[0]

        returns = []
        old_params = init_params
        random_key = jax.random.PRNGKey(0)
        for _ in range(500):
            random_keys = jax.random.split(random_key, config.no_agents+1)
            
            n_variation = config.no_agents
            rng = random_keys[-1]
            
            if n_variation > 0:
                rng, rng_ = jax.random.split(rng)
                samples =  jax.tree_util.tree_map(
            lambda x: jax.random.choice(rng_, x, shape=(n_variation,), p=1/config.no_agents*jnp.ones(config.no_agents)),
            old_params,
        )
            x1, rng = samples, rng
            
            rng, rng_ = jax.random.split(rng)
            samples_ =  jax.tree_util.tree_map(
            lambda x: jax.random.choice(rng_, x, shape=(n_variation,), p=1/config.no_agents*jnp.ones(config.no_agents)),
            old_params,
            )
            
            x2, rng = samples_, rng
            
            x_variation, rng = variation_fn(x1,x2, rng)
            
            new_params_ = x_variation
            
            
            
            new_params = me_mcpg_emitter.emitters[0].emit_mcpg(emitter_state, new_params_, random_keys[:config.no_agents])
            fitnesses, descriptors, extra_scores, random_key = scoring_fn(
                new_params, random_key
            )
            emitter_state = me_mcpg_emitter.emitters[0].state_update(
                emitter_state=emitter_state,
                repertoire=repertoire,
                genotypes=new_params,
                fitnesses=fitnesses,
                descriptors=descriptors,
                extra_scores=extra_scores,
            )
            old_params = new_params
            print(f"mean fitness: {fitnesses.mean()}")
            returns.append(fitnesses)
            
        returns = jnp.array(returns)  # Assuming 'returns' is already defined as a 2D array
        


        # Determine the overall min and max fitness values for setting y-axis limits
        ymin = returns.min()
        ymax = returns.max()

        # Plotting
        #policy_index = 0  # Index for the policy you want to plot
        num_plots = num_  # You can change this to the number of plots you need

        # Calculate the number of rows and columns needed to accommodate num_plots
        rows = int(np.ceil(np.sqrt(num_plots)))
        cols = int(np.ceil(num_plots / rows))

        # Set up the figure and axes for the plots
        fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))  # Adjust figsize based on rows and cols

        # Make sure axs is always an array, even when it's just one subplot
        axs = np.array(axs).reshape(-1)

        # Loop through the number of plots you want to create
        for i in range(num_plots):
            if i < len(returns[0]):  # Check if the returns array has enough columns
                axs[i].plot(returns[:, i])  # Indexing directly into the returns data for each policy
                axs[i].set_title(f"Policy {i+1}", fontsize=10)
                axs[i].set_xlabel('Steps', fontsize=8)
                axs[i].set_ylabel('Returns', fontsize=8)
                axs[i].set_ylim([ymin, ymax])  # Ensure ymin and ymax are defined or adjust as necessary
                axs[i].tick_params(axis='both', which='major', labelsize=8)
            else:
                axs[i].axis('off')  # Turn off unused subplots

        plt.tight_layout()
        # Save each plot in the specified directory with the correct filename
        plot_filename = f"{env_dir}/GA.png"
        plt.savefig(plot_filename)
        plt.close(fig)  # Close the plot to free up memory

[128, 128]
Number of parameters in policy_network:  19724


  repertoire = MapElitesRepertoire.init(


mean fitness: 101.20379638671875
mean fitness: 221.11297607421875
mean fitness: 235.9285125732422
mean fitness: 313.3321533203125
mean fitness: 338.99774169921875
mean fitness: 353.32244873046875
mean fitness: 371.5552978515625
mean fitness: 409.09820556640625
mean fitness: 416.9298095703125
mean fitness: 221.32861328125
mean fitness: 408.37109375
mean fitness: 343.1261291503906
mean fitness: 358.4908752441406
mean fitness: 346.9027404785156
mean fitness: 363.14813232421875
mean fitness: 353.48199462890625
mean fitness: 257.61083984375
mean fitness: 264.4591369628906
mean fitness: 380.06201171875
mean fitness: 397.9027099609375
mean fitness: 380.66241455078125
mean fitness: 349.48114013671875
mean fitness: 365.24407958984375
mean fitness: 374.01190185546875
mean fitness: 358.41741943359375
mean fitness: 358.4165344238281
mean fitness: 361.4411315917969
mean fitness: 389.627685546875
mean fitness: 462.6044616699219
mean fitness: 454.854736328125
mean fitness: 410.25830078125
mean fitnes

  repertoire = MapElitesRepertoire.init(


mean fitness: 246.82373046875
mean fitness: 311.4052429199219
mean fitness: 313.65338134765625
mean fitness: 327.39794921875
mean fitness: 268.77581787109375
mean fitness: 343.9360656738281
mean fitness: 298.47247314453125
mean fitness: 326.738037109375
mean fitness: 266.7596435546875
mean fitness: 323.3248596191406
mean fitness: 193.0327606201172
mean fitness: 278.9822998046875
mean fitness: 286.07257080078125
mean fitness: 327.9527282714844
mean fitness: 285.991455078125
mean fitness: 327.47235107421875
mean fitness: 454.2737121582031
mean fitness: 334.44342041015625
mean fitness: 351.16546630859375
mean fitness: 371.9578857421875
mean fitness: 411.59130859375
mean fitness: 422.74273681640625
mean fitness: 308.7674560546875
mean fitness: 228.77450561523438
mean fitness: 313.2408142089844
mean fitness: 368.40093994140625
mean fitness: 343.8827209472656
mean fitness: 281.3729248046875
mean fitness: 345.3567199707031
mean fitness: 331.80694580078125
mean fitness: 371.434326171875
mean f

  repertoire = MapElitesRepertoire.init(


mean fitness: 159.47265625
mean fitness: 282.4602355957031
mean fitness: 288.38531494140625
mean fitness: 266.53814697265625
mean fitness: 281.6738586425781
mean fitness: 294.4285888671875
mean fitness: 277.1993408203125
mean fitness: 278.862548828125
mean fitness: 345.6484375
mean fitness: 357.08367919921875
mean fitness: 381.388427734375
mean fitness: 366.80908203125
mean fitness: 332.30401611328125
mean fitness: 326.26904296875
mean fitness: 294.0029296875
mean fitness: 399.1662292480469
mean fitness: 306.5213317871094
mean fitness: 298.62579345703125
mean fitness: 337.5686340332031
mean fitness: 361.5914001464844
mean fitness: 279.4268798828125
mean fitness: 273.18487548828125
mean fitness: 322.8315124511719
mean fitness: 275.3382568359375
mean fitness: 305.2926330566406
mean fitness: 260.9662780761719
mean fitness: 307.3330383300781
mean fitness: 265.240234375
mean fitness: 311.13543701171875
mean fitness: 276.6916809082031
mean fitness: 259.5745849609375
mean fitness: 273.8926696

  repertoire = MapElitesRepertoire.init(


mean fitness: 114.90567016601562
mean fitness: 286.9957275390625
mean fitness: 273.83404541015625
mean fitness: 265.35894775390625
mean fitness: 288.07843017578125
mean fitness: 285.52252197265625
mean fitness: 276.0851745605469
mean fitness: 272.1968078613281
mean fitness: 300.3630676269531
mean fitness: 280.6463623046875
mean fitness: 131.64675903320312
mean fitness: 259.2718505859375
mean fitness: 310.6446228027344
mean fitness: 338.26763916015625
mean fitness: 283.980224609375
mean fitness: 351.22161865234375
mean fitness: 361.25518798828125
mean fitness: 321.7648010253906
mean fitness: 340.54400634765625
mean fitness: 345.75457763671875
mean fitness: 337.65765380859375
mean fitness: 349.2409973144531
mean fitness: 235.78009033203125
mean fitness: 311.9605712890625
mean fitness: 291.1492614746094
mean fitness: 334.53900146484375
mean fitness: 353.06195068359375
mean fitness: 329.34033203125
mean fitness: 335.26373291015625
mean fitness: 327.3946228027344
mean fitness: 365.176330566

  repertoire = MapElitesRepertoire.init(


mean fitness: 138.78884887695312
mean fitness: 270.73333740234375
mean fitness: 271.07330322265625
mean fitness: 263.5583801269531
mean fitness: 273.33453369140625
mean fitness: 293.23291015625
mean fitness: 257.43023681640625
mean fitness: 277.70172119140625
mean fitness: 267.6070556640625
mean fitness: 281.7461853027344
mean fitness: 267.1669921875
mean fitness: 274.1107482910156
mean fitness: 307.90399169921875
mean fitness: 225.34017944335938
mean fitness: 296.95172119140625
mean fitness: 293.20220947265625
mean fitness: 277.4762878417969
mean fitness: 289.5174560546875
mean fitness: 316.9580993652344
mean fitness: 326.87109375
mean fitness: 338.7632141113281
mean fitness: 356.5633239746094
mean fitness: 410.204345703125
mean fitness: 358.09814453125
mean fitness: 354.53607177734375
mean fitness: 389.55120849609375
mean fitness: 406.3512268066406
mean fitness: 398.86883544921875
mean fitness: 326.0578308105469
mean fitness: 376.8964538574219
mean fitness: 323.9906311035156
mean fit

  repertoire = MapElitesRepertoire.init(


mean fitness: 111.52474975585938
mean fitness: 236.91241455078125
mean fitness: 287.92584228515625
mean fitness: 271.850830078125
mean fitness: 292.716552734375
mean fitness: 266.354248046875
mean fitness: 293.7925109863281
mean fitness: 267.16461181640625
mean fitness: 283.5797119140625
mean fitness: 249.76205444335938
mean fitness: 309.3751525878906
mean fitness: 286.0057373046875
mean fitness: 265.4831848144531
mean fitness: 286.2658386230469
mean fitness: 312.8748779296875
mean fitness: 329.60797119140625
mean fitness: 332.60052490234375
mean fitness: 323.9108581542969
mean fitness: 314.9349365234375
mean fitness: 337.54132080078125
mean fitness: 272.50274658203125
mean fitness: 319.7987976074219
mean fitness: 215.1125030517578
mean fitness: 313.5736083984375
mean fitness: 353.7313537597656
mean fitness: 320.0770263671875
mean fitness: 346.9225769042969
mean fitness: 365.8255615234375
mean fitness: 341.5171813964844
mean fitness: 383.18157958984375
mean fitness: 358.5174560546875
m

  repertoire = MapElitesRepertoire.init(


mean fitness: 978.0534057617188
mean fitness: 857.990234375
mean fitness: 955.0418090820312
mean fitness: 537.417724609375
mean fitness: 682.457275390625
mean fitness: 581.46240234375
mean fitness: 657.0904541015625
mean fitness: 623.7232666015625
mean fitness: 683.1290893554688
mean fitness: 676.3953857421875
mean fitness: 804.2655639648438
mean fitness: 737.8770751953125
mean fitness: 396.20361328125
mean fitness: 845.73193359375
mean fitness: 768.076416015625
mean fitness: 886.3348388671875
mean fitness: 866.5972900390625
mean fitness: 859.6617431640625
mean fitness: 821.705078125
mean fitness: 891.9891357421875
mean fitness: 656.6300048828125
mean fitness: 409.3643493652344
mean fitness: 742.237548828125
mean fitness: 548.4837036132812
mean fitness: 292.50506591796875
mean fitness: 463.45989990234375
mean fitness: 575.7135009765625
mean fitness: 611.153076171875
mean fitness: 544.03955078125
mean fitness: 464.9637756347656
mean fitness: 511.6178283691406
mean fitness: 507.403472900

  repertoire = MapElitesRepertoire.init(


mean fitness: 949.5767211914062
mean fitness: 980.8359375
mean fitness: 977.6769409179688
mean fitness: 982.6015625
mean fitness: 977.25830078125
mean fitness: 976.2156982421875
mean fitness: 837.0089721679688
mean fitness: 935.1986083984375
mean fitness: 882.6375122070312
mean fitness: 838.4920043945312
mean fitness: 869.451416015625
mean fitness: 958.2095947265625
mean fitness: 809.783203125
mean fitness: 958.001953125
mean fitness: 965.3832397460938
mean fitness: 931.2962036132812
mean fitness: 885.8930053710938
mean fitness: 970.01171875
mean fitness: 871.906494140625
mean fitness: 920.9986572265625
mean fitness: 905.5587158203125
mean fitness: 752.5625
mean fitness: 888.3526611328125
mean fitness: 883.380859375
mean fitness: 868.3768310546875
mean fitness: 825.1533203125
mean fitness: 786.1097412109375
mean fitness: 768.4002075195312
mean fitness: 843.4575805664062
mean fitness: 782.294189453125
mean fitness: 814.2530517578125
mean fitness: 822.692138671875
mean fitness: 392.62927

  repertoire = MapElitesRepertoire.init(


mean fitness: 955.4833984375
mean fitness: 942.3570556640625
mean fitness: 978.9051513671875
mean fitness: 755.7518920898438
mean fitness: 949.4219360351562
mean fitness: 960.2784423828125
mean fitness: 921.3300170898438
mean fitness: 903.1066284179688
mean fitness: 953.038818359375
mean fitness: 957.1637573242188
mean fitness: 817.2145385742188
mean fitness: 866.9396362304688
mean fitness: 958.18359375
mean fitness: 782.7852172851562
mean fitness: 885.7709350585938
mean fitness: 830.4513549804688
mean fitness: 964.9437255859375
mean fitness: 772.2672119140625
mean fitness: 903.4418334960938
mean fitness: 876.251953125
mean fitness: 935.2109375
mean fitness: 911.9022827148438
mean fitness: 940.9013671875
mean fitness: 888.9413452148438
mean fitness: 933.9766845703125
mean fitness: 832.9078369140625
mean fitness: 879.5755004882812
mean fitness: 916.7008056640625
mean fitness: 937.4151000976562
mean fitness: 949.7883911132812
mean fitness: 924.8409423828125
mean fitness: 943.580444335937

  repertoire = MapElitesRepertoire.init(


mean fitness: 973.4189453125
mean fitness: 983.1883544921875
mean fitness: 904.130615234375
mean fitness: 924.7152099609375
mean fitness: 941.6031494140625
mean fitness: 974.619873046875
mean fitness: 942.78125
mean fitness: 973.48974609375
mean fitness: 964.4776000976562
mean fitness: 986.36181640625
mean fitness: 889.5186767578125
mean fitness: 955.5464477539062
mean fitness: 976.5444946289062
mean fitness: 880.855224609375
mean fitness: 847.667724609375
mean fitness: 907.8348388671875
mean fitness: 962.2066650390625
mean fitness: 948.6849365234375
mean fitness: 965.2679443359375
mean fitness: 962.4400634765625
mean fitness: 933.592041015625
mean fitness: 880.31298828125
mean fitness: 921.70849609375
mean fitness: 933.0811767578125
mean fitness: 953.8823852539062
mean fitness: 935.2974243164062
mean fitness: 924.0590209960938
mean fitness: 859.6145629882812
mean fitness: 932.6071166992188
mean fitness: 968.846923828125
mean fitness: 868.4738159179688
mean fitness: 968.77197265625
mea

  repertoire = MapElitesRepertoire.init(


mean fitness: 972.4139404296875
mean fitness: 986.1680297851562
mean fitness: 937.0501708984375
mean fitness: 974.9779052734375
mean fitness: 960.872314453125
mean fitness: 986.7969970703125
mean fitness: 993.3392333984375
mean fitness: 972.92724609375
mean fitness: 973.10693359375
mean fitness: 952.3115234375
mean fitness: 988.61328125
mean fitness: 919.8995971679688
mean fitness: 979.3229370117188
mean fitness: 961.6806640625
mean fitness: 960.2367553710938
mean fitness: 928.983154296875
mean fitness: 961.2230224609375
mean fitness: 964.3626708984375
mean fitness: 903.7022705078125
mean fitness: 956.4638061523438
mean fitness: 975.6012573242188
mean fitness: 990.6779174804688
mean fitness: 950.3560791015625
mean fitness: 966.5010375976562
mean fitness: 991.116455078125
mean fitness: 975.43994140625
mean fitness: 1005.52734375
mean fitness: 978.516357421875
mean fitness: 998.5469360351562
mean fitness: 1029.0869140625
mean fitness: 1007.1132202148438
mean fitness: 1028.598388671875
me

  repertoire = MapElitesRepertoire.init(


mean fitness: 963.1318359375
mean fitness: 985.5206298828125
mean fitness: 962.8487548828125
mean fitness: 984.1138916015625
mean fitness: 979.6318359375
mean fitness: 971.547607421875
mean fitness: 976.6240234375
mean fitness: 982.2915649414062
mean fitness: 985.3194580078125
mean fitness: 984.771484375
mean fitness: 941.1282348632812
mean fitness: 967.291748046875
mean fitness: 950.4185180664062
mean fitness: 993.9168090820312
mean fitness: 969.46826171875
mean fitness: 951.8135986328125
mean fitness: 968.2984619140625
mean fitness: 969.197509765625
mean fitness: 974.17724609375
mean fitness: 950.8778076171875
mean fitness: 927.2979736328125
mean fitness: 943.4139404296875
mean fitness: 922.6597290039062
mean fitness: 948.9742431640625
mean fitness: 943.6939086914062
mean fitness: 959.4990234375
mean fitness: 930.4453735351562
mean fitness: 916.223876953125
mean fitness: 930.2914428710938
mean fitness: 919.1533203125
mean fitness: 887.792236328125
mean fitness: 886.1984252929688
mean