In [1]:
!nvidia-smi

Sat Aug  3 23:41:53 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   33C    P0              30W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
import os

os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['WANDB_CACHE_DIR'] = '/tmp/wandb_cache'
os.environ['JAX_LOG_COMPILATION'] = '1'
#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
#os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'

In [3]:
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Optional, Tuple

import flax.linen as nn
import jax
import optax
from jax import numpy as jnp

from qdax.core.containers.repertoire import Repertoire
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer
from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn
from qdax.core.neuroevolution.networks.networks import QModule
from qdax.environments.base_wrappers import QDEnv
from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, Params, RNGKey

In [4]:
from typing import Tuple
from dataclasses import dataclass
import functools
import os
import time
import pickle

import jax
import jax.numpy as jnp
from flax import serialization

from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.environments import behavior_descriptor_extractor
from qdax.core.map_elites_pga import MAPElites
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results, plot_2d_map_elites_repertoire
from set_up_brax import get_reward_offset_brax
from qdax.utils.sampling import sampling 
from qdax.types import RNGKey, Genotype
from typing import Any, Dict, Tuple, List, Callable
import matplotlib.pyplot as plt
from qdax import environments_v1

In [5]:
def get_env(env_name):
    if env_name == "hopper_uni":
        episode_length = 1000
        
        env = environments_v1.create(env_name, episode_length=episode_length)
    elif env_name == "halfcheetah_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)
        
    elif env_name == "walker2d_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)	
    elif env_name == "ant_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=True)
    elif env_name == "humanoid_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, exclude_current_positions_from_observation=True)	
    '''
    elif env_name == "ant_omni":
        episode_length = 250
        max_bd = 30.

        env = environments.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=False)	
    elif env_name == "humanoid_uni":
        episode_length = 1000
        max_bd = 1.

        env = environments.create(env_name, episode_length=episode_length)	
    else:
        ValueError(f"Environment {env_name} not supported.")
    '''
    return env

In [6]:
@dataclass
class QualityPGConfig:
    """Configuration for QualityPG Emitter"""

    env_batch_size: int = 256
    num_critic_training_steps: int = 300
    num_pg_training_steps: int = 100

    # TD3 params
    replay_buffer_size: int = 1000000
    critic_hidden_layer_size: Tuple[int, ...] = (256, 256)
    critic_learning_rate: float = 3e-4
    actor_learning_rate: float = 3e-4
    policy_learning_rate: float = 1e-3
    noise_clip: float = 0.5
    policy_noise: float = 0.2
    discount: float = 0.99
    reward_scaling: float = 1.0
    batch_size: int = 100
    soft_tau_update: float = 0.005
    policy_delay: int = 2


class QualityPGEmitterState(EmitterState):
    """Contains training state for the learner."""

    critic_params: Params
    critic_optimizer_state: optax.OptState
    actor_params: Params
    actor_opt_state: optax.OptState
    target_critic_params: Params
    target_actor_params: Params
    replay_buffer: ReplayBuffer
    random_key: RNGKey
    steps: jnp.ndarray


