# Training an SNN using Neuroevolution!

Cartpole

In [19]:
import spyx
import spyx.nn as snn

# JAX imports
import os
import jax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"
from jax import numpy as jnp

from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# optimize the parameters using evosax
import evosax
from evosax.strategies import OpenES as ES

import gymnax

# rendering tools
import matplotlib.pyplot as plt
%matplotlib notebook
import graphviz
import mediapy as media

## Create Env

In [70]:
rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("CartPole-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
done

Array(False, dtype=bool)

In [146]:
class binarize:
    def __init__(self, neuron_count, min_val, max_val):
        self.neuron_count = neuron_count
        self.min_val = min_val
        self.max_val = max_val
        
    def __call__(self, obs):
        digital = jnp.digitize(obs, jnp.linspace(self.min_val, self.max_val, self.neuron_count))
        return jax.nn.one_hot(digital, self.neuron_count)
    
class NeuromorphicCartpole:
    def __init__(self, angle_neurons=16, cart_v_neurons=16, pole_w_neurons=16):
        self.angle_converter = binarize(angle_neurons, -.21, .21)
        self.v_converter = binarize(cart_v_neurons, -3.5, 3.5)
        self.w_converter = binarize(pole_w_neurons, -3.5, 3.5)
        
    def __call__(self, obs):
        cart_v = self.v_converter(obs[1]) #self.v_converter(obs[1])
        theta = self.angle_converter(obs[2])
        pole_w = self.w_converter(obs[3]) #self.w_converter(obs[3])
        
        return jnp.concatenate([cart_v, theta, pole_w])
        

In [147]:
adapter = NeuromorphicCartpole()

## SNN

In [148]:
    
def action_selection(spike_trains):
    return jnp.argmax(spike_trains, axis=0)
        
def controller(x, state):
    # seqs is [T, F].
    core = hk.DeepRNN([
        hk.Linear(64, with_bias=False),
        snn.LIF(64, beta=0.8, activation=spyx.activation.Heaviside()),
        hk.Linear(2, with_bias=False),
        snn.LI(2)
    ])
    
    spikes, out_state = core(x, state)
    return spikes, out_state

In [149]:
key = jax.random.PRNGKey(0)
init_state = (jnp.zeros(64, dtype=jnp.float16), jnp.zeros(2, dtype=jnp.float16))
policy = hk.without_apply_rng(hk.transform(controller))
policy_params = policy.init(rng=key, x=adapter(obs), state=init_state)

In [150]:
policy.apply(policy_params, adapter(obs), init_state)

(Array([0., 0.], dtype=float16),
 (Array([ 0.3274  , -0.2544  ,  0.2534  ,  0.0263  , -0.1855  , -0.1671  ,
         -0.00456 ,  0.3271  ,  0.4004  , -0.1048  , -0.1736  , -0.2947  ,
         -0.276   , -0.1814  , -0.2439  , -0.2079  , -0.3247  , -0.01633 ,
          0.1384  ,  0.2717  ,  0.2734  ,  0.5234  ,  0.2039  ,  0.126   ,
          0.05368 , -0.09973 ,  0.1622  ,  0.1447  ,  0.05576 , -0.10504 ,
         -0.03168 , -0.3457  ,  0.01399 ,  0.1906  ,  0.0853  , -0.2593  ,
          0.1192  , -0.11957 , -0.0918  ,  0.0825  ,  0.1083  , -0.3496  ,
          0.1732  , -0.167   , -0.3674  ,  0.3113  , -0.0644  , -0.3525  ,
         -0.1984  , -0.2327  , -0.4382  , -0.11115 ,  0.07574 ,  0.1929  ,
          0.0474  ,  0.491   ,  0.005825, -0.2385  ,  0.1746  ,  0.1804  ,
          0.3054  , -0.1537  ,  0.1804  ,  0.0543  ], dtype=float16),
  Array([0., 0.], dtype=float16)))

In [151]:
adapter(obs).shape

(48,)

## Evolution

In [152]:
# Instantiate the environment & its settings.
env, env_params = gymnax.make("CartPole-v1")

def rollout(policy_params, init_policy_state, env_params, rng_input, steps_in_episode):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, env_state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""
        obs, env_state, policy_params, policy_state, prev_done, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        activation, new_policy_state = policy.apply(policy_params, adapter(obs), policy_state)
        action = action_selection(activation)
        next_obs, next_state, reward, done, _ = env.step(
            rng_step, env_state, action, env_params
        )
        carry = [next_obs, next_state, policy_params, new_policy_state, prev_done + done, rng]
        return carry, [obs, action, reward, next_obs, prev_done+done]

    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, env_state, policy_params, init_policy_state, 
         False, rng_episode],
        (),
        steps_in_episode
    )
    # Return masked sum of rewards accumulated by agent in episode
    obs, action, reward, next_obs, done = scan_out
    return obs, action, reward, next_obs, done

