# Training an SNN using evolution!

This notebook contains the experimental results for evolving SNNs for NMNIST.

In [1]:
import jax.numpy as jnp
# Tiny throwaway computation on a three-element array.
# NOTE(review): presumably here to trigger JAX backend start-up before the
# real work begins - confirm.
warmup = jnp.array((1, 2, 3))
12.34 * warmup

Array([12.34, 24.68, 37.02], dtype=float32, weak_type=True)

In [2]:
import spyx
import spyx.nn as snn

# JAX imports
import jax
from jax import numpy as jnp
import jmp
import numpy as np

from jax_tqdm import scan_tqdm
from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# evolutionary strategies for gradient-free training.
import evosax
from evosax.strategies import CR_FM_NES as CRFMNES
from evosax.strategies import LM_MA_ES as LMMAES

from evosax import FitnessShaper

# rendering tools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
%matplotlib notebook

## Set Mixed Precision Policy

In [3]:
# Use jmp's 'half' policy for every module type that appears in the network.
policy = jmp.get_policy('half')

# Register the policy class-by-class, in the same order as the layers are
# listed: flatten, dense, spiking (LIF) and leaky-integrator readout (LI).
for module_cls in (hk.Flatten, hk.Linear, snn.LIF, snn.LI):
    hk.mixed_precision.set_policy(module_cls, policy)

## Data Loading

In [4]:
# Build the NMNIST loader with a batch size of 256; its train_epoch(),
# val_epoch() and test_epoch() methods are used by the cells below.
nmnist_dl = spyx.data.NMNIST_loader(batch_size=256)

In [5]:
key = jax.random.PRNGKey(0)
# Draw one full training epoch so that a sample batch (x[0]) is available for
# initialising the network parameters further down.
# NOTE(review): the key presumably controls shuffling - confirm against spyx.
x, y = nmnist_dl.train_epoch(key)

In [6]:
# Inspect the epoch tensor. The leading axis is the one scanned over as
# batches during training.
# NOTE(review): remaining axes are presumably (batch, bit-packed time,
# polarity, height, width) - confirm against the spyx loader.
x.shape

(54, 256, 5, 2, 34, 34)

## SNN

A simple feed-forward architecture is used to keep the computational demand low; evaluating hundreds of convolutional SNNs in parallel would be considerably more expensive.

In [7]:
def nmnist_snn(x):
    """Feed-forward spiking network: flatten -> Linear(512) -> LIF -> Linear(10) -> LI.

    Args:
        x: batch-major input (``time_major=False``), i.e. the first axis is the
            batch and the second is the sequence axis that is unrolled over.

    Returns:
        The pair returned by ``hk.dynamic_unroll``: the per-step outputs of the
        recurrent core (the LI readout layer) and the core's final state.
    """
    # NOTE: modules are constructed in exactly this order on purpose; Haiku
    # derives parameter names from creation order.
    flatten = hk.Flatten()
    x = hk.BatchApply(flatten)(x)

    input_projection = hk.Linear(512, with_bias=False)
    x = hk.BatchApply(input_projection)(x)

    recurrent_layers = [
        snn.LIF((512,)),
        hk.Linear(10, with_bias=False),
        snn.LI((10,)),
    ]
    core = hk.DeepRNN(recurrent_layers)

    batch_size = x.shape[0]
    initial_state = core.initial_state(batch_size)

    # Dynamic unroll, with 20 steps unrolled per loop iteration for speed.
    spikes, V = hk.dynamic_unroll(core, x, initial_state, time_major=False, unroll=20)

    return spikes, V


In [8]:
key = jax.random.PRNGKey(0)

# Nothing in the network is stochastic, so drop the RNG argument from apply().
transformed_snn = hk.transform(nmnist_snn)
SNN = hk.without_apply_rng(transformed_snn)

# Initialise the parameters from the first batch of the sampled epoch.
# NOTE(review): x[0] is not passed through jnp.unpackbits here, unlike in
# training; this presumably only works because no parameter shape depends on
# the sequence axis - confirm.
params = SNN.init(rng=key, x=x[0])

## Evolution