class QualityPGEmitter(Emitter):
    """
    A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites
    (PGA-Map-Elites) algorithm.
    """

    def __init__(
        self,
        config: QualityPGConfig,
        policy_network: nn.Module,
        env: QDEnv,
    ) -> None:
        self._config = config
        self._env = env
        self._policy_network = policy_network

        # Init Critics
        critic_network = QModule(
            n_critics=2, hidden_layer_sizes=self._config.critic_hidden_layer_size
        )
        self._critic_network = critic_network

        # Set up the losses and optimizers - return the opt states
        self._policy_loss_fn, self._critic_loss_fn = make_td3_loss_fn(
            policy_fn=policy_network.apply,
            critic_fn=critic_network.apply,
            reward_scaling=self._config.reward_scaling,
            discount=self._config.discount,
            noise_clip=self._config.noise_clip,
            policy_noise=self._config.policy_noise,
        )

        # Init optimizers
        self._actor_optimizer = optax.adam(
            learning_rate=self._config.actor_learning_rate
        )
        self._critic_optimizer = optax.adam(
            learning_rate=self._config.critic_learning_rate
        )
        self._policies_optimizer = optax.adam(
            learning_rate=self._config.policy_learning_rate
        )

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._config.env_batch_size

    @property
    def use_all_data(self) -> bool:
        """Whether to use all data or not when used along other emitters.

        QualityPGEmitter uses the transitions from the genotypes that were generated
        by other emitters.
        """
        return True

    def init(
        self,
        random_key: RNGKey,
        repertoire: Repertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> Tuple[QualityPGEmitterState, RNGKey]:
        """Initializes the emitter state.

        Args:
            genotypes: The initial population.
            random_key: A random key.

        Returns:
            The initial state of the PGAMEEmitter, a new random key.
        """

        observation_size = self._env.observation_size
        action_size = self._env.action_size
        descriptor_size = self._env.state_descriptor_length

        # Initialise critic, greedy actor and population
        random_key, subkey = jax.random.split(random_key)
        fake_obs = jnp.zeros(shape=(observation_size,))
        fake_action = jnp.zeros(shape=(action_size,))
        critic_params = self._critic_network.init(
            subkey, obs=fake_obs, actions=fake_action
        )
        target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params)

        actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes)
        target_actor_params = jax.tree_util.tree_map(lambda x: x[0], genotypes)

        # Prepare init optimizer states
        critic_optimizer_state = self._critic_optimizer.init(critic_params)
        actor_optimizer_state = self._actor_optimizer.init(actor_params)

        # Initialize replay buffer
        dummy_transition = QDTransition.init_dummy(
            observation_dim=observation_size,
            action_dim=action_size,
            descriptor_dim=descriptor_size,
        )

        replay_buffer = ReplayBuffer.init(
            buffer_size=self._config.replay_buffer_size, transition=dummy_transition
        )

        # get the transitions out of the dictionary
        assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
        transitions = extra_scores["transitions"]

        # add transitions in the replay buffer
        replay_buffer = replay_buffer.insert(transitions)

        # Initial training state
        random_key, subkey = jax.random.split(random_key)
        emitter_state = QualityPGEmitterState(
            critic_params=critic_params,
            critic_optimizer_state=critic_optimizer_state,
            actor_params=actor_params,
            actor_opt_state=actor_optimizer_state,
            target_critic_params=target_critic_params,
            target_actor_params=target_actor_params,
            replay_buffer=replay_buffer,
            random_key=subkey,
            steps=jnp.array(0),
        )

        return emitter_state, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: Repertoire,
        emitter_state: QualityPGEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores, RNGKey]:
        """Do a step of PG emission.

        Args:
            repertoire: the current repertoire of genotypes
            emitter_state: the state of the emitter used
            random_key: a random key

        Returns:
            A batch of offspring, the new emitter state and a new key.
        """

        batch_size = self._config.env_batch_size

        # sample parents
        mutation_pg_batch_size = int(batch_size - 1)
        print(type(repertoire))
        parents, random_key = repertoire.sample(random_key, mutation_pg_batch_size)

        # apply the pg mutation
        offsprings_pg = self.emit_pg(emitter_state, parents)

        # get the actor (greedy actor)
        offspring_actor = self.emit_actor(emitter_state)

        # add dimension for concatenation
        offspring_actor = jax.tree_util.tree_map(
            lambda x: jnp.expand_dims(x, axis=0), offspring_actor
        )

        # gather offspring
        genotypes = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            offsprings_pg,
            offspring_actor,
        )

        return genotypes, {}, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit_pg(
        self, emitter_state: QualityPGEmitterState, parents: Genotype
    ) -> Genotype:
        """Emit the offsprings generated through pg mutation.

        Args:
            emitter_state: current emitter state, contains critic and
                replay buffer.
            parents: the parents selected to be applied gradients in order
                to mutate towards better performance.

        Returns:
            A new set of offsprings.
        """
        mutation_fn = partial(
            self._mutation_function_pg,
            emitter_state=emitter_state,
        )
        offsprings = jax.vmap(mutation_fn)(parents)

        return offsprings

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype:
        """Emit the greedy actor.

        Simply needs to be retrieved from the emitter state.

        Args:
            emitter_state: the current emitter state, it stores the
                greedy actor.

        Returns:
            The parameters of the actor.
        """
        return emitter_state.actor_params

    @partial(jax.jit, static_argnames=("self",))
    def state_update(
        self,
        emitter_state: QualityPGEmitterState,
        repertoire: Optional[Repertoire],
        genotypes: Optional[Genotype],
        fitnesses: Optional[Fitness],
        descriptors: Optional[Descriptor],
        extra_scores: ExtraScores,
    ) -> QualityPGEmitterState:
        """This function gives an opportunity to update the emitter state
        after the genotypes have been scored.

        Here it is used to fill the Replay Buffer with the transitions
        from the scoring of the genotypes, and then the training of the
        critic/actor happens. Hence the params of critic/actor are updated,
        as well as their optimizer states.

        Args:
            emitter_state: current emitter state.
            repertoire: the current genotypes repertoire
            genotypes: unused here - but compulsory in the signature.
            fitnesses: unused here - but compulsory in the signature.
            descriptors: unused here - but compulsory in the signature.
            extra_scores: extra information coming from the scoring function,
                this contains the transitions added to the replay buffer.

        Returns:
            New emitter state where the replay buffer has been filled with
            the new experienced transitions.
        """
        # get the transitions out of the dictionary
        assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
        transitions = extra_scores["transitions"]

        # add transitions in the replay buffer
        replay_buffer = emitter_state.replay_buffer.insert(transitions)
        emitter_state = emitter_state.replace(replay_buffer=replay_buffer)

        def scan_train_critics(
            carry: QualityPGEmitterState, unused: Any
        ) -> Tuple[QualityPGEmitterState, Any]:
            emitter_state = carry
            new_emitter_state = self._train_critics(emitter_state)
            return new_emitter_state, ()

        # Train critics and greedy actor
        emitter_state, _ = jax.lax.scan(
            scan_train_critics,
            emitter_state,
            (),
            length=self._config.num_critic_training_steps,
        )

        return emitter_state  # type: ignore

    @partial(jax.jit, static_argnames=("self",))
    def _train_critics(
        self, emitter_state: QualityPGEmitterState
    ) -> QualityPGEmitterState:
        """Apply one gradient step to critics and to the greedy actor
        (contained in carry in training_state), then soft update target critics
        and target actor.

        Those updates are very similar to those made in TD3.

        Args:
            emitter_state: actual emitter state

        Returns:
            New emitter state where the critic and the greedy actor have been
            updated. Optimizer states have also been updated in the process.
        """

        # Sample a batch of transitions in the buffer
        random_key = emitter_state.random_key
        replay_buffer = emitter_state.replay_buffer
        transitions, random_key = replay_buffer.sample(
            random_key, sample_size=self._config.batch_size
        )

        # Update Critic
        (
            critic_optimizer_state,
            critic_params,
            target_critic_params,
            random_key,
        ) = self._update_critic(
            critic_params=emitter_state.critic_params,
            target_critic_params=emitter_state.target_critic_params,
            target_actor_params=emitter_state.target_actor_params,
            critic_optimizer_state=emitter_state.critic_optimizer_state,
            transitions=transitions,
            random_key=random_key,
        )

        # Update greedy actor
        (actor_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond(
            emitter_state.steps % self._config.policy_delay == 0,
            lambda x: self._update_actor(*x),
            lambda _: (
                emitter_state.actor_opt_state,
                emitter_state.actor_params,
                emitter_state.target_actor_params,
            ),
            operand=(
                emitter_state.actor_params,
                emitter_state.actor_opt_state,
                emitter_state.target_actor_params,
                emitter_state.critic_params,
                transitions,
            ),
        )

        # Create new training state
        new_emitter_state = emitter_state.replace(
            critic_params=critic_params,
            critic_optimizer_state=critic_optimizer_state,
            actor_params=actor_params,
            actor_opt_state=actor_optimizer_state,
            target_critic_params=target_critic_params,
            target_actor_params=target_actor_params,
            random_key=random_key,
            steps=emitter_state.steps + 1,
            replay_buffer=replay_buffer,
        )

        return new_emitter_state  # type: ignore

    @partial(jax.jit, static_argnames=("self",))
    def _update_critic(
        self,
        critic_params: Params,
        target_critic_params: Params,
        target_actor_params: Params,
        critic_optimizer_state: Params,
        transitions: QDTransition,
        random_key: RNGKey,
    ) -> Tuple[Params, Params, Params, RNGKey]:

        # compute loss and gradients
        random_key, subkey = jax.random.split(random_key)
        critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)(
            critic_params,
            target_actor_params,
            target_critic_params,
            transitions,
            subkey,
        )
        critic_updates, critic_optimizer_state = self._critic_optimizer.update(
            critic_gradient, critic_optimizer_state
        )

        # update critic
        critic_params = optax.apply_updates(critic_params, critic_updates)

        # Soft update of target critic network
        target_critic_params = jax.tree_map(
            lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
            + self._config.soft_tau_update * x2,
            target_critic_params,
            critic_params,
        )

        return critic_optimizer_state, critic_params, target_critic_params, random_key

    @partial(jax.jit, static_argnames=("self",))
    def _update_actor(
        self,
        actor_params: Params,
        actor_opt_state: optax.OptState,
        target_actor_params: Params,
        critic_params: Params,
        transitions: QDTransition,
    ) -> Tuple[optax.OptState, Params, Params]:

        # Update greedy actor
        policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)(
            actor_params,
            critic_params,
            transitions,
        )
        (
            policy_updates,
            actor_optimizer_state,
        ) = self._actor_optimizer.update(policy_gradient, actor_opt_state)
        actor_params = optax.apply_updates(actor_params, policy_updates)

        # Soft update of target greedy actor
        target_actor_params = jax.tree_map(
            lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
            + self._config.soft_tau_update * x2,
            target_actor_params,
            actor_params,
        )

        return (
            actor_optimizer_state,
            actor_params,
            target_actor_params,
        )

    @partial(jax.jit, static_argnames=("self",))
    def _mutation_function_pg(
        self,
        policy_params: Genotype,
        emitter_state: QualityPGEmitterState,
    ) -> Genotype:
        """Apply pg mutation to a policy via multiple steps of gradient descent.
        First, update the rewards to be diversity rewards, then apply the gradient
        steps.

        Args:
            policy_params: a policy, supposed to be a differentiable neural
                network.
            emitter_state: the current state of the emitter, containing among others,
                the replay buffer, the critic.

        Returns:
            The updated params of the neural network.
        """

        # Define new policy optimizer state
        policy_optimizer_state = self._policies_optimizer.init(policy_params)

        def scan_train_policy(
            carry: Tuple[QualityPGEmitterState, Genotype, optax.OptState],
            unused: Any,
        ) -> Tuple[Tuple[QualityPGEmitterState, Genotype, optax.OptState], Any]:
            emitter_state, policy_params, policy_optimizer_state = carry
            (
                new_emitter_state,
                new_policy_params,
                new_policy_optimizer_state,
            ) = self._train_policy(
                emitter_state,
                policy_params,
                policy_optimizer_state,
            )
            return (
                new_emitter_state,
                new_policy_params,
                new_policy_optimizer_state,
            ), ()

        (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan(
            scan_train_policy,
            (emitter_state, policy_params, policy_optimizer_state),
            (),
            length=self._config.num_pg_training_steps,
        )

        return policy_params

    @partial(jax.jit, static_argnames=("self",))
    def _train_policy(
        self,
        emitter_state: QualityPGEmitterState,
        policy_params: Params,
        policy_optimizer_state: optax.OptState,
    ) -> Tuple[QualityPGEmitterState, Params, optax.OptState]:
        """Apply one gradient step to a policy (called policy_params).

        Args:
            emitter_state: current state of the emitter.
            policy_params: parameters corresponding to the weights and bias of
                the neural network that defines the policy.

        Returns:
            The new emitter state and new params of the NN.
        """

        # Sample a batch of transitions in the buffer
        random_key = emitter_state.random_key
        replay_buffer = emitter_state.replay_buffer
        transitions, random_key = replay_buffer.sample(
            random_key, sample_size=self._config.batch_size
        )

        # update policy
        policy_optimizer_state, policy_params = self._update_policy(
            critic_params=emitter_state.critic_params,
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            transitions=transitions,
        )

        # Create new training state
        new_emitter_state = emitter_state.replace(
            random_key=random_key,
            replay_buffer=replay_buffer,
        )

        return new_emitter_state, policy_params, policy_optimizer_state

    @partial(jax.jit, static_argnames=("self",))
    def _update_policy(
        self,
        critic_params: Params,
        policy_optimizer_state: optax.OptState,
        policy_params: Params,
        transitions: QDTransition,
    ) -> Tuple[optax.OptState, Params]:

        # compute loss
        _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)(
            policy_params,
            critic_params,
            transitions,
        )
        # Compute gradient and update policies
        (
            policy_updates,
            policy_optimizer_state,
        ) = self._policies_optimizer.update(policy_gradient, policy_optimizer_state)
        policy_params = optax.apply_updates(policy_params, policy_updates)

        return policy_optimizer_state, policy_params

