CMA-ES on CTRNN_SDI

In [17]:
import jax
import jax.numpy as jnp
import jax.random as jr
import diffrax as dfx
import equinox as eqx
from oua.oua import OUAModel
from neuromorphic_intelligence.models import StochasticDoubleIntegrator, CTRNN, CoupledSystem
import matplotlib.pyplot as plt
from matplotlib.transforms import ScaledTranslation
from evosax.problems.problem import Problem, State
from functools import partial
from evosax.types import Fitness, Metrics, Population, Solution

In [18]:
class CTRNNSDIProblem(Problem):
    """Hyperparameter search problem for an OUA-adapted CTRNN controlling an SDI.

    A candidate solution is the tuple
    ``(param_rate, noise_rate, mean_rate, reward_rate)`` that is forwarded to
    ``OUAModel``. The fitness of a candidate is the reward summed over every
    solver step of one simulated rollout of the coupled agent/environment
    system (higher reward is better; each reward term is a negative cost).
    """

    @partial(jax.jit, static_argnames=("self",))
    def init(self, key: jax.Array) -> State:
        """Initialize state."""
        return State(counter=0)

    def _rollout(self, key: jax.Array, solution: Solution) -> jax.Array:
        """Simulate one candidate and return its summed reward as a scalar.

        Args:
            key: PRNG key used to build the stochastic terms of the system.
            solution: One (un-batched) candidate
                ``(param_rate, noise_rate, mean_rate, reward_rate)``.

        Returns:
            Sum of the reward over all solver steps that were actually taken.
        """
        # `sample` produces leaves of shape (1,); reduce them to scalars, the
        # same kind of value as the literal `0.0` passed for `noise_rate` below.
        param_rate, _noise_rate, mean_rate, reward_rate = (
            jnp.squeeze(leaf) for leaf in solution
        )

        agent = CTRNN(num_inputs=2, num_neurons=2, num_outputs=1, noise_scale=0.0)
        env = StochasticDoubleIntegrator(mass=1.0, damping_factor=0.5, noise_scale=0.1)
        coupled_system = CoupledSystem(agent=agent, env=env)

        # NOTE(review): the original code built this model twice and only the
        # second instance (with `noise_rate=0.0`) was ever used, so the sampled
        # `noise_rate` has no influence on the fitness. That effective behaviour
        # is kept here; pass `noise_rate=_noise_rate` if it was meant to be tuned.
        system = OUAModel(
            model=coupled_system,
            param_rate=param_rate,
            noise_rate=0.0,
            mean_rate=mean_rate,
            reward_rate=reward_rate,
        )

        def reward(t, x, args):
            """Negative weighted cost: 0.9 * |position|^2 + 0.1 * |u|^2."""
            (agent_state, env_state), agent_params, _agent_means, _avg_reward = x
            position, _velocity = env_state
            _tau, _A, _bias, _B, C = agent_params
            u = jax.nn.tanh(jnp.squeeze(C @ agent_state))
            return -0.9 * jnp.linalg.norm(position)**2 - 0.1 * jnp.linalg.norm(u)**2

        # Passed through to the solver unchanged, exactly as in the original
        # call (presumably read by the model's terms -- verify against OUAModel).
        args = {'reward': reward}

        sol = dfx.diffeqsolve(
            system.terms(key),
            dfx.EulerHeun(),
            args=args,
            t0=0,
            t1=300,
            dt0=0.01,
            y0=system.initial,
            # Only the reward is needed for the fitness, so do not store the
            # full state trajectory for every one of the `max_steps` slots.
            saveat=dfx.SaveAt(steps=True, fn=reward),
            adjoint=dfx.DirectAdjoint(),
            max_steps=int(1e6),
        )

        # With `steps=True` the saved arrays have length `max_steps` and the
        # slots after the final step are padded with inf; mask them out so the
        # total stays finite.
        taken = jnp.isfinite(sol.ts)
        return jnp.sum(jnp.where(taken, sol.ys, 0.0))

    @partial(jax.jit, static_argnames=("self",))
    def eval(
        self,
        key: jax.Array,
        solutions: Population,
        state: State,
    ) -> tuple[Fitness, State, Metrics]:
        """Evaluate a batch of solutions.

        Args:
            key: PRNG key for the rollouts.
            solutions: Population PyTree whose leaves carry a leading
                population axis.
            state: Problem state; returned unchanged.

        Returns:
            ``(fitness, state, metrics)`` where ``fitness`` holds one summed
            reward per population member and ``metrics`` is an empty dict.
        """
        (rollout_key,) = jr.split(key, 1)

        # Each member is simulated independently. The key is shared across the
        # population so that candidates are compared under the same noise
        # realisation.
        fitness = jax.vmap(self._rollout, in_axes=(None, 0))(rollout_key, solutions)

        return fitness, state, {}

    @partial(jax.jit, static_argnames=("self",))
    def sample(self, key: jax.Array) -> Solution:
        """Sample a solution in the search space.

        The searched hyperparameters are (lambda, sigma, eta, rho). k, z0 and
        phi0 are kept fixed, where k is the number of neurons, z0 is the neuron
        state and phi0 = (theta0, mu0, nu0) with theta0 all CTRNN parameters.

        Returns:
            Tuple ``(lambda_, sigma, eta, rho)`` of arrays with shape ``(1,)``,
            each drawn uniformly from its own range.
        """

        # Sample hyperparameters
        key, subkey = jr.split(key)
        lambda_ = jr.uniform(subkey, shape=(1,), minval=0.0, maxval=5.0)
        key, subkey = jr.split(key)
        sigma = jr.uniform(subkey, shape=(1,), minval=0.01, maxval=1.0)
        key, subkey = jr.split(key)
        eta = jr.uniform(subkey, shape=(1,), minval=0.0, maxval=100.0)
        key, subkey = jr.split(key)
        rho = jr.uniform(subkey, shape=(1,), minval=0.0, maxval=5.0)

        return (lambda_, sigma, eta, rho)

        


