In [1]:
# Report the GPU, driver and CUDA version visible to this kernel.
!nvidia-smi

Mon Jul  1 01:31:19 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        Off | 00000000:2D:00.0 Off |                  N/A |
|  0%   49C    P3              56W / 320W |     86MiB / 16376MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import os

# Environment variables are set before matplotlib, wandb and jax are imported
# further down in this cell — presumably so those libraries pick them up at
# import time (TODO confirm for each library).
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['WANDB_CACHE_DIR'] = '/tmp/wandb_cache'
os.environ['JAX_LOG_COMPILATION'] = '1'

import logging
import time
from dataclasses import dataclass
from functools import partial
from math import floor
from typing import Any, Dict, Tuple, List, Callable
import pickle
from flax import serialization
#logging.basicConfig(level=logging.DEBUG)
import hydra
from omegaconf import OmegaConf, DictConfig
import jax
import jax.numpy as jnp
from hydra.core.config_store import ConfigStore
from qdax.core.map_elites import MAPElites
from qdax.types import RNGKey, Genotype
from qdax.utils.sampling import sampling 
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire
from qdax.core.neuroevolution.networks.networks import MLP, MLPRein
from qdax.core.emitters.rein_var import REINConfig, REINEmitter
#from qdax.core.emitters.rein_emitter_advanced import REINaiveConfig, REINaiveEmitter
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.environments import behavior_descriptor_extractor
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from utils import Config, get_env
from qdax.core.emitters.mutation_operators import isoline_variation
import wandb
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results, plot_2d_map_elites_repertoire
import matplotlib.pyplot as plt
from set_up_brax import get_reward_offset_brax
from qdax import environments_v1, environments


In [3]:
def get_env(env_name, episode_length=None):
    """Create one of the supported uni-directional environments.

    Args:
        env_name: one of "hopper_uni", "halfcheetah_uni", "walker2d_uni",
            "ant_uni" or "humanoid_uni".
        episode_length: optional override for the per-environment default
            episode length. When None (the default) the per-environment
            value from the table below is used.

    Returns:
        The object returned by `environments_v1.create` for `env_name`.

    Raises:
        ValueError: if `env_name` is not one of the supported names.
    """
    # env_name -> (default episode length, extra keyword arguments for create).
    # NOTE(review): "ant_uni" defaults to 10 steps while every other
    # environment uses 1000 — this looks like a debugging value; confirm.
    # "ant_omni" (via `environments.create`) was disabled in the original
    # notebook and is intentionally not listed here.
    env_specs = {
        "hopper_uni": (1000, {}),
        "halfcheetah_uni": (1000, {}),
        "walker2d_uni": (1000, {}),
        "ant_uni": (
            10,
            {
                "use_contact_forces": False,
                "exclude_current_positions_from_observation": True,
            },
        ),
        "humanoid_uni": (
            1000,
            {"exclude_current_positions_from_observation": True},
        ),
    }

    if env_name not in env_specs:
        # Fail loudly instead of falling through to an unbound `env`.
        raise ValueError(f"Environment {env_name} not supported.")

    default_episode_length, extra_kwargs = env_specs[env_name]
    if episode_length is None:
        episode_length = default_episode_length

    return environments_v1.create(
        env_name, episode_length=episode_length, **extra_kwargs
    )

In [4]:
@dataclass
class Config:
    """Hyper-parameters for the MAP-Elites / REINFORCE-emitter experiment.

    The declaration order of the fields is part of the interface (it fixes
    the positional constructor), so it must not be changed.
    """

    # ---- Environment ----
    # (disabled field: alg_name: str)
    seed: int  # seeds the notebook's jax PRNG key
    env_name: str  # environment name, e.g. 'ant_uni'
    episode_length: int
    policy_hidden_layer_sizes: Tuple[int, ...]  # given to the policy network as layer_sizes

    # ---- MAP-Elites ----
    num_evaluations: int
    num_iterations: int
    batch_size: int  # number of initial policies; also the emitter batch size
    num_samples: int  # num_samples argument of the `sampling` wrapper
    fixed_init_state: bool
    discard_dead: bool

    # ---- Variation operator (isoline_variation) ----
    iso_sigma: float
    line_sigma: float
    # (disabled field: crossover_percentage: float)

    # ---- Grid / centroids ----
    grid_shape: Tuple[int, ...]
    num_init_cvt_samples: int  # argument of compute_cvt_centroids
    num_centroids: int  # argument of compute_cvt_centroids

    # ---- Logging ----
    log_period: int
    store_repertoire: bool
    store_repertoire_log_period: int

    # ---- REINFORCE emitter (forwarded to REINConfig, except `temperature`) ----
    proportion_mutation_ga: float
    rollout_number: int
    num_rein_training_steps: int
    adam_optimizer: bool
    learning_rate: float
    discount_rate: float
    temperature: int  # not read anywhere in the visible notebook cells
    buffer_size: int