In [7]:
@dataclass
class Config:
    """Configuration for the experiment script
    """
    # Basic Configuration
    seed: int
    num_iterations: int
    num_samples: int
    env_batch_size: int  # Used in GA emitter and PG emitter
    batch_size: int

    # Archive Configuration
    num_init_cvt_samples: int
    num_centroids: int
    policy_hidden_layer_sizes: Tuple[int, ...]
    proportion_mutation_ga: float

    # GA Emitter Configuration
    iso_sigma: float
    line_sigma: float

    # PG Emitter Configuration
    critic_hidden_layer_size: Tuple[int, ...]
    num_critic_training_steps: int
    num_pg_training_steps: int
    replay_buffer_size: int
    discount: float
    reward_scaling: float
    critic_learning_rate: float
    actor_learning_rate: float
    policy_learning_rate: float
    noise_clip: float
    policy_noise: float
    soft_tau_update: float
    policy_delay: int

config = Config(
    seed=0,
    num_iterations=200,
    num_samples=32,
    batch_size=100,
    env_batch_size=5096,
    num_init_cvt_samples=50000,
    num_centroids=1024,
    policy_hidden_layer_sizes=[128, 128],
    proportion_mutation_ga=0.5,
    iso_sigma=0.005,
    line_sigma=0.05,
    critic_hidden_layer_size=[256, 256],
    num_critic_training_steps=3000,
    num_pg_training_steps=150,
    replay_buffer_size=5096000, #2048000, #1_000_000,
    discount=0.99,
    reward_scaling=1.0,
    critic_learning_rate=3e-4,
    actor_learning_rate=3e-4,
    policy_learning_rate=5e-3,
    noise_clip=0.5,
    policy_noise=0.2,
    soft_tau_update=0.005,
    policy_delay=2
)

