In [1]:
!nvidia-smi

Fri Sep  6 16:52:22 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   30C    P0              25W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
from dataclasses import dataclass
from functools import partial
from math import floor 
from typing import Callable, Tuple, Any

import jax
from jax import debug
import jax.numpy as jnp
import flax.linen as nn
import optax
from chex import ArrayTree
from qdax.core.containers.repertoire import Repertoire
from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from qdax.environments.base_wrappers import QDEnv
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer
from qdax import environments_v1, environments

from qdax.core.emitters.emitter import Emitter, EmitterState

2024-09-06 16:52:29.290073: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-06 16:52:29.317045: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-06 16:52:29.317098: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
def get_env(env_name):
    if env_name == "hopper_uni":
        episode_length = 250
        
        env = environments_v1.create(env_name, episode_length=episode_length)
    elif env_name == "halfcheetah_uni":
        episode_length = 250

        env = environments_v1.create(env_name, episode_length=episode_length)
        
    elif env_name == "walker2d_uni":
        episode_length = 250

        env = environments_v1.create(env_name, episode_length=episode_length)	
    elif env_name == "ant_uni":
        episode_length = 250

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=True)
        
    elif env_name == "ant_omni":
        episode_length = 250

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=False)	
        
    return env

In [5]:
@dataclass
class Config:
    seed: 42
    num_iterations: 2000
    batch_size: 512

# Archive
    num_init_cvt_samples: 50000
    num_centroids: 1024

In [6]:
config = Config(
    seed=42,
    num_iterations=2000,
    batch_size=16,
    num_init_cvt_samples=50000,
    num_centroids=1024,
)

In [8]:
import os

os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['WANDB_CACHE_DIR'] = '/tmp/wandb_cache'
os.environ['JAX_LOG_COMPILATION'] = '1'

import logging
import time
from dataclasses import dataclass
from functools import partial
from math import floor
from typing import Any, Dict, Tuple, List, Callable
import pickle
from flax import serialization
#logging.basicConfig(level=logging.DEBUG)
import hydra
from omegaconf import OmegaConf
import jax
import jax.numpy as jnp
from hydra.core.config_store import ConfigStore
from qdax.core.map_elites_advanced_baseline_time_step import MAPElites
from qdax.types import RNGKey, Genotype
from qdax.utils.sampling import sampling 
from qdax.core.containers.mapelites_repertoire_advanced_baseline_time_step import compute_cvt_centroids, MapElitesRepertoire
from qdax.core.neuroevolution.networks.networks import MLPMCPG
from qdax.core.emitters.me_mcpg_emitter_advanced_baseline_time_step import MEMCPGConfig, MEMCPGEmitter
#from qdax.core.emitters.rein_emitter_advanced import REINaiveConfig, REINaiveEmitter
from qdax.core.neuroevolution.buffers.buffer import QDTransition, QDMCTransition
from qdax.environments_v1 import behavior_descriptor_extractor
from qdax.tasks.brax_envs_advanced_baseline_time_step import reset_based_scoring_function_brax_envs as scoring_function
from qdax.core.emitters.mutation_operators import isoline_variation
import wandb
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_2d_map_elites_repertoire
import matplotlib.pyplot as plt







