In [1]:
!nvidia-smi

Thu Oct 24 17:24:49 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.107.02             Driver Version: 550.107.02     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        Off |   00000000:2D:00.0  On |                  N/A |
|  0%   34C    P8             11W /  320W |   12473MiB /  16376MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import jax.numpy as jnp


In [3]:
data = jnp.load('data_time_efficiency/output/ant_omni_250/mcpg_me/2024-08-30_132117_340383/repertoire/genotypes.npy')

In [4]:
data.shape

(1024, 6672)

In [3]:
import functools
import pickle

import jax
import jax.numpy as jnp
from flax import serialization

from qdax import environments
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import PPOTransition, QDTransition, QDMCTransition
from qdax.core.neuroevolution.networks.networks import MLPMCPG, MLP, MLPDC

from brax.v1.io import html
from IPython.display import HTML
from omegaconf import OmegaConf
from utils import get_env, get_config


In [81]:
from pathlib import Path


run_dir = Path("data_time_efficiency/output/anttrap_omni_250/mcpg_me/2024-08-25_081532_471122")
config = get_config(run_dir)



In [82]:
rng = jax.random.PRNGKey(config.seed)
env = get_env(config)
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)


policy_network = MLPMCPG(
    action_dim=env.action_size,
    activation=config.algo.activation,
    no_neurons=config.algo.no_neurons,
)

#policy_layer_sizes = config.policy_hidden_layer_sizes + (env.action_size,)
#policy_network = MLP(
#    layer_sizes=policy_layer_sizes,
#    kernel_init=jax.nn.initializers.lecun_uniform(),
#    final_activation=jnp.tanh,
#)
#actor_dc_network = MLPDC(
#    layer_sizes=policy_layer_sizes,
#    kernel_init=jax.nn.initializers.lecun_uniform(),
#    final_activation=jnp.tanh,
#)

    
@jax.jit
def play_step_fn(env_state, policy_params, random_key):
    #random_key, subkey = jax.random.split(random_key)
    pi, action = policy_network.apply(policy_params, env_state.obs)
    logp = pi.log_prob(action)
    #logp = policy_network.apply(policy_params, env_state.obs, actions, method=policy_network.logp)
    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, action)

    transition = QDMCTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        truncations=next_state.info["truncation"],
        actions=action,
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
        logp=logp,
    )

    return (next_state, policy_params, random_key), transition


#@jax.jit
#def play_step_fn(env_state, policy_params, random_key):
#    actions = policy_network.apply(policy_params, env_state.obs)
#    state_desc = env_state.info["state_descriptor"]
#    next_state = env.step(env_state, actions)

#    transition = QDTransition(
#        obs=env_state.obs,
#        next_obs=next_state.obs,
#        rewards=next_state.reward,
#        dones=next_state.done,
#        truncations=next_state.info["truncation"],
#        actions=actions,
#        state_desc=state_desc,
#        desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
#        next_state_desc=next_state.info["state_descriptor"],
#        desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
#    )

#    return (next_state, policy_params, random_key), transition



bd_extraction_fn = environments.behavior_descriptor_extractor[config.env.name]
scoring_fn = functools.partial(
    scoring_function,
    episode_length=config.env.episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)

# Build the reconstruction function
fake_obs = jnp.zeros(shape=(env.observation_size,))
fake_desc = jnp.zeros(shape=(env.behavior_descriptor_length,))

In [83]:
from jax.flatten_util import ravel_pytree
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire


random_key, random_subkey = jax.random.split(rng)
fake_params = policy_network.init(random_subkey, fake_obs)

_, reconstruction_fn = ravel_pytree(fake_params)

# Build the repertoire
repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=str(run_dir) + "/repertoire/")

In [84]:
repertoire.descriptors

Array([[ 12.210036  , -29.787737  ],
       [  1.9766684 ,   4.9344063 ],
       [ -7.5166144 , -21.614779  ],
       ...,
       [ -0.87479806,  27.659557  ],
       [ 14.554674  , -17.739021  ],
       [ -4.0704613 , -18.532244  ]], dtype=float32)

In [101]:
indices = [index for index, (x, y) in enumerate(repertoire.descriptors) if 14 < x < 5 and 0 < y < 1]
#indices = [index for index, x in enumerate(repertoire.descriptors) if 0.1 < x < 0.2]


print("Requests indices:", indices)

Requests indices: []


In [102]:
repertoire.descriptors[365]

Array([ 0.5147489, 13.374974 ], dtype=float32)

In [103]:
index_desired = 365
#index_desired = jnp.argmax([repertoire.fitnesses])
policy_params = jax.tree_util.tree_map(lambda x: x[index_desired], repertoire.genotypes)


random_key, subkey = jax.random.split(random_key)
state = reset_fn(subkey)
rollout = [state]
while True:
    _, actions = policy_network.apply(policy_params, state.obs)
    state = step_fn(state, actions)

    if state.done:
        break
    else:
        rollout.append(state)

if env.state_descriptor_name == "feet_contact":
    descriptor = sum([s.info["state_descriptor"] for s in rollout]) / len(rollout)
elif env.state_descriptor_name == "xy_position":
    descriptor = rollout[-1].info["state_descriptor"]
else:
    raise NotImplementedError

print("Episode length: {}/{}".format(len(rollout), config.env.episode_length))
print("Fitness: {}".format(sum([s.reward for s in rollout])))
print("Observed descriptor: {}".format(descriptor))

Episode length: 250/250
Fitness: 986.7078247070312
Observed descriptor: [0.4580231 6.1755204]


In [104]:
a = html.render(env.sys, [state.qp for state in rollout])
with open("rollout.html", "w") as f:
    f.write(a)

HTML(a)