In [8]:
'''
class Normalizer:
    def __init__(self, size, epsilon=1e-8):
        self.size = size
        self.mean = jnp.zeros(size)
        self.var = jnp.ones(size)
        self.count = epsilon

    def update(self, x):
        batch_mean = jnp.mean(x, axis=0)
        batch_var = jnp.var(x, axis=0)
        batch_count = x.shape[0]

        self.mean, self.var, self.count = self._update_mean_var_count(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

    def normalize(self, x):
        return (x - self.mean) / jnp.sqrt(self.var + EPS)

    def _update_mean_var_count(self, mean, var, count, batch_mean, batch_var, batch_count):
        delta = batch_mean - mean
        tot_count = count + batch_count

        new_mean = mean + delta * batch_count / tot_count
        m_a = var * count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        return new_mean, new_var, new_count
'''
'''
class Normalizer:
    def __init__(self, size, epsilon=1e-8):
        self.size = size  # Expecting size to be the dimensionality of the observation features (z)
        self.mean = jnp.zeros(size)
        self.var = jnp.ones(size)
        self.count = epsilon

    def update(self, x):
        # Flatten the first two dimensions (x, y) to treat as a single batch dimension
        flat_x = x.reshape(-1, self.size)
        batch_mean = jnp.mean(flat_x, axis=0)
        batch_var = jnp.var(flat_x, axis=0)
        batch_count = flat_x.shape[0]

        self.mean, self.var, self.count = self._update_mean_var_count(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

    def normalize(self, x):
        # Normalize maintaining the original shape, using broadcasting
        return (x - self.mean) / jnp.sqrt(self.var + 1e-8)

    def _update_mean_var_count(self, mean, var, count, batch_mean, batch_var, batch_count):
        delta = batch_mean - mean
        tot_count = count + batch_count

        new_mean = mean + delta * batch_count / tot_count
        m_a = var * count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        return new_mean, new_var, new_count
        
'''

