In [1]:
!nvidia-smi

Thu Jul  4 11:00:01 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        Off | 00000000:2D:00.0  On |                  N/A |
|  0%   40C    P5              20W / 320W |     89MiB / 16376MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
from dataclasses import dataclass
from functools import partial
from math import floor 
from typing import Callable, Tuple, Any

import jax
from jax import debug
import jax.numpy as jnp
import flax.linen as nn
import optax
from chex import ArrayTree
from qdax.core.containers.repertoire import Repertoire
from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from qdax.environments.base_wrappers import QDEnv
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer
from rein_related import *

from qdax.core.emitters.emitter import Emitter, EmitterState

In [3]:
EPS = 1e-8


@dataclass
class MCPGConfig:
    """Configuration for the REINaive emitter.
    
    Args:
        rollout_number: num of rollouts for gradient estimate
        sample_sigma: std to sample the samples for gradient estimate  (IS THIS PARAMETER SPACE EXPLORATION?)
        sample_mirror: if True, use mirroring sampling
        sample_rank_norm: if True, use normalisation
        
        num_generations_sample: frequency of archive-sampling
        
        adam_optimizer: if True, use ADAM, if False, use SGD
        learning_rate: obvious
        l2_coefficient: coefficient for regularisation
        
        novelty_nearest_neighbors: num of nearest neigbors for novelty computation
        use_novelty_archive: if True, use novelty archive for novelty (default is to use the content of the reperoire)
        use_novelty_fifo: if True, use fifo archive for novelty (default is to use the content of the repertoire)
        fifo_size: size of the novelty fifo bugger if used
        
        proprtion_explore: proportion of explore
    """
    no_agents: int = 256
    batch_size: int = 1000*256
    mini_batch_size: int = 1000*256
    no_epochs: int = 16
    learning_rate: float = 3e-4
    discount_rate: float = 0.99
    adam_optimizer: bool = True
    buffer_size: int = 256000
    clip_param: float = 0.2
    
class MCPGEmitterState(EmitterState):
    """Containes the trajectory buffer.
    """
    buffer: TrajectoryBuffer
    random_key: RNGKey
    
