In [1]:
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
from qdax.core.containers.mome_repertoire import MOMERepertoire
from qdax import environments
from qdax.core.neuroevolution.networks.networks import MLP
import os
from typing import Tuple
from qdax.types import RNGKey
from IPython.display import HTML
from brax.io import html
from jax.flatten_util import ravel_pytree


ModuleNotFoundError: No module named 'qdax'

In [None]:
# Experiment configuration. The genotype template built in a later cell is derived from
# these values, so they presumably have to match the run that produced the saved repertoire.
policy_hidden_layer_sizes = (64, 64)
episode_length = 1000
env_name = "kicker_multi"
fixed_init_state = False
mutation_ga_batch_size = 256
mutation_qpg_batch_size = 64
num_objective_functions = 2
num_centroids = 256
pareto_front_max_length = 50
batch_size = mutation_ga_batch_size + mutation_qpg_batch_size * num_objective_functions

env = environments.create(env_name, episode_length=episode_length, fixed_init_state=fixed_init_state)

# Paths can be overridden via environment variables so the notebook is runnable on other
# machines; the defaults are the original author's local paths.
repertoire_path = os.environ.get(
    "MOME_REPERTOIRE_PATH",
    "/Users/joaquinarias/Downloads/kicker_vis/2024-05-10_193529_42/final/repertoire/",
)
# Guarantee a trailing separator. NOTE(review): the repertoire loader presumably builds
# file names as `path + "<name>.npy"`, so a missing slash would break loading — confirm.
repertoire_path = os.path.join(repertoire_path, "")
num_save_visualisations = 1
save_dir = os.environ.get(
    "MOME_VIS_SAVE_DIR",
    "/Users/joaquinarias/Documents/Thesis/Project/MOME_PGX/",
)



In [None]:
# Build the policy network and a zero-filled genotype template whose leaves have shape
# (num_centroids, pareto_front_max_length, *per-individual parameter shape).
random_key = jax.random.PRNGKey(42)
# The subkey from this first split is never used; the split is kept so that `random_key`
# (consumed later when sampling the repertoire) follows the same stream as before.
random_key, subkey = jax.random.split(random_key)

# Reuse `env` from the configuration cell instead of re-creating it, so that
# `fixed_init_state` (and any other creation argument) is honoured consistently.
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)
random_key, subkey = jax.random.split(random_key)

# A batch of initialised parameters is only needed for its pytree structure and leaf shapes.
keys = jax.random.split(subkey, num=batch_size)
fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
init_genotypes = jax.vmap(policy_network.init)(keys, fake_batch)

# Drop the batch axis (x.shape[1:]) and replace it with the repertoire's two leading axes.
default_genotypes = jax.tree_util.tree_map(
    lambda x: jnp.zeros(
        shape=(num_centroids, pareto_front_max_length) + x.shape[1:]
    ),
    init_genotypes,
)



In [None]:
# Function that turns one flat genotype row back into a parameter pytree; it is passed to
# the repertoire loader as the reconstruction function.
# It is derived from a single template slice (leaves of shape (pareto_front_max_length, ...)),
# which is exactly the per-row structure `jax.vmap(flatten)` sees. This replaces capturing it
# through a `global` assignment inside a vmapped function: side effects under JAX transforms
# only run at trace time and are fragile.
_, global_unravel_fn = ravel_pytree(
    jax.tree_util.tree_map(lambda x: x[0], default_genotypes)
)


def flatten(genotype):
    """Flatten a genotype pytree into one 1-D array.

    Pure function (no side effects), so it is safe to use under `jax.vmap`.
    """
    flat_genotype, _ = ravel_pytree(genotype)
    return flat_genotype


flat_genotypes = jax.vmap(flatten)(default_genotypes)

In [None]:
# Load the saved MOME repertoire from disk, using `global_unravel_fn` as the function that
# rebuilds each stored flat genotype into a parameter pytree.
repertoire = MOMERepertoire.load(
    path=repertoire_path,
    reconstruction_fn=global_unravel_fn,
)

In [None]:
# NOTE(review): assumes `repertoire.fitnesses` is (num_centroids, pareto_front_max_length,
# num_objective_functions) and `repertoire.descriptors` is (num_centroids,
# pareto_front_max_length, descriptor_dim) — confirm against the loaded arrays.
# A bare `jnp.argmax` returns an index into the *flattened* fitness array. Using it directly
# on `descriptors` indexes the first axis with a value that can be out of range (JAX clamps
# it), so the wrong descriptor would be reported. Unravel it into per-axis coordinates.
# "Best" here means the largest value of any single objective, not a Pareto notion of best.
best_idx = tuple(
    int(i)
    for i in jnp.unravel_index(
        jnp.argmax(repertoire.fitnesses), repertoire.fitnesses.shape
    )
)
best_fitness = jnp.max(repertoire.fitnesses)
# Descriptors have no objective axis: index them with the leading coordinates only.
best_bd = repertoire.descriptors[best_idx[: repertoire.descriptors.ndim - 1]]

In [None]:
# One concatenated string: passing several arguments to print() inserts its default " "
# separator, which left a stray leading space on the second and third lines.
print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n"
    f"Behavior descriptor of the best individual in the repertoire: {best_bd}\n"
    f"Index in the repertoire of this individual: {best_idx}\n"
)

In [None]:
# Draw individuals to visualise; `sample` also returns a PRNG key, which is not needed here.
sampled_genotypes = repertoire.sample(random_key, num_save_visualisations)[0]

In [None]:
for sample in range(num_save_visualisations):
        params = jax.tree_util.tree_map(
            lambda x: x[sample],
            sampled_genotypes
        )

        visualise_individual(
            env,
            policy_network,
            params,
            f"sample_{sample}.html",
            save_dir
        )

In [None]:
def visualise_individual(
    env,
    policy_network,
    params,
    name,
    save_dir,
):
    """ Roll out individual policy and save visualisation"""
    path = os.path.join(save_dir, name)

    jit_env_reset = jax.jit(env.reset)
    jit_env_step = jax.jit(env.step)
    jit_inference_fn = jax.jit(policy_network.apply)

    rollout = []
    rng = jax.random.PRNGKey(seed=1)
    state = jit_env_reset(rng=rng)

    while not state.done:
        rollout.append(state)
        action = jit_inference_fn(params, state.obs)
        state = jit_env_step(state, action)

    with File(path, 'w') as fout:
        fout.write(html.render(env.sys, [s.qp for s in rollout], height=480))
