# Training an SNN using Neuroevolution!

In [4]:
import spyx
import spyx.nn as snn

# JAX imports
import os
import jax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"
from jax import numpy as jnp
import numpy as np

from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# optimize the parameters using evosax
import evosax
from evosax.strategies import PersistentES as PES

# rendering tools
import matplotlib.pyplot as plt
%matplotlib notebook
import graphviz
import mediapy as media

## Data Loading

In [33]:
# Samples per batch for every training and validation step.
BATCH_SIZE = 64
mnist_dl = spyx.data.MNIST_loader(BATCH_SIZE)

## SNN

In [3]:
def mnist_snn(x):
    """Fully-connected LIF spiking network for MNIST, unrolled over time.

    ``hk.dynamic_unroll`` is called with ``time_major=False`` and the initial
    state is sized by ``x.shape[0]``, so ``x`` is batch-major: axis 0 is the
    batch and axis 1 is time. The remaining axes are the per-step input, which
    is flattened before the first linear layer.

    Returns:
        ``(outputs, final_state)`` as produced by ``hk.dynamic_unroll``: the
        per-step outputs of the ``snn.LI(10)`` readout and the recurrent state
        after the last step.
    """
    # NOTE(review): hk.Linear(784) has (flattened input size) * 784 weights --
    # presumably 784 * 784 for MNIST frames -- which dominates the parameter count.
    core = hk.DeepRNN([
        hk.Flatten(),
        hk.Linear(784, with_bias=False),
        snn.LIF(784, activation=spyx.activation.SuperSpike()),
        hk.Linear(128, with_bias=False),
        snn.LIF(128, activation=spyx.activation.SuperSpike()),
        hk.Linear(10, with_bias=False),
        snn.LI(10)
    ])
    outputs, final_state = hk.dynamic_unroll(core, x.astype(jnp.float32), core.initial_state(x.shape[0]), time_major=False)
    return outputs, final_state

In [4]:
# Haiku modules must be built inside a transform; the model needs no RNG at apply time.
SNN = hk.without_apply_rng(hk.transform(mnist_snn))
key = jax.random.PRNGKey(0)
# One training batch, used only so Haiku can infer the parameter shapes.
example_obs = mnist_dl.train_step().obs
params = SNN.init(rng=key, x=example_obs)

In [5]:
# Per-module table of configs, parameter counts and I/O shapes, traced on one training batch.
summarize = hk.experimental.tabulate(SNN)
print(summarize(mnist_dl.train_step().obs))

+-------------------------------------+----------------------------------------------------------------------------------+-----------------+---------------------------------------------------------+------------------------------------------------------+---------------+---------------+
| Module                              | Config                                                                           | Module params   | Input                                                   | Output                                               |   Param count |   Param bytes |
| deep_rnn (DeepRNN.initial_state)    | DeepRNN(                                                                         |                 | 64                                                      | (f16[64,784], f16[64,128], f32[64,10])               |           912 |       1.82 KB |
|                                     |     layers=[Flatten(),                                                           |                 |  

## Evolution

In [15]:
def evolution(SNN, params, dl, epochs=15, test_every=5, key=0):
    """Optimize the parameters of a Haiku SNN with evosax's Persistent ES.

    Each training step asks the strategy for a population of 64 flat parameter
    vectors and reshapes them into parameter pytrees. Every candidate is scored
    on one training batch with ``spyx.loss.integral_crossentropy``, and the
    strategy is told those losses. After every ``test_every``-th epoch, the best
    member found so far is evaluated on the whole validation split. Its running
    mean loss and accuracy are shown in the progress bar.

    Args:
        SNN: transformed Haiku model (no apply RNG). ``SNN.apply(params, x)``
            must return an ``(outputs, state)`` pair.
        params: initial parameter pytree, used as the starting ES mean.
        dl: loader providing ``train_reset``/``train_step`` and
            ``val_reset``/``val_step`` (each step unpacks to
            ``(events, targets)``), plus ``train_len``, ``val_len`` and
            ``batch_size``.
        epochs: number of passes over the training split.
        test_every: validate after every ``test_every``-th epoch.
        key: integer seed for ``jax.random.PRNGKey``.

    Returns:
        The best member's parameters as a pytree with the same leaf shapes as
        ``params`` (``params`` itself when ``epochs`` is 0).
    """
    rng = jax.random.PRNGKey(key)
    param_reshaper = evosax.ParameterReshaper(params)

    strategy = PES(popsize=64,
                   num_dims=param_reshaper.total_params,
                   opt_name="clipup",
                   )

    es_params = strategy.default_params
    # NOTE(review): the mean is overwritten below, so init_min/init_max
    # presumably only affect the discarded random initial mean -- confirm.
    es_params = es_params.replace(init_min=-1, init_max=1)
    # es_params are passed explicitly to initialize/ask/tell; otherwise the
    # strategy would silently run on its default parameters.
    state = strategy.initialize(rng, es_params)
    # Start the search from the Haiku initialization. best_member is seeded as
    # well, so a run with epochs=0 returns ``params``.
    flat_params = param_reshaper.flatten_single(params)
    state = state.replace(mean=flat_params, best_member=flat_params)

    # vmap over the population axis of the parameters; the batch is shared.
    sim_fn = jax.vmap(SNN.apply, (0, None))
    acc_fn = jax.vmap(spyx.loss.integral_accuracy, (0, None))
    loss_fn = jax.vmap(spyx.loss.integral_crossentropy, (0, None))

    @jax.jit
    def step(rng, state, events, targets):
        rng, rng_ask = jax.random.split(rng, 2)
        # ASK: sample a population of flat parameter vectors.
        pop, state = strategy.ask(rng_ask, state, es_params)
        population_params = param_reshaper.reshape(pop.astype(jnp.float32))
        # EVAL: one loss per population member on this batch.
        outputs, _ = sim_fn(population_params, events)
        loss = loss_fn(outputs, targets)
        # TELL: update the search distribution, using the losses as fitness.
        state = strategy.tell(pop, loss, state, es_params)
        return rng, state, loss

    @jax.jit
    def evaluate(member_params, events, targets):
        # Compiled once and reused for every validation batch.
        outputs, _ = sim_fn(member_params, events)
        acc, _ = acc_fn(outputs, targets)
        loss = loss_fn(outputs, targets)
        return acc, loss

    for gen in range(epochs):
        pbar = tqdm(range(dl.train_len // dl.batch_size))
        pbar.set_description("Epoch #{}".format(gen))
        dl.train_reset()
        for _ in pbar:
            events, targets = dl.train_step()  # non-jittable...
            rng, state, loss = step(rng, state, events, targets)
            pbar.set_postfix(Loss=jnp.min(loss))

        if gen % test_every == test_every - 1:
            # Length-1 population axis so the vmapped functions accept the elite.
            elite = param_reshaper.reshape(state.best_member[None])
            dl.val_reset()
            acc_total = 0.0
            loss_total = 0.0

            pbar = tqdm(range(dl.val_len // dl.batch_size))
            pbar.set_description("Validate")
            for n_batches, _ in enumerate(pbar, start=1):
                events, targets = dl.val_step()
                acc, loss = evaluate(elite, events, targets)
                # Running means over the batches seen so far (O(1) per step).
                acc_total += float(jnp.mean(acc))
                loss_total += float(jnp.mean(loss))
                pbar.set_postfix(Loss=loss_total / n_batches,
                                 Accuracy=acc_total / n_batches)

    # Drop only the population axis. jnp.squeeze would also remove genuine
    # size-1 parameter axes (e.g. Haiku's DepthwiseConv2D kernel has one),
    # leaving a pytree that no longer matches SNN.apply.
    elite = param_reshaper.reshape(state.best_member[None])
    return jax.tree_util.tree_map(lambda leaf: leaf[0], elite)

In [9]:
# Evolve with the default schedule: 15 epochs, validating after every 5th.
elite_params = evolution(SNN, params, dl=mnist_dl)

ParameterReshaper: 717200 parameters detected for optimization.


Epoch #0: 100%|█████████████████████████████████████████| 656/656 [03:14<00:00,  3.37it/s, Loss=3.101177]
Epoch #1: 100%|████████████████████████████████████████| 656/656 [03:12<00:00,  3.42it/s, Loss=2.7998104]
Epoch #2: 100%|████████████████████████████████████████| 656/656 [03:12<00:00,  3.41it/s, Loss=2.6480782]
Epoch #3: 100%|█████████████████████████████████████████| 656/656 [03:12<00:00,  3.40it/s, Loss=2.570093]
Epoch #4: 100%|█████████████████████████████████████████| 656/656 [03:10<00:00,  3.44it/s, Loss=2.508956]
Validate: 100%|████████████████████████████| 281/281 [01:07<00:00,  4.16it/s, Accuracy=0.0895, Loss=2.84]
Epoch #5: 100%|████████████████████████████████████████| 656/656 [03:12<00:00,  3.40it/s, Loss=2.3636222]
Epoch #6: 100%|████████████████████████████████████████| 656/656 [03:12<00:00,  3.41it/s, Loss=2.3868988]
Epoch #7: 100%|████████████████████████████████████████| 656/656 [03:12<00:00,  3.41it/s, Loss=2.3733869]
Epoch #8: 100%|███████████████████████████████

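Yikes... it looks like 717,200 parameters are too many for neuroevolution to handle! After five epochs the best population member reaches only ~9% validation accuracy, which is no better than guessing among 10 classes.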

### Attempt #2: a much smaller convolutional SNN

In [38]:
def mnist_scnn(x):
    """Small convolutional LIF spiking network for MNIST, unrolled over time.

    ``hk.dynamic_unroll`` is called with ``time_major=False`` and the initial
    state is sized by ``x.shape[0]``, so ``x`` is batch-major: axis 0 is the
    batch and axis 1 is time. The remaining axes are the per-step input.

    Each step's input goes through these stages in order:
    - max-pooling;
    - a bias-free depthwise convolution (channel multiplier 8, kernel 3);
    - flattening;
    - a 64-unit LIF layer;
    - a 10-unit leaky-integrator readout.

    Returns:
        ``(outputs, final_state)`` as produced by ``hk.dynamic_unroll``: the
        per-step outputs of the ``snn.LI(10)`` readout and the recurrent state
        after the last step.
    """
    core = hk.DeepRNN([
        hk.MaxPool(3, 3, "SAME"),
        hk.DepthwiseConv2D(8, 3, with_bias=False),
        hk.Flatten(),
        hk.Linear(64, with_bias=False),
        snn.LIF(64, activation=spyx.activation.SuperSpike()),
        hk.Linear(10, with_bias=False),
        snn.LI(10)
    ])
    outputs, final_state = hk.dynamic_unroll(core, x.astype(jnp.float32), core.initial_state(x.shape[0]), time_major=False)
    return outputs, final_state

In [39]:
# Rebind SNN/params to the convolutional model; no RNG is needed at apply time.
SNN = hk.without_apply_rng(hk.transform(mnist_scnn))
key = jax.random.PRNGKey(0)
# One training batch, used only so Haiku can infer the parameter shapes.
example_obs = mnist_dl.train_step().obs
params = SNN.init(rng=key, x=example_obs)

In [40]:
# Per-module table of configs, parameter counts and I/O shapes, traced on one training batch.
summarize = hk.experimental.tabulate(SNN)
print(summarize(mnist_dl.train_step().obs))

+-------------------------------------+-------------------------------------------------------------------------------------+-----------------+-------------------------------------------+----------------------------------------+---------------+---------------+
| Module                              | Config                                                                              | Module params   | Input                                     | Output                                 |   Param count |   Param bytes |
| deep_rnn (DeepRNN.initial_state)    | DeepRNN(                                                                            |                 | 64                                        | (f16[64,64], f32[64,10])               |            64 |      128.00 B |
|                                     |     layers=[MaxPool(window_shape=2, strides=2, padding='SAME'),                     |                 |                                           |                              

In [41]:
# Same default schedule as before (15 epochs, validating after every 5th), now on the smaller model.
elite_params = evolution(SNN, params, dl=mnist_dl)

ParameterReshaper: 25864 parameters detected for optimization.


Epoch #0: 100%|████████████████████████████████████████| 656/656 [00:20<00:00, 32.18it/s, Loss=6.8789053]
Epoch #1: 100%|█████████████████████████████████████████| 656/656 [00:19<00:00, 32.85it/s, Loss=3.243165]
Epoch #2: 100%|████████████████████████████████████████| 656/656 [00:20<00:00, 31.69it/s, Loss=2.3815699]
Epoch #3: 100%|████████████████████████████████████████| 656/656 [00:24<00:00, 27.17it/s, Loss=2.2731974]
Epoch #4: 100%|████████████████████████████████████████| 656/656 [00:20<00:00, 32.46it/s, Loss=2.2836933]
Validate: 100%|████████████████████████████| 281/281 [00:52<00:00,  5.35it/s, Accuracy=0.0931, Loss=2.45]
Epoch #5: 100%|█████████████████████████████████████████| 656/656 [00:19<00:00, 32.84it/s, Loss=2.277646]
Epoch #6: 100%|█████████████████████████████████████████| 656/656 [00:20<00:00, 32.45it/s, Loss=2.288693]
Epoch #7: 100%|█████████████████████████████████████████| 656/656 [00:20<00:00, 32.34it/s, Loss=2.283051]
Epoch #8: 100%|███████████████████████████████