In [9]:
def evo(params, dl, key, epochs=300, popsize=256, sigma_init=0.03):
    """Evolve the SNN weights with CR-FM-NES and record per-epoch accuracy.

    Every training batch is treated as one generation: a population of flat
    weight vectors is sampled from the strategy, each candidate is scored by
    its classification accuracy on that batch, and the strategy is updated
    with the shaped fitness. After each epoch the strategy's ``best_member``
    is scored on the validation split.

    Args:
        params: Haiku parameter tree. It is only handed to
            ``evosax.ParameterReshaper`` to define the layout and size of the
            search space; its values are not passed to the strategy here.
        dl: data loader providing ``train_epoch(key)`` and ``val_epoch()``.
            Each must return a scannable container with an ``obs`` field whose
            slices unpack as ``(events, targets)``. ``events`` is bit-packed
            along axis 1 of a batch and is unpacked with ``jnp.unpackbits``.
        key: JAX PRNG key used for strategy initialisation, population
            sampling and per-epoch data shuffling.
        epochs: number of passes over the training data.
        popsize: number of candidates sampled per generation (default 256).
        sigma_init: initial search step size of the strategy (default 0.03).

    Returns:
        ``(final_params, metrics)`` where ``final_params`` is the parameter
        tree of the strategy's best member after the last epoch, and
        ``metrics`` has one row per epoch: ``[train_acc, val_acc]``.
        ``train_acc`` is the maximum accuracy over all candidates and all
        batches of that epoch; ``val_acc`` is the mean validation accuracy of
        the best member.
    """
    rng = key

    # Converts between flat population vectors and the Haiku parameter tree.
    param_reshaper = evosax.ParameterReshaper(params)

    # Fitness is an accuracy, so higher is better.
    fit_shaper = FitnessShaper(maximize=True)

    strategy = CRFMNES(
        popsize=popsize,
        num_dims=param_reshaper.total_params,
        sigma_init=sigma_init,
    )

    # NOTE(review): `rng` is consumed here and is also used below both as the
    # initial carried key and as the base of the per-epoch shuffle keys.
    # Splitting it into independent keys would be better PRNG hygiene, but it
    # would change the random stream and therefore the recorded results.
    opt_state = strategy.initialize(rng)

    def best_member_params(es_state):
        """Return the strategy's best member as a single parameter tree."""
        # reshape() works on a population, so wrap the vector in a batch of
        # one and strip that leading axis from every leaf afterwards.
        batched = param_reshaper.reshape(jnp.array([es_state.best_member]))
        return jax.tree_util.tree_map(lambda leaf: leaf[0], batched)

    # Accuracy of one candidate network on one batch.
    @jax.jit
    def net_eval(weights, events, targets):
        traces, V_f = SNN.apply(weights, events)
        acc, _ = spyx.fn.integral_accuracy(traces, targets)
        return acc

    # Evaluate the whole population on the same batch (map over weights only).
    sim_fn = jax.vmap(net_eval, (0, None, None))

    # One generation: ask -> evaluate -> tell.
    @jax.jit
    def train_step(state, data):
        old_state, carry_key = state
        carry_key, ask_key = jax.random.split(carry_key)
        events, targets = data
        pop, evo_state = strategy.ask(ask_key, old_state)
        # Candidates are cast to float16 before being reshaped into trees.
        population_params = param_reshaper.reshape(pop.astype(jnp.float16))
        # EVAL
        fit = sim_fn(population_params, jnp.unpackbits(events, axis=1), targets)
        # TELL
        fit_shaped = fit_shaper.apply(pop, fit)
        new_state = [strategy.tell(pop, fit_shaped, evo_state), carry_key]

        return new_state, fit

    # Validation: score one fixed parameter tree on one batch. The key is
    # carried through untouched so the scan carry keeps a constant structure.
    @jax.jit
    def eval_step(eval_params_rng, data):
        eval_params, carry_key = eval_params_rng
        events, targets = data
        traces, V_f = SNN.apply(eval_params, jnp.unpackbits(events, axis=1))
        acc, _ = spyx.fn.integral_accuracy(traces, targets)
        return [eval_params, carry_key], acc

    # The validation split is fetched once and reused every epoch.
    val_data = dl.val_epoch()

    # One epoch: a scan over training batches followed by a validation pass.
    @scan_tqdm(epochs)
    def epoch(curr_opt_state, epoch_num):

        # Derive a distinct shuffle key per epoch from the base key.
        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = dl.train_epoch(shuffle_rng)

        # train epoch
        end_state, train_acc = jax.lax.scan(
            train_step,  # func
            curr_opt_state,  # init
            train_data,  # xs
            train_data.obs.shape[0]  # len
        )

        new_params = best_member_params(end_state[0])

        # val epoch
        val_carry, val_acc = jax.lax.scan(
            eval_step,  # func
            [new_params, end_state[1]],  # init
            val_data,  # xs
            val_data.obs.shape[0]  # len
        )

        _, new_rng = val_carry

        epoch_metrics = jnp.hstack([jnp.max(train_acc), jnp.mean(val_acc)])
        return [end_state[0], new_rng], epoch_metrics
    # end epoch

    # epoch loop
    final_state, metrics = jax.lax.scan(
        epoch,
        [opt_state, rng],  # initial carry: strategy state and PRNG key
        jnp.arange(epochs),  # epoch indices
        epochs  # len of loop
    )

    final_params = best_member_params(final_state[0])

    # return our final, optimized network.
    return final_params, metrics

