In [1]:
!nvidia-smi

Thu Aug  8 14:55:15 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   31C    P0              29W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
from dataclasses import dataclass
from functools import partial
from math import floor 
from typing import Callable, Tuple, Any

import jax
from jax import debug
import jax.numpy as jnp
import flax.linen as nn
import optax
from chex import ArrayTree
from qdax.core.containers.repertoire import Repertoire
from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from qdax.environments.base_wrappers import QDEnv
from qdax.core.neuroevolution.buffers.buffer import QDTransition, QDMCTransition
#from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer
import flashbax as fbx
import chex
from rein_related import *

from qdax.core.emitters.emitter import Emitter, EmitterState

In [3]:
import os

os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['WANDB_CACHE_DIR'] = '/tmp/wandb_cache'
os.environ['JAX_LOG_COMPILATION'] = '1'

import logging
import time
from dataclasses import dataclass
from functools import partial
from math import floor
from typing import Any, Dict, Tuple, List, Callable
import pickle
from flax import serialization
#logging.basicConfig(level=logging.DEBUG)
import hydra
from omegaconf import OmegaConf, DictConfig
import jax
import jax.numpy as jnp
from hydra.core.config_store import ConfigStore
from qdax.core.map_elites import MAPElites
from qdax.types import RNGKey, Genotype
from qdax.utils.sampling import sampling 
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire
from qdax.core.neuroevolution.networks.networks import MLPMCPG, MLPPPO
from qdax.core.emitters.me_mcpg_emitter import MEMCPGConfig, MEMCPGEmitter
from qdax.core.emitters.ppo_me_emitter import PPOMEConfig, PPOMEmitter
#from qdax.core.emitters.rein_emitter_advanced import REINaiveConfig, REINaiveEmitter
from qdax.core.neuroevolution.buffers.buffer import QDTransition, QDMCTransition, PPOTransition
from qdax.environments import behavior_descriptor_extractor
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from utils import Config, get_env
from qdax.core.emitters.mutation_operators import isoline_variation
import wandb
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results, plot_2d_map_elites_repertoire
import matplotlib.pyplot as plt
from set_up_brax import get_reward_offset_brax
from qdax import environments_v1, environments

In [4]:
import matplotlib.pyplot as plt
%matplotlib inline
import jax.numpy as jnp  # Assuming you are using jax.numpy as jnp

In [5]:
def get_env(env_name):
    if env_name == "hopper_uni":
        episode_length = 1000
        
        env = environments_v1.create(env_name, episode_length=episode_length)
    elif env_name == "halfcheetah_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)
        
    elif env_name == "walker2d_uni":
        episode_length = 1024

        env = environments_v1.create(env_name, episode_length=episode_length)	
    elif env_name == "ant_uni":
        episode_length = 1024

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=True)
    elif env_name == "humanoid_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, exclude_current_positions_from_observation=True)	
    '''
    elif env_name == "ant_omni":
        episode_length = 250
        max_bd = 30.

        env = environments.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=False)	
    elif env_name == "humanoid_uni":
        episode_length = 1000
        max_bd = 1.

        env = environments.create(env_name, episode_length=episode_length)	
    else:
        ValueError(f"Environment {env_name} not supported.")
    '''
    return env

In [6]:
@dataclass
class Config:
    """Configuration from this experiment script
    """
    # Env config
    #alg_name: str
    seed: int
    env_name: str
    episode_length: int
    policy_hidden_layer_sizes: Tuple[int, ...]   
    # ME config
    num_evaluations: int
    num_iterations: int
    no_agents: int
    num_samples: int
    fixed_init_state: bool
    discard_dead: bool
    # Emitter config
    iso_sigma: float
    line_sigma: float
    #crossover_percentage: float
    # Grid config 
    grid_shape: Tuple[int, ...]
    num_init_cvt_samples: int
    num_centroids: int
    # Log config
    log_period: int
    store_repertoire: bool
    store_repertoire_log_period: int
    
    # REINFORCE Parameters
    proportion_mutation_ga : float
    buffer_sample_batch_size : int
    buffer_add_batch_size: int
    adam_optimizer: bool
    learning_rate: float
    discount_rate: float
    #buffer_size: int
    clip_param: float
    no_epochs: int
    
    no_neurons: int
    activation: str
    vf_coef: float
    num_minibatches: int
    max_grad_norm: float
    lecun: bool