from jax import tree_util


@dataclass
class Normalizer:
    size: int
    mean: jnp.ndarray = None
    var: jnp.ndarray = None
    count: jnp.ndarray = 1e-4
    
    def __post_init__(self):
        if self.mean is None:
            self.mean = jnp.zeros(self.size)
        if self.var is None:
            self.var = jnp.ones(self.size)
            
    def update(self, x):
        # Flatten the first two dimensions (x, y) to treat as a single batch dimension
        flat_x = x.reshape(-1, self.size)
        batch_mean = jnp.mean(flat_x, axis=0)
        batch_var = jnp.var(flat_x, axis=0)
        batch_count = flat_x.shape[0]

        new_mean, new_var, new_count = self._update_mean_var_count(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
        
        return self.replace(mean=new_mean, var=new_var, count=new_count)

    def normalize(self, x):
        # Normalize maintaining the original shape, using broadcasting
        return (x - self.mean) / jnp.sqrt(self.var + 1e-8)

    def _update_mean_var_count(self, mean, var, count, batch_mean, batch_var, batch_count):
        delta = batch_mean - mean
        tot_count = count + batch_count

        new_mean = mean + delta * batch_count / tot_count
        m_a = var * count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        return new_mean, new_var, new_count
'''
    def tree_flatten(self):
        return ((self.mean, self.var, self.count), self.size)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        size = aux_data
        mean, var, count = children
        return cls(size=size, mean=mean, var=var, count=count)

# Register Normalizer as a pytree node with JAX
tree_util.register_pytree_node(
    Normalizer,
    Normalizer.tree_flatten,
    Normalizer.tree_unflatten
)
'''

@dataclass
class RewardNormalizer:
    size: int
    mean: jnp.ndarray = 0.0
    var: jnp.ndarray = 1.0
    count: jnp.ndarray = 1e-4
    return_val: jnp.ndarray = None
    
    def __post_init__(self):
        if self.return_val is None:
            self.return_val = jnp.zeros((self.size,))

         
    def update(self, reward, done, gamma=0.99):
        
        def _update_column_scan(carry, x):
            mean, var, count, return_val = carry
            (reward, done) = x
            
            # Update the return value
            new_return_val = reward + gamma * return_val * (1 - done)
            
            # Update the mean, var, and count
            batch_mean = jnp.mean(new_return_val, axis=0)
            batch_var = jnp.var(new_return_val, axis=0)
            batch_count = new_return_val.shape[0]
            
            delta = batch_mean - mean
            tot_count = count + batch_count
            
            new_mean = mean + delta * batch_count / tot_count
            m_a = var * count
            m_b = batch_var * batch_count
            M2 = m_a + m_b + jnp.square(delta) * count * batch_count / tot_count
            new_var = M2 / tot_count
            new_count = tot_count
            
            normalized_reward = reward / jnp.sqrt(new_var + 1e-8)
            
            return (new_mean, new_var, new_count, new_return_val), normalized_reward
        
        (new_mean, new_var, new_count, _), normalized_rewards = jax.lax.scan(
            _update_column_scan,
            (self.mean, self.var, self.count, self.return_val),
            (reward.T, done.T),
        )

        
        return self.replace(mean=new_mean, var=new_var, count=new_count), normalized_rewards.T

'''
    def tree_flatten(self):
        return ((self.mean, self.var, self.count, self.return_val), self.size)
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        size = aux_data
        mean, var, count, return_val = children
        return cls(size=size, mean=mean, var=var, count=count, return_val=return_val)
    
        
tree_util.register_pytree_node(
    RewardNormalizer,
    RewardNormalizer.tree_flatten,
    RewardNormalizer.tree_unflatten
)
'''

'\n    def tree_flatten(self):\n        return ((self.mean, self.var, self.count, self.return_val), self.size)\n    \n    @classmethod\n    def tree_unflatten(cls, aux_data, children):\n        size = aux_data\n        mean, var, count, return_val = children\n        return cls(size=size, mean=mean, var=var, count=count, return_val=return_val)\n    \n        \ntree_util.register_pytree_node(\n    RewardNormalizer,\n    RewardNormalizer.tree_flatten,\n    RewardNormalizer.tree_unflatten\n)\n'

In [9]:
env.observation_size

NameError: name 'env' is not defined

In [9]:
random_key = jax.random.PRNGKey(config.seed)

# Init environment
env = get_env('ant_uni')
reset_fn = jax.jit(env.reset)
normalizer = Normalizer(env.observation_size)
reward_normalizer = RewardNormalizer(config.env_batch_size)

# Compute the centroids
centroids, random_key = compute_cvt_centroids(
    num_descriptors=env.behavior_descriptor_length,
    num_init_cvt_samples=config.num_init_cvt_samples,
    num_centroids=config.num_centroids,
    minval=0,
    maxval=1,
    random_key=random_key,
)

# Init policy network

policy_layer_sizes = config.policy_hidden_layer_sizes + [env.action_size]
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=config.env_batch_size)
fake_batch_obs = jnp.zeros(shape=(config.env_batch_size, env.observation_size))
init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
print("Number of parameters in policy_network: ", param_count)