In [5]:
# Experiment settings, written in the same order as the Config fields.
config = Config(
    # Environment
    seed=0,
    env_name='ant_uni',
    episode_length=10,
    # NOTE(review): Config annotates this field (and grid_shape) as
    # Tuple[int, ...], but lists are passed here.
    policy_hidden_layer_sizes=[128, 128],
    # MAP-Elites
    num_evaluations=0,
    num_iterations=200,
    batch_size=10,
    num_samples=32,
    fixed_init_state=False,
    discard_dead=False,
    # Variation operator
    iso_sigma=0.005,
    line_sigma=0.05,
    # Grid / centroids
    grid_shape=[50, 50],
    num_init_cvt_samples=50000,
    num_centroids=1296,
    # Logging
    log_period=400,
    store_repertoire=True,
    store_repertoire_log_period=800,
    # REINFORCE emitter
    proportion_mutation_ga=0.5,
    rollout_number=1,  # number of episodes used for the gradient estimate
    num_rein_training_steps=1,  # number of gradient steps per generation
    adam_optimizer=True,
    learning_rate=1e-3,
    discount_rate=0.99,
    temperature=0,
    buffer_size=200,  # size of the replay buffer
)

In [6]:
# Seed the PRNG from the experiment config.
random_key = jax.random.PRNGKey(config.seed)

# Build the environment named by config.env_name rather than a hard-coded
# name, so changing the config actually changes the environment.
env = get_env(config.env_name)
reset_fn = jax.jit(env.reset)

# Compute the CVT centroids of the descriptor space (bounds 0..1).
centroids, random_key = compute_cvt_centroids(
    num_descriptors=env.behavior_descriptor_length,
    num_init_cvt_samples=config.num_init_cvt_samples,
    num_centroids=config.num_centroids,
    minval=0,
    maxval=1,
    random_key=random_key,
)

# Policy network. The action size is passed separately to MLPRein, so it is
# not appended to the layer sizes here.
policy_layer_sizes = config.policy_hidden_layer_sizes  # + (env.action_size,)
print(policy_layer_sizes)

# Alternative (disabled) initialisation, kept for reference:
# policy_network = MLPRein(
#     action_size=env.action_size,
#     layer_sizes=policy_layer_sizes,
#     kernel_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
#     kernel_init_final=jax.nn.initializers.orthogonal(scale=0.01),
# )
policy_network = MLPRein(
    action_size=env.action_size,
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    kernel_init_final=jax.nn.initializers.lecun_uniform(),
)

# Initialise a population of config.batch_size controllers, one PRNG key each.
# TODO: maybe consider adding two random keys for each policy, e.g.
#   split_keys = jax.vmap(lambda k: jax.random.split(k, 2))(keys)
#   keys1, keys2 = split_keys[:, 0], split_keys[:, 1]
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=config.batch_size)
fake_batch_obs = jnp.zeros(shape=(config.batch_size, env.observation_size))
init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

# init_params is batched over policies; x[0] selects the first policy's leaf,
# so this counts the parameters of a single policy.
param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
print("Number of parameters in policy_network: ", param_count)

# Function that plays one environment step with the policy.
def play_step_fn(env_state, policy_params, random_key):
    """Advance the environment by one step using the policy's output as action.

    Args:
        env_state: current environment state; its `obs` attribute and
            `info["state_descriptor"]` entry are read.
        policy_params: parameters handed to `policy_network.apply`.
        random_key: PRNG key; returned unchanged (it is not split or used here).

    Returns:
        A tuple `(next_state, policy_params, random_key, transition)` where
        `transition` is the QDTransition recorded for this step.
    """
    observation = env_state.obs
    actions = policy_network.apply(policy_params, observation)

    # Read the current descriptor BEFORE stepping, exactly as the original
    # code did — kept in this order in case `env.step` touches
    # `env_state.info` (TODO confirm).
    current_descriptor = env_state.info["state_descriptor"]
    stepped_state = env.step(env_state, actions)

    # The descriptor fields desc / desc_prime are intentionally left unset.
    transition = QDTransition(
        obs=observation,
        next_obs=stepped_state.obs,
        rewards=stepped_state.reward,
        dones=stepped_state.done,
        truncations=stepped_state.info["truncation"],
        actions=actions,
        state_desc=current_descriptor,
        next_state_desc=stepped_state.info["state_descriptor"],
    )

    return stepped_state, policy_params, random_key, transition