In [8]:
from dataclasses import replace

@dataclass
class ObsNormalizer:
    size: int
    mean: jnp.ndarray = None
    var: jnp.ndarray = None
    count: jnp.ndarray = 1e-4
    
    def __post_init__(self):
        if self.mean is None:
            self.mean = jnp.zeros(self.size)
        if self.var is None:
            self.var = jnp.ones(self.size)
            
    def update(self, x):
        # Flatten the first two dimensions (x, y) to treat as a single batch dimension
        flat_x = x.reshape(-1, self.size)
        batch_mean = jnp.mean(flat_x, axis=0)
        batch_var = jnp.var(flat_x, axis=0)
        batch_count = flat_x.shape[0]

        new_mean, new_var, new_count = self._update_mean_var_count(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
        
        return replace(self, mean=new_mean, var=new_var, count=new_count)

    def normalize(self, x):
        # Normalize maintaining the original shape, using broadcasting
        return (x - self.mean) / jnp.sqrt(self.var + 1e-8)

    def _update_mean_var_count(self, mean, var, count, batch_mean, batch_var, batch_count):
        delta = batch_mean - mean
        tot_count = count + batch_count

        new_mean = mean + delta * batch_count / tot_count
        m_a = var * count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        return new_mean, new_var, new_count


    def tree_flatten(self):
        return ((self.mean, self.var, self.count), self.size)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        size = aux_data
        mean, var, count = children
        return cls(size=size, mean=mean, var=var, count=count)

# Register Normalizer as a pytree node with JAX
jax.tree_util.register_pytree_node(
    ObsNormalizer,
    ObsNormalizer.tree_flatten,
    ObsNormalizer.tree_unflatten
)


@dataclass
class RewardNormalizer:
    size: int
    mean: jnp.ndarray = 0.0
    var: jnp.ndarray = 1.0
    count: jnp.ndarray = 1e-4
    return_val: jnp.ndarray = None
    
    def __post_init__(self):
        if self.return_val is None:
            self.return_val = jnp.zeros((self.size,))

         
    def update(self, reward, done, gamma=0.99):
        
        def _update_column_scan(carry, x):
            mean, var, count, return_val = carry
            (reward, done) = x
            
            #jax.debug.print("Reward shape: {}", reward.shape)
            
            # Update the return value
            new_return_val = reward + gamma * return_val * (1 - done)
            
            # Update the mean, var, and count
            batch_mean = jnp.mean(new_return_val, axis=0)
            batch_var = jnp.var(new_return_val, axis=0)
            batch_count = new_return_val.shape[0]
            
            delta = batch_mean - mean
            tot_count = count + batch_count
            
            new_mean = mean + delta * batch_count / tot_count
            m_a = var * count
            m_b = batch_var * batch_count
            M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
            new_var = M2 / tot_count
            new_count = tot_count
            
            normalized_reward = reward / jnp.sqrt(new_var + 1e-8)
            
            return (new_mean, new_var, new_count, new_return_val), normalized_reward
        
        (new_mean, new_var, new_count, _), normalized_rewards = jax.lax.scan(
            _update_column_scan,
            (self.mean, self.var, self.count, self.return_val),
            (reward.T, done.T),
        )

        
        return replace(self, mean=new_mean, var=new_var, count=new_count), normalized_rewards.T






    def tree_flatten(self):
        return ((self.mean, self.var, self.count, self.return_val), self.size)
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        size = aux_data
        mean, var, count, return_val = children
        return cls(size=size, mean=mean, var=var, count=count, return_val=return_val)
    
        
jax.tree_util.register_pytree_node(
    RewardNormalizer,
    RewardNormalizer.tree_flatten,
    RewardNormalizer.tree_unflatten
)

In [9]:
no_epochs = [4]

envs = ["walker2d_uni", "ant_uni"]

#os.makedirs("grad_steps_experiments/reps=16_new/", exist_ok=True)


for env_ in envs:
    
    env_dir = f"grad_steps_experiments/value/no_agents=1/epochs=4/{env_}"
    os.makedirs(env_dir, exist_ok=True)
    
    for no_epoch in no_epochs:
        config = Config(
            seed=10,
            env_name=env_,
            episode_length=1024,
            policy_hidden_layer_sizes=[128, 128],
            num_evaluations=0,
            num_iterations=4000,
            num_samples=8,
            no_agents=1,
            fixed_init_state=False,
            discard_dead=False,
            grid_shape=[50, 50],
            num_init_cvt_samples=50000,
            num_centroids=1024,
            log_period=400,
            store_repertoire=True,
            store_repertoire_log_period=800,
            iso_sigma=0.005,
            line_sigma=0.05,
            proportion_mutation_ga=0.5,
            buffer_sample_batch_size=1,
            buffer_add_batch_size=1,
            no_epochs=no_epoch,
            #buffer_size=64000,
            adam_optimizer=True,
            learning_rate=3e-4,
            discount_rate=0.99,
            clip_param=0.2,
            no_neurons=64,
            vf_coef=0.5,
            activation="tanh",
            num_minibatches=32,
            max_grad_norm=0.5,
            lecun=False,
        )
        random_key = jax.random.PRNGKey(config.seed)

        # Init environment
        env = get_env(env_)
        reset_fn = jax.jit(env.reset)

        # Compute the centroids
        centroids, random_key = compute_cvt_centroids(
            num_descriptors=env.behavior_descriptor_length,
            num_init_cvt_samples=config.num_init_cvt_samples,
            num_centroids=config.num_centroids,
            minval=0,
            maxval=1,
            random_key=random_key,
        )
        # Init policy network
        
        
        obs_normalizer = ObsNormalizer(env.observation_size)
        reward_normalizer = RewardNormalizer(config.no_agents)

        

        
        
        policy_network = MLPPPO(
            action_dim=env.action_size,
            activation=config.activation,
            no_neurons=config.no_neurons,
        )

        # Init population of controllers
        
        # maybe consider adding two random keys for each policy
        random_key, subkey = jax.random.split(random_key)
        keys = jax.random.split(subkey, num=config.no_agents)
        #split_keys = jax.vmap(lambda k: jax.random.split(k, 2))(keys)
        #keys1, keys2 = split_keys[:, 0], split_keys[:, 1]
        fake_batch_obs = jnp.zeros(shape=(config.no_agents, env.observation_size))
        init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

        param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
        print("Number of parameters in policy_network: ", param_count)

        # Define the fonction to play a step with the policy in the environment
        @jax.jit
        def play_step_fn(env_state, policy_params, random_key):
            random_key, subkey = jax.random.split(random_key)
            #pi, val = policy_network.apply(policy_params, env_state.obs)
            #action = pi.sample(seed=subkey)
            pi, action, val = policy_network.apply(policy_params, env_state.obs)
            
            logp = pi.log_prob(action)
            state_desc = env_state.info["state_descriptor"]
            next_state = env.step(env_state, action)
            _, _, next_val = policy_network.apply(policy_params, next_state.obs)

            transition = PPOTransition(
                obs=env_state.obs,
                next_obs=next_state.obs,
                rewards=next_state.reward,
                dones=next_state.done,
                truncations=next_state.info["truncation"],
                actions=action,
                state_desc=state_desc,
                next_state_desc=next_state.info["state_descriptor"],
                val_adv=val,
                target=next_val,
                logp=logp
                #desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
                #desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            )

            return (next_state, policy_params, random_key), transition




        # Prepare the scoring function
        bd_extraction_fn = behavior_descriptor_extractor[config.env_name]
        scoring_fn = partial(
            scoring_function,
            episode_length=config.episode_length,
            play_reset_fn=reset_fn,
            play_step_fn=play_step_fn,  
            behavior_descriptor_extractor=bd_extraction_fn,
        )
        #reward_offset = get_reward_offset_brax(env, config.env_name)
        #print(f"Reward offset: {reward_offset}")
        
        me_scoring_fn = partial(
        sampling,
        scoring_fn=scoring_fn,
        num_samples=config.num_samples,
    )

        reward_offset = 0
    
                    
        
            

    
        '''
        def get_n_offspring_added(metrics):
            split = jnp.cumsum(jnp.array([emitter.batch_size for emitter in map_elites._emitter.emitters]))
            split = jnp.split(metrics["is_offspring_added"], split, axis=-1)[:-1]
            qpg_offspring_added, ai_offspring_added = jnp.split(split[0], (split[0].shape[1]-1,), axis=-1)
            return (jnp.sum(split[1], axis=-1), jnp.sum(qpg_offspring_added, axis=-1), jnp.sum(ai_offspring_added, axis=-1))
        '''
        # Get minimum reward value to make sure qd_score are positive
        

        # Define a metrics function
        metrics_function = partial(
            default_qd_metrics,
            qd_offset=reward_offset * config.episode_length,
        )

        # Define the PG-emitter config
        
        ppo_me_config = PPOMEConfig(
            proportion_mutation_ga=config.proportion_mutation_ga,
            no_agents=config.no_agents,
            buffer_sample_batch_size=config.buffer_sample_batch_size,
            buffer_add_batch_size=config.buffer_add_batch_size,
            #batch_size=config.batch_size,
            #mini_batch_size=config.mini_batch_size,
            no_epochs=config.no_epochs,
            #buffer_size=config.buffer_size,
            learning_rate=config.learning_rate,
            adam_optimizer=config.adam_optimizer,
            clip_param=config.clip_param,
            num_minibatches=config.num_minibatches,
            vf_coef=config.vf_coef,
            max_grad_norm=config.max_grad_norm,
        )
        
        variation_fn = partial(
            isoline_variation, iso_sigma=config.iso_sigma, line_sigma=config.line_sigma
        )
        
        ppo_me_emitter = PPOMEmitter(
            config=ppo_me_config,
            policy_network=policy_network,
            env=env,
            variation_fn=variation_fn,
            )


        # Instantiate MAP Elites
        map_elites = MAPElites(
            scoring_function=scoring_fn,
            emitter=ppo_me_emitter,
            metrics_function=metrics_function,
        )

        fitnesses, descriptors, extra_scores, random_key, obs_normalizer, reward_normalizer = scoring_fn(
            init_params, random_key, obs_normalizer, reward_normalizer
        )
        
        repertoire = MapElitesRepertoire.init(
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            extra_scores=extra_scores,
        )
        

        emitter_state, random_key = ppo_me_emitter.init(
            random_key=random_key,
            repertoire=repertoire,
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )
        
        emitter_state = ppo_me_emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores={**extra_scores}#, **extra_info},
        )
        
        emitter_state = emitter_state.emitter_states[0]
        
        


        returns = []
        old_params = init_params
        random_key = jax.random.PRNGKey(0)
        for _ in range(250):
            random_keys = jax.random.split(random_key, config.no_agents)
            new_params = ppo_me_emitter.emitters[0].emit_mcpg(emitter_state, old_params, random_keys)
            fitnesses, descriptors, extra_scores, random_key, obs_normalizer, reward_normalizer = scoring_fn(
                new_params, random_key, obs_normalizer, reward_normalizer
            )
            emitter_state = ppo_me_emitter.emitters[0].state_update(
                emitter_state=emitter_state,
                repertoire=repertoire,
                genotypes=new_params,
                fitnesses=fitnesses,
                descriptors=descriptors,
                extra_scores=extra_scores,
            )
            old_params = new_params
            print(f"mean fitness: {fitnesses.mean()}")
            returns.append(fitnesses)
            
        returns = jnp.array(returns)  # Assuming 'returns' is already defined as a 2D array

        # Determine the overall min and max fitness values for setting y-axis limits
        ymin = returns.min()
        ymax = returns.max()

        # Plotting
        fig, axs = plt.subplots(16, 16, figsize=(32, 64))  # Adjust the subplot grid to 16x16
        for i in range(256):  # Loop through 256 plots
            ax = axs[i // 16, i % 16]  # This assumes a 16x16 grid of subplots
            ax.plot(returns[:, i])
            ax.set_title(f"Policy {i+1}", fontsize=8)
            ax.set_xlabel('Steps', fontsize=6)
            ax.set_ylabel('Returns', fontsize=6)
            ax.set_ylim([ymin, ymax])  # Set the same y-axis limits for all subplots
            ax.tick_params(axis='both', which='major', labelsize=6)

        plt.tight_layout()
        # Save each plot in the specified directory with the correct filename
        plot_filename = f"{env_dir}/no_GA.png"
        plt.savefig(plot_filename)
        plt.close(fig)  # Close the plot to free up memory

Number of parameters in policy_network:  11213


  repertoire = MapElitesRepertoire.init(


mean fitness: 73.96699523925781
mean fitness: 79.9056625366211
mean fitness: 169.95632934570312
mean fitness: 145.59222412109375
mean fitness: 179.7495574951172
mean fitness: 152.00436401367188
mean fitness: 163.63609313964844
mean fitness: 335.499267578125
mean fitness: 316.33038330078125
mean fitness: 290.455810546875
mean fitness: 329.5206604003906
mean fitness: 144.17889404296875
mean fitness: 114.39218139648438
mean fitness: 136.88348388671875
mean fitness: 419.93743896484375
mean fitness: 488.0634460449219
mean fitness: 140.0913543701172
mean fitness: 408.1572265625
mean fitness: 204.183837890625
mean fitness: 393.0503845214844
mean fitness: 132.5863494873047
mean fitness: 412.3414306640625
mean fitness: 92.5159912109375
mean fitness: 122.61912536621094
mean fitness: 158.54681396484375
mean fitness: 126.49113464355469
mean fitness: 154.40621948242188
mean fitness: 151.46058654785156
mean fitness: 181.7580108642578
mean fitness: 118.58670043945312
mean fitness: 109.83159637451172


  repertoire = MapElitesRepertoire.init(


mean fitness: 900.4405517578125
mean fitness: 270.5308837890625
mean fitness: 229.413330078125
mean fitness: 32.42040252685547
mean fitness: 27.690799713134766
mean fitness: 87.607666015625
mean fitness: 24.929851531982422
mean fitness: 0.6529399156570435
mean fitness: 42.85212707519531
mean fitness: 18.09977149963379
mean fitness: 109.50254821777344
mean fitness: 12.309200286865234
mean fitness: 24.262470245361328
mean fitness: 15.440832138061523
mean fitness: 24.750446319580078
mean fitness: 21.30231285095215
mean fitness: 16.904834747314453
mean fitness: 25.967872619628906
mean fitness: 6.589447498321533
mean fitness: 8.174848556518555
mean fitness: 6.11345100402832
mean fitness: 4.38013219833374
mean fitness: 4.478180885314941
mean fitness: 1.7931606769561768
mean fitness: 3.423398494720459
mean fitness: 0.5623123645782471
mean fitness: 2.0097997188568115
mean fitness: 0.6213541626930237
mean fitness: 0.0
mean fitness: 2.120435953140259
mean fitness: 0.16271032392978668
mean fitnes

KeyboardInterrupt: 

In [12]:
no_epochs = [4]

envs = ["walker2d_uni", "ant_uni"]

#os.makedirs("grad_steps_experiments/reps=16_new/", exist_ok=True)


for env_ in envs:
    
    env_dir = f"grad_steps_experiments/value/sample=8/epochs=4/{env_}"
    os.makedirs(env_dir, exist_ok=True)
    
    for no_epoch in no_epochs:
        config = Config(
            seed=10,
            env_name=env_,
            episode_length=1024,
            policy_hidden_layer_sizes=[128, 128],
            num_evaluations=0,
            num_iterations=4000,
            num_samples=8,
            no_agents=256,
            fixed_init_state=False,
            discard_dead=False,
            grid_shape=[50, 50],
            num_init_cvt_samples=50000,
            num_centroids=1024,
            log_period=400,
            store_repertoire=True,
            store_repertoire_log_period=800,
            iso_sigma=0.005,
            line_sigma=0.05,
            proportion_mutation_ga=0.5,
            buffer_sample_batch_size=8,
            buffer_add_batch_size=256,
            no_epochs=no_epoch,
            #buffer_size=64000,
            adam_optimizer=True,
            learning_rate=3e-4,
            discount_rate=0.99,
            clip_param=0.2,
            no_neurons=64,
            vf_coef=0.5,
            activation="tanh",
            num_minibatches=32,
            max_grad_norm=0.5,
            lecun=False,
        )
        random_key = jax.random.PRNGKey(config.seed)

        # Init environment
        env = get_env(env_)
        reset_fn = jax.jit(env.reset)

        # Compute the centroids
        centroids, random_key = compute_cvt_centroids(
            num_descriptors=env.behavior_descriptor_length,
            num_init_cvt_samples=config.num_init_cvt_samples,
            num_centroids=config.num_centroids,
            minval=0,
            maxval=1,
            random_key=random_key,
        )
        # Init policy network
        
        
        obs_normalizer = ObsNormalizer(env.observation_size)
        reward_normalizer = RewardNormalizer(config.no_agents)

        

        
        
        policy_network = MLPPPO(
            action_dim=env.action_size,
            activation=config.activation,
            no_neurons=config.no_neurons,
        )

        # Init population of controllers
        
        # maybe consider adding two random keys for each policy
        random_key, subkey = jax.random.split(random_key)
        keys = jax.random.split(subkey, num=config.no_agents)
        #split_keys = jax.vmap(lambda k: jax.random.split(k, 2))(keys)
        #keys1, keys2 = split_keys[:, 0], split_keys[:, 1]
        fake_batch_obs = jnp.zeros(shape=(config.no_agents, env.observation_size))
        init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

        param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
        print("Number of parameters in policy_network: ", param_count)

        # Define the fonction to play a step with the policy in the environment
        @jax.jit
        def play_step_fn(env_state, policy_params, random_key):
            random_key, subkey = jax.random.split(random_key)
            #pi, val = policy_network.apply(policy_params, env_state.obs)
            #action = pi.sample(seed=subkey)
            pi, action, val = policy_network.apply(policy_params, env_state.obs)
            
            logp = pi.log_prob(action)
            state_desc = env_state.info["state_descriptor"]
            next_state = env.step(env_state, action)
            _, _, next_val = policy_network.apply(policy_params, next_state.obs)

            transition = PPOTransition(
                obs=env_state.obs,
                next_obs=next_state.obs,
                rewards=next_state.reward,
                dones=next_state.done,
                truncations=next_state.info["truncation"],
                actions=action,
                state_desc=state_desc,
                next_state_desc=next_state.info["state_descriptor"],
                val_adv=val,
                target=next_val,
                logp=logp
                #desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
                #desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            )

            return (next_state, policy_params, random_key), transition




        # Prepare the scoring function
        bd_extraction_fn = behavior_descriptor_extractor[config.env_name]
        scoring_fn = partial(
            scoring_function,
            episode_length=config.episode_length,
            play_reset_fn=reset_fn,
            play_step_fn=play_step_fn,  
            behavior_descriptor_extractor=bd_extraction_fn,
        )
        #reward_offset = get_reward_offset_brax(env, config.env_name)
        #print(f"Reward offset: {reward_offset}")
        
        me_scoring_fn = partial(
        sampling,
        scoring_fn=scoring_fn,
        num_samples=config.num_samples,
    )

        reward_offset = 0
    
                    
        
            

    
        '''
        def get_n_offspring_added(metrics):
            split = jnp.cumsum(jnp.array([emitter.batch_size for emitter in map_elites._emitter.emitters]))
            split = jnp.split(metrics["is_offspring_added"], split, axis=-1)[:-1]
            qpg_offspring_added, ai_offspring_added = jnp.split(split[0], (split[0].shape[1]-1,), axis=-1)
            return (jnp.sum(split[1], axis=-1), jnp.sum(qpg_offspring_added, axis=-1), jnp.sum(ai_offspring_added, axis=-1))
        '''
        # Get minimum reward value to make sure qd_score are positive
        

        # Define a metrics function
        metrics_function = partial(
            default_qd_metrics,
            qd_offset=reward_offset * config.episode_length,
        )

        # Define the PG-emitter config
        
        ppo_me_config = PPOMEConfig(
            proportion_mutation_ga=config.proportion_mutation_ga,
            no_agents=config.no_agents,
            buffer_sample_batch_size=config.buffer_sample_batch_size,
            buffer_add_batch_size=config.buffer_add_batch_size,
            #batch_size=config.batch_size,
            #mini_batch_size=config.mini_batch_size,
            no_epochs=config.no_epochs,
            #buffer_size=config.buffer_size,
            learning_rate=config.learning_rate,
            adam_optimizer=config.adam_optimizer,
            clip_param=config.clip_param,
            num_minibatches=config.num_minibatches,
            vf_coef=config.vf_coef,
            max_grad_norm=config.max_grad_norm,
        )
        
        variation_fn = partial(
            isoline_variation, iso_sigma=config.iso_sigma, line_sigma=config.line_sigma
        )
        
        ppo_me_emitter = PPOMEmitter(
            config=ppo_me_config,
            policy_network=policy_network,
            env=env,
            variation_fn=variation_fn,
            )


        # Instantiate MAP Elites
        map_elites = MAPElites(
            scoring_function=scoring_fn,
            emitter=ppo_me_emitter,
            metrics_function=metrics_function,
        )

        fitnesses, descriptors, extra_scores, random_key, obs_normalizer, reward_normalizer = scoring_fn(
            init_params, random_key, obs_normalizer, reward_normalizer
        )
        
        repertoire = MapElitesRepertoire.init(
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            extra_scores=extra_scores,
        )
        

        emitter_state, random_key = ppo_me_emitter.init(
            random_key=random_key,
            repertoire=repertoire,
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )
        
        emitter_state = ppo_me_emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_params,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores={**extra_scores}#, **extra_info},
        )
        
        emitter_state = emitter_state.emitter_states[0]
        
        


        returns = []
        old_params = init_params
        random_key = jax.random.PRNGKey(0)
        for _ in range(250):
            random_keys = jax.random.split(random_key, config.no_agents)
            new_params = ppo_me_emitter.emitters[0].emit_mcpg(emitter_state, old_params, random_keys)
            fitnesses, descriptors, extra_scores, random_key, obs_normalizer, reward_normalizer = scoring_fn(
                new_params, random_key, obs_normalizer, reward_normalizer
            )
            emitter_state = ppo_me_emitter.emitters[0].state_update(
                emitter_state=emitter_state,
                repertoire=repertoire,
                genotypes=new_params,
                fitnesses=fitnesses,
                descriptors=descriptors,
                extra_scores=extra_scores,
            )
            old_params = new_params
            print(f"mean fitness: {fitnesses.mean()}")
            returns.append(fitnesses)
            
        returns = jnp.array(returns)  # Assuming 'returns' is already defined as a 2D array

        # Determine the overall min and max fitness values for setting y-axis limits
        ymin = returns.min()
        ymax = returns.max()

        # Plotting
        fig, axs = plt.subplots(16, 16, figsize=(32, 64))  # Adjust the subplot grid to 16x16
        for i in range(256):  # Loop through 256 plots
            ax = axs[i // 16, i % 16]  # This assumes a 16x16 grid of subplots
            ax.plot(returns[:, i])
            ax.set_title(f"Policy {i+1}", fontsize=8)
            ax.set_xlabel('Steps', fontsize=6)
            ax.set_ylabel('Returns', fontsize=6)
            ax.set_ylim([ymin, ymax])  # Set the same y-axis limits for all subplots
            ax.tick_params(axis='both', which='major', labelsize=6)

        plt.tight_layout()
        # Save each plot in the specified directory with the correct filename
        plot_filename = f"{env_dir}/no_GA.png"
        plt.savefig(plot_filename)
        plt.close(fig)  # Close the plot to free up memory

Number of parameters in policy_network:  11213


  repertoire = MapElitesRepertoire.init(


mean fitness: 175.44656372070312
mean fitness: 256.46307373046875
mean fitness: 264.683349609375
mean fitness: 260.86041259765625
mean fitness: 258.6240539550781
mean fitness: 255.4115753173828
mean fitness: 253.8948974609375
mean fitness: 253.1417236328125
mean fitness: 251.8593292236328
mean fitness: 250.58399963378906
mean fitness: 249.24386596679688
mean fitness: 247.8236541748047
mean fitness: 245.6031494140625
mean fitness: 243.05429077148438
mean fitness: 240.92770385742188
mean fitness: 238.82809448242188
mean fitness: 237.3988037109375
mean fitness: 236.3238525390625
mean fitness: 235.1883087158203
mean fitness: 234.16075134277344
mean fitness: 233.25027465820312
mean fitness: 232.9304962158203
mean fitness: 232.72140502929688
mean fitness: 232.60238647460938
mean fitness: 232.53341674804688
mean fitness: 232.27178955078125
mean fitness: 231.92001342773438
mean fitness: 231.93496704101562
mean fitness: 231.7150421142578
mean fitness: 231.8817138671875
mean fitness: 232.8479003

  repertoire = MapElitesRepertoire.init(


mean fitness: 985.8655395507812
mean fitness: 949.7366943359375
mean fitness: 948.462158203125
mean fitness: 928.8270263671875
mean fitness: 898.0606689453125
mean fitness: 862.854248046875
mean fitness: 839.592529296875
mean fitness: 821.0694580078125
mean fitness: 795.7606811523438
mean fitness: 751.163330078125
mean fitness: 695.0880126953125
mean fitness: 618.5626220703125
mean fitness: 550.2085571289062
mean fitness: 503.65948486328125
mean fitness: 462.3690185546875
mean fitness: 434.7431945800781
mean fitness: 401.716796875
mean fitness: 373.7886962890625
mean fitness: 345.31884765625
mean fitness: 309.61181640625
mean fitness: 265.2862854003906
mean fitness: 225.62075805664062
mean fitness: 197.2504425048828
mean fitness: 163.7965545654297
mean fitness: 154.7968292236328
mean fitness: 127.79566192626953
mean fitness: 104.03175354003906
mean fitness: 89.88478088378906
mean fitness: 79.24995422363281
mean fitness: 61.73939514160156
mean fitness: 53.6400260925293
mean fitness: 42.