<a href="https://colab.research.google.com/github/google/evojax/blob/main/examples/notebooks/GymnaxEvosax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook showcases how to use any [`evosax`](https://github.com/RobertTLange/evosax) evolutionary optimization algorithm and [`gymnax`](https://github.com/RobertTLange/gymnax) RL tasks jointly with the EvoJAX neuroevolution pipeline.

In [None]:
# @title Install Packages

# clear_output is called at the end of this cell to remove the pip logs.
from IPython.display import clear_output

# chex is installed from PyPI; evojax, gymnax and evosax are installed
# directly from the `main` branch of their GitHub repositories.
# NOTE(review): none of these installs is pinned to a version or commit, so
# re-running the notebook at a later date may pull different code.
!pip install chex
!pip install git+https://github.com/google/evojax.git@main
!pip install git+https://github.com/RobertTLange/gymnax.git@main
!pip install git+https://github.com/RobertTLange/evosax.git@main

# Wipe the installation output so the cell stays compact.
clear_output()

In [1]:
# @title Import Libraries

import time
import numpy as np
import matplotlib.pyplot as plt
from typing import Sequence, Tuple

import jax
import jax.numpy as jnp
import chex
from flax import linen as nn

from evojax import SimManager
from evojax import ObsNormalizer
from evojax.policy.base import PolicyNetwork
from evojax.policy.base import PolicyState
from evojax.task.base import TaskState
from evojax.util import get_params_format_fn

import os
# The presence of COLAB_TPU_ADDR is taken as the signal that this kernel runs
# on a Colab TPU runtime; only then is the explicit TPU setup performed.
if os.getenv('COLAB_TPU_ADDR') is not None:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

# List the devices JAX can see; the bare expression is the cell's output.
print('jax.devices():')
jax.devices()

jax.devices():


[StreamExecutorGpuDevice(id=0, process_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0),
 StreamExecutorGpuDevice(id=2, process_index=0),
 StreamExecutorGpuDevice(id=3, process_index=0)]

# Import gymnax task and define MinAtar Policy

In [2]:
from gymnax.utils.evojax_wrapper import GymnaxTask

# Gymnax environment exposed as an EvoJAX task. Both tasks share the same
# environment name and step budget and differ only in the `test` flag.
env_name = "Asterix-MinAtar"
max_steps = 500
train_task, test_task = (
    GymnaxTask(env_name, max_steps=max_steps, test=is_test)
    for is_test in (False, True)
)

In [3]:
# Define the MinAtar CNN Policy
class MinAtarCNN(nn.Module):
    """Small conv net: one 3x3 convolution followed by a dense head.

    The input is passed through a single 16-feature convolution with ReLU,
    flattened per batch element, run through one ReLU dense layer per entry
    of `hidden_dims`, and finally projected to `out_dim` outputs without an
    activation.
    """

    # Width of each hidden dense layer, applied in order.
    hidden_dims: Sequence[int]
    # Size of the final (linear) output layer.
    out_dim: int

    @nn.compact
    def __call__(self, x: chex.Array) -> chex.Array:
        """Map a batch of inputs to a batch of `out_dim`-sized outputs."""
        conv = nn.Conv(
            features=16, kernel_size=(3, 3), strides=1, padding="SAME"
        )
        features = nn.relu(conv(x))

        # Keep the leading (batch) axis and flatten everything else.
        batch_size = features.shape[0]
        hidden = features.reshape((batch_size, -1))

        for width in self.hidden_dims:
            hidden = nn.Dense(features=width)(hidden)
            hidden = nn.relu(hidden)

        return nn.Dense(features=self.out_dim)(hidden)


class MinAtarPolicy(PolicyNetwork):
    """Deterministic CNN policy - greedy action selection.

    Wraps a `MinAtarCNN` so that a batch of flat parameter vectors can be
    evaluated at once: both the parameter formatting function and the
    network forward pass are vectorized over their leading axis with
    `jax.vmap`.
    """

    def __init__(
        self,
        input_dim: Sequence[int],
        hidden_dims: Sequence[int],
        output_dim: int,
    ):
        """Build the network and the vectorized helper functions.

        Args:
            input_dim: Shape of a single observation (no batch axis).
            hidden_dims: Width of each hidden dense layer of the CNN.
            output_dim: Number of network outputs; `get_actions` returns the
                index of the largest one.
        """
        # Prepend a batch axis of size 1 for the dummy input used at init.
        self.input_dim = [1, *input_dim]
        self.model = MinAtarCNN(hidden_dims=hidden_dims, out_dim=output_dim)
        self.params = self.model.init(
            jax.random.PRNGKey(0), jnp.ones(self.input_dim)
        )
        # get_params_format_fn yields the parameter count plus a function
        # that is applied to flat parameter vectors before the forward pass
        # (presumably rebuilding the parameter tree - confirm in evojax).
        self.num_params, format_params_fn = get_params_format_fn(self.params)
        self._format_params_fn = jax.vmap(format_params_fn)
        self._forward_fn = jax.vmap(self.model.apply)

    def get_actions(
        self, t_states: TaskState, params: chex.Array, p_states: PolicyState
    ) -> Tuple[chex.Array, PolicyState]:
        """Pick the greedy action for every member of the batch.

        Args:
            t_states: Task states; only `t_states.obs` is read. Assumed to
                carry one observation per row of `params` - TODO confirm
                against the task wrapper.
            params: Flat parameter vectors, one per leading-axis entry.
            p_states: Policy states, returned unchanged.

        Returns:
            A tuple `(action, p_states)` where `action` holds one argmax
            index per leading-axis entry of `params`.
        """
        params = self._format_params_fn(params)
        # Give every member its own batch axis of size 1 for the network.
        obs = jnp.expand_dims(t_states.obs, axis=1)
        activations = self._forward_fn(params, obs)
        # Drop only the singleton axis added above. A bare squeeze() would
        # also remove the leading axis when it has size 1 and return a
        # scalar instead of a length-1 array of actions.
        action = jnp.argmax(activations, axis=2).squeeze(axis=1)
        return action, p_states