In [10]:
def test_evo(params, dl):
    """Score a parameter tree on the test split.

    Args:
        params: Haiku parameter tree to evaluate with ``SNN.apply``.
        dl: data loader providing ``test_epoch()``; the result must expose an
            ``obs`` field and yield ``(events, targets)`` slices, with
            ``events`` bit-packed along axis 1 of a batch.

    Returns:
        ``(acc, preds, tgts)``: the mean of the per-batch accuracies, and the
        flattened predictions and targets for all test batches.
    """
    carry_key = jax.random.PRNGKey(0)

    @jax.jit
    def test_step(carry, batch):
        weights, step_key = carry
        packed_events, labels = batch
        unpacked_events = jnp.unpackbits(packed_events, axis=1)
        traces, V_f = SNN.apply(weights, unpacked_events)
        batch_acc, batch_pred = spyx.fn.integral_accuracy(traces, labels)
        # The carry is returned unchanged; only the per-batch results vary.
        return [weights, step_key], [batch_acc, batch_pred, labels]

    test_data = dl.test_epoch()
    num_batches = test_data.obs.shape[0]

    _, (batch_accs, batch_preds, batch_tgts) = jax.lax.scan(
        test_step,
        [params, carry_key],
        test_data,
        num_batches,
    )

    acc = jnp.mean(batch_accs)
    preds = jnp.array(batch_preds).reshape(-1)
    tgts = jnp.array(batch_tgts).reshape(-1)
    return acc, preds, tgts

## Training Time



In [11]:
from time import time

In [12]:
# Timed evolution run with seed 0, followed by a test-set evaluation.
key = jax.random.PRNGKey(0)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50)
# JAX dispatches work asynchronously, so wait until the results really exist
# before stopping the clock; otherwise the elapsed time can be under-reported.
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), (evolved_params, metrics))
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.9609375, val_acc=0.88671875 Elapsed Time: 418.15694999694824
Test Acc: 0.89152646


In [14]:
# Timed evolution run with seed 7, followed by a test-set evaluation.
key = jax.random.PRNGKey(7)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50)
# JAX dispatches work asynchronously, so wait until the results really exist
# before stopping the clock; otherwise the elapsed time can be under-reported.
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), (evolved_params, metrics))
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.96484375, val_acc=0.9021739363670349 Elapsed Time: 415.0913519859314
Test Acc: 0.8958334


In [15]:
# Timed evolution run with seed 42, followed by a test-set evaluation.
key = jax.random.PRNGKey(42)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50)
# JAX dispatches work asynchronously, so wait until the results really exist
# before stopping the clock; otherwise the elapsed time can be under-reported.
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), (evolved_params, metrics))
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.96875, val_acc=0.8904551863670349 Elapsed Time: 415.1407585144043
Test Acc: 0.8924279