# Prepare the scoring function.
# The descriptor extractor is looked up by config.env_name rather than a
# hard-coded environment name.
bd_extraction_fn = behavior_descriptor_extractor[config.env_name]
scoring_fn = partial(
    scoring_function,
    episode_length=env.episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)

# scoring_fn wrapped by `sampling` with config.num_samples.
# NOTE(review): me_scoring_fn is not used below — MAPElites receives the
# plain scoring_fn. Confirm this is intended.
me_scoring_fn = partial(
    sampling,
    scoring_fn=scoring_fn,
    num_samples=config.num_samples,
)

# Minimum reward value used to make sure QD scores are positive.
# The environment-specific lookup is disabled, so a zero offset is used:
# reward_offset = get_reward_offset_brax(env, config.env_name)
# print(f"Reward offset: {reward_offset}")
reward_offset = 0

# Define a metrics function.
metrics_function = partial(
    default_qd_metrics,
    qd_offset=reward_offset * env.episode_length,
)

# Define the REINFORCE emitter config.
rein_emitter_config = REINConfig(
    proportion_mutation_ga=config.proportion_mutation_ga,
    batch_size=config.batch_size,
    num_rein_training_steps=config.num_rein_training_steps,
    buffer_size=config.buffer_size,
    rollout_number=config.rollout_number,
    discount_rate=config.discount_rate,
    adam_optimizer=config.adam_optimizer,
    learning_rate=config.learning_rate,
)

variation_fn = partial(
    isoline_variation, iso_sigma=config.iso_sigma, line_sigma=config.line_sigma
)

rein_emitter = REINEmitter(
    config=rein_emitter_config,
    policy_network=policy_network,
    env=env,
    variation_fn=variation_fn,
)

# Instantiate MAP-Elites.
map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=rein_emitter,
    metrics_function=metrics_function,
)

# Compute the initial repertoire.
repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key)

# NOTE(review): this local log_period (1) is used instead of config.log_period.
log_period = 1
num_loops = config.num_iterations // log_period

# Main loop — disabled; later cells call map_elites.update by hand instead.
# It is kept as `#` comments rather than a bare string literal so the cell
# does not echo the text as its output.
map_elites_scan_update = map_elites.scan_update
# for i in range(num_loops):
#     print(f"Loop {i+1}/{num_loops}")
#     start_time = time.time()
#
#     (repertoire, emitter_state, random_key,), current_metrics = jax.lax.scan(
#         map_elites_scan_update,
#         (repertoire, emitter_state, random_key),
#         (),
#         length=log_period,
#     )
#     timelapse = time.time() - start_time


