Imports and setup

In [11]:
import jax
from jax import numpy as jnp
from flax import nnx
import flax
import evojax

import numpy as np

import sys
sys.path.append("..")

import os

from IPython.display import Image

# Reload the CTM module so edits to models/ctm.py are picked up without a
# kernel restart. `importlib.reload` refreshes the module object in place but
# does NOT rebind names previously imported from it, so `CTM` is re-fetched
# from the reloaded module; otherwise it would keep pointing at the stale class.
import importlib

import models.ctm as ctm_module

ctm_module = importlib.reload(ctm_module)
CTM = ctm_module.CTM
ctm_module

<module 'models.ctm' from '/home/kevin/projects/ctm-experiments/experiments/../models/ctm.py'>

Create the model and test parameter flattening

In [12]:
def flatten_params(model):
    """Pack every ``nnx.Param`` leaf of ``model`` into one 1-D array.

    Returns:
        ``(vector, restore_info)`` where ``restore_info`` is the tuple
        ``(tree_def, param_shapes, split_indices)`` needed to undo the packing.
        ``split_indices`` holds plain Python ints so that ``jnp.split`` can be
        used with them inside traced (jit/vmap) code.
    """
    leaves, tree_def = jax.tree_util.tree_flatten(nnx.state(model, nnx.Param))

    param_shapes = []
    split_indices = []
    offset = 0
    for leaf in leaves:
        param_shapes.append(leaf.shape)
        offset += int(np.prod(leaf.shape))
        split_indices.append(offset)
    # The final running offset is the total length, not a split point.
    split_indices = split_indices[:-1]

    vector = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    return vector, (tree_def, param_shapes, split_indices)


def unflatten_and_set_params(model, flattened_vector, restore_info):
    """Write a flat parameter vector back into ``model``'s ``nnx.Param`` leaves.

    Inverse of ``flatten_params``. ``model`` is mutated in place via
    ``nnx.update``.

    Args:
        model: the NNX module whose parameters are overwritten.
        flattened_vector: 1-D array laid out as produced by ``flatten_params``.
        restore_info: ``(tree_def, shapes, split_indices)`` from ``flatten_params``.

    Raises:
        ValueError: if ``flattened_vector`` is not 1-D or its length does not
            equal the total parameter count recorded in ``restore_info``.
    """
    tree_def, shapes, split_indices = restore_info

    # Validate up front: a wrong-sized vector would otherwise surface as a
    # confusing reshape error on whichever chunk happens to be mis-sized.
    # Shapes are static, so this check is also safe under jit/vmap.
    expected_size = sum(int(np.prod(shape)) for shape in shapes)
    if flattened_vector.ndim != 1 or flattened_vector.shape[0] != expected_size:
        raise ValueError(
            f"expected a 1-D parameter vector of length {expected_size}, "
            f"got shape {tuple(flattened_vector.shape)}"
        )

    # Split at the pre-computed concrete indices and restore original shapes.
    param_arrays = jnp.split(flattened_vector, split_indices)
    reshaped_params = [arr.reshape(shape) for arr, shape in zip(param_arrays, shapes)]

    # Rebuild the parameter tree and push it into the model.
    new_param_state = jax.tree_util.tree_unflatten(tree_def, reshaped_params)
    nnx.update(model, new_param_state)

# Smoke-test flatten_params / unflatten_and_set_params on a small CTM.
config = {
    "input_size": 10,
    "hidden_size": 10,
    "output_size": 10,
}

ctm = CTM(config, nnx.Rngs(0))

# Get original output
probe_input = jnp.zeros((1, 10))
original_output = ctm(probe_input)
print("Original output:", original_output)

# Flatten parameters
flattened_params, restore_info = flatten_params(ctm)
print(f"Flattened parameter vector shape: {flattened_params.shape}")

# Writing perturbed parameters must be visible when re-flattening the model.
noise = 0.1 * jax.random.normal(jax.random.PRNGKey(42), flattened_params.shape)
modified_params = flattened_params + noise
unflatten_and_set_params(ctm, modified_params, restore_info)
np.testing.assert_array_equal(
    np.asarray(flatten_params(ctm)[0]), np.asarray(modified_params)
)