# Define the fonction to play a step with the policy in the environment
@jax.jit
def play_step_fn(env_state, policy_params, random_key):
    actions = policy_network.apply(policy_params, env_state.obs)
    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        truncations=next_state.info["truncation"],
        actions=actions,
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
        #desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
        #desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
    )

    return next_state, policy_params, random_key, transition

# Prepare the scoring function
bd_extraction_fn = behavior_descriptor_extractor['ant_uni']
scoring_fn = functools.partial(
    scoring_function,
    episode_length=env.episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
    #normalizer=normalizer,
    #reward_normalizer=reward_normalizer,
)


me_scoring_fn = functools.partial(
sampling,
scoring_fn=scoring_fn,
num_samples=config.num_samples,
)

reward_offset = 0

metrics_function = functools.partial(
    default_qd_metrics,
    qd_offset=reward_offset * env.episode_length,
)

# Define the PG-emitter config
pga_emitter_config = PGAMEConfig(
    env_batch_size=config.env_batch_size,
    proportion_mutation_ga=config.proportion_mutation_ga,
    critic_hidden_layer_size=config.critic_hidden_layer_size,
    num_critic_training_steps=config.num_critic_training_steps,
    num_pg_training_steps=config.num_pg_training_steps,
    batch_size=config.batch_size,
    replay_buffer_size=config.replay_buffer_size,
    discount=config.discount,
    reward_scaling=config.reward_scaling,
    critic_learning_rate=config.critic_learning_rate,
    #actor_learning_rate=config.algo.actor_learning_rate,
    policy_learning_rate=config.policy_learning_rate,
    noise_clip=config.noise_clip,
    policy_noise=config.policy_noise,
    soft_tau_update=config.soft_tau_update,
    policy_delay=config.policy_delay,
)

