# Training an SNN using evolution!

This notebook contains the LM-MA-ES experimental results on NMNIST

In [1]:
import jax.numpy as jnp
warmup = jnp.array([1,2,3])
warmup * 12.34

Array([12.34, 24.68, 37.02], dtype=float32, weak_type=True)

In [2]:
import spyx
import spyx.nn as snn

# JAX imports
import jax
from jax import numpy as jnp
import jmp
import numpy as np

from jax_tqdm import scan_tqdm
from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# for surrogate loss training.
import evosax
from evosax.strategies import CR_FM_NES as CRFMNES
from evosax.strategies import LM_MA_ES as LMMAES

from evosax import FitnessShaper

# rendering tools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
%matplotlib notebook

## Set Mixed Precision Policy

In [3]:
policy = jmp.get_policy('half')


hk.mixed_precision.set_policy(hk.Flatten, policy)
hk.mixed_precision.set_policy(hk.Linear, policy)
hk.mixed_precision.set_policy(snn.LIF, policy)
hk.mixed_precision.set_policy(snn.LI, policy)

## Data Loading

In [4]:
nmnist_dl = spyx.data.NMNIST_loader(batch_size=256)

In [5]:
key = jax.random.PRNGKey(0)
x, y = nmnist_dl.train_epoch(key)

In [6]:
x.shape

(54, 256, 5, 2, 34, 34)

## SNN

Here we define a simple feed-forward SNN using Haiku's RNN features, incorporating our
LIF neuron models where activation functions would usually go. Haiku manages all of the state for us, so when we transform the function and get an apply() function we just need to pass the params!

Since spiking neurons have a discrete all-or-nothing activation, in order to do gradient descent we'll have to approximate the derivative of the Heaviside function with something smoother. In this case, we use the SuperSpike surrogate gradient from Zenke & Ganguli 2017.
Also not that we aren't using bias terms on the linear layers and since the inputs are images, we flatten the data before feeding it to the first layer.

Depending on computational constraints, we can use haiku's dynamic unroll to iterate the SNN, or we can use static unroll where the SNN will be unrolled during the JIT compiling process to further increase speed when training on GPU. Note that the static unroll will take longer to compile, but once it runs the iterations per second will be 2x-3x greater than the dynamic unroll.

In [7]:
def nmnist_snn(x):
    
    x = hk.BatchApply(hk.Flatten())(x)
    x = hk.BatchApply(hk.Linear(512, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((512,)),
        hk.Linear(10, with_bias=False),
        snn.LI((10,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=20)
    
    return spikes, V


In [8]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN = hk.without_apply_rng(hk.transform(nmnist_snn))
params = SNN.init(rng=key, x=x[0])

## evolution


In [9]:
def evo(params, dl, key, epochs=300): # rename this
    rng = key        
    
    
    param_reshaper = evosax.ParameterReshaper(params)
        

    fit_shaper = FitnessShaper(maximize=True)

    # Shape the evaluated fitness scores
    
    strategy = LMMAES(popsize=256, 
                num_dims=param_reshaper.total_params,
                sigma_init=0.03
                )

    opt_state = strategy.initialize(rng)
    
            
    # define and compile our eval function that computes the loss for our SNN
    @jax.jit
    def net_eval(weights, events, targets):
        readout = SNN.apply(weights, events)
        traces, V_f = readout
        acc, _ = spyx.fn.integral_accuracy(traces, targets)
        return acc
        
    sim_fn = jax.vmap(net_eval, (0, None, None))
    
    # compile the meat of our training loop for speed
    @jax.jit
    def train_step(state, data):
        old_state, rng = state
        rng, rng_ask = jax.random.split(rng)
        events, targets = data
        pop, evo_state = strategy.ask(rng_ask, old_state)
        population_params = param_reshaper.reshape(pop.astype(jnp.float16)) # this cast is bad...
        # EVAL
        fit = sim_fn(population_params, jnp.unpackbits(events, axis=1), targets) 
        # TELL
        fit_shaped = fit_shaper.apply(pop, fit) 
        new_state = [strategy.tell(pop, fit_shaped, evo_state), rng]
        
        return new_state, fit
    
    # For validation epochs, do the same as before but compute the
    # accuracy, predictions and losses (no gradients needed)
    @jax.jit
    def eval_step(grad_params_rng, data):
        grad_params, rng = grad_params_rng
        events, targets = data # fix
        readout = SNN.apply(grad_params, jnp.unpackbits(events, axis=1))
        traces, V_f = readout
        acc, _ = spyx.fn.integral_accuracy(traces, targets)
        return [grad_params, rng], acc
        
    
    val_data = dl.val_epoch()
    
    # Here's the start of our training loop!
    @scan_tqdm(epochs)
    def epoch(curr_opt_state, epoch_num):
        
        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = dl.train_epoch(shuffle_rng)
        
        # train epoch
        end_state, train_acc = jax.lax.scan(
            train_step,# func
            curr_opt_state,# init
            train_data,# xs
            train_data.obs.shape[0]# len
        )
        
        elite = param_reshaper.reshape(jnp.array([end_state[0].best_member]))
        new_params = jax.tree_util.tree_map(lambda x: x[0], elite)
                        
        # val epoch
        _rng, val_acc = jax.lax.scan(
            eval_step,# func
            [new_params, end_state[1]],# init
            val_data,# xs
            val_data.obs.shape[0]# len
        )

        _, new_rng = _rng
        
        return [end_state[0], new_rng], jnp.hstack([jnp.max(train_acc), jnp.mean(val_acc)])
    # end epoch
    
    # epoch loop
    final_state, metrics = jax.lax.scan(
        epoch,
        [opt_state, rng], # metric arrays
        jnp.arange(epochs), # 
        epochs # len of loop
    )
    
    elite = param_reshaper.reshape(jnp.array([final_state[0].best_member]))
    final_params = jax.tree_util.tree_map(lambda x: x[0], elite)
                
    # return our final, optimized network.       
    return final_params, metrics

In [10]:
def test_evo(params, dl):

    rng = jax.random.PRNGKey(0)
    
    @jax.jit
    def test_step(params_rng, data):
        params, rng = params_rng
        events, targets = data
        events = jnp.unpackbits(events, axis=1)
        readout = SNN.apply(params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        return [params, rng], [acc, pred, targets ]
    
    test_data = dl.test_epoch()
    
    _, test_metrics = jax.lax.scan(
            test_step,# func
            [params, rng],# init
            test_data,# xs
            test_data.obs.shape[0]# len
    )
    
    acc = jnp.mean(test_metrics[0])
    preds = jnp.array(test_metrics[1]).flatten()
    tgts = jnp.array(test_metrics[2]).flatten()
    return acc, preds, tgts

## Training Time

We'll train the network for 50 epochs since SHD is more difficult than MNIST.

The SHD dataloader for Spyx has built-in leave-one-group-out cross validation. This is becuase the test set for SHD has two unseen speakers, so when we train our model we need to make it robust to speakers it isn't training on in the hopes of improving generalization accuracy.

In [11]:
from time import time

In [12]:
key = jax.random.PRNGKey(0)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50) # 3:11
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.9609375, val_acc=0.893172562122345 Elapsed Time: 660.3495941162109
Test Acc: 0.8977364


In [13]:
key = jax.random.PRNGKey(7)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50) # 3:11
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.95703125, val_acc=0.87941575050354 Elapsed Time: 655.1253998279572
Test Acc: 0.8839143


In [14]:
key = jax.random.PRNGKey(42)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50) # 3:11
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.9609375, val_acc=0.890285313129425 Elapsed Time: 653.5114979743958
Test Acc: 0.88952327


In [15]:
key = jax.random.PRNGKey(12345)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50) # 3:11
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.95703125, val_acc=0.89113450050354 Elapsed Time: 655.3028016090393
Test Acc: 0.89783657


