In [1]:
!nvidia-smi

Tue Sep 10 20:50:42 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   28C    P0              25W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
import jax.numpy as jnp


In [7]:
data = jnp.load('data_time_efficiency/output/anttrap_omni_250/mcpg_me/2024-08-30_151310_156288/repertoire/genotypes.npy')

ValueError: cannot reshape array of size 1572832 into shape (1024,6672)

In [4]:
data.shape

(1024, 6664)

In [54]:
import functools
import pickle

import jax
import jax.numpy as jnp
from flax import serialization

from qdax import environments
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import PPOTransition, QDTransition
from qdax.core.neuroevolution.networks.networks import MLPMCPG, MLP, MLPDC

from brax.v1.io import html
from IPython.display import HTML
from omegaconf import OmegaConf
from utils import get_env, get_config

In [135]:
from pathlib import Path


run_dir = Path("data_time_efficiency/output/hopper_uni_1000/dcg_me/2024-08-30_111553_509846")
config = get_config(run_dir)



In [136]:
rng = jax.random.PRNGKey(config.seed)
env = get_env(config)
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)


#policy_network = MLPMCPG(
#    action_dim=env.action_size,
#    activation=config.algo.activation,
#    no_neurons=config.algo.no_neurons,
#)

policy_layer_sizes = config.policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)
actor_dc_network = MLPDC(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)

    
#@jax.jit
#def play_step_fn(env_state, policy_params, key):
#    rng, rng_ = jax.random.split(key)
#    pi, action = policy_network.apply(policy_params, env_state.obs)
#    #action_ = pi.sample(seed=rng_)
#    log_prob = pi.log_prob(action)
    
    #rng, rng_ = jax.random.split(rng)
    #rng_step = jax.random.split(rng_, num=self._config.NUM_ENVS)
#    next_env_state = env.step(env_state, action)
#    transition = PPOTransition(
#        obs=env_state.obs,
#        next_obs=next_env_state.obs,
#        rewards=next_env_state.reward,
#        dones=next_env_state.done,
#        truncations=next_env_state.info["truncation"],
#        actions=action,
#        state_desc=env_state.info["state_descriptor"],
#        next_state_desc=next_env_state.info["state_descriptor"],
#        val= 0.0,
#        logp=log_prob,
#    )
    
#    return (next_env_state, policy_params, rng), transition


@jax.jit
def play_step_fn(env_state, policy_params, random_key):
    actions = policy_network.apply(policy_params, env_state.obs)
    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        truncations=next_state.info["truncation"],
        actions=actions,
        state_desc=state_desc,
        desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
        next_state_desc=next_state.info["state_descriptor"],
        desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
    )

    return (next_state, policy_params, random_key), transition



bd_extraction_fn = environments.behavior_descriptor_extractor[config.env.name]
scoring_fn = functools.partial(
    scoring_function,
    episode_length=config.env.episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)

# Build the reconstruction function
fake_obs = jnp.zeros(shape=(env.observation_size,))
fake_desc = jnp.zeros(shape=(env.behavior_descriptor_length,))

In [137]:
from jax.flatten_util import ravel_pytree
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire


random_key, random_subkey = jax.random.split(rng)
fake_params = policy_network.init(random_subkey, fake_obs)

_, reconstruction_fn = ravel_pytree(fake_params)

# Build the repertoire
repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=str(run_dir) + "/repertoire/")

In [118]:
repertoire.descriptors

Array([[0.76400006],
       [0.3292929 ],
       [0.08114558],
       ...,
       [0.21400002],
       [0.938     ],
       [0.06273063]], dtype=float32)

In [138]:
#indices = [index for index, (x, y) in enumerate(repertoire.descriptors) if 0.4 < x < 0.5 and 0.6 < y < 0.7]
indices = [index for index, x in enumerate(repertoire.descriptors) if 0.1 < x < 0.2]


print("Indices of tuples that are either [29, 29] or [30, 30]:", indices)

Indices of tuples that are either [29, 29] or [30, 30]: [11, 34, 84, 148, 157, 208, 212, 290, 355, 376, 378, 387, 492, 520, 558, 586, 614, 618, 622, 658, 680, 699, 703, 764, 782, 786, 804, 845, 885, 887, 901, 934, 962, 989, 1015]


In [143]:
repertoire.fitnesses

Array([1674.4199, 1841.0847, 1525.0276, ...,      -inf, 1583.8702,
       1434.6731], dtype=float32)

In [144]:
index_desired = 1
#index_desired = jnp.argmax([repertoire.fitnesses])
policy_params = jax.tree_util.tree_map(lambda x: x[index_desired], repertoire.genotypes)


random_key, subkey = jax.random.split(random_key)
state = reset_fn(subkey)
rollout = [state]
while True:
    actions = policy_network.apply(policy_params, state.obs)
    state = step_fn(state, actions)

    if state.done:
        break
    else:
        rollout.append(state)

if env.state_descriptor_name == "feet_contact":
    descriptor = sum([s.info["state_descriptor"] for s in rollout]) / len(rollout)
elif env.state_descriptor_name == "xy_position":
    descriptor = rollout[-1].info["state_descriptor"]
else:
    raise NotImplementedError

print("Episode length: {}/{}".format(len(rollout), config.env.episode_length))
print("Fitness: {}".format(sum([s.reward for s in rollout])))
print("Observed descriptor: {}".format(descriptor))

Episode length: 963/1000
Fitness: 1617.8909912109375
Observed descriptor: [0.7227415]


In [145]:
a = html.render(env.sys, [state.qp for state in rollout])
with open(run_dir / "rollout.html", "w") as f:
    f.write(a)

HTML(a)