# Get the emitter
variation_fn = functools.partial(
    isoline_variation, iso_sigma=config.iso_sigma, line_sigma=config.line_sigma
)

pg_emitter = PGAMEEmitter(
    config=pga_emitter_config,
    policy_network=policy_network,
    env=env,
    variation_fn=variation_fn,
)

# Instantiate MAP Elites
map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=pg_emitter,
    metrics_function=metrics_function,
)

# compute initial repertoire
repertoire, emitter_state, random_key, normalizer, reward_normalizer = map_elites.init(init_params, centroids, random_key)


Number of parameters in policy_network:  21256


  repertoire = MapElitesRepertoire.init(


<class 'qdax.core.containers.mapelites_repertoire.MapElitesRepertoire'>


In [18]:
emitter_state.emitter_states[0].extra_scores

AttributeError: 'QualityPGEmitterState' object has no attribute 'extra_scores'

In [19]:
pga_emitter_config = QualityPGConfig(
    env_batch_size=config.env_batch_size,
    num_critic_training_steps=config.num_critic_training_steps,
    num_pg_training_steps=config.num_pg_training_steps,
    replay_buffer_size=config.replay_buffer_size,
    critic_hidden_layer_size=config.critic_hidden_layer_size,
    critic_learning_rate=config.critic_learning_rate,
    actor_learning_rate=config.actor_learning_rate,
    policy_learning_rate=config.policy_learning_rate,
    noise_clip=config.noise_clip,
    policy_noise=config.policy_noise,
    discount=config.discount,
    reward_scaling=config.reward_scaling,
    batch_size=config.batch_size,
    soft_tau_update=config.soft_tau_update,
    policy_delay=config.policy_delay,
)

In [20]:
pga_emitter = QualityPGEmitter(
    config=pga_emitter_config,
    policy_network=policy_network,
    env=env,
)

In [21]:
fitnesses, descriptors, extra_scores, random_key, normalizer, reward_normalizer = scoring_fn(
    init_params, random_key, normalizer, reward_normalizer
)

In [25]:
reward_normalizer.return_val

Array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)

In [16]:
extra_scores["rewards"][0]