class MCPGEmitter(Emitter):
    
    def __init__(
        self,
        config: MCPGConfig,
        policy_net: nn.Module,
        env: QDEnv,
    ) -> None:
        
        self._config = config
        self._policy = policy_net
        self._env = env
        
        self._policy_opt = optax.adam(
            learning_rate=self._config.learning_rate
        )
        
    @property
    def batch_size(self) -> int:
        """
        Returns:
            int: the batch size emitted by the emitter.
        """
        return self._config.no_agents
    
    @property
    def use_all_data(self) -> bool:
        """Whther to use all data or not when used along other emitters.
        """
        return True
    
    @partial(jax.jit, static_argnames=("self",))
    def init(
        self,
        random_key: RNGKey,
        repertoire: Repertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> Tuple[MCPGEmitterState, RNGKey]:
        """Initializes the emitter state.
        """
        obs_size = self._env.observation_size
        action_size = self._env.action_size
        descriptor_size = self._env.state_descriptor_length
        
        # Init trajectory buffer
        dummy_transition = QDTransition.init_dummy(
            observation_dim=obs_size,
            action_dim=action_size,
            descriptor_dim=descriptor_size,
        )
        
        buffer = TrajectoryBuffer.init(
            buffer_size=self._config.buffer_size,
            transition=dummy_transition,
            env_batch_size=self._config.no_agents*2,
            episode_length=self._env.episode_length,
        )
        
        random_key, subkey = jax.random.split(random_key)
        emitter_state = MCPGEmitterState(
            buffer=buffer,
            random_key=subkey,
        )
        
        return emitter_state, random_key
    
    @partial(jax.jit, static_argnames=("self",))
    def emit(
        self,
        repertoire: Repertoire,
        emitter_state: MCPGEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Do a step of MCPG emission.
        """
        
        no_agents = self._config.no_agents
        
        # sample parents
        parents, random_key = repertoire.sample(
            random_key=random_key,
            num_samples=no_agents,
        )
        
        offsprings_mcpg = self.emit_mcpg(emitter_state, parents)
        
        return offsprings_mcpg, {}, random_key
    
    @partial(jax.jit, static_argnames=("self",))
    def emit_mcpg(
        self,
        emitter_state: MCPGEmitterState,
        parents: Genotype,
    ) -> Genotype:
        """Emit the offsprings generated through MCPG mutation.
        """
        
        mutation_fn = partial(
            self._mutation_function_mcpg,
            emitter_state=emitter_state,
        )
        
        offsprings = jax.vmap(mutation_fn)(parents)
        
        return offsprings
    
    @partial(jax.jit, static_argnames=("self",))
    def state_update(
        self,
        emitter_state: MCPGEmitterState,
        repertoire: Optional[Repertoire],
        genotypes: Optional[Genotype],
        fitnesses: Optional[Fitness],
        descriptors: Optional[Descriptor],
        extra_scores: ExtraScores,
    ) -> MCPGEmitterState:
        """Update the emitter state.
        """
        
        assert "transitions" in extra_scores.keys(), "Missing transtitions or wrong key"
        transitions = extra_scores["transitions"]
        
        # update the buffer
        replay_buffer = emitter_state.buffer.insert(transitions)
        emitter_state = emitter_state.replace(buffer=replay_buffer)
        
        return emitter_state
    
    @partial(jax.jit, static_argnames=("self",))
    def compute_mask(
        self,
        done,
    ):
        return 1. - jnp.clip(jnp.cumsum(done), a_min=0., a_max=1.)
    '''
    @partial(jax.jit, static_argnames=("self",))
    def compute_logps(
        self,
        policy_params,
        obs,
        actions,
    ):
        """Compute the log probabilities of the actions.
        """
        compute_logp = partial(
            self._policy.apply,
            params=policy_params,
            method=self._policy.logp,
        )
        
        return jax.vmap(compute_logp)(obs, actions)
    '''
    
    @partial(jax.jit, static_argnames=("self",))
    def compute_logps(self, policy_params, obs, actions):
        def compute_logp(single_obs, single_action):
            # Correctly handle operations on single_obs and single_action
            # Ensure no inappropriate method calls like .items() are made
            return self._policy.apply(policy_params, single_obs, single_action, method=self._policy.logp)

        # Use jax.vmap to apply compute_logp across batches of obs and actions
        return jax.vmap(compute_logp, in_axes=(0, 0))(obs, actions)
       
    @partial(jax.jit, static_argnames=("self",))
    def get_return(
        self,
        rewards,
    ):
        def _body(carry, x):
            (next_return,) = carry
            (rewards,) = x

            current_return = rewards + self._config.discount_rate * next_return
            return (current_return,), (current_return,)
        
        
        
        #jax.debug.print("rewards", rewards.shape)
        
        _, (return_,) = jax.lax.scan(
            _body,
            (jnp.array(0.),),
            (rewards,),
            length=self._env.episode_length,
            reverse=True,
        )
        
        return return_
    
    
    '''
    @partial(jax.jit, static_argnames=("self",))
    def get_return(self, rewards):
        def _body(carry, reward):
            next_return = carry  # carry should be unpacked directly if it's a single element
            current_return = reward + self._config.discount_rate * next_return
            return current_return, current_return  # Maintain the same shape and type

        initial_return = jnp.array(0.0)  # Ensure initial_return is correctly shaped as a scalar
        _, return_ = jax.lax.scan(
            _body,
            initial_return,
            rewards,  # Pass rewards directly without extra tuple wrapping
            length=int(self._env.episode_length),
            reverse=True,
        )

        return return_
    '''
    
    @partial(jax.jit, static_argnames=("self",))
    def standardize(
        self,
        return_,
    ):
        return jax.nn.standardize(return_, axis=0, variance=1, epsilon=EPS)
    
    @partial(jax.jit, static_argnames=("self",))
    def get_standardized_return(
        self,
        rewards,
        mask,
    ):
        mask = jnp.expand_dims(mask, axis=-1)
        valid_rewards = (rewards * mask).squeeze(axis=-1)
        #jax.debug.print("mask: {}", mask.shape)
        #jax.debug.print("rewards*mask: {}", (rewards * mask).shape)
        return_ = jax.vmap(self.get_return)(valid_rewards)
        return self.standardize(return_)
    
    @partial(jax.jit, static_argnames=("self",))
    def _mutation_function_mcpg(
        self,
        policy_params,
        emitter_state: MCPGEmitterState,
    ) -> Genotype:
        """Mutation function for MCPG.
        """
        
        buffer = emitter_state.buffer
        
        policy_opt_state = self._policy_opt.init(policy_params)
        
        random_key = emitter_state.random_key
        
        #random_key, subkey = jax.random.split(emitter_state.random_key)
        sample_size = int(self._config.batch_size) // int(self._env.episode_length)
        #print(f"episodic_data_size: {int(buffer.current_episodic_data_size)}")
        #episodic_data_size = buffer.current_episodic_data_size.item()
        
        trans, random_key = buffer.sample(
            random_key=random_key,
            sample_size=sample_size,
            episodic_data_size=64,
            sample_traj=True,
        )
        new_emitter_state = emitter_state.replace(random_key=random_key)
        # trans has shape (episde_length*sample_size, transition_dim)
        
        obs = trans.obs.reshape(sample_size, self._env.episode_length, -1)
        actions = trans.actions.reshape(sample_size, self._env.episode_length, -1)
        rewards = trans.rewards.reshape(sample_size, self._env.episode_length, -1)
        #jax.debug.print("rewards shape: {}", rewards.shape)
        #print(f"rewards shape: {rewards.shape}")
        dones = trans.dones.reshape(sample_size, self._env.episode_length, -1)
        
        mask = jax.vmap(self.compute_mask, in_axes=0)(dones)
        logps = jax.vmap(self.compute_logps, in_axes=(None, 0, 0))(policy_params, obs, actions)
        
        standardized_returns = self.get_standardized_return(rewards, mask)
        
        def scan_train_policy(
            carry: Tuple[MCPGEmitterState, Genotype, optax.OptState],
            unused: Any,
        ) -> Tuple[Tuple[MCPGEmitterState, Genotype, optax.OptState], Any]:
            
            policy_params, policy_opt_state = carry
            
            (
                new_policy_params,
                new_policy_opt_state,
            ) = self._train_policy_(
                policy_params,
                policy_opt_state,
                obs,
                actions,
                standardized_returns,
                mask,
                logps,
            )
            
            return (
                new_policy_params,
                new_policy_opt_state,
            ), None

        (policy_params, policy_opt_state), _ = jax.lax.scan(
            scan_train_policy,
            (policy_params, policy_opt_state),
            None,
            length=self._config.no_epochs,
        )
        
        return policy_params
    
    
    @partial(jax.jit, static_argnames=("self",))
    def _train_policy_(
        self,
        policy_params,
        policy_opt_state,
        obs,
        actions,
        standardized_returns,
        mask,
        logps,
    ):
        """Train the policy.
        """
        
        def _scan_update(carry, _):
            policy_params, policy_opt_state = carry
            grads = jax.grad(self.loss_ppo)(policy_params, obs, actions, logps, mask, standardized_returns)
            updates, new_policy_opt_state = self._policy_opt.update(grads, policy_opt_state)
            new_policy_params = optax.apply_updates(policy_params, updates)
            return (new_policy_params, new_policy_opt_state), None
        
        (final_policy_params, final_policy_opt_state), _ = jax.lax.scan(
            _scan_update,
            (policy_params, policy_opt_state),
            None,
            length=1,
        )

        return final_policy_params, final_policy_opt_state
    
    @partial(jax.jit, static_argnames=("self",))
    def loss_ppo(
        self,
        params,
        obs,
        actions,
        logps,
        mask,
        standardized_returns,
    ):
        
        logps_ = self._policy.apply(
            params,
            jax.lax.stop_gradient(obs),
            jax.lax.stop_gradient(actions),
            method=self._policy.logp,
        )
        ratio = jnp.exp(logps_ - jax.lax.stop_gradient(logps))
        
        pg_loss_1 = jnp.multiply(ratio * mask, jax.lax.stop_gradient(standardized_returns))
        pg_loss_2 = jax.lax.stop_gradient(standardized_returns) * jax.lax.clamp(1. - self._config.clip_param, ratio, 1. + self._config.clip_param) * mask
        
        return -jnp.mean(jnp.minimum(pg_loss_1, pg_loss_2))

In [4]:
import os

os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['WANDB_CACHE_DIR'] = '/tmp/wandb_cache'
os.environ['JAX_LOG_COMPILATION'] = '1'

import logging
import time
from dataclasses import dataclass
from functools import partial
from math import floor
from typing import Any, Dict, Tuple, List, Callable
import pickle
from flax import serialization
#logging.basicConfig(level=logging.DEBUG)
import hydra
from omegaconf import OmegaConf, DictConfig
import jax
import jax.numpy as jnp
from hydra.core.config_store import ConfigStore
from qdax.core.map_elites import MAPElites
from qdax.types import RNGKey, Genotype
from qdax.utils.sampling import sampling 
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire
from qdax.core.neuroevolution.networks.networks import MLPMCPG
from qdax.core.emitters.me_mcpg_emitter import MEMCPGConfig, MEMCPGEmitter
#from qdax.core.emitters.rein_emitter_advanced import REINaiveConfig, REINaiveEmitter
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.environments import behavior_descriptor_extractor
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from utils import Config, get_env
from qdax.core.emitters.mutation_operators import isoline_variation
import wandb
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results, plot_2d_map_elites_repertoire
import matplotlib.pyplot as plt
from set_up_brax import get_reward_offset_brax
from qdax import environments_v1, environments

In [5]:
def get_env(env_name):
    if env_name == "hopper_uni":
        episode_length = 1000
        
        env = environments_v1.create(env_name, episode_length=episode_length)
    elif env_name == "halfcheetah_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)
        
    elif env_name == "walker2d_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)	
    elif env_name == "ant_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=True)
    elif env_name == "humanoid_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, exclude_current_positions_from_observation=True)	
    '''
    elif env_name == "ant_omni":
        episode_length = 250
        max_bd = 30.

        env = environments.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=False)	
    elif env_name == "humanoid_uni":
        episode_length = 1000
        max_bd = 1.

        env = environments.create(env_name, episode_length=episode_length)	
    else:
        ValueError(f"Environment {env_name} not supported.")
    '''
    return env

In [6]:
@dataclass
class Config:
    """Configuration from this experiment script
    """
    # Env config
    #alg_name: str
    seed: int
    env_name: str
    episode_length: int
    policy_hidden_layer_sizes: Tuple[int, ...]   
    # ME config
    num_evaluations: int
    num_iterations: int
    no_agents: int
    num_samples: int
    fixed_init_state: bool
    discard_dead: bool
    # Emitter config
    iso_sigma: float
    line_sigma: float
    #crossover_percentage: float
    # Grid config 
    grid_shape: Tuple[int, ...]
    num_init_cvt_samples: int
    num_centroids: int
    # Log config
    log_period: int
    store_repertoire: bool
    store_repertoire_log_period: int
    
    # REINFORCE Parameters
    proportion_mutation_ga : float
    batch_size: int
    mini_batch_size: int
    adam_optimizer: bool
    learning_rate: float
    discount_rate: float
    buffer_size: int
    clip_param: float
    no_epochs: int

In [7]:
config = Config(
    seed=0,
    env_name='ant_uni',
    episode_length=1000,
    policy_hidden_layer_sizes=[128, 128],
    num_evaluations=0,
    num_iterations=4000,
    num_samples=16,
    no_agents=64,
    fixed_init_state=False,
    discard_dead=False,
    grid_shape=[50, 50],
    num_init_cvt_samples=50000,
    num_centroids=1296,
    log_period=400,
    store_repertoire=True,
    store_repertoire_log_period=800,
    iso_sigma=0.005,
    line_sigma=0.05,
    proportion_mutation_ga=0.5,
    batch_size=64000,
    mini_batch_size=64000,
    no_epochs=16,
    buffer_size=64000,
    adam_optimizer=True,
    learning_rate=3e-4,
    discount_rate=0.99,
    clip_param=0.2
)

In [8]:
# Init a random key
random_key = jax.random.PRNGKey(config.seed)

# Init environment
env = get_env('ant_uni')
reset_fn = jax.jit(env.reset)

# Compute the centroids
centroids, random_key = compute_cvt_centroids(
    num_descriptors=env.behavior_descriptor_length,
    num_init_cvt_samples=config.num_init_cvt_samples,
    num_centroids=config.num_centroids,
    minval=0,
    maxval=1,
    random_key=random_key,
)
# Init policy network
policy_layer_sizes = config.policy_hidden_layer_sizes #+ (env.action_size,)
print(policy_layer_sizes)

'''
policy_network = MLPRein(
    action_size=env.action_size,
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    kernel_init_final=jax.nn.initializers.orthogonal(scale=0.01),
)
'''
policy_network = MLPMCPG(
    hidden_layers_size=policy_layer_sizes,
    action_size=env.action_size,
    activation=jax.nn.tanh,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    mean_init=jax.nn.initializers.orthogonal(scale=0.01),
)

# Init population of controllers

# maybe consider adding two random keys for each policy
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=config.no_agents)
#split_keys = jax.vmap(lambda k: jax.random.split(k, 2))(keys)
#keys1, keys2 = split_keys[:, 0], split_keys[:, 1]
fake_batch_obs = jnp.zeros(shape=(config.no_agents, env.observation_size))
init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
print("Number of parameters in policy_network: ", param_count)

# Define the fonction to play a step with the policy in the environment
def play_step_fn(env_state, policy_params, random_key):
    #random_key, subkey = jax.random.split(random_key)
    actions = policy_network.apply(policy_params, env_state.obs)
    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        truncations=next_state.info["truncation"],
        actions=actions,
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
        #desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
        #desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
    )

    return next_state, policy_params, random_key, transition

# Prepare the scoring function
bd_extraction_fn = behavior_descriptor_extractor['ant_uni']
scoring_fn = partial(
    scoring_function,
    episode_length=env.episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)
#reward_offset = get_reward_offset_brax(env, config.env_name)
#print(f"Reward offset: {reward_offset}")

me_scoring_fn = partial(
sampling,
scoring_fn=scoring_fn,
num_samples=config.num_samples,
)

reward_offset = 0

metrics_function = partial(
    default_qd_metrics,
    qd_offset=reward_offset * env.episode_length,
)

# Define the PG-emitter config

me_mcpg_config = MEMCPGConfig(
    proportion_mutation_ga=config.proportion_mutation_ga,
    no_agents=config.no_agents,
    batch_size=config.batch_size,
    mini_batch_size=config.mini_batch_size,
    no_epochs=config.no_epochs,
    buffer_size=config.buffer_size,
    learning_rate=config.learning_rate,
    adam_optimizer=config.adam_optimizer,
    clip_param=config.clip_param,
)

variation_fn = partial(
    isoline_variation, iso_sigma=config.iso_sigma, line_sigma=config.line_sigma
)

me_mcpg_emitter = MEMCPGEmitter(
    config=me_mcpg_config,
    policy_network=policy_network,
    env=env,
    variation_fn=variation_fn,
    )

'''
rein_emitter = REINaiveEmitter(
    config=rein_emitter_config,
    policy_network=policy_network,
    env=env,
    )
'''
'''
me_scoring_fn = partial(
    sampling,
    scoring_fn=scoring_fn,
    num_samples=config.num_samples,
)
'''

# Instantiate MAP Elites
map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=me_mcpg_emitter,
    metrics_function=metrics_function,
)


[128, 128]
Number of parameters in policy_network:  21264


In [9]:
fitnesses, descriptors, extra_scores, random_key = scoring_fn(
    init_params, random_key
)

In [10]:

repertoire = MapElitesRepertoire.init(
    genotypes=init_params,
    fitnesses=fitnesses,
    descriptors=descriptors,
    centroids=centroids,
    extra_scores=extra_scores,
)

  repertoire = MapElitesRepertoire.init(


In [11]:
emitter_state, random_key = me_mcpg_emitter.init(
    random_key=random_key,
    repertoire=repertoire,
    genotypes=init_params,
    fitnesses=fitnesses,
    descriptors=descriptors,
    extra_scores=extra_scores,
)

In [12]:
emitter_state = me_mcpg_emitter.state_update(
    emitter_state=emitter_state,
    repertoire=repertoire,
    genotypes=init_params,
    fitnesses=fitnesses,
    descriptors=descriptors,
    extra_scores={**extra_scores}#, **extra_info},
)

In [13]:
repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key)

  repertoire = MapElitesRepertoire.init(


In [14]:
emitter = me_mcpg_emitter.emitters[0]

In [15]:
buffer = emitter._buffer

In [16]:
buffer

TrajectoryBuffer(init=functools.partial(<function init at 0x79daa2f177f0>, add_batch_size=64, max_length_time_axis=1000), add=functools.partial(<function add at 0x79daa2f452d0>), sample=functools.partial(<function sample at 0x79daa2f455a0>, batch_size=64, sequence_length=1000, period=1000), can_sample=functools.partial(<function can_sample at 0x79daa2f45630>, min_length_time_axis=1000))

In [16]:
buffer_state = emitter_state.emitter_states[0].buffer_state

In [17]:
buffer_state.experience.obs[0][0][:]

Array([ 0.5290031 ,  1.        ,  0.        ,  0.        ,  0.        ,
       -0.03629085,  0.81704986, -0.05197861, -0.8097921 ,  0.00165272,
       -0.91692144,  0.01418419,  0.7871757 ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        , -0.08168617,
       -0.06450727, -0.08687162,  0.02566254, -0.08559051,  0.0939841 ,
        0.045632  ,  0.03564072,  1.        ], dtype=float32)

In [18]:
buffer_state.experience.obs[2][4][:]

Array([ 4.98008877e-01,  9.99349773e-01, -3.56806740e-02, -5.19142998e-03,
       -3.93481110e-04,  4.40969393e-02,  7.58652568e-01, -1.28756398e-02,
       -7.34257400e-01, -3.35761197e-02, -9.11176920e-01, -8.13868940e-02,
        9.62469995e-01,  1.61133230e-01, -4.71483134e-02,  1.97709322e-01,
        1.09147727e-02,  2.21215472e-01, -1.43821780e-02,  5.36329113e-02,
        1.73439249e-01, -1.05735414e-01,  5.38548417e-02,  7.56153017e-02,
        9.68139153e-03, -1.33329168e-01,  1.84919506e-01,  9.96000051e-01],      dtype=float32)

In [19]:
buffer_state.experience.actions[0][0][:]

Array([-0.002731  ,  0.00471119, -0.00401392, -0.00211869,  0.00360038,
        0.00064279, -0.00639079,  0.00423873], dtype=float32)

In [20]:
buffer_state.experience.actions[2][4][:]

Array([ 3.0507769e-03,  9.6688431e-04, -3.4789811e-03, -4.6893256e-06,
       -6.8702502e-04,  1.1279928e-03,  3.0877090e-03,  8.3098169e-03],      dtype=float32)

In [21]:
batch = buffer.sample(buffer_state, random_key)

In [22]:
trans = batch.experience

In [23]:
trans.obs[0][0][:]

Array([ 0.5291505 ,  1.        ,  0.        ,  0.        ,  0.        ,
       -0.02182825,  0.847605  , -0.01074276, -0.8526422 , -0.06634107,
       -0.91735   , -0.09717821,  0.91673166,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.04825523,
        0.00142464, -0.08119813,  0.08092397, -0.08638358, -0.02924797,
       -0.09949973,  0.06462932,  1.        ], dtype=float32)

In [24]:
trans.obs[2][4][:]

Array([ 5.3121257e-01,  9.9999011e-01,  1.7877686e-03, -3.5283160e-03,
       -2.0204734e-03,  8.3804600e-02,  9.6697682e-01,  8.0522284e-02,
       -8.5231072e-01,  4.2598136e-02, -9.2366564e-01,  8.9302110e-03,
        8.8921934e-01,  2.9175675e-02,  6.4497255e-02,  1.8257830e-01,
       -7.8799352e-02,  4.0022347e-02, -1.8428907e-02,  1.6422386e-04,
       -1.3612559e-01,  4.1865464e-02,  5.3015087e-02, -6.9273517e-02,
        1.8536733e-01,  8.5005034e-03, -4.6814773e-02,  9.9600005e-01],      dtype=float32)

In [25]:
trans.actions[0][0][:]

Array([-0.00435192, -0.00528283,  0.00254858, -0.00021622, -0.00222638,
        0.00776927, -0.00214178,  0.00236052], dtype=float32)

In [26]:
trans.actions[2][4][:]

Array([-0.00428008,  0.00254722, -0.0030441 ,  0.00269318,  0.00718779,
       -0.00362687, -0.0017461 , -0.00067989], dtype=float32)

In [28]:
new_genotypes, _, random_key = me_mcpg_emitter.emit(repertoire, emitter_state, random_key)

In [29]:
fitnesses, descriptors, extra_scores, random_key = scoring_fn(
    new_genotypes, random_key
)

In [30]:
emitter_state = me_mcpg_emitter.state_update(
    emitter_state=emitter_state,
    repertoire=repertoire,
    genotypes=init_params,
    fitnesses=fitnesses,
    descriptors=descriptors,
    extra_scores={**extra_scores}#, **extra_info},
)

In [31]:
buffer_state = emitter_state.emitter_states[0].buffer_state

In [32]:
buffer_state.experience.obs[0][0][:]

Array([ 0.541036  ,  1.        ,  0.        ,  0.        ,  0.        ,
       -0.06537589,  0.8383412 , -0.02807844, -0.9527363 , -0.05029454,
       -0.95151484, -0.04050948,  0.8234748 ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        , -0.08284527,
        0.0715866 , -0.00264992,  0.07823871, -0.01610541,  0.06063195,
        0.08534649,  0.06462934,  1.        ], dtype=float32)

In [33]:
buffer_state.experience.obs[2][4][:]

Array([ 0.50336933,  0.99352807, -0.03358008,  0.05137687, -0.09557562,
       -0.0974462 ,  0.53065926,  0.21062675, -0.95509666,  0.41207787,
       -1.1381879 ,  0.34035403,  0.7715985 , -0.12753744,  0.06876729,
       -0.02821796, -0.5192963 ,  0.559862  , -0.8220054 ,  1.7768593 ,
       -0.19251561, -1.3835865 ,  0.26395983, -1.6554354 ,  1.644803  ,
       -1.2617568 ,  0.3183149 ,  0.99600005], dtype=float32)

In [34]:
batch = buffer.sample(buffer_state, random_key)

In [35]:
trans = batch.experience

In [36]:
trans.obs[0][0][:]

Array([ 0.54057497,  1.        ,  0.        ,  0.        ,  0.        ,
       -0.09035441,  0.94714147, -0.02431593, -0.8130315 , -0.03586044,
       -0.8318781 ,  0.0333339 ,  0.9513314 ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        , -0.08577353,
       -0.03481684, -0.03581014, -0.07189173,  0.00718719,  0.09306866,
       -0.08394337, -0.05480371,  1.        ], dtype=float32)

In [37]:
trans.obs[2][4][:]

Array([ 4.9661317e-01,  9.9974120e-01, -7.0252432e-04, -2.2742180e-02,
       -1.7487181e-04, -1.4617446e-02,  8.6953437e-01,  4.7205113e-02,
       -7.7981186e-01,  9.1657862e-02, -7.9237008e-01,  3.9201371e-02,
        8.8778913e-01,  7.7902801e-02,  4.2957343e-02,  2.5344259e-01,
       -1.4647022e-01,  1.9517195e-01, -5.8037054e-02,  2.5176302e-01,
        6.7438555e-01,  1.9923742e-01,  1.8061393e-01, -2.0841539e-01,
        3.6137107e-01, -4.1610140e-01, -1.7057478e-02,  9.9600005e-01],      dtype=float32)

In [38]:
trans.actions[0][0][:]

Array([ 0.00846978, -0.01904888,  0.01534687, -0.01836575,  0.03139308,
        0.01971916,  0.00268593,  0.00334678], dtype=float32)

In [39]:
trans.actions[2][4][:]

Array([-0.02622877, -0.01858241,  0.0061986 ,  0.00380347, -0.00801598,
        0.00849342,  0.03925945,  0.00011326], dtype=float32)

In [40]:
trans.obs.shape

(64, 1000, 28)

In [None]:
trans.actions.shape

(64, 1000, 8)

In [None]:
trans.rewards.shape

(64, 1000)

In [None]:
trans.dones.shape

(64, 1000)

In [None]:
obs = trans.obs
actions = trans.actions
rewards = trans.rewards
dones = trans.dones

In [None]:
mask = jax.vmap(emitter.compute_mask, in_axes=0)(dones)

In [None]:
mask.shape

(64, 1000)

In [None]:
returns = jax.vmap(emitter.get_return)(rewards *mask)

In [None]:
returns

Array([[ 98.876976  ,  98.86738   ,  98.86858   , ...,   1.975918  ,
          0.99457633,   0.        ],
       [100.559685  , 100.56604   , 100.56904   , ...,   1.9825337 ,
          0.9975025 ,   0.        ],
       [ 99.55968   ,  99.55858   ,  99.632225  , ...,   1.9780444 ,
          0.9920622 ,   0.        ],
       ...,
       [100.4241    , 100.42156   , 100.34936   , ...,   1.978764  ,
          0.99755573,   0.        ],
       [ 99.32288   ,  99.314575  ,  99.198715  , ...,   1.9770596 ,
          0.99299014,   0.        ],
       [ 99.32288   ,  99.314575  ,  99.198715  , ...,   1.9770596 ,
          0.99299014,   0.        ]], dtype=float32)

In [None]:
rewards.shape

(64, 1000)

In [None]:
returns = emitter.get_standardized_return(rewards, mask)

In [None]:
returns

Array([[-6.2009430e-01, -6.2589264e-01, -6.2570190e-01, ...,
        -2.6960373e-03, -2.7143955e-04,  0.0000000e+00],
       [ 1.0626144e+00,  1.0727692e+00,  1.0747528e+00, ...,
         3.9196014e-03,  2.6547313e-03,  0.0000000e+00],
       [ 6.2606812e-02,  6.5307617e-02,  1.3793945e-01, ...,
        -5.6970119e-04, -2.7855635e-03,  0.0000000e+00],
       ...,
       [ 9.2703247e-01,  9.2829132e-01,  8.5507202e-01, ...,
         1.4996529e-04,  2.7079582e-03,  0.0000000e+00],
       [-1.7418671e-01, -1.7869568e-01, -2.9557037e-01, ...,
        -1.5544891e-03, -1.8576384e-03,  0.0000000e+00],
       [-1.7418671e-01, -1.7869568e-01, -2.9557037e-01, ...,
        -1.5544891e-03, -1.8576384e-03,  0.0000000e+00]], dtype=float32)

In [None]:
trans, random_key = buffer.sample(
    random_key=random_key,
    sample_size=64,
    episodic_data_size=64,
    sample_traj=True,
)

In [None]:
trans.obs.shape

(64000, 28)

In [None]:
obs = trans.obs.reshape(64, 1000, -1)

In [None]:
obs.shape

(64, 1000, 28)

In [None]:
actions = trans.actions.reshape(64, 1000, -1)

In [None]:
rewards = trans.rewards.reshape(64, 1000, -1)

In [None]:
dones = trans.dones.reshape(64, 1000, -1)

In [None]:
dones

Array([[[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [1.]],

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [1.]],

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [1.]],

       ...,

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [1.]],

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [1.]],

       [[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.],
        [1.]]], dtype=float32)

In [None]:
mask = jax.vmap(me_mcpg_emitter.emitters[0].compute_mask, in_axes=0)(dones)

In [None]:
mask.shape

(64, 1000)

In [None]:
init_params

{'params': {'hidden_layers_0': {'bias': Array([[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),
   'kernel': Array([[[ 8.90522450e-02, -3.81690421e-04,  1.66703574e-03, ...,
            -7.17403442e-02,  1.57643072e-02, -6.27806410e-02],
           [-2.31234238e-01, -8.34264830e-02, -2.93827336e-03, ...,
             2.18603268e-01,  2.39269957e-01, -1.19056545e-01],
           [-1.64635301e-01,  2.17456639e-01,  4.62460518e-02, ...,
            -7.15165064e-02, -6.95979670e-02,  3.31199393e-02],
           ...,
           [-1.32311285e-01,  2.42075622e-01,  9.79097933e-03, ...,
             3.13385651e-02,  5.16709909e-02, -8.71773344e-03],
           [ 6.32021800e-02, -1.83664367e-01,  1.79741327e-02, ...,
            -1.99814420e-02, -5.96128218e-02,  7.21886987e-03],

In [None]:
first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_params)

In [None]:
logps = jax.vmap(me_mcpg_emitter.emitters[0].compute_logps, in_axes=(None, 0, 0))(first_genotype, obs, actions)

In [None]:
logps

Array([[-1.806705 , -1.8067458, -1.8067316, ..., -1.806603 , -1.8066237,
        -1.8066106],
       [-1.8064888, -1.8065231, -1.8065507, ..., -1.8064908, -1.8065021,
        -1.8064985],
       [-1.8068137, -1.8068669, -1.806854 , ..., -1.8065951, -1.8065946,
        -1.8065865],
       ...,
       [-1.807189 , -1.8070706, -1.8070931, ..., -1.8068595, -1.8067987,
        -1.8067913],
       [-1.8069313, -1.806901 , -1.8069676, ..., -1.806754 , -1.8067567,
        -1.8067582],
       [-1.8068497, -1.8068788, -1.8069222, ..., -1.8066405, -1.8066336,
        -1.8066478]], dtype=float32)

In [None]:
standardized_returns = me_mcpg_emitter.emitters[0].get_standardized_return(rewards, mask)

In [None]:
rewards[0][20:30]

Array([[0.99250793],
       [0.9928867 ],
       [0.99709135],
       [0.998076  ],
       [0.9950902 ],
       [0.99378616],
       [0.99568135],
       [0.99687713],
       [0.99556273],
       [0.9943279 ]], dtype=float32)

In [None]:
mask_ = jnp.expand_dims(mask, axis=-1)
valid_rewards = (rewards * mask_).squeeze(axis=-1)
return_ = jax.vmap(me_mcpg_emitter.emitters[0].get_return)(valid_rewards)

In [None]:
return_[0][]

Array([99.53472  , 99.52752  , 99.49046  , 99.47181  , 99.49231  ,
       99.49552  , 99.466805 , 99.43712  , 99.44918  , 99.46025  ,
       99.4587   , 99.440926 , 99.44003  , 99.44752  , 99.44797  ,
       99.439156 , 99.435074 , 99.43884  , 99.43999  , 99.43548  ,
       99.43209  , 99.43392  , 99.43539  , 99.43262  , 99.428825 ,
       99.42802  , 99.42851  , 99.4271   , 99.42447  , 99.42314  ,
       99.42304  , 99.4223   , 99.42064  , 99.41941  , 99.419044 ,
       99.41863  , 99.41767  , 99.41675  , 99.416336 , 99.41608  ,
       99.415565 , 99.414986 , 99.414665 , 99.414505 , 99.41426  ,
       99.413956 , 99.41378  , 99.41374  , 99.41372  , 99.413666 ,
       99.41368  , 99.41382  , 99.414    , 99.41419  , 99.41443  ,
       99.41476  , 99.41515  , 99.41559  , 99.41606  , 99.41661  ,
       99.41724  , 99.41791  , 99.41863  , 99.419426 , 99.42029  ,
       99.42121  , 99.422195 , 99.42325  , 99.42437  , 99.42556  ,
       99.42682  , 99.42814  , 99.429535 , 99.431    , 99.4325