env_name = "ant_uni"
def main(config, env_name) -> None:
    #profiler_dir = "Memory_Investigation"
    #os.makedirs(profiler_dir, exist_ok=True)
    #wandb.login(key="ab476069b53a15ad74ff1845e8dee5091d241297")
    #wandb.init(
    #    project="me-mcpg",
    #    name=config.algo.name,
    #    config=OmegaConf.to_container(config, resolve=True),
    #)
    # Init a random key
    
    random_key = jax.random.PRNGKey(config.seed)

    # Init environment
    env = get_env(env_name)
    reset_fn = jax.jit(env.reset)

    # Compute the centroids
    if env_name not in ["ant_uni", "walker2d_uni"]:
        centroids, random_key = compute_cvt_centroids(
            num_descriptors=env.behavior_descriptor_length,
            num_init_cvt_samples=config.num_init_cvt_samples,
            num_centroids=config.num_centroids,
            minval=-30,
            maxval=30,
            random_key=random_key,
        )
        
    else:
        centroids, random_key = compute_cvt_centroids(
            num_descriptors=env.behavior_descriptor_length,
            num_init_cvt_samples=config.num_init_cvt_samples,
            num_centroids=config.num_centroids,
            minval=0,
            maxval=1,
            random_key=random_key,
        )  
    # Init policy network

    

    
    
    policy_network = MLPMCPG(
        action_dim=env.action_size,
        activation='tanh',
        no_neurons=64,
    )
    
    # Init population of controllers
    
    # maybe consider adding two random keys for each policy
    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(subkey, num=config.batch_size)
    #split_keys = jax.vmap(lambda k: jax.random.split(k, 2))(keys)
    #keys1, keys2 = split_keys[:, 0], split_keys[:, 1]
    fake_batch_obs = jnp.zeros(shape=(config.batch_size, env.observation_size))
    init_params = jax.vmap(policy_network.init)(keys, fake_batch_obs)

    param_count = sum(x[0].size for x in jax.tree_util.tree_leaves(init_params))
    print("Number of parameters in policy_network: ", param_count)

    # Define the fonction to play a step with the policy in the environment
    @jax.jit
    def play_step_fn(env_state, policy_params, random_key):
        #random_key, subkey = jax.random.split(random_key)
        pi, action = policy_network.apply(policy_params, env_state.obs)
        logp = pi.log_prob(action)
        #logp = policy_network.apply(policy_params, env_state.obs, actions, method=policy_network.logp)
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, action)

        transition = QDMCTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=action,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            logp=logp,
        )

        return (next_state, policy_params, random_key), transition





    # Prepare the scoring function
    bd_extraction_fn = behavior_descriptor_extractor[env_name]
    scoring_fn = partial(
        scoring_function,
        episode_length=250,
        play_reset_fn=reset_fn,
        play_step_fn=play_step_fn,
        behavior_descriptor_extractor=bd_extraction_fn,
    )



    
    reward_offset = 0
    
    
    
         

    # Get minimum reward value to make sure qd_score are positive
    

    # Define a metrics function
    metrics_function = partial(
        default_qd_metrics,
        qd_offset=0,
    )

    # Define the PG-emitter config
    
    me_mcpg_config = MEMCPGConfig(
        proportion_mutation_ga=0.5,
        no_agents=config.batch_size,
        buffer_sample_batch_size=1,
        buffer_add_batch_size=config.batch_size,
        no_epochs=32,
        learning_rate=3e-3,
        clip_param=0.2,
        discount_rate=0.99,
    )
    
    variation_fn = partial(
        isoline_variation, iso_sigma=0.005, line_sigma=0.05
    )
    
    me_mcpg_emitter = MEMCPGEmitter(
        config=me_mcpg_config,
        policy_network=policy_network,
        env=env,
        variation_fn=variation_fn,
        )
    


    # Instantiate MAP Elites
    map_elites = MAPElites(
        scoring_function=scoring_fn,
        emitter=me_mcpg_emitter,
        metrics_function=metrics_function,
    )

    # compute initial repertoire
    repertoire, emitter_state, random_key = map_elites.init(init_params, centroids, random_key)
    
    return scoring_fn, me_mcpg_emitter, metrics_function, repertoire, emitter_state, random_key




In [9]:
scoring_fn, me_mcpg_emitter, metrics_function, repertoire, emitter_state, random_key = main(config, env_name)