In [19]:
from evosax.algorithms import CMA_ES as ES

# Root PRNG key; the statements below split fresh subkeys off it.
key = jax.random.PRNGKey(0)

# Settings for the evolutionary search.
num_generations = 64
population_size = 16

problem = CTRNNSDIProblem()

key, subkey = jax.random.split(key)
problem_state = problem.init(subkey)

# Instantiate evolution strategy
# The solution is a PyTree: problem.sample returns the tuple
# (lambda, sigma, eta, rho). z0 (neuron state) and phi0 = (theta0, mu0, nu0),
# with theta0 all CTRNN parameters, are kept fixed and are not searched.
key, subkey = jax.random.split(key)
solution = problem.sample(subkey)

es = ES(
    population_size=population_size,
    solution=solution,  # requires a dummy solution
)

# Use default parameters
params = es.default_params

# Initialize evolution strategy
key, subkey = jax.random.split(key)
state = es.init(subkey, solution, params)

In [20]:
# Run one ask / eval / tell step, with a separate PRNG key for each phase.
key, subkey = jax.random.split(key)
key_ask, key_eval, key_tell = jax.random.split(subkey, 3)

# Generate a set of candidate solutions to evaluate
population, state = es.ask(key_ask, state, params)

# Evaluate the fitness of the population
fitness, problem_state, info = problem.eval(key_eval, population, problem_state)

# Update the evolution strategy
# NOTE(review): problem.eval returns a summed reward (higher is better);
# confirm whether CMA_ES minimizes or maximizes the values passed to `tell`.
state, metrics = es.tell(key_tell, population, fitness, state, params)

ValueError: Terms are not compatible with solver!

In [None]:
# Re-initialize the strategy state so the full run starts from scratch.
key, subkey = jax.random.split(key)
state = es.init(subkey, solution, params)

# One metrics entry is appended per generation.
metrics_log = []
for i in range(num_generations):
    # Separate PRNG keys for the ask, eval and tell phases of this generation.
    key, subkey = jax.random.split(key)
    key_ask, key_eval, key_tell = jax.random.split(subkey, 3)

    # Propose candidate solutions.
    population, state = es.ask(key_ask, state, params)

    # Score every candidate.
    fitness, problem_state, info = problem.eval(key_eval, population, problem_state)

    # Update the strategy with the scored population.
    # NOTE(review): the fitness is a summed reward (higher is better); confirm
    # the optimization direction CMA_ES uses for the values passed to `tell`.
    state, metrics = es.tell(key_tell, population, fitness, state, params)

    # Log metrics
    metrics_log.append(metrics)

In [None]:
# Pull the per-generation series out of the logged metrics
generations = [entry["generation_counter"] for entry in metrics_log]
best_fitness = [entry["best_fitness"] for entry in metrics_log]

# Draw the curve through the object-oriented matplotlib interface
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(generations, best_fitness, label="Best Fitness", marker="o", markersize=3)

ax.set_title("Best fitness over generations")
ax.set_xlabel("Generation")
ax.set_ylabel("Fitness")
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()