In [16]:
# Timed evolution run with seed 12345, followed by a test-set evaluation.
key = jax.random.PRNGKey(12345)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50)
# JAX dispatches work asynchronously, so wait until the results really exist
# before stopping the clock; otherwise the elapsed time can be under-reported.
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), (evolved_params, metrics))
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.9609375, val_acc=0.8924932479858398 Elapsed Time: 415.57034158706665
Test Acc: 0.89703524


In [17]:
# Timed evolution run with seed 54321, followed by a test-set evaluation.
key = jax.random.PRNGKey(54321)

start = time()
evolved_params, metrics = evo(params, nmnist_dl, key, epochs=50)
# JAX dispatches work asynchronously, so wait until the results really exist
# before stopping the clock; otherwise the elapsed time can be under-reported.
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), (evolved_params, metrics))
elapsed_time = time()-start
print("Performance: train_acc={}, val_acc={}".format(*metrics[-1]), "Elapsed Time: {}".format(elapsed_time))
acc, preds, tgts = test_evo(evolved_params, nmnist_dl)
print("Test Acc:",acc)

ParameterReshaper: 1189376 parameters detected for optimization.


  0%|          | 0/50 [00:00<?, ?it/s]

Performance: train_acc=0.96484375, val_acc=0.88892662525177 Elapsed Time: 414.9552752971649
Test Acc: 0.8954327


In [61]:
# Per-epoch metrics returned by evo(): one row per epoch, holding the epoch's
# maximum training accuracy followed by the mean validation accuracy.
metrics

Array([[0.3046875 , 0.22894022],
       [0.5078125 , 0.42747962],
       [0.67578125, 0.6110734 ],
       [0.75390625, 0.67493206],
       [0.8046875 , 0.727072  ],
       [0.81640625, 0.7204484 ],
       [0.85546875, 0.7498302 ],
       [0.83984375, 0.7498302 ],
       [0.8515625 , 0.7498302 ],
       [0.8515625 , 0.7498302 ],
       [0.859375  , 0.7814199 ],
       [0.8828125 , 0.78396744],
       [0.8671875 , 0.78396744],
       [0.87890625, 0.78396744],
       [0.87109375, 0.78396744],
       [0.90625   , 0.84714675],
       [0.93359375, 0.8661685 ],
       [0.9375    , 0.8755095 ],
       [0.94140625, 0.8816237 ],
       [0.94140625, 0.8816237 ],
       [0.94921875, 0.86956525],
       [0.9453125 , 0.86956525],
       [0.953125  , 0.8806046 ],
       [0.95703125, 0.8873981 ],
       [0.96484375, 0.8828125 ],
       [0.96875   , 0.8868886 ],
       [0.95703125, 0.8868886 ],
       [0.9609375 , 0.8868886 ],
       [0.96484375, 0.8868886 ],
       [0.96484375, 0.8868886 ],
       [0.

In [18]:
from sklearn.metrics import confusion_matrix

# Confusion matrix for the targets and predictions returned by test_evo().
# Targets are passed first, so rows are true classes and columns are
# predicted classes (scikit-learn's y_true, y_pred convention).
confusion_matrix(tgts, preds)

array([[ 945,    0,    1,    5,    2,    6,    2,   11,    5,    3],
       [   0, 1091,    6,   12,    3,    4,    8,    1,    7,    1],
       [  23,   14,  845,   40,   30,    1,   23,   13,   23,   18],
       [   9,    3,   13,  916,    3,   20,    7,    5,   14,   19],
       [   4,    4,    3,    0,  922,    0,    9,    0,    4,   33],
       [  15,    0,    3,   40,    5,  752,   24,    8,   26,   18],
       [  18,    4,    3,    1,    9,   23,  887,    5,    3,    1],
       [   1,   17,   15,    3,    7,    2,    0,  915,    5,   62],
       [  17,   10,   13,   63,   11,   29,   12,   24,  738,   56],
       [  14,    5,    1,    9,   20,    8,    2,   16,    4,  929]])