In [1]:
# Numerical stack (JAX) and the evosax reshaper used below to convert flat
# ES parameter vectors into NDP parameter trees.
import jax
import jax.numpy as jnp
from evosax import ParameterReshaper
# ipyplot is only referenced by the commented-out visualisation cells at the end.
import ipyplot
import matplotlib.pyplot as plt
# Global matplotlib style for every figure in this notebook.
plt.style.use('fivethirtyeight')

# Setup

### Create environment

In [2]:
import gymnax as gym

In [3]:
# Build the gymnax environment together with its parameter object.
env_name = "MountainCar-v0"
env, env_params = gym.make(env_name)
# Number of actions and length of the observation vector; both are reused
# below to size the policy network and the NDP.
n_actions = env.action_space(env_params).n
obs_dims = env.observation_space(env_params).shape[0]

In [4]:
# Maximum number of steps per episode, read from the environment parameters.
env_steps = env_params.max_steps_in_episode

### Create NDP and Policy

In [5]:
from models import HyperNCA, HyperNCA_Config
from policies import MLPPolicy

In [6]:
# MLP configuration
# MLP policy sizing, derived from the environment dimensions.
mlp_obs_dims = obs_dims  # input width
mlp_action_dims = n_actions  # output width
mlp_hidden_layers = 1  # nb of hidden layers
mlp_hidden_dims = max(obs_dims, n_actions)  # nb of features for each hidden layer

In [7]:
# Instantiate the categorical MLP policy and initialise its parameters on an
# all-zeros observation.
policy = MLPPolicy(mlp_action_dims, mlp_hidden_layers,
                  mlp_hidden_dims, mode="categorical")
key = jax.random.PRNGKey(42)
# NOTE(review): the same key is passed twice to init (as the init RNG and as an
# extra call argument, presumably the action-sampling key) — consider
# jax.random.split to avoid key reuse; confirm intent.
policy_params = policy.init(key, jnp.zeros((obs_dims,)), key)
# Display the initialised parameter tree.
policy_params

FrozenDict({
    params: {
        mlp: {
            layers_0: {
                kernel: Array([[-0.0854697 , -0.42540255, -0.10164906],
                       [-0.11355377,  0.05332496, -1.1430653 ]], dtype=float32),
            },
            out_layer: {
                kernel: Array([[-0.68307227,  0.38337517, -0.16882634],
                       [-0.7283222 , -0.1463625 , -0.5922689 ],
                       [-0.16727802, -1.1538497 , -0.2005003 ]], dtype=float32),
            },
        },
    },
})

In [8]:
# NCA Configuration
channels = 4
alpha = 0.1 # alive threshold
perception_dims = 3 # nb of perception kernels
update_features = (16,) # hidden features of update network

iterations = 20 #number of development steps

In [9]:
# Bundle the NCA hyper-parameters with the dimensions of the target MLP policy
# (observation, hidden and action sizes) into the HyperNCA configuration.
ndp_config = HyperNCA_Config(
    channels = channels,
    alpha = alpha,
    perception_dims = perception_dims,
    update_features = update_features,
    iterations = iterations,
    action_dims = n_actions,
    obs_dims = obs_dims,
    hidden_dims = mlp_hidden_dims,
    hidden_layers = mlp_hidden_layers
)

In [10]:
# Build the HyperNCA developmental model from its configuration.
ndp = HyperNCA(ndp_config)

In [11]:
# Size of the latent vector the NDP takes as input; displayed for reference.
z_dims = ndp.z_dims # nb of dimensions of the latent space (channels)
z_dims

4

In [12]:
# Initialise the NDP parameters from a latent vector of ones, then build the
# reshaper that maps flat parameter vectors onto this parameter tree.
ndp_params = ndp.init(jax.random.PRNGKey(42), jnp.ones((z_dims, )))
parameter_reshaper = ParameterReshaper(ndp_params)

ParameterReshaper: 580 parameters detected for optimization.


### Create evaluator

In [13]:
from evaluators import DiversityEvaluator, DiversityEvaluator_Config 
from envs import bd_mountain_car