# Restore the original parameters and check the round trip is exact.
unflatten_and_set_params(ctm, flattened_params, restore_info)
np.testing.assert_array_equal(
    np.asarray(flatten_params(ctm)[0]), np.asarray(flattened_params)
)

Original output: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Flattened parameter vector shape: (220,)


In [13]:
from brax import envs
from brax.io import html

from evojax import SimManager
from evojax import ObsNormalizer
from evojax.algo import PGPE
from evojax.policy import MLPPolicy
from evojax.policy.base import PolicyState
from evojax.policy.base import PolicyNetwork
from evojax.task.cartpole import CartPoleSwingUp
from evojax.task.slimevolley import SlimeVolley
from evojax.util import create_logger
from evojax import Trainer

print('jax.devices():')
jax.devices()

# Let's create a directory to save logs and models.
log_dir = '../logs'
logger = create_logger(name='EvoJAX', log_dir=log_dir)
logger.info('Testing CTM')

logger.info('Jax backend: {}'.format(jax.local_devices()))
!nvidia-smi --query-gpu=name --format=csv,noheader

class CTMPolicy(PolicyNetwork):
    """EvoJAX policy that evaluates a population of CTM parameter vectors.

    The CTM's ``nnx.Param`` leaves are exposed to the evolution strategy as a
    single flat vector of length ``num_params`` (layout from ``flatten_params``).
    """

    def __init__(self, input_dim, output_dim, hidden_dim, rngs=None):
        # A default of `nnx.Rngs(0)` in the signature would be created once and
        # shared (and advanced) by every instance; build a fresh one per call.
        if rngs is None:
            rngs = nnx.Rngs(0)
        self.ctm = CTM(
            {"input_size": input_dim, "hidden_size": hidden_dim, "output_size": output_dim},
            rngs,
        )
        params, restore_info = flatten_params(self.ctm)
        self.restore_info = restore_info
        self.num_params = params.shape[0]
        # Static structure plus all non-Param state. Used to rebuild a fresh
        # CTM per parameter vector instead of mutating `self.ctm` while tracing.
        self._graphdef, _, self._rest_state = nnx.split(self.ctm, nnx.Param, ...)

    def _param_state(self, flat_params):
        """Rebuild the ``nnx.Param`` state tree from one flat parameter vector."""
        tree_def, shapes, split_indices = self.restore_info
        chunks = jnp.split(flat_params, split_indices)
        leaves = [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
        return jax.tree_util.tree_unflatten(tree_def, leaves)

    def get_actions(self, t_states, params, p_states):
        """Compute one action per (parameter vector, observation) pair.

        ``params`` is expected to have one row per population member, aligned
        with ``t_states.obs`` (both are mapped over their leading axis).
        """
        def get_action_single(single_params, single_obs):
            # Build a throwaway model from the traced parameters. This keeps the
            # computation functional and never stores tracers on `self.ctm`.
            model = nnx.merge(
                self._graphdef, self._param_state(single_params), self._rest_state
            )
            return model(single_obs)

        # vmap over parameter vectors
        actions = jax.vmap(get_action_single)(params, t_states.obs)
        return actions, p_states

EvoJAX: 2025-07-29 02:41:54,662 [INFO] Testing CTM
EvoJAX: 2025-07-29 02:41:54,663 [INFO] Jax backend: [CudaDevice(id=0)]


jax.devices():
NVIDIA GeForce RTX 3060 Ti


  pid, fd = os.forkpty()


In [None]:
seed = 42  # Wish me luck!
HIDDEN_DIM = 3  # CTM hidden size

train_task = SlimeVolley(test=False, max_steps=3000)
test_task = SlimeVolley(test=True, max_steps=1000)

# Seed the CTM initialisation with the same `seed` as the solver and trainer,
# so a single value controls the whole run.
policy = CTMPolicy(
    input_dim=train_task.obs_shape[0],
    output_dim=train_task.act_shape[0],
    hidden_dim=HIDDEN_DIM,
    rngs=nnx.Rngs(seed),
)

print(train_task.obs_shape)
print(train_task.act_shape)

print(policy.num_params)

# We use PGPE as our evolution algorithm.
# If you want to know more about the algorithm, please take a look at the paper:
# https://people.idsia.ch/~juergen/nn2010.pdf
solver = PGPE(
    pop_size=16,
    param_size=policy.num_params,
    optimizer='adam',
    center_learning_rate=0.05,
    seed=seed,
)

# Now that all three components are instantiated, we can create a trainer
# and start the training process.
trainer = Trainer(
    policy=policy,
    solver=solver,
    train_task=train_task,
    test_task=test_task,
    max_iter=600,
    log_interval=100,
    test_interval=200,
    n_repeats=5,  # duplicates
    n_evaluations=128,
    seed=seed,
    log_dir=log_dir,
    logger=logger,
)

_ = trainer.run()

EvoJAX: 2025-07-29 02:57:15,875 [INFO] use_for_loop=False
EvoJAX: 2025-07-29 02:57:15,892 [INFO] Start to train for 600 iterations.


(12,)
(3,)
35


EvoJAX: 2025-07-29 02:57:48,561 [INFO] Iter=100, size=16, max=-25.6000, avg=-28.2750, min=-31.6000, std=1.6460
EvoJAX: 2025-07-29 02:58:17,014 [INFO] Iter=200, size=16, max=-25.2000, avg=-27.7125, min=-30.0000, std=1.3546
EvoJAX: 2025-07-29 02:58:19,768 [INFO] [TEST] Iter=200, #tests=128, max=-2.0000, avg=-4.7188, min=-5.0000, std=0.5581


In [None]:
# Let's visualize the learned policy.

def render(task, algo, policy, rng_seed=None, render_every=3):
    """Roll out ``algo.best_params`` on ``task`` and return the rendered frames.

    Args:
        task: EvoJAX task providing ``reset``, ``step`` and ``render(state, idx)``.
        algo: solver exposing ``best_params`` (a 1-D parameter vector).
        policy: policy network providing ``reset`` and ``get_actions``.
        rng_seed: seed for the reset key; when None, the notebook-level
            global ``seed`` is used (as before).
        render_every: keep one frame every ``render_every`` environment steps.

    Returns:
        A list of frames produced by ``task.render``, starting with the initial state.
    """
    task_reset_fn = jax.jit(task.reset)
    policy_reset_fn = jax.jit(policy.reset)
    step_fn = jax.jit(task.step)
    act_fn = jax.jit(policy.get_actions)

    def first_env(state):
        # States carry a leading batch axis (see the [None, :] below);
        # rendering works on a single environment.
        return jax.tree.map(lambda x: x[0], state)

    key_seed = seed if rng_seed is None else rng_seed
    params = algo.best_params[None, :]  # a population of one
    task_s = task_reset_fn(jax.random.PRNGKey(key_seed)[None, :])
    policy_s = policy_reset_fn(task_s)

    images = [task.render(first_env(task_s), 0)]
    done = False
    step = 0
    reward = 0
    while not done:
        act, policy_s = act_fn(task_s, params, policy_s)
        task_s, r, d = step_fn(task_s, act)
        step += 1
        reward = reward + r
        done = bool(d[0])
        if step % render_every == 0:
            images.append(task.render(first_env(task_s), 0))
    print('reward={}'.format(reward))
    return images


import imageio
from IPython.display import Video

# Roll out the best parameters found so far and collect the rendered frames.
imgs = render(test_task, solver, policy)

# Encode the frames as an MP4 in the log directory and show it inline.
mp4_file = os.path.join(log_dir, 'slimevolley.mp4')
imageio.mimsave(mp4_file, imgs, fps=30)
Video(mp4_file)

  single_task_s = jax.tree_map(lambda x: x[0], task_s)
  single_task_s = jax.tree_map(lambda x: x[0], task_s)


reward=[-4]


  self.pid = _fork_exec(