Number of parameters in policy_network:  6544


  repertoire = MapElitesRepertoire.init(


In [9]:
import pandas as pd
import timeit
import statistics
import matplotlib.pyplot as plt
import seaborn as sns

In [16]:
envs = ["ant_uni", "walker2d_uni", "ant_omni"]
batch_sizes = [2, 4, 8, 16]

results_data = []


for env_name in envs:
    for batch_size in batch_sizes:
        config = Config(
            seed=42,
            num_iterations=2000,
            batch_size=batch_size,
            num_init_cvt_samples=50000,
            num_centroids=1024,
        )
        scoring_fn, me_mcpg_emitter, metrics_function, repertoire, emitter_state, random_key = main(config, env_name)
        

        def emit():
            genotypes, _, rng = me_mcpg_emitter.emit(repertoire, emitter_state, random_key)
            return genotypes, rng
        
        genotypes, rng = emit()


        timer_1 = timeit.Timer(emit)
        results_1 = timer_1.repeat(repeat=100, number=1)  # Adjust 'repeat' and 'number' as needed

        median_time_1 = statistics.median(results_1)
        
        def evaluate():
            fitnesses, descriptors, extra_scores, rng_ = scoring_fn(genotypes, rng)
            return fitnesses, descriptors, extra_scores, rng_
        
        fitnesses, descriptors, extra_scores, rng_ = evaluate()
        
        timer_2 = timeit.Timer(evaluate)
        results_2 = timer_2.repeat(repeat=100, number=1)
        
        median_time_2 = statistics.median(results_2)
        
        
        print(fitnesses.shape)
        
        print(genotypes.shape)
        
        def add_to_repertoire():
            repertoire_, _ = repertoire.add(genotypes, fitnesses, descriptors, extra_scores)
            return repertoire_
        
        repertoire_ = add_to_repertoire()
        
        timer_3 = timeit.Timer(add_to_repertoire)
        results_3 = timer_3.repeat(repeat=100, number=1)
        
        median_time_3 = statistics.median(results_3)
        
        def update_emitter_state():
            emitter_state = me_mcpg_emitter.state_update(
                emitter_state=emitter_state,
                genotypes=genotypes,
                fitnesses=fitnesses,
                descriptors=descriptors,
                extra_scores={**extra_scores},
            )      
            
            return emitter_state
        
        emitter_state = update_emitter_state()
        
        timer_4 = timeit.Timer(update_emitter_state)
        results_4 = timer_4.repeat(repeat=100, number=1)
        
        median_time_4 = statistics.median(results_4)
        
        functions = ['emit', 'evaluate', 'add_to_repertoire', 'update_emitter_state']
        times = [median_time_1, median_time_2, median_time_3, median_time_4]
        for func, time in zip(functions, times):
            results_data.append({
                'Environment': env_name,
                'Batch Size': batch_size,
                'Function': func,
                'Median Time (s)': time
            })
            

df = pd.DataFrame(results_data)

# Plotting with Seaborn
sns.set(style="whitegrid")
g = sns.FacetGrid(df, col="Environment", hue="Function", col_wrap=4, height=4, aspect=1)
g.map(sns.barplot, "Batch Size", "Median Time (s)", order=batch_sizes, hue_order=functions)

# Adjust the labels and titles
g.set_titles("{col_name}")
g.set_axis_labels("Batch Size", "Median Time (s)")
g.add_legend()

# Save the plot as a PDF
plt.savefig("time_analysis.pdf")
plt.close()

Number of parameters in policy_network:  6544


  repertoire = MapElitesRepertoire.init(


ValueError: indices and arr must have the same number of dimensions; 2 vs. 3

In [10]:
import timeit
import statistics

# Emit

In [16]:
# 2
def emit():
    genotypes, _, rng = me_mcpg_emitter.emit(repertoire, emitter_state, random_key)
    return genotypes

genotypes__ = emit()


timer = timeit.Timer(emit)
results = timer.repeat(repeat=100, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5
median_time = statistics.median(results)
# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Median:", median_time)
print("Standard deviation:", standard_deviation)

Mean time: 0.0039826165791600945
Median: 0.003259650431573391
Standard deviation: 0.0009600521763875027


{'params': {'Dense_0': {'bias': Array([[-0.02268209, -0.01245804,  0.03800016, ...,  0.01663512,
           -0.01081221,  0.00783238],
          [-0.00453394, -0.02123269,  0.01229084, ..., -0.00201197,
            0.00371878,  0.00885391],
          [-0.01737902,  0.0069883 ,  0.01333423, ...,  0.00307016,
           -0.00815608, -0.00309304],
          ...,
          [ 0.0008353 , -0.00128999, -0.00413902, ..., -0.00108962,
           -0.00235767,  0.00602286],
          [-0.00202536,  0.00155167,  0.00458123, ...,  0.00341697,
            0.00427263, -0.00461006],
          [-0.00491072,  0.00486673, -0.00167077, ...,  0.00041222,
           -0.00267532, -0.00449609]], dtype=float32),
   'kernel': Array([[[ 0.01755947, -0.00582279,  0.08840003, ..., -0.36163956,
            -0.25936425, -0.19244075],
           [-0.08389208,  0.05417093, -0.02185524, ...,  0.03562257,
             0.17757055,  0.08871203],
           [ 0.23376095, -0.16447133, -0.04926711, ...,  0.14301273,
        

In [74]:
# 4
def emit():
    me_mcpg_emitter.emit(repertoire, emitter_state, random_key)

emit()


timer = timeit.Timer(emit)
results = timer.repeat(repeat=10000, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5

# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Standard deviation:", standard_deviation)

Mean time: 0.0027077434998936953
Standard deviation: 9.75418948432544e-05


In [80]:
# 4
def emit():
    me_mcpg_emitter.emit(repertoire, emitter_state, random_key)

emit()


timer = timeit.Timer(emit)
results = timer.repeat(repeat=10000, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5

# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Standard deviation:", standard_deviation)

Mean time: 0.002706670105550438
Standard deviation: 8.496294775579063e-05


In [89]:
# 8
def emit():
    me_mcpg_emitter.emit(repertoire, emitter_state, random_key)

emit()


timer = timeit.Timer(emit)
results = timer.repeat(repeat=10000, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5

# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Standard deviation:", standard_deviation)

Mean time: 0.0028861560557968914
Standard deviation: 9.583239244856506e-05


In [98]:
# 16
def emit():
    me_mcpg_emitter.emit(repertoire, emitter_state, random_key)

emit()


timer = timeit.Timer(emit)
results = timer.repeat(repeat=10000, number=1)  # Adjust 'repeat' and 'number' as needed

# Calculate mean time and standard deviation
mean_time = sum(results) / len(results)
standard_deviation = (sum((x - mean_time) ** 2 for x in results) / len(results)) ** 0.5

# Now you can use `mean_time` and `standard_deviation` as needed
print("Mean time:", mean_time)
print("Standard deviation:", standard_deviation)

Mean time: 0.0032492196750827135
Standard deviation: 9.439742016336214e-05
