In [1]:
# Print the NVIDIA driver / GPU status of the machine running this notebook.
!nvidia-smi

Mon Sep  2 00:22:35 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   30C    P0              26W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
import jax.numpy as jnp


In [7]:
# Load the saved genotypes array of one run's repertoire for a quick inspection.
# NOTE(review): the recorded run of this cell raised a ValueError about the array size,
# which suggests the file on disk may be incomplete - confirm before relying on `data`.
genotypes_path = 'data_time_efficiency/output/anttrap_omni_250/mcpg_me/2024-08-30_151310_156288/repertoire/genotypes.npy'
data = jnp.load(genotypes_path)

ValueError: cannot reshape array of size 1572832 into shape (1024,6672)

In [4]:
# Display the dimensions of `data`.
# NOTE(review): the recorded output belongs to execution count 4, whereas the cell that
# loads `data` was last run as count 7 and raised - this result may be stale; re-run in order.
data.shape

(1024, 6664)

In [8]:
import functools
import pickle

import jax
import jax.numpy as jnp
from flax import serialization

from qdax import environments
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import PPOTransition
from qdax.core.neuroevolution.networks.networks import MLPMCPG

from brax.v1.io import html
from IPython.display import HTML
from omegaconf import OmegaConf
from utils import get_env, get_config

In [5]:
import os
from pathlib import Path


# Directory of the training run to inspect. The default is the run this notebook was
# originally written against; set the QD_RUN_DIR environment variable to point the
# notebook at another run (or another machine's copy) without editing the code.
DEFAULT_RUN_DIR = "/vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/2024-08-19_141121_177022"
# `or` (rather than a .get default) so that an empty QD_RUN_DIR also falls back.
run_dir = Path(os.environ.get("QD_RUN_DIR") or DEFAULT_RUN_DIR)

# Configuration of that run, as returned by the project's `get_config` helper.
config = get_config(run_dir)



In [9]:
# Seed the PRNG from the run's configuration and build the environment.
rng = jax.random.PRNGKey(config.seed)
env = get_env(config)

# Compiled entry points for resetting and stepping the environment.
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

# Arguments common to both initialisation variants of the policy network.
policy_kwargs = dict(
    action_dim=env.action_size,
    activation=config.algo.ACTIVATION,
    no_neurons=config.algo.NO_NEURONS,
)
if config.algo.init_lecun:
    # Pass explicit LeCun-uniform initialisers; otherwise MLPMCPG's own defaults apply.
    policy_kwargs.update(
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_init=jax.nn.initializers.lecun_uniform(),
    )
policy_network = MLPMCPG(**policy_kwargs)
    
    
@jax.jit
def play_step_fn(env_state, policy_params, key):
    """Advance the environment by one policy step and record the transition.

    Args:
        env_state: current environment state; its `obs` and `info` fields are read.
        policy_params: parameters handed to `policy_network.apply`.
        key: PRNG key; it is split and the first half is returned to the caller.

    Returns:
        A pair `((next_env_state, policy_params, rng), transition)`, where
        `transition` is a `PPOTransition` describing the step that was taken.
    """
    # Only the first half of the split is kept: the action comes straight from the
    # network output below, no additional sampling is performed here.
    next_key, _ = jax.random.split(key)

    action_dist, action = policy_network.apply(policy_params, env_state.obs)
    action_logp = action_dist.log_prob(action)

    stepped_state = env.step(env_state, action)
    transition = PPOTransition(
        obs=env_state.obs,
        next_obs=stepped_state.obs,
        rewards=stepped_state.reward,
        dones=stepped_state.done,
        truncations=stepped_state.info["truncation"],
        actions=action,
        state_desc=env_state.info["state_descriptor"],
        next_state_desc=stepped_state.info["state_descriptor"],
        val=0.0,  # constant placeholder; no value estimate is computed in this function
        logp=action_logp,
    )

    carry = (stepped_state, policy_params, next_key)
    return carry, transition



# All-zero dummy inputs sized from the environment. `fake_obs` is what the policy
# network is initialised with further down; `fake_desc` is not referenced in the
# cells visible here.
fake_obs = jnp.zeros((env.observation_size,))
fake_desc = jnp.zeros((env.behavior_descriptor_length,))

# Behaviour-descriptor extractor registered under this environment's name.
bd_extraction_fn = environments.behavior_descriptor_extractor[config.env.name]

# Scoring function with the episode length, reset/step functions and descriptor
# extractor already bound.
scoring_fn = functools.partial(
    scoring_function,
    play_step_fn=play_step_fn,
    play_reset_fn=reset_fn,
    episode_length=config.env.episode_length,
    behavior_descriptor_extractor=bd_extraction_fn,
)

In [11]:
from jax.flatten_util import ravel_pytree
from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire


# Split the seed key: one half initialises throwaway parameters, the other is kept.
random_key, random_subkey = jax.random.split(rng)

# The parameter values do not matter here; only their tree structure is needed so
# that flat genotype vectors can be unravelled back into network parameters.
fake_params = policy_network.init(random_subkey, fake_obs)
reconstruction_fn = ravel_pytree(fake_params)[1]

# Load the saved repertoire of the run
repertoire_dir = str(run_dir) + "/repertoire/"
repertoire = MapElitesRepertoire.load(
    reconstruction_fn=reconstruction_fn,
    path=repertoire_dir,
)

In [18]:
# Select the elite with the highest fitness and slice its parameters out of the
# batched genotype tree.
index_desired = jnp.argmax(repertoire.fitnesses)
policy_params = jax.tree_util.tree_map(lambda x: x[index_desired], repertoire.genotypes)


# NOTE: `random_key` is advanced in place, so re-running this cell resets the
# environment with a different key each time.
random_key, subkey = jax.random.split(random_key)
state = reset_fn(subkey)
rollout = [state]
# Bounded by the configured episode length so the cell cannot spin forever if the
# environment never reports `done`.
for _step in range(config.env.episode_length):
    _, actions = policy_network.apply(policy_params, state.obs)
    state = step_fn(state, actions)

    if state.done:
        # The terminal state itself is not stored in the rollout.
        break
    rollout.append(state)

# Summarise the trajectory into a descriptor, depending on the environment's
# state-descriptor type.
descriptor_name = env.state_descriptor_name
if descriptor_name == "feet_contact":
    # Average of the per-state descriptor over the stored states.
    descriptor = sum(s.info["state_descriptor"] for s in rollout) / len(rollout)
elif descriptor_name == "xy_position":
    # Descriptor of the last stored state.
    descriptor = rollout[-1].info["state_descriptor"]
else:
    raise NotImplementedError(
        "Unsupported state descriptor: {!r}".format(descriptor_name)
    )

print("Episode length: {}/{}".format(len(rollout), config.env.episode_length))
# NOTE(review): the sum runs over the stored states only, i.e. it includes the reset
# state and leaves out the reward of the terminal step - confirm this is intended
# when comparing against the fitness stored in the repertoire.
print("Fitness: {}".format(sum(s.reward for s in rollout)))

Episode length: 804/1000
Fitness: 22567.06640625


In [19]:
# Render the physics state of every stored rollout step into an HTML viewer.
a = html.render(env.sys, [s.qp for s in rollout])

# Keep a copy of the rendering inside the run directory.
(run_dir / "rollout.html").write_text(a)

# Show the viewer inline as this cell's output.
HTML(a)