In [1]:
!nvidia-smi

Sat Jun 29 15:37:42 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        Off | 00000000:2D:00.0  On |                  N/A |
|  0%   41C    P5              13W / 320W |     86MiB / 16376MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import os

os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['WANDB_CACHE_DIR'] = '/tmp/wandb_cache'
os.environ['JAX_LOG_COMPILATION'] = '1'

import logging
import time
from dataclasses import dataclass
from functools import partial
from math import floor
from typing import Any, Dict, Tuple, List, Callable
import pickle
from flax import serialization
#logging.basicConfig(level=logging.DEBUG)
import hydra
from omegaconf import OmegaConf, DictConfig
import jax
import jax.numpy as jnp
from hydra.core.config_store import ConfigStore
from qdax.core.map_elites import MAPElites
from qdax.types import RNGKey, Genotype
from qdax.utils.sampling import sampling 
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire
from qdax.core.neuroevolution.networks.networks import MLP, MLPRein
from qdax.core.emitters.rein_var import REINConfig, REINEmitter
#from qdax.core.emitters.rein_emitter_advanced import REINaiveConfig, REINaiveEmitter
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.environments import behavior_descriptor_extractor
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from utils import Config, get_env
from qdax.core.emitters.mutation_operators import isoline_variation
import wandb
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results, plot_2d_map_elites_repertoire
import matplotlib.pyplot as plt
from set_up_brax import get_reward_offset_brax

In [None]:
@dataclass
class ExperimentConfig:
    """Structured configuration schema for this REINFORCE + MAP-Elites experiment.

    The dataclass fields (their names, order and annotations) define the
    generated ``__init__`` signature, so they are kept exactly as-is even where
    the spelling looks wrong.

    NOTE(review): ``main`` in this notebook is annotated with ``Config`` from
    ``utils`` rather than this class. It also reads attributes that are not
    declared here: ``iso_sigma``, ``proportion_mutation_ga``,
    ``num_rein_training_steps``, ``buffer_size``, ``rollout_number``,
    ``discount_rate``, and a nested ``env`` group (``config.env.name``,
    ``config.env.min_bd``, ...). If this class is used to validate the Hydra
    config, confirm it is still in sync with what ``main`` reads.
    """
    # Env config
    alg_name: str  # presumably an algorithm label for logging/run naming — confirm
    seed: int  # random seed; main seeds jax.random.PRNGKey from `config.seed`
    env_name: str  # NOTE(review): main reads `config.env.name`, not this flat field
    episode_length: int  # NOTE(review): main reads `config.env.episode_length`, not this flat field
    policy_hidden_layer_sizes: Tuple[int, ...]  # passed to MLPRein as `layer_sizes` in main
    # ME config
    num_evaluations: int  # presumably the total evaluation budget — not read in the visible part of main
    num_iterations: int  # main runs num_iterations / log_period outer loops
    batch_size: int  # main initialises this many policies (one PRNG key each)
    num_samples: int  # forwarded to qdax `sampling` as `num_samples` in main
    fixed_init_state: bool  # NOTE(review): not read in the visible part of main
    discard_dead: bool  # NOTE(review): not read in the visible part of main
    # Emitter config
    is_sigma: float  # NOTE(review): main reads `config.iso_sigma`; this name looks like a typo of it
    line_sigma: float  # forwarded to isoline_variation as `line_sigma` in main
    crossover_percentage: float  # NOTE(review): not read in the visible part of main
    # Grid config
    grid_shape: Tuple[int, ...]  # NOTE(review): not read in the visible part of main (CVT centroids are used)
    num_init_cvt_samples: int  # forwarded to compute_cvt_centroids in main
    num_centroids: int  # forwarded to compute_cvt_centroids in main
    # Log config
    log_period: int  # NOTE(review): main shadows this with a hard-coded local `log_period = 1`
    store_repertoire: bool  # presumably toggles saving the repertoire — confirm
    store_reperoire_log_period: int  # "reperoire" is a misspelling of "repertoire"; kept for config compatibility
    
    # REINFORCE Parameters
    sample_number: int  # NOTE(review): not read in the visible part of main
    num_in_optimizer_steps: int  # NOTE(review): main reads `num_rein_training_steps`; confirm relationship
    adam_optimizer: bool  # forwarded to REINConfig in main
    learning_rate: float  # forwarded to REINConfig in main
    l2_coefficient: float  # NOTE(review): not read in the visible part of main
    scan_batch_size: int  # NOTE(review): not read in the visible part of main
    