In [14]:
# Configuration of the diversity evaluator: behaviour descriptors are extracted
# with `bd_mountain_car` and scored with the 'knn_sparsity' function.
evaluator_config = DiversityEvaluator_Config(
    epochs = 3,
    env = env,
    env_params = env_params,
    # Use the episode length read from env_params rather than a hard-coded 200,
    # so the rollout length follows the environment if it is changed.
    env_steps = env_steps,
    n_params = z_dims,
    bd_extractor = bd_mountain_car,
    popsize = 60,
    score_fn = 'knn_sparsity'
)

In [15]:
# Evaluator that scores an NDP (through the policies it produces) for diversity.
evaluator = DiversityEvaluator(evaluator_config, ndp, policy)

### Create trainer

In [16]:
from metandp import NDP_Trainer, Config

In [17]:
# Outer-loop (meta) optimisation settings: the evolution strategy named "des"
# searches over the flattened NDP parameters, whose count comes from the reshaper.
trainer_config = Config(
    epochs = 200,
    n_params = parameter_reshaper.total_params,
    params_shaper = parameter_reshaper,
    es = "des",
    popsize = 64,
    es_config = {},
    es_params = None
)

In [18]:
# Trainer tying together the ES configuration, the NDP and its evaluator.
ndp_trainer = NDP_Trainer(trainer_config, ndp, evaluator)

# Train

In [19]:
# Run the training loop with a fixed seed. It returns the final ES state and a
# `data` record; `data['fitness']` and `data['es_state']` are read by later cells.
# NOTE(review): the recorded run below was interrupted manually
# (KeyboardInterrupt), so the following cells have no saved outputs.
key = jax.random.PRNGKey(42)
es_state, data = ndp_trainer.train(key)

	INNER LOOP #0
	INNER LOOP #1
	INNER LOOP #2
OUTER LOOP #0 : avg = 0.003312010783702135, top = 0.004178614355623722, best = -0.004178614355623722
	INNER LOOP #0
	INNER LOOP #1
	INNER LOOP #2
OUTER LOOP #1 : avg = 0.003257723757997155, top = 0.004132397472858429, best = -0.004178614355623722
	INNER LOOP #0
	INNER LOOP #1


XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py(1760): _wrapped_callback
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py(1346): __call__
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/profiler.py(314): wrapper
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/dispatch.py(142): apply_primitive
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/core.py(790): process_primitive
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/core.py(383): bind_with_trace
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/core.py(2633): bind
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py(1031): scan_bind
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py(262): scan
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback
  /Users/erpl/Library/CloudStorage/OneDrive-ITU/Documents/projects/MetaNDP/metandp.py(115): train
  /var/folders/8d/5vjv51g957zb9g2039y64g1c0000gn/T/ipykernel_66784/2174908769.py(2): <module>
  /Users/erpl/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3508): run_code
  /Users/erpl/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3448): run_ast_nodes
  /Users/erpl/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3269): run_cell_async
  /Users/erpl/.local/lib/python3.11/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /Users/erpl/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3064): _run_cell
  /Users/erpl/.local/lib/python3.11/site-packages/IPython/core/interactiveshell.py(3009): run_cell
  /Users/erpl/.local/lib/python3.11/site-packages/ipykernel/zmqshell.py(540): run_cell
  /Users/erpl/.local/lib/python3.11/site-packages/ipykernel/ipkernel.py(422): do_execute
  /Users/erpl/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py(729): execute_request
  /Users/erpl/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py(409): dispatch_shell
  /Users/erpl/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py(502): process_one
  /Users/erpl/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py(513): dispatch_queue
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/asyncio/events.py(80): _run
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/asyncio/base_events.py(1922): _run_once
  /Users/erpl/anaconda3/envs/metandp/lib/python3.11/asyncio/base_events.py(607): run_forever
  /Users/erpl/.local/lib/python3.11/site-packages/tornado/platform/asyncio.py(195): start
  /Users/erpl/.local/lib/python3.11/site-packages/ipykernel/kernelapp.py(725): start
  /Users/erpl/.local/lib/python3.11/site-packages/traitlets/config/application.py(1043): launch_instance
  /Users/erpl/.local/lib/python3.11/site-packages/ipykernel_launcher.py(17): <module>
  <frozen runpy>(88): _run_code
  <frozen runpy>(198): _run_module_as_main


