# Training an SNN using Neuroevolution!

In [1]:
import spyx
import spyx.nn as snn
from synecdoche import hyper

# JAX imports
import os
import jax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"
from jax import numpy as jnp
import numpy as np

from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# optimize the parameters using evosax
import evosax
from evosax.strategies import LM_MA_ES as LMMAES

# rendering tools
import matplotlib.pyplot as plt
%matplotlib notebook
import graphviz
import mediapy as media

  warn(


## Data Loading

In [2]:
# Build the MNIST data loader shared by every later cell.
# NOTE(review): 64 is presumably the batch size — confirm against spyx.data.MNIST_loader.
mnist_dl = spyx.data.MNIST_loader(64)

## SNN

In [3]:
def mnist_snn(x):
    """Spiking MLP (Linear/LIF x2 followed by a Linear/LI readout), unrolled with Haiku.

    ``x`` is consumed batch-major: the unroll is called with ``time_major=False``
    and axis 0 of ``x`` sizes the initial recurrent state. The input is cast to
    float32 before it is fed to the network.

    Returns the two values produced by ``hk.dynamic_unroll``.
    """
    # Modules are created in the same order as they are applied, which keeps
    # Haiku's automatically generated parameter names stable.
    layers = [
        hk.Flatten(),
        hk.Linear(128, with_bias=False),
        snn.LIF(128, activation=spyx.activation.SuperSpike()),
        hk.Linear(128, with_bias=False),
        snn.LIF(128, activation=spyx.activation.SuperSpike()),
        hk.Linear(10, with_bias=False),
        snn.LI(10),
    ]
    network = hk.DeepRNN(layers)

    inputs = x.astype(jnp.float32)
    batch_size = x.shape[0]
    outputs, state = hk.dynamic_unroll(
        network,
        inputs,
        network.initial_state(batch_size),
        time_major=False,
    )
    return outputs, state

In [4]:
# Fixed seed so parameter initialisation is repeatable across runs.
key = jax.random.PRNGKey(0)
# Turn mnist_snn into a pure init/apply pair; without_apply_rng means
# SNN.apply is called as SNN.apply(params, x) with no rng argument.
SNN = hk.without_apply_rng(hk.transform(mnist_snn))
# Initialise the weights by running the network on the `.obs` field of one training batch.
params = SNN.init(rng=key, x=mnist_dl.train_step().obs)

In [5]:
# Hypernetwork built from the SNN parameter tree; evolution() feeds the output of
# HyperNet.apply(...) straight into SNN.apply, so it must produce a full SNN parameter tree.
# NOTE(review): 256 is presumably the number of DCT coefficients per weight tensor — confirm
# against synecdoche.hyper.DCT.
# The lambda reads the global `params` when it is traced, so re-run this cell
# whenever `params` is redefined for a different network.
HyperNet = hk.without_apply_rng(hk.transform(lambda: hyper.DCT(256, params)()))
# Reuses the same `key` as the SNN initialisation cell.
hypernet_params = HyperNet.init(key)

In [6]:
from jax import tree_util as tree

def param_count(hypernetwork_params):
    """Return the total number of scalar entries over all leaves of a parameter tree."""
    total = 0
    for leaf in tree.tree_leaves(hypernetwork_params):
        total += jnp.size(leaf)
    return total

In [7]:
# Number of hypernetwork parameters, i.e. the size of the search space handed to evolution().
param_count(hypernet_params)

1280

In [8]:
# Number of parameters in the SNN itself, for comparison with the hypernetwork.
param_count(params)

118272

## Evolution

In [11]:
def evolution(SNN, HyperNet, hyper_params, dl, epochs=15, test_every=5, key=0):
    """Evolve hypernetwork parameters with LM-MA-ES, scoring each candidate by the SNN's loss.

    Every candidate is a flat vector that is reshaped into a hypernetwork
    parameter tree, expanded to SNN parameters with ``HyperNet.apply`` and
    evaluated with ``SNN.apply`` on the current training batch.

    Args:
        SNN: Haiku-transformed network, called as ``SNN.apply(params, events)``.
        HyperNet: Haiku-transformed hypernetwork, called as
            ``HyperNet.apply(hyper_params)``; its output is passed to ``SNN.apply``.
        hyper_params: Template hypernetwork parameter tree. It only configures
            the ``evosax.ParameterReshaper`` (flat vector <-> tree).
        dl: Data loader providing ``train_len``, ``val_len``, ``batch_size``,
            ``train_reset()``, ``train_step()``, ``val_reset()`` and ``val_step()``.
        epochs: Number of passes over the training set.
        test_every: Validate the current best member every ``test_every`` epochs.
        key: Integer seed for ``jax.random.PRNGKey``.

    Returns:
        The best member found, as a hypernetwork parameter tree with the
        leading population axis removed. With ``epochs=0`` this is the
        strategy's initial best member.
    """
    rng = jax.random.PRNGKey(key)
    param_reshaper = evosax.ParameterReshaper(hyper_params)

    strategy = LMMAES(popsize=64, num_dims=param_reshaper.total_params)

    # The customised strategy hyper-parameters must be passed explicitly to
    # initialize/ask/tell; when omitted, evosax falls back to default_params
    # and the init_min/init_max override below would be silently ignored.
    es_params = strategy.default_params.replace(init_min=-1.0, init_max=1.0)
    state = strategy.initialize(rng, es_params)

    def forward(hyper_params, events):
        return SNN.apply(HyperNet.apply(hyper_params), events)

    # Population axis is mapped over the parameters only; every member sees the
    # same batch. Jitted once here so the validation loop reuses the compiled
    # function instead of re-tracing the vmap on every batch.
    sim_fn = jax.jit(jax.vmap(forward, (0, None)))
    acc_fn = jax.vmap(spyx.loss.integral_accuracy, (0, None))
    loss_fn = jax.vmap(spyx.loss.integral_crossentropy, (0, None))

    @jax.jit
    def step(rng, state, events, targets):
        rng, rng_ask = jax.random.split(rng, 2)
        # ASK: sample a population of flat candidate vectors.
        pop, state = strategy.ask(rng_ask, state, es_params)
        population_params = param_reshaper.reshape(pop.astype(jnp.float32))
        # EVAL: one loss value per population member.
        spikes, V = sim_fn(population_params, events)
        loss = loss_fn(spikes, targets)
        # TELL: update the search distribution from the fitness values.
        state = strategy.tell(pop, loss, state, es_params)

        return rng, state, loss

    # Defined before the loop so the return value exists even when epochs == 0.
    elite = param_reshaper.reshape(jnp.array([state.best_member]))

    for gen in range(epochs):
        pbar = tqdm(range(dl.train_len // dl.batch_size), desc="Epoch #{}".format(gen))
        dl.train_reset()
        for _ in pbar:
            events, targets = dl.train_step()  # non-jittable...

            rng, state, loss = step(rng, state, events, targets)

            pbar.set_postfix(Loss=jnp.min(loss))

        # Wrap the best member in a population axis of size 1 so it can go
        # through the same vmapped functions as the training population.
        elite = param_reshaper.reshape(jnp.array([state.best_member]))
        if gen % test_every == test_every - 1:
            dl.val_reset()
            accs = []
            losses = []

            pbar = tqdm(range(dl.val_len // dl.batch_size), desc="Validate")
            for _ in pbar:
                events, targets = dl.val_step()
                spikes, V = sim_fn(elite, events)
                acc, pred = acc_fn(spikes, targets)
                loss = loss_fn(spikes, targets)

                accs.append(acc)
                losses.append(loss)

                pbar.set_postfix(Loss=np.mean(losses), Accuracy=np.mean(accs))

    # Remove only the leading population axis; a bare jnp.squeeze would also
    # collapse any genuine size-1 dimensions inside the parameter tensors.
    return jax.tree_util.tree_map(lambda leaf: jnp.squeeze(leaf, axis=0), elite)

In [12]:
# Run the evolutionary search with the default schedule (15 epochs, validation every 5).
# The result is the best *hypernetwork* parameter tree; pass it through
# HyperNet.apply to obtain SNN weights.
elite_params = evolution(SNN, HyperNet, hypernet_params, mnist_dl)

ParameterReshaper: 1280 parameters detected for optimization.


Epoch #0: 100%|███████████████| 656/656 [00:43<00:00, 14.97it/s, Loss=2.3025846]
Epoch #1: 100%|███████████████| 656/656 [00:39<00:00, 16.50it/s, Loss=2.3025846]
Epoch #2: 100%|███████████████| 656/656 [00:41<00:00, 15.81it/s, Loss=2.3025846]
Epoch #3: 100%|███████████████| 656/656 [00:38<00:00, 17.05it/s, Loss=2.3025846]
Epoch #4: 100%|███████████████| 656/656 [00:37<00:00, 17.56it/s, Loss=2.3025846]
Validate: 100%|█████| 281/281 [00:05<00:00, 48.05it/s, Accuracy=0.101, Loss=2.3]
Epoch #5: 100%|███████████████| 656/656 [00:38<00:00, 17.01it/s, Loss=2.3025846]
Epoch #6: 100%|███████████████| 656/656 [00:40<00:00, 16.20it/s, Loss=2.3025846]
Epoch #7: 100%|███████████████| 656/656 [00:35<00:00, 18.51it/s, Loss=2.3025846]
Epoch #8: 100%|███████████████| 656/656 [00:35<00:00, 18.52it/s, Loss=2.3025846]
Epoch #9: 100%|███████████████| 656/656 [00:39<00:00, 16.72it/s, Loss=2.3025846]
Validate: 100%|█████| 281/281 [00:05<00:00, 48.07it/s, Accuracy=0.101, Loss=2.3]
Epoch #10: 100%|████████████

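Yikes... The loss never moves off ln(10) ≈ 2.3026 and validation accuracy stays at chance (≈10%), so this setup isn't learning anything. Looks like 717,200 parameters is too much for neuroevolution to handle! (Note: the 717,200 figure does not match the counts printed above — 118,272 SNN parameters generated from 1,280 evolved hypernetwork parameters — and appears to be left over from an earlier version of this notebook.)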

### Attempt # 2

In [None]:
def mnist_scnn(x):
    """Spiking conv net (MaxPool, DepthwiseConv2D, Linear/LIF, Linear/LI), unrolled with Haiku.

    ``x`` is consumed batch-major: the unroll is called with ``time_major=False``
    and axis 0 of ``x`` sizes the initial recurrent state. The input is cast to
    float32 before it is fed to the network.

    Returns the two values produced by ``hk.dynamic_unroll``.
    """
    # Modules are created in the same order as they are applied, which keeps
    # Haiku's automatically generated parameter names stable.
    layers = [
        hk.MaxPool(3, 3, "SAME"),
        hk.DepthwiseConv2D(8, 3, with_bias=False),
        hk.Flatten(),
        hk.Linear(64, with_bias=False),
        snn.LIF(64, activation=spyx.activation.SuperSpike()),
        hk.Linear(10, with_bias=False),
        snn.LI(10),
    ]
    network = hk.DeepRNN(layers)

    inputs = x.astype(jnp.float32)
    batch_size = x.shape[0]
    outputs, state = hk.dynamic_unroll(
        network,
        inputs,
        network.initial_state(batch_size),
        time_major=False,
    )
    return outputs, state

In [None]:
# Same fixed seed as the first attempt.
key = jax.random.PRNGKey(0)
# Rebinds the globals `SNN` and `params` to the convolutional network; anything
# built from the previous `params` (e.g. the hypernetwork) must be rebuilt afterwards.
SNN = hk.without_apply_rng(hk.transform(mnist_scnn))
# Initialise the weights by running the network on the `.obs` field of one training batch.
params = SNN.init(rng=key, x=mnist_dl.train_step().obs)

In [None]:
# Print Haiku's module summary table for the convolutional SNN, evaluated on one training batch.
print(hk.experimental.tabulate(SNN)(mnist_dl.train_step().obs))

In [None]:
# evolution() takes (SNN, HyperNet, hyper_params, dl); the previous three-argument
# call raised a TypeError. The hypernetwork is rebuilt here because it has to
# generate weights for the *new* `params` tree of the convolutional network.
HyperNet = hk.without_apply_rng(hk.transform(lambda: hyper.DCT(256, params)()))
hypernet_params = HyperNet.init(key)
elite_params = evolution(SNN, HyperNet, hypernet_params, mnist_dl)