@hydra.main(version_base="1.2", config_path="configs", config_name="rein-me")
def main(config: Config) -> None:
    # Init a random key
    random_key = jax.random.PRNGKey(config.seed)

    # Init environment
    env = get_env(config)
    reset_fn = jax.jit(env.reset)

    # Compute the centroids
    centroids, random_key = compute_cvt_centroids(
        num_descriptors=env.behavior_descriptor_length,
        num_init_cvt_samples=config.num_init_cvt_samples,
        num_centroids=config.num_centroids,
        minval=config.env.min_bd,
        maxval=config.env.max_bd,
        random_key=random_key,
    )
    # Init policy network
    policy_layer_sizes = config.policy_hidden_layer_sizes #+ (env.action_size,)
    print(policy_layer_sizes)
    
    '''
    policy_network = MLPRein(
        action_size=env.action_size,
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
        kernel_init_final=jax.nn.initializers.orthogonal(scale=0.01),
    )
    '''
    policy_network = MLPRein(
        action_size=env.action_size,
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        kernel_init_final=jax.nn.initializers.lecun_uniform(),
    )


    # Init population of controllers
    
    # maybe consider adding two random keys for each policy
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(subkey, num=config.batch_size)
    #split_keys = jax.vmap(lambda k: jax.random.split(k, 2))(keys)
    #keys1, keys2 = split_keys[:, 0], split_keys[:, 1]
    fake_batch_obs = jnp.zeros(shape=(config.batch_size, env.observation_size))
    init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

    param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
    print("Number of parameters in policy_network: ", param_count)

    # Define the fonction to play a step with the policy in the environment
    def play_step_fn(env_state, policy_params, random_key):
        """Roll the environment forward by a single step under ``policy_params``.

        The action is the direct output of ``policy_network.apply`` on the
        current observation. ``random_key`` is not consumed here; it is handed
        back untouched, as are ``policy_params``.

        Returns:
            ``(next_state, policy_params, random_key, transition)``, where
            ``transition`` is a ``QDTransition`` recording this step.
        """
        current_obs = env_state.obs
        chosen_actions = policy_network.apply(policy_params, current_obs)
        # Read the descriptor before stepping, in case env.step updates a
        # shared `info` dict in place and would otherwise expose the
        # post-step value here.
        current_desc = env_state.info["state_descriptor"]
        stepped_state = env.step(env_state, chosen_actions)
        step_info = stepped_state.info

        transition = QDTransition(
            obs=current_obs,
            next_obs=stepped_state.obs,
            rewards=stepped_state.reward,
            dones=stepped_state.done,
            truncations=step_info["truncation"],
            actions=chosen_actions,
            state_desc=current_desc,
            next_state_desc=step_info["state_descriptor"],
        )
        return stepped_state, policy_params, random_key, transition

    # Prepare the scoring function
    bd_extraction_fn = behavior_descriptor_extractor[config.env.name]
    scoring_fn = partial(
        scoring_function,
        episode_length=config.env.episode_length,
        play_reset_fn=reset_fn,
        play_step_fn=play_step_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    )
    #reward_offset = get_reward_offset_brax(env, config.env_name)
    #print(f"Reward offset: {reward_offset}")
    
    me_scoring_fn = partial(
    sampling,
    scoring_fn=scoring_fn,
    num_samples=config.num_samples,
)

    
    
    reward_offset = 0
    
    

    # Get minimum reward value to make sure qd_score are positive
    

    # Define a metrics function
    metrics_function = partial(
        default_qd_metrics,
        qd_offset=reward_offset * config.env.episode_length,
    )

    # Define the PG-emitter config
    
    rein_emitter_config = REINConfig(
        proportion_mutation_ga=config.proportion_mutation_ga,
        batch_size=config.batch_size,
        num_rein_training_steps=config.num_rein_training_steps,
        buffer_size=config.buffer_size,
        rollout_number=config.rollout_number,
        discount_rate=config.discount_rate,
        adam_optimizer=config.adam_optimizer,
        learning_rate=config.learning_rate,
    )
    

    variation_fn = partial(
        isoline_variation, iso_sigma=config.iso_sigma, line_sigma=config.line_sigma
    )
    
    rein_emitter = REINEmitter(
        config=rein_emitter_config,
        policy_network=policy_network,
        env=env,
        variation_fn=variation_fn,
        )
    


    # Instantiate MAP Elites
    map_elites = MAPElites(
        scoring_function=scoring_fn,
        emitter=rein_emitter,
        metrics_function=metrics_function,
    )

    # compute initial repertoire
    repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key)

    log_period = 1
    num_loops = int(config.num_iterations / log_period)


    # Main loop
    map_elites_scan_update = map_elites.scan_update
    eval_num = int(config.proportion_mutation_ga * (config.batch_size * config.rollout_number * config.num_rein_training_steps)) + config.batch_size
    print(f"Number of evaluations per iteration: {eval_num}")
    for i in range(num_loops):
        print(f"Loop {i+1}/{num_loops}")
        start_time = time.time()
        
        (repertoire, emitter_state, random_key,), current_metrics = jax.lax.scan(
            map_elites_scan_update,
            (repertoire, emitter_state, random_key),
            (),
            length=log_period,
        )
        timelapse = time.time() - start_time