[128, 128]
Number of parameters in policy_network:  21264
10


  repertoire = MapElitesRepertoire.init(


'\nfor i in range(num_loops):\n    print(f"Loop {i+1}/{num_loops}")\n    start_time = time.time()\n    \n    (repertoire, emitter_state, random_key,), current_metrics = jax.lax.scan(\n        map_elites_scan_update,\n        (repertoire, emitter_state, random_key),\n        (),\n        length=log_period,\n    )\n    timelapse = time.time() - start_time\n    \n'

In [18]:
# Call map_elites.update once and rebind the repertoire, emitter state and
# PRNG key to the returned values (metrics is kept for inspection).
repertoire, emitter_state, metrics, random_key = map_elites.update(repertoire, emitter_state, random_key)

dones : [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
Flattened transitions pre: [[[ 0.5473534   1.          0.         ...  0.          1.
    0.        ]
  [ 0.53608376  0.99988157  0.00364783 ...  0.          1.
    0.        ]
  [ 0.5102447   0.9995933   0.00332651 ...  0.          1.
    1.        ]
  ...
  [ 0.52190715  0.9976168  -0.00360796 ...  0.          1.
    0.        ]
  [ 0.54783225  0.99806434 -0.00888216 ...  0.          1.
    0.        ]
  [ 0.59710866  0.9972295  -0.00430547 ...  0.          0.
    0.        ]]

 [[ 0.5230979   1.          0.         ...  0.          0.
    1.        ]
  [ 0.5165462   0.9990979  -0.01194586 ...  0.          1.
    1.        ]
  [ 0.5282194   0.99505246

In [19]:
# Grab the trajectory buffer held by the first emitter state for inspection.
x = emitter_state.emitter_states[0].trajectory_buffer

In [19]:
# NOTE(review): the recorded run of this cell raises
#   AttributeError: 'ArrayImpl' object has no attribute 'masks'
# so x.data is an array here, not a container with a `masks` field.
# TODO: drop this cell or inspect x.data itself.
x.data.masks

AttributeError: 'ArrayImpl' object has no attribute 'masks'

In [20]:
# Dump the trajectory buffer's bookkeeping fields, one blank line between
# entries (each entry ends in "\n" and entries are joined with "\n").
print(
    f"Episode length: {x.episode_length}\n",
    f"Environment batch size: {x.env_batch_size}\n",
    f"Number of trajectories: {x.num_trajectories}\n",
    f"Current position: {x.current_position}\n",
    f"Current size: {x.current_size}\n",
    f"Trajectory positions: {x.trajectory_positions}\n",
    f"Timestep positions: {x.timestep_positions}\n",
    f"Episodic data: {x.episodic_data}\n",
    f"Current episodic data size: {x.current_episodic_data_size}\n",
    sep="\n",
)

Episode length: 10

Environment batch size: 10

Number of trajectories: 20

Current position: 0

Current size: 200

Trajectory positions: [4 4 4 4 4 4 4 4 4 4]

Timestep positions: [0 0 0 0 0 0 0 0 0 0]

Episodic data: [[  0.  10.  20.  30.  40.  50.  60.  70.  80.  90.]
 [  1.  11.  21.  31.  41.  51.  61.  71.  81.  91.]
 [  2.  12.  22.  32.  42.  52.  62.  72.  82.  92.]
 [  3.  13.  23.  33.  43.  53.  63.  73.  83.  93.]
 [  4.  14.  24.  34.  44.  54.  64.  74.  84.  94.]
 [  5.  15.  25.  35.  45.  55.  65.  75.  85.  95.]
 [  6.  16.  26.  36.  46.  56.  66.  76.  86.  96.]
 [  7.  17.  27.  37.  47.  57.  67.  77.  87.  97.]
 [  8.  18.  28.  38.  48.  58.  68.  78.  88.  98.]
 [  9.  19.  29.  39.  49.  59.  69.  79.  89.  99.]
 [100. 110. 120. 130. 140. 150. 160. 170. 180. 190.]
 [101. 111. 121. 131. 141. 151. 161. 171. 181. 191.]
 [102. 112. 122. 132. 142. 152. 162. 172. 182. 192.]
 [103. 113. 123. 133. 143. 153. 163. 173. 183. 193.]
 [104. 114. 124. 134. 144. 154. 164. 17

In [11]:
# Display the buffer's current_episodic_data_size field.
x.current_episodic_data_size

Array(10, dtype=int32)

In [33]:
# Call map_elites.update again on the current repertoire / emitter state and
# rebind them (and the PRNG key) to the returned values.
repertoire, emitter_state, metrics, random_key = map_elites.update(repertoire, emitter_state, random_key)

transitions size: (array(10, dtype=int32), array(10, dtype=int32), array(28, dtype=int32))


In [34]:
# Re-read the first emitter state's trajectory buffer after the update above.
x = emitter_state.emitter_states[0].trajectory_buffer

In [35]:
# Same buffer dump as earlier in the notebook, for the re-read buffer:
# every entry ends in "\n" and entries are joined with "\n", giving one
# blank line between fields.
print(
    f"Episode length: {x.episode_length}\n",
    f"Environment batch size: {x.env_batch_size}\n",
    f"Number of trajectories: {x.num_trajectories}\n",
    f"Current position: {x.current_position}\n",
    f"Current size: {x.current_size}\n",
    f"Trajectory positions: {x.trajectory_positions}\n",
    f"Timestep positions: {x.timestep_positions}\n",
    f"Episodic data: {x.episodic_data}\n",
    f"Current episodic data size: {x.current_episodic_data_size}\n",
    sep="\n",
)

Episode length: 10

Environment batch size: 5

Number of trajectories: 20

Current position: 0

Current size: 200

Trajectory positions: [ 0  0  0  0 20]

Timestep positions: [40 40 40 40  0]

Episodic data: [[ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [164. 169.  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [174. 179.  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [ nan  nan  nan  nan  nan  nan  nan  nan  nan  nan]
 [184. 189.  nan  nan  nan  nan  nan  nan  nan  na

In [37]:
# Display the `dones` field of the buffer's `transition` attribute.
x.transition.dones

Array([0.], dtype=float32)