In [1]:
!nvidia-smi


Fri Nov 29 15:26:25 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 2080        Off |   00000000:0A:00.0 Off |                  N/A |
| 32%   43C    P8             24W /  225W |      41MiB /   8192MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 2080        Off |   00

In [2]:
import jax.numpy as jnp


In [3]:
data = jnp.load('output/ant_omni/mcpg_me/2024-11-18_114515_007801/repertoire/genotypes.npy')

In [6]:
jnp.min(data)

Array(-1.3110462, dtype=float32)

In [3]:
import functools
import pickle

import jax
import jax.numpy as jnp
from flax import serialization

from qdax import environments
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import PPOTransition, QDTransition, QDMCTransition
from qdax.core.neuroevolution.networks.networks import MLPMCPG, MLP, MLPDC, MLPMCPG_

from brax.v1.io import html
from IPython.display import HTML
from omegaconf import OmegaConf
from utils import get_env, get_config


In [4]:
genotypes = jnp.load("output/hopper_uni_250/mcpg_me_fixed/2024-11-27_141857_414364")

IsADirectoryError: [Errno 21] Is a directory: 'output/hopper_uni_250/mcpg_me_fixed/2024-11-27_141857_414364'

In [8]:
jnp.min(genotypes)

Array(-1.4786335, dtype=float32)

In [4]:
from pathlib import Path


run_dir = Path("output/walker2d_uni_250/mcpg_me_fixed/2024-11-29_151947_678968")
config = get_config(run_dir)



In [5]:
rng = jax.random.PRNGKey(config.seed)
env = get_env(config)
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)


policy_network = MLPMCPG_(
    action_dim=env.action_size,
    activation=config.algo.activation,
    no_neurons=config.algo.no_neurons,
)

#policy_layer_sizes = config.policy_hidden_layer_sizes + (env.action_size,)
#policy_network = MLP(
#    layer_sizes=policy_layer_sizes,
#    kernel_init=jax.nn.initializers.lecun_uniform(),
#    final_activation=jnp.tanh,
#)
#actor_dc_network = MLPDC(
#    layer_sizes=policy_layer_sizes,
#    kernel_init=jax.nn.initializers.lecun_uniform(),
#    final_activation=jnp.tanh,
#)

    
@jax.jit
def play_step_fn(env_state, policy_params, random_key):
    #random_key, subkey = jax.random.split(random_key)
    pi, action = policy_network.apply(policy_params, env_state.obs)
    logp = pi.log_prob(action)
    #logp = policy_network.apply(policy_params, env_state.obs, actions, method=policy_network.logp)
    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, action)

    transition = QDMCTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        truncations=next_state.info["truncation"],
        actions=action,
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
        logp=logp,
    )

    return (next_state, policy_params, random_key), transition


#@jax.jit
#def play_step_fn(env_state, policy_params, random_key):
#    actions = policy_network.apply(policy_params, env_state.obs)
#    state_desc = env_state.info["state_descriptor"]
#    next_state = env.step(env_state, actions)

#    transition = QDTransition(
#        obs=env_state.obs,
#        next_obs=next_state.obs,
#        rewards=next_state.reward,
#        dones=next_state.done,
#        truncations=next_state.info["truncation"],
#        actions=actions,
#        state_desc=state_desc,
#        desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
#        next_state_desc=next_state.info["state_descriptor"],
#        desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
#    )

#    return (next_state, policy_params, random_key), transition



bd_extraction_fn = environments.behavior_descriptor_extractor[config.env.name]
scoring_fn = functools.partial(
    scoring_function,
    episode_length=config.env.episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)

# Build the reconstruction function
fake_obs = jnp.zeros(shape=(env.observation_size,))
fake_desc = jnp.zeros(shape=(env.behavior_descriptor_length,))

In [6]:
from jax.flatten_util import ravel_pytree
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire


random_key, random_subkey = jax.random.split(rng)
fake_params = policy_network.init(random_subkey, fake_obs)

_, reconstruction_fn = ravel_pytree(fake_params)

# Build the repertoire
repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=str(run_dir) + "/repertoire/")

In [8]:
repertoire.descriptors

Array([[0.45800003],
       [0.777666  ],
       [0.163     ],
       ...,
       [0.09716599],
       [0.12857144],
       [0.855     ]], dtype=float32)

Array([False, False, False, ..., False, False, False], dtype=bool)

In [7]:
#indices = [index for index, (x) in enumerate(repertoire.descriptors) if 0 < x < 0.2]
indices = [index for index, x in enumerate(repertoire.fitnesses) if    x > 1200]

#indices__ = [i for i in indices if i in indices_]
print("Requests indices:", indices)

Requests indices: [6, 28, 79, 83, 128, 196, 271, 387, 422, 453, 564, 570, 590, 599, 610, 655, 662, 679, 717, 799, 819, 846, 870, 913, 937, 994]


In [23]:
index_desired = 79
#index_desired = np.argmax(repertoire.fitnesses)
print(f"Index desired: {index_desired}, Max Fitness: {repertoire.fitnesses[index_desired]}")

Index desired: 79, Max Fitness: 1207.300048828125


In [28]:
#index_desired = jnp.argmax(repertoire.fitnesses)
policy_params = jax.tree_util.tree_map(lambda x: x[index_desired], repertoire.genotypes)


random_key, subkey = jax.random.split(random_key)
state = reset_fn(subkey)
rollout = [state]
while True:
    _, actions = policy_network.apply(policy_params, state.obs)
    state = step_fn(state, actions)

    if state.done:
        break
    else:
        rollout.append(state)

if env.state_descriptor_name == "feet_contact":
    descriptor = sum([s.info["state_descriptor"] for s in rollout]) / len(rollout)
elif env.state_descriptor_name == "xy_position":
    descriptor = rollout[-1].info["state_descriptor"]
else:
    raise NotImplementedError

print("Episode length: {}/{}".format(len(rollout), config.env.episode_length))
print("Fitness: {}".format(sum([s.reward for s in rollout])))
print("Observed descriptor: {}".format(descriptor))

Episode length: 244/250
Fitness: 1155.619140625
Observed descriptor: [0.33196723 0.2540984 ]


In [39]:
state.obs.shape

(18,)

In [29]:
a = html.render(env.sys, [state.qp for state in rollout])
with open("rollout.html", "w") as f:
    f.write(a)
    


HTML(a)