In [None]:
# Fitness curves over generations: population mean, per-generation maximum
# and the best fitness tracked by the ES state.
fitness = data['fitness']
avg_fitness = jnp.mean(fitness, axis=-1)
max_fitness = jnp.max(fitness, axis=-1)
# The ES state's best_fitness is negated here so it is on the same scale as the
# other curves (the training log prints it with the opposite sign).
best_fitness = -data['es_state'].best_fitness

for curve, label in ((avg_fitness, 'avg'),
                     (max_fitness, 'top-gen'),
                     (best_fitness, 'best')):
    plt.plot(curve, label=label)
plt.legend()
plt.xlabel('generation')
plt.ylabel('fitness')
plt.show()

In [None]:
# Evaluate the best NDP found by the ES on 1000 test samples and extract the
# behaviour descriptor of every sample.
best_params = es_state.best_member
# `reshape` is fed a batch of flat vectors: add a leading axis of size 1, then
# strip it again from every leaf of the resulting parameter tree.
best_params = parameter_reshaper.reshape(best_params[None, :])
# jax.tree_map is deprecated (and removed in recent JAX releases);
# jax.tree_util.tree_map is the stable equivalent.
best_params = jax.tree_util.tree_map(lambda x: x[0], best_params)
test_key = jax.random.PRNGKey(90)
fit_best, best_test_data = evaluator.test(test_key, best_params, n_samples=1000,
                                  render=False)
bd_extractor = evaluator.config.bd_extractor
bds_best = jax.vmap(bd_extractor)(best_test_data)
# Display the test fitness of the evolved NDP.
fit_best

In [None]:
# Baseline for comparison: the best member recorded at index 0 of the logged
# ES states (plotted later as "random NDP").
# NOTE(review): despite the name, this is not the worst individual — presumably
# it is the best member of the first generation; confirm against the trainer.
worst_params = data['es_state'].best_member[0]
# Same batch-of-one reshape trick as for the best parameters.
worst_params = parameter_reshaper.reshape(worst_params[None, :])
# jax.tree_map is deprecated (and removed in recent JAX releases);
# jax.tree_util.tree_map is the stable equivalent.
worst_params = jax.tree_util.tree_map(lambda x: x[0], worst_params)
test_key = jax.random.PRNGKey(66)
fit_worst, worst_test_data = evaluator.test(test_key, worst_params, n_samples=1000,
                                  render=False)
bd_extractor = evaluator.config.bd_extractor
bds_worst = jax.vmap(bd_extractor)(worst_test_data)
# Display the test fitness of the baseline NDP.
fit_worst

In [None]:
# Scatter the two behaviour-descriptor components of every test sample for the
# evolved and the baseline NDP.
s=20.
for bds, label in ((bds_best, 'evolved NDP'), (bds_worst, 'random NDP')):
    plt.scatter(bds[:, 0], bds[:, 1], label=label, alpha=.3, s=s)
plt.legend()
plt.xlabel('min x')
plt.ylabel('max x')
plt.show()

In [None]:
import pandas as pd
import seaborn as sns

In [None]:
# Tabulate the behaviour descriptors of the evolved NDP (two columns) and tag
# every row with its NDP label.
data_best = pd.DataFrame()
data_best[["BC_1", "BC_2"]] = bds_best
data_best['NDP'] = "HyperNCA - evolved"

In [None]:
# Same table for the baseline NDP, with the same column layout as data_best.
data_wrst = pd.DataFrame()
data_wrst[["BC_1", "BC_2"]] = bds_worst
data_wrst['NDP'] = "HyperNCA - random"

In [None]:
# Stack both tables. ignore_index=True renumbers the rows so the combined frame
# has unique index labels; otherwise each input frame's own labels are kept and
# repeated, which hue-based seaborn plots can fail to align.
full_data = pd.concat([data_best, data_wrst], ignore_index=True)

In [None]:
# Joint scatter of the two behaviour-descriptor columns with marginal
# distributions, coloured by NDP label (legend disabled).
sns.jointplot(data=full_data, x="BC_1", y="BC_2", hue="NDP", kind='scatter',
             alpha = .3, s= 15., legend=False)

In [None]:
# Best-member parameter vectors from the logged ES states; evaluated one by one
# (along the leading axis) in the compositionality section below.
best_members = data['es_state'].best_member

### Evaluate compositionality

In [None]:
from evaluators import CompositionalityEvaluator, CompositionalityEvaluator_Config