jit_rollout = jax.jit(rollout, static_argnums=[4])
vector_rollout = jax.vmap(jit_rollout, (0,None,None,None,None))

In [153]:
def evolution(SNN, params, epochs=25, trials=32, steps=500, key=0):

    rng = jax.random.PRNGKey(key)
    param_reshaper = evosax.ParameterReshaper(params)
    
    # Instantiate and initialize the evolution strategy
    strategy = ES(popsize=128,
                      num_dims=param_reshaper.total_params,
                      opt_name="adam"
                    )

    es_params = strategy.default_params
    es_params = es_params.replace(sigma_init=0.1, sigma_decay=0.999, sigma_limit=0.01)
    es_params = es_params.replace(opt_params=es_params.opt_params.replace(
        lrate_init=0.1, lrate_decay=0.999, lrate_limit=0.001))
    
    # check the initialization here....
    strat_state = strategy.initialize(rng, es_params)
        
    @jax.jit
    def step(rng, pop):
        rng, rng_eval = jax.random.split(rng)
        # ASK
        population_params = param_reshaper.reshape(pop)
        init_policy_state = init_state
                
        # EVAL
        obs, action, reward, next_obs, done = \
            vector_rollout(population_params, init_policy_state, env_params, rng_eval, steps)
        

        # TELL
        total_reward = jnp.sum(reward*(1-done), axis=-1)        
        return rng, total_reward, done
    
    
    
    for gen in range(epochs):
        
        # figure out way to JIT this inner loop better to account for trials
        total_reward = jnp.zeros([128])
        
        rng, rng_ask = jax.random.split(rng)
        pop, strat_state = strategy.ask(rng_ask, strat_state)

        
        pbar = tqdm([*range(trials)])
        pbar.set_description("Epoch #{}".format(gen))
        for trials_so_far in pbar:
            
            rng, reward, done = step(rng, pop)
            total_reward += reward
            pbar.set_postfix(Reward=jnp.max(total_reward)/(trials_so_far+1))
        
        strat_state = strategy.tell(pop, -total_reward/trials, strat_state)
            
        
    elite = param_reshaper.reshape(jnp.array([strat_state.best_member]))
    return jax.tree_util.tree_map(jnp.squeeze, elite)

In [154]:
elite_params = evolution(policy, policy_params)

ParameterReshaper: 3200 parameters detected for optimization.


Epoch #0: 100%|████████████████| 32/32 [00:05<00:00,  6.18it/s, Reward=21.53125]
Epoch #1: 100%|██████████████████| 32/32 [00:04<00:00,  7.28it/s, Reward=31.625]
Epoch #2: 100%|████████████████| 32/32 [00:04<00:00,  7.28it/s, Reward=58.90625]
Epoch #3: 100%|████████████████| 32/32 [00:04<00:00,  7.25it/s, Reward=62.71875]
Epoch #4: 100%|███████████████████| 32/32 [00:04<00:00,  7.26it/s, Reward=71.25]
Epoch #5: 100%|█████████████████| 32/32 [00:04<00:00,  7.25it/s, Reward=112.875]
Epoch #6: 100%|███████████████| 32/32 [00:04<00:00,  7.26it/s, Reward=117.21875]
Epoch #7: 100%|████████████████| 32/32 [00:04<00:00,  7.25it/s, Reward=160.0625]
Epoch #8: 100%|███████████████| 32/32 [00:04<00:00,  7.25it/s, Reward=213.90625]
Epoch #9: 100%|███████████████| 32/32 [00:04<00:00,  7.24it/s, Reward=182.59375]
Epoch #10: 100%|███████████████| 32/32 [00:04<00:00,  7.25it/s, Reward=199.1875]
Epoch #11: 100%|████████████████| 32/32 [00:04<00:00,  7.25it/s, Reward=199.625]
Epoch #12: 100%|████████████

In [157]:
#activation_seq = []
action_seq = []
state_seq, reward_seq = [], []
rng, rng_reset = jax.random.split(rng)
obs, env_state = env.reset(rng_reset, env_params)
new_policy_state = init_state
while True:
    state_seq.append(env_state)
    rng, rng_step = jax.random.split(rng, 2)
    activation, new_policy_state = policy.apply(elite_params, adapter(obs), new_policy_state)
    action = action_selection(activation)
    action_seq.append(action)
    #activation_seq.append(activation)
    next_obs, next_env_state, reward, done, info = env.step(
        rng_step, env_state, action, env_params
    )
    reward_seq.append(reward)
    if done:
        break
    else:
        obs = next_obs
        env_state = next_env_state

cum_rewards = jnp.sum(jnp.array(reward_seq))
cum_rewards

Array(265., dtype=float32)

In [115]:
action_seq

[Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=