In [16]:
key = jax.random.PRNGKey(54321)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50) # 3:11
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.96484375, val_acc=0.88892662525177 Elapsed Time: 655.1524081230164
Test Acc: 0.89503205


In [17]:
metrics

Array([[0.4375    , 0.3947011 ],
       [0.67578125, 0.62788725],
       [0.80859375, 0.7236753 ],
       [0.85546875, 0.77632475],
       [0.8671875 , 0.79959244],
       [0.8828125 , 0.79432744],
       [0.89453125, 0.8289742 ],
       [0.8984375 , 0.8192935 ],
       [0.9140625 , 0.8439199 ],
       [0.921875  , 0.8350883 ],
       [0.92578125, 0.850034  ],
       [0.9296875 , 0.8580163 ],
       [0.9296875 , 0.8580163 ],
       [0.93359375, 0.8483356 ],
       [0.93359375, 0.8483356 ],
       [0.94140625, 0.8644701 ],
       [0.93359375, 0.8644701 ],
       [0.9453125 , 0.8496943 ],
       [0.94140625, 0.8496943 ],
       [0.94140625, 0.8496943 ],
       [0.9375    , 0.8496943 ],
       [0.94140625, 0.8496943 ],
       [0.9453125 , 0.8496943 ],
       [0.94140625, 0.8496943 ],
       [0.9453125 , 0.8496943 ],
       [0.94921875, 0.8731318 ],
       [0.953125  , 0.8717731 ],
       [0.953125  , 0.8717731 ],
       [0.94921875, 0.8717731 ],
       [0.94921875, 0.8717731 ],
       [0.

In [18]:
from sklearn.metrics import confusion_matrix

confusion_matrix(tgts, preds)

array([[ 923,    0,   14,    7,    3,    6,    7,    4,   11,    3],
       [   1, 1081,   11,    5,    2,    4,    7,    2,   19,    1],
       [  18,    1,  896,   30,   23,    3,    7,    9,   42,    2],
       [   0,    3,   20,  897,    1,   30,    5,    5,   35,   12],
       [   5,    8,    6,    1,  878,    3,    7,    2,   12,   57],
       [  12,    0,    8,   46,    4,  754,   16,    3,   38,   10],
       [  24,    3,   21,    5,   19,   21,  852,    1,   10,    1],
       [   5,   10,   33,   15,    6,    4,    0,  879,   16,   59],
       [   3,    5,   13,   28,    8,   26,    6,    8,  865,   10],
       [  13,    4,    7,   17,   16,   10,    0,    9,   21,  911]])