In [None]:
# Configuration and construction of the compositionality evaluator.
# NOTE: this rebinds `evaluator_config` and `evaluator`, replacing the diversity
# evaluator created earlier in the notebook.
evaluator_config = CompositionalityEvaluator_Config(
    epochs = 1,
    env = env,
    env_params = env_params,
    # Use the episode length read from env_params rather than a hard-coded 200,
    # so the rollout length follows the environment if it is changed.
    env_steps = env_steps,
    mlp_hidden_dims = mlp_hidden_dims,
    mlp_hidden_layers = mlp_hidden_layers,
    n_params = z_dims,
    bd_extractor = bd_mountain_car,
    popsize = 60
)
evaluator = CompositionalityEvaluator(evaluator_config, ndp)

In [None]:
# Evaluate compositionality for every logged best member in one vectorised call:
# one PRNG key per member, and the whole batch of flat vectors reshaped at once.
key_eval = jax.random.PRNGKey(42)
keys_eval = jax.random.split(key_eval, best_members.shape[0])
# NOTE: rebinds `ndp_params` (previously the freshly initialised NDP parameters)
# to the batched parameter tree of the best members.
ndp_params = parameter_reshaper.reshape(best_members)
compos, data_compo = jax.vmap(evaluator.eval)(ndp_params, keys_eval)

In [None]:
# Compositionality score of each logged best member, labelled like the other
# figures in this notebook and shown explicitly to avoid the Line2D repr output.
plt.plot(compos)
plt.xlabel('generation')  # assumes one best member is logged per generation — TODO confirm
plt.ylabel('compositionality')
plt.show()

# Eval

In [None]:
# from evaluators.simple_evaluator import SimpleEvaluator, SimpleEvaluator_Config

In [None]:
# opt_config = SimpleEvaluator_Config(
#     epochs = 200,
#     env = env,
#     env_params = env_params,
#     env_steps = 200,
#     mlp_hidden_dims = mlp_hidden_dims,
#     mlp_hidden_layers = mlp_hidden_layers,
#     n_params = z_dims,
#     es = "openes",
#     popsize = 128
# )

# opt = SimpleEvaluator(opt_config, ndp)

In [None]:
# key = jax.random.PRNGKey(33)
# best_fit, best_data = opt.eval(best_params, key)
# best_fit

In [None]:
# fitness = best_data['fitness']
# avg_fitness = jnp.mean(fitness, axis=-1)
# max_fitness = jnp.max(fitness, axis=-1)
# best_fitness = -best_data['es_state'].best_fitness

# plt.plot(avg_fitness, label='avg')
# plt.plot(max_fitness, label='top-gen')
# plt.plot(best_fitness, label='best')
# plt.legend()
# plt.xlabel('generation')
# plt.ylabel('fitness')
# plt.show()

In [None]:
# worst_fit, worst_data = opt.eval(worst_params, key)

In [None]:
# fitness = worst_data['fitness']
# avg_fitness = jnp.mean(fitness, axis=-1)
# max_fitness = jnp.max(fitness, axis=-1)
# best_fitness = -worst_data['es_state'].best_fitness

# plt.plot(avg_fitness, label='avg')
# plt.plot(max_fitness, label='top-gen')
# plt.plot(best_fitness, label='best')
# plt.legend()
# plt.xlabel('generation')
# plt.ylabel('fitness')
# plt.show()

# Visualize

In [None]:
# %%capture
# best_params = es_state.best_member
# best_params = parameter_reshaper.reshape(best_params[None, :])
# best_params = jax.tree_map(lambda x: x[0], best_params)
# test_key = jax.random.PRNGKey(66)
# fit, test_data, files = evaluator.test(test_key, best_params, n_samples=5, 
#                                   render=True, save_file="best")

In [None]:
# ipyplot.plot_images(files)

In [None]:
# %%capture
# worst_params = data['es_state'].best_member[0]
# worst_params = parameter_reshaper.reshape(worst_params[None, :])
# worst_params = jax.tree_map(lambda x: x[0], worst_params)
# test_key = jax.random.PRNGKey(66)
# fit, test_data, files = evaluator.test(test_key, worst_params, n_samples=5, 
#                                   render=True, save_file="worst")

In [None]:
# ipyplot.plot_images(files)