In [2]:
!nvidia-smi

Tue Nov  5 14:32:01 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 2080        Off |   00000000:0A:00.0 Off |                  N/A |
| 33%   37C    P8             21W /  225W |      41MiB /   8192MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 2080        Off |   00

In [3]:
import jax.numpy as jnp


In [3]:
data = jnp.load('data_time_efficiency/output/ant_omni_250/mcpg_me/2024-08-30_132117_340383/repertoire/genotypes.npy')

IsADirectoryError: [Errno 21] Is a directory: 'output/anttrap_omni_250/mcpg_me/2024-10-25_084444_739206'

In [4]:
data.shape

(1024, 6672)

In [4]:
import functools
import pickle

import jax
import jax.numpy as jnp
from flax import serialization

from qdax import environments
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import PPOTransition, QDTransition, QDMCTransition
from qdax.core.neuroevolution.networks.networks import MLPMCPG, MLP, MLPDC

from brax.v1.io import html
from IPython.display import HTML
from omegaconf import OmegaConf
from utils import get_env, get_config


In [22]:
from pathlib import Path


run_dir = Path("scalability/output/walker2d_uni_250/pga_me/2024-09-04_172119_197084")
config = get_config(run_dir)



In [23]:
rng = jax.random.PRNGKey(config.seed)
env = get_env(config)
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)


#policy_network = MLPMCPG(
#    action_dim=env.action_size,
#    activation=config.algo.activation,
#    no_neurons=config.algo.no_neurons,
#)

policy_layer_sizes = config.policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)
#actor_dc_network = MLPDC(
#    layer_sizes=policy_layer_sizes,
#    kernel_init=jax.nn.initializers.lecun_uniform(),
#    final_activation=jnp.tanh,
#)

    
@jax.jit
def play_step_fn(env_state, policy_params, random_key):
    #random_key, subkey = jax.random.split(random_key)
    pi, action = policy_network.apply(policy_params, env_state.obs)
    logp = pi.log_prob(action)
    #logp = policy_network.apply(policy_params, env_state.obs, actions, method=policy_network.logp)
    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, action)

    transition = QDMCTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        truncations=next_state.info["truncation"],
        actions=action,
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
        logp=logp,
    )

    return (next_state, policy_params, random_key), transition


#@jax.jit
#def play_step_fn(env_state, policy_params, random_key):
#    actions = policy_network.apply(policy_params, env_state.obs)
#    state_desc = env_state.info["state_descriptor"]
#    next_state = env.step(env_state, actions)

#    transition = QDTransition(
#        obs=env_state.obs,
#        next_obs=next_state.obs,
#        rewards=next_state.reward,
#        dones=next_state.done,
#        truncations=next_state.info["truncation"],
#        actions=actions,
#        state_desc=state_desc,
#        desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
#        next_state_desc=next_state.info["state_descriptor"],
#        desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
#    )

#    return (next_state, policy_params, random_key), transition



bd_extraction_fn = environments.behavior_descriptor_extractor[config.env.name]
scoring_fn = functools.partial(
    scoring_function,
    episode_length=config.env.episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)

# Build the reconstruction function
fake_obs = jnp.zeros(shape=(env.observation_size,))
fake_desc = jnp.zeros(shape=(env.behavior_descriptor_length,))

In [24]:
from jax.flatten_util import ravel_pytree
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire


random_key, random_subkey = jax.random.split(rng)
fake_params = policy_network.init(random_subkey, fake_obs)

_, reconstruction_fn = ravel_pytree(fake_params)

# Build the repertoire
repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=str(run_dir) + "/repertoire/")

In [8]:
repertoire.descriptors

Array([[0.45800003],
       [0.777666  ],
       [0.163     ],
       ...,
       [0.09716599],
       [0.12857144],
       [0.855     ]], dtype=float32)

Array([False, False, False, ..., False, False, False], dtype=bool)

In [25]:
#indices = [index for index, (x, y) in enumerate(repertoire.descriptors) if 28 < x < 30 and -30 < y < -26]
indices = [index for index, x in enumerate(repertoire.fitnesses) if x > 900]


print("Requests indices:", indices)

Requests indices: [1, 6, 12, 19, 23, 25, 27, 28, 29, 49, 50, 54, 60, 66, 72, 73, 75, 79, 81, 86, 89, 91, 95, 97, 99, 109, 112, 120, 133, 134, 137, 138, 139, 141, 156, 158, 164, 168, 173, 183, 185, 190, 192, 193, 203, 205, 223, 225, 230, 233, 234, 250, 255, 261, 262, 265, 270, 273, 274, 281, 285, 287, 293, 295, 297, 300, 305, 313, 317, 321, 324, 331, 332, 352, 358, 366, 370, 372, 374, 376, 378, 380, 388, 398, 399, 402, 407, 408, 409, 412, 416, 419, 429, 434, 436, 443, 445, 446, 448, 449, 457, 458, 459, 461, 464, 466, 470, 473, 496, 499, 501, 515, 516, 517, 519, 520, 523, 524, 526, 532, 541, 547, 549, 551, 552, 553, 555, 559, 560, 576, 582, 583, 594, 595, 605, 611, 619, 621, 622, 625, 630, 638, 651, 656, 670, 673, 678, 683, 684, 691, 694, 695, 698, 701, 703, 706, 710, 715, 718, 724, 736, 741, 742, 754, 763, 769, 776, 784, 790, 794, 795, 803, 804, 806, 808, 812, 813, 814, 816, 819, 820, 826, 832, 840, 846, 849, 851, 854, 862, 864, 871, 873, 882, 887, 891, 892, 896, 898, 899, 913, 914, 916

In [26]:
index_desired = 1
#index_desired = jnp.argmax(repertoire.fitnesses)
print(f"Index desired: {index_desired}, Max Fitness: {repertoire.fitnesses[index_desired]}")

Index desired: 1, Max Fitness: 1021.4940185546875


In [28]:
#index_desired = jnp.argmax(repertoire.fitnesses)
policy_params = jax.tree_util.tree_map(lambda x: x[index_desired], repertoire.genotypes)


random_key, subkey = jax.random.split(random_key)
state = reset_fn(subkey)
rollout = [state]
while True:
    actions = policy_network.apply(policy_params, state.obs)
    state = step_fn(state, actions)

    if state.done:
        break
    else:
        rollout.append(state)

if env.state_descriptor_name == "feet_contact":
    descriptor = sum([s.info["state_descriptor"] for s in rollout]) / len(rollout)
elif env.state_descriptor_name == "xy_position":
    descriptor = rollout[-1].info["state_descriptor"]
else:
    raise NotImplementedError

print("Episode length: {}/{}".format(len(rollout), config.env.episode_length))
print("Fitness: {}".format(sum([s.reward for s in rollout])))
print("Observed descriptor: {}".format(descriptor))

Episode length: 250/250
Fitness: 1014.3215942382812
Observed descriptor: [0.36400002 0.40800002]


In [29]:
a = html.render(env.sys, [state.qp for state in rollout])
with open("rollout.html", "w") as f:
    f.write(a)

HTML(a)