# Size the policy from the training task: observation shape in, one output
# per action, with a single hidden dense layer of 32 units in between.
policy = MinAtarPolicy(
    input_dim=train_task.obs_shape,
    hidden_dims=[32],
    output_dim=train_task.num_actions,
)

# Set up evosax strategy wrapper for OpenAI-ES

In [4]:
from evosax import Strategies
from evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper

In [5]:
# Settings handed to the evosax strategy through the wrapper's `es_config`
# argument: fitness handling plus learning-rate and sigma values.
es_config = dict(
    maximize=True,
    centered_rank=True,
    lrate_init=0.01,
    lrate_decay=0.999,
    lrate_limit=0.001,
    sigma_init=0.1,
    sigma_decay=0.999,
    sigma_limit=0.01,
)

# Wrap the "OpenES" strategy for use as the solver; it searches over a flat
# vector with one entry per policy parameter.
open_es = Strategies["OpenES"]
solver = Evosax2JAX_Wrapper(
    open_es,
    param_size=policy.num_params,
    pop_size=256,
    es_config=es_config,
    seed=42,
)

# Run the neuroevolution training loop

In [6]:
# Observation normalizer created with dummy=True
# (presumably a pass-through - confirm in the evojax documentation).
obs_normalizer = ObsNormalizer(obs_shape=train_task.obs_shape, dummy=True)

# Simulation manager that rolls out the policy on the train and test tasks.
# pop_size is the same value (256) that is passed to the solver.
sim_mgr = SimManager(
    n_repeats=4,
    test_n_repeats=1,
    pop_size=256,
    n_evaluations=64,
    policy_net=policy,
    train_vec_task=train_task,
    valid_vec_task=test_task,
    seed=42,
    obs_normalizer=obs_normalizer,
    use_for_loop=False,
)

In [7]:
 print(f"START EVOLVING {policy.num_params} PARAMS.")
# Run ES Loop.
num_generations = 5000
for gen_counter in range(num_generations):
    
    params = solver.ask()
    scores, _ = sim_mgr.eval_params(params=params, test=False)
    solver.tell(fitness=scores)

    if gen_counter == 0 or (gen_counter + 1) % 200 == 0:
        test_scores, _ = sim_mgr.eval_params(
            params=solver.best_params, test=True
        )
        print(
            {
                "num_gens": gen_counter + 1,
                "train_perf": float(np.nanmean(scores)),
                "test_perf": float(np.nanmean(test_scores)),
            },
        )

START EVOLVING 51989 PARAMS.




{'num_gens': 1, 'train_perf': 0.6416015625, 'test_perf': 0.421875}
{'num_gens': 200, 'train_perf': 0.5029296875, 'test_perf': 0.53125}
{'num_gens': 400, 'train_perf': 0.494140625, 'test_perf': 0.34375}
{'num_gens': 600, 'train_perf': 0.103515625, 'test_perf': 0.84375}
{'num_gens': 800, 'train_perf': 1.3564453125, 'test_perf': 1.0625}
{'num_gens': 1000, 'train_perf': 0.97265625, 'test_perf': 1.59375}
{'num_gens': 1200, 'train_perf': 2.9208984375, 'test_perf': 1.609375}
{'num_gens': 1400, 'train_perf': 2.2841796875, 'test_perf': 2.59375}
{'num_gens': 1600, 'train_perf': 2.294921875, 'test_perf': 2.5}
{'num_gens': 1800, 'train_perf': 4.8662109375, 'test_perf': 3.671875}
{'num_gens': 2000, 'train_perf': 2.4375, 'test_perf': 4.96875}
{'num_gens': 2200, 'train_perf': 3.634765625, 'test_perf': 3.703125}
{'num_gens': 2400, 'train_perf': 3.8642578125, 'test_perf': 5.0625}
{'num_gens': 2600, 'train_perf': 5.5986328125, 'test_perf': 4.5}
{'num_gens': 2800, 'train_perf': 3.359375, 'test_perf': 5.5