Array([0.9133669 , 0.64897674, 0.46943998, 0.33254927, 0.13727005,
       0.        , 0.        , 0.        , 0.4512778 , 0.5159382 ,
       0.36667266, 0.69073284, 0.6368465 , 0.56640613, 0.86703014,
       0.94544053, 0.7863475 , 0.35145256, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.8015611 ,
       0.9226498 , 0.2591268 , 0.        , 0.5491032 , 0.11409856,
       0.2736179 , 0.9644985 , 0.4596038 , 0.73729664, 0.44664145,
       0.73134035, 0.9176377 , 0.5421054 , 0.5023163 , 0.6601326 ,
       0.62721986, 0.584525  , 0.43393892, 0.46974394, 0.48667565,
       0.37871093, 0.        , 0.22576216, 0.        , 0.        ,
       0.23547144, 0.7468129 , 0.6553663 , 0.        , 0.        ,
       0.4728555 , 0.56349236, 0.843907  , 0.38519058, 1.1742105 ,
       1.0411208 , 0.76950645, 0.69206184, 0.04383034, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.33414796,
       0.02556707, 0.        , 0.        , 0.        , 0.     

In [15]:
extra_scores["transitions"].rewards[0]

Array([0.04110002, 0.02920406, 0.02112583, 0.01496618, 0.0061781 ,
       0.        , 0.        , 0.        , 0.02031648, 0.02322963,
       0.01651075, 0.0311061 , 0.02868279, 0.02551345, 0.03906013,
       0.04259862, 0.0354357 , 0.01584024, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.03617386,
       0.04164728, 0.0116992 , 0.        , 0.02480238, 0.0051549 ,
       0.01236484, 0.04359643, 0.02077975, 0.0333433 , 0.02020397,
       0.03309102, 0.0415314 , 0.02454173, 0.02274659, 0.02990126,
       0.02841833, 0.02649132, 0.01967216, 0.02130143, 0.02207559,
       0.0171833 , 0.        , 0.01024955, 0.        , 0.        ,
       0.01069991, 0.03394569, 0.02979809, 0.        , 0.        ,
       0.02151941, 0.02565213, 0.03842942, 0.01754604, 0.05350371,
       0.04745412, 0.03508488, 0.0315637 , 0.00199965, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.01527321,
       0.00116898, 0.        , 0.        , 0.        , 0.     

In [19]:
normalizer.count

1e-08

In [16]:
import timeit

def score():
    scoring_fn(
        init_params, random_key
    )

score()

timer = timeit.Timer(score)
results = timer.repeat(repeat=10, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5

# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Standard deviation:", standard_deviation)



Mean time: 3.8053678499534724
Standard deviation: 0.0019505445173990574


In [17]:
repertoire = MapElitesRepertoire.init(
    genotypes=init_params,
    fitnesses=fitnesses,
    descriptors=descriptors,
    centroids=centroids,
    extra_scores=extra_scores,
)

  repertoire = MapElitesRepertoire.init(


In [18]:
emitter_state, random_key = pga_emitter.init(
    random_key=random_key,
    repertoire=repertoire,
    genotypes=init_params,
    fitnesses=fitnesses,
    descriptors=descriptors,
    extra_scores=extra_scores,
)

In [19]:
import timeit

def update_state():
    pga_emitter.state_update(emitter_state=emitter_state, extra_scores=extra_scores, repertoire=repertoire, genotypes=init_params, fitnesses=fitnesses, descriptors=descriptors)
    
    
update_state()

timer = timeit.Timer(update_state)
results = timer.repeat(repeat=10, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5

# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Standard deviation:", standard_deviation)

Mean time: 0.7683740634471178
Standard deviation: 0.0012103786950750863


In [20]:
first_genotype = jax.tree_map(lambda x: x[0], init_params)

In [21]:
random_key = jax.random.PRNGKey(0)
def mutation_fn():
    pga_emitter._mutation_function_pg(policy_params=first_genotype, emitter_state=emitter_state)
mutation_fn()

timer = timeit.Timer(mutation_fn)
results = timer.repeat(repeat=10, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5

# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Standard deviation:", standard_deviation)

Mean time: 0.02882700189948082
Standard deviation: 0.005437930463612316


In [22]:
half_batch = jax.tree_map(lambda x: x[:int(0.5*config.env_batch_size)], init_params)

In [23]:
def emit():
    pga_emitter.emit_pg(emitter_state=emitter_state, parents=half_batch)
    
emit()

timer = timeit.Timer(emit)

results = timer.repeat(repeat=10, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5

# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Standard deviation:", standard_deviation)

Mean time: 5.0848997985944155
Standard deviation: 0.002601589654325258
