# Training an SNN using surrogate gradients!

Train your first SNN in JAX in less than 10 minutes without needing a heavy-duty GPU!

In [1]:
import spyx
import spyx.nn as snn

# JAX imports
import os
import jax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"
from jax import numpy as jnp
import jmp
import numpy as np

from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# for surrogate loss training.
import optax

# rendering tools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
%matplotlib notebook

  warn(


## Set Mixed Precision Policy

In [2]:
from jax import tree_util as tree

In [3]:
class l2_reg:
    """Clipped squared-error penalty on per-neuron firing rates.

    For every leaf of the spike-count pytree, the counts are summed over
    axis 0 and divided by ``num_classes`` and then ``time_steps`` to obtain a
    rate per neuron. The squared deviation of that rate from ``target_rate``
    goes through a dead-band, so deviations smaller than ``tolerance`` cost
    nothing. Calling the instance returns the mean penalty over all neurons
    of all leaves as a scalar.
    """

    def __init__(self, target_rate, tolerance, time_steps, num_classes):
        def _rate(counts):
            # Sum over axis 0 (presumably the batch axis -- TODO confirm), then
            # normalise by the number of classes and the number of time steps.
            return (jnp.sum(counts, axis=0) / num_classes) / time_steps

        def _squared_error(rates):
            # A flat reference vector of length rates.size; this lines up with
            # `rates` only when each rate leaf is 1-D.
            reference = jnp.array([target_rate] * rates.size)
            return optax.squared_error(rates, reference)

        def _dead_band(err):
            # For tolerance > 0: err/tol - tol > 0  <=>  err > tol**2
            # <=>  |rate - target| > tol.  Inside the band the penalty is zero;
            # outside it grows linearly in err with slope 1/tol.
            return jnp.maximum(0, (err / tolerance) - tolerance)

        self.rate_map = _rate
        self.sq_err_map = _squared_error
        self.clip = _dead_band

    def __call__(self, spikes):
        rates = tree.tree_map(self.rate_map, spikes)
        errors = tree.tree_map(self.sq_err_map, rates)
        penalties = tree.tree_map(self.clip, errors)
        # Pool every neuron of every leaf into one vector and average to a scalar.
        return jnp.mean(jnp.concatenate(tree.tree_leaves(penalties)))
    
    
        

In [4]:
class l1_reg:
    """Clipped absolute-error penalty on per-row spike totals.

    Each leaf of the spike-count pytree is expected to be 2-D (presumably
    (batch, neurons) -- TODO confirm). For each leaf, the counts are summed
    over axis 1 and divided by ``time_steps``. That total is compared with
    ``(n_columns / num_classes) * target_rate``, the absolute error is reduced
    by ``tolerance``, and the result is floored at zero. Calling the instance
    returns the mean of those values over all rows of all leaves, which is a
    scalar.
    """

    def __init__(self, target_rate, tolerance, time_steps, num_classes):
        def _abs_error(counts):
            per_row = jnp.sum(counts, axis=1) / time_steps
            expected = (counts.shape[1] / num_classes) * target_rate
            return jnp.abs(per_row - expected)

        def _dead_band(err):
            # Errors within `tolerance` are free; larger ones are shifted down.
            return jnp.maximum(0, err - tolerance)

        self.l1_loss = _abs_error
        self.clip = _dead_band

    def __call__(self, spikes):
        errors = tree.tree_map(self.l1_loss, spikes)
        penalties = tree.tree_map(self.clip, errors)
        return jnp.mean(jnp.concatenate(tree.tree_leaves(penalties)))

In [5]:
class lasso:
    """Combined activity penalty: ``l1_reg(spikes) + l2_reg(spikes)``.

    Both terms are built from the same ``(target_rate, tolerance, time_steps,
    num_classes)`` hyper-parameters and are summed with equal weight. Despite
    the name, this adds an L1-style and an L2-style term, so it is closer to
    an elastic-net-style combination than to the lasso.
    """

    def __init__(self, target_rate, tolerance, time_steps, num_classes):
        hyper = (target_rate, tolerance, time_steps, num_classes)
        self.l1 = l1_reg(*hyper)
        self.l2 = l2_reg(*hyper)

    def __call__(self, spikes):
        l1_term = self.l1(spikes)
        l2_term = self.l2(spikes)
        return l1_term + l2_term

In [6]:
class ActivityRegularization(hk.Module):
    """Pass-through layer that accumulates the spikes it sees in Haiku state.

    Each call adds the incoming ``spikes`` element-wise to a state entry named
    ``"spike_count"``. That entry has the same shape and dtype as ``spikes`` and
    is zero-initialised. The call then returns ``spikes`` unchanged.

    Because the count lives in Haiku state, the network must be transformed
    with ``hk.transform_with_state``. The state returned by ``init`` is an
    all-zero count that can be passed as the state argument of ``apply`` to
    start counting from scratch. The accumulated counts come back in the new
    state returned by ``apply``.
    """

    def __init__(self, name="ActReg"):
        super().__init__(name=name)

    def __call__(self, spikes):
        running_total = hk.get_state(
            "spike_count", spikes.shape, dtype=spikes.dtype, init=jnp.zeros
        )
        # NOTE(review): presumably accumulates across every call made within one
        # apply (e.g. each unrolled time step); the original author flagged this
        # line as "maybe wrong" -- confirm the intended accumulation window.
        hk.set_state("spike_count", running_total + spikes)
        return spikes

In [7]:
policy = jmp.get_policy('half')

# Register the half-precision policy on each module class, in the original order.
# NOTE(review): shd_snn builds its hidden layers from snn.RLIF, which is not in
# this tuple (only snn.LIF is) -- unless RLIF picks the policy up some other way,
# those layers do not run under it. Confirm whether that is intended.
for _module_cls in (hk.Linear, ActivityRegularization, snn.LIF, snn.LI):
    hk.mixed_precision.set_policy(_module_cls, policy)
del _module_cls

## Data Loading

In [8]:
# Spyx SHD data loader. NOTE(review): the positional arguments are presumably
# (batch_size, sample_T, channels) -- confirm against spyx.data.SHD_loader.
shd_dl = spyx.data.SHD_loader(256, 100, 350)

In [9]:
# Fetch one training batch (this advances the training iterator) and keep its observations in `x`.
x = shd_dl.train_step().obs

## SNN

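Here we define a simple layered SNN using Haiku's RNN features. Recurrent LIF (RLIF) neurons sit where activation functions would usually go, and a leaky-integrator (LI) readout sits at the output. Haiku manages all of the state for us. Once we transform the function, the resulting apply() function only needs the params and the initial regularization state (the spike counters tracked by the ActivityRegularization layers)!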

Since spiking neurons have a discrete all-or-nothing activation, in order to do gradient descent we'll have to approximate the derivative of the Heaviside function with something smoother. In this case, we use the SuperSpike surrogate gradient from Zenke & Ganguli 2017.
Also note that we aren't using bias terms on the linear layers.

Depending on computational constraints, we can use haiku's dynamic unroll to iterate the SNN, or we can use static unroll, where the SNN is unrolled during the JIT compilation process to further increase speed when training on GPU. Note that the static unroll will take longer to compile, but once it runs the iterations per second will be 2x-3x greater than with the dynamic unroll. The network below uses `hk.dynamic_unroll` with `unroll=10`, which unrolls 10 time steps per loop iteration as a middle ground.

In [10]:
def shd_snn(x):
    """Forward pass of a two-hidden-layer spiking network for SHD.

    Layer stack: Linear(128) -> RLIF(128) -> ActivityRegularization ->
    Linear(128) -> RLIF(128) -> ActivityRegularization -> Linear(20) -> LI(20).
    The linear layers have no bias. The RLIF layers use the SuperSpike
    surrogate gradient for their spike nonlinearity.

    Args:
        x: Batch of input events in batch-major layout (``time_major=False``),
            presumably (batch, time, channels) -- TODO confirm against the
            loader. It is cast to float32 before unrolling.

    Returns:
        ``(outputs, final_state)`` from ``hk.dynamic_unroll``. ``outputs`` holds
        the LI readout at every time step, stacked batch-major over time; these
        are the readout traces fed to the loss. ``final_state`` is the final
        state of every core in the stack.
    """
    core = hk.DeepRNN([
        hk.Linear(128, with_bias=False),
        snn.RLIF((128,), activation=spyx.activation.SuperSpike()),
        ActivityRegularization(),
        hk.Linear(128, with_bias=False),
        snn.RLIF((128,), activation=spyx.activation.SuperSpike()),
        ActivityRegularization(),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    # Dynamic (scan-based) unroll keeps compile time low; unroll=10 unrolls 10
    # time steps per loop iteration as a middle ground toward a static unroll.
    outputs, final_state = hk.dynamic_unroll(
        core,
        x.astype(jnp.float32),
        core.initial_state(x.shape[0]),
        time_major=False,
        unroll=10,
    )
    return outputs, final_state

In [11]:
# The forward pass has no stochastic operations, so apply() is built without an
# RNG argument; a PRNG key is still required to initialise the parameters.
SNN = hk.without_apply_rng(hk.transform_with_state(shd_snn))
key = jax.random.PRNGKey(0)

# Initialise on a real training batch (this advances the training iterator).
# `reg_init` is the initial Haiku state: the zeroed ActivityRegularization spike counters.
params, reg_init = SNN.init(rng=key, x=shd_dl.train_step().obs)

## Gradient Descent

We define a training loop below.

We use the Lion optimizer from Optax, which is a more efficient competitor to the popular Adam. The eval steps and updates are JIT'ed to maximize time spent in optimized GPU code and minimize time spent in higher-level python.

The use of regularizers in the spiking network will be covered in a separate tutorial.

In [12]:
def gd(SNN, params, dl, epochs=50, test_every=1, reg_state=None, reg_weight=5.0):
    """Train ``SNN`` with surrogate-gradient descent, validating periodically.

    Each epoch does the following:
    - rewinds the training iterator;
    - applies ``spyx.data.shift_augment`` to every batch;
    - takes one jitted optimizer step per batch.

    After every ``test_every``-th epoch, the validation split is evaluated and
    the running loss and accuracy are shown on a progress bar.

    Args:
        SNN: Network transformed with ``hk.transform_with_state`` and
            ``hk.without_apply_rng``; called as ``SNN.apply(params, state, events)``.
        params: Initial parameters; training resumes from these.
        dl: Data loader exposing ``train_reset``/``train_step``,
            ``val_reset``/``val_step``, ``train_len``, ``val_len`` and ``batch_size``.
        epochs: Number of passes over the training split.
        test_every: Validate after every ``test_every``-th epoch.
        reg_state: Haiku state passed to every ``SNN.apply`` call (the zeroed
            spike counters). Defaults to the module-level ``reg_init``.
        reg_weight: Multiplier on the activity-regularization penalty.

    Returns:
        The optimized parameters.
    """
    if reg_state is None:
        # Backward-compatible default: the notebook-level state from SNN.init.
        reg_state = reg_init

    aug = spyx.data.shift_augment()

    # Activity penalty on the spike counts, built with
    # (target_rate, tolerance, time_steps, num_classes). The 100 and 20
    # presumably match the loader's time steps and SHD's 20 classes.
    regularizer = lasso(0.4, .25, 100, 20)

    # Learning rate decays smoothly from 6e-4 by a factor of 0.99 every 250
    # steps, and never goes below 2e-4.
    schedule = optax.exponential_decay(
        init_value=6e-4,
        transition_steps=250,
        decay_rate=0.99,
        end_value=2e-4,
    )

    opt = optax.chain(
        # Clip gradients unit-wise relative to the parameter norm (AGC).
        optax.adaptive_grad_clip(1.0),
        optax.lion(learning_rate=schedule, weight_decay=0.01),
        # Replace NaNs in the final update with zeros instead of corrupting params.
        optax.zero_nans(),
    )

    opt_state = opt.init(params)
    grad_params = params

    @jax.jit
    def net_eval(weights, events, targets):
        # Training loss: label-smoothed cross-entropy on the readout traces, plus
        # the weighted activity penalty on the spike counts returned as Haiku state.
        readout, spike_counts = SNN.apply(weights, reg_state, events)
        traces, _ = readout
        xe_loss = spyx.loss.integral_crossentropy(traces, targets, smoothing=0.1)
        return xe_loss + reg_weight * regularizer(spike_counts)

    # Loss and its gradient w.r.t. the weights (via the surrogate spike gradient).
    surrogate_grad = jax.value_and_grad(net_eval)

    @jax.jit
    def step(grad_params, opt_state, events, targets):
        loss, grads = surrogate_grad(grad_params, events, targets)
        updates, opt_state = opt.update(grads, opt_state, grad_params)
        return optax.apply_updates(grad_params, updates), opt_state, loss

    @jax.jit
    def eval_step(grad_params, events, targets):
        # Validation metrics use plain (unsmoothed) cross-entropy and no penalty.
        readout, _ = SNN.apply(grad_params, reg_state, events)
        traces, _ = readout
        acc, pred = spyx.loss.integral_accuracy(traces, targets)
        loss = spyx.loss.integral_crossentropy(traces, targets)
        return acc, pred, loss

    for gen in range(epochs):
        # Reset before building the progress bar: the number of training batches
        # can differ between epochs, so train_len must be read after the reset.
        dl.train_reset()
        pbar = tqdm(range(dl.train_len // dl.batch_size))
        pbar.set_description("Epoch #{}".format(gen))

        for _ in pbar:
            events, targets = dl.train_step()
            events = aug(events)
            grad_params, opt_state, loss = step(grad_params, opt_state, events, targets)
            pbar.set_postfix(Loss=loss)

        if gen % test_every == test_every - 1:
            dl.val_reset()

            # Per-batch results; only used here for the running averages below.
            accs = []
            preds = []
            losses = []

            pbar = tqdm(range(dl.val_len // dl.batch_size))
            pbar.set_description("Validating")
            for _ in pbar:
                events, targets = dl.val_step()
                acc, pred, loss = eval_step(grad_params, events, targets)
                accs.append(acc)
                preds.append(pred)
                losses.append(loss)
                pbar.set_postfix(Loss=np.mean(losses), Accuracy=np.mean(accs))

    return grad_params

In [13]:
def test_gd(SNN, params, dl, reg_state=None):
    """Evaluate ``params`` on the test split of ``dl``.

    Args:
        SNN: Network transformed with ``hk.transform_with_state``; called as
            ``SNN.apply(params, state, events)``.
        params: Parameters to evaluate.
        dl: Data loader exposing ``test_reset``/``test_step``, ``test_len``
            and ``batch_size``.
        reg_state: Haiku state passed to ``SNN.apply``. Defaults to the
            module-level ``reg_init``.

    Returns:
        ``(accs, preds, tgts, losses)``. ``accs`` and ``losses`` are per-batch
        lists. ``preds`` and ``tgts`` hold all batches concatenated into 1-D
        arrays; these are empty if no batch was evaluated.
    """
    if reg_state is None:
        reg_state = reg_init

    @jax.jit
    def eval_step(weights, events, targets):
        readout, _ = SNN.apply(weights, reg_state, events)
        traces, _ = readout
        acc, pred = spyx.loss.integral_accuracy(traces, targets)
        loss = spyx.loss.integral_crossentropy(traces, targets)
        return acc, pred, loss

    def _flatten(batches):
        # Join per-batch arrays into one 1-D array; also copes with a ragged
        # final batch, which stacking with jnp.array cannot.
        if not batches:
            return jnp.zeros((0,))
        return jnp.concatenate([jnp.ravel(b) for b in batches])

    dl.test_reset()
    accs, preds, tgts, losses = [], [], [], []
    pbar = tqdm(range(dl.test_len // dl.batch_size))
    pbar.set_description("Validating")
    for _ in pbar:
        events, targets = dl.test_step()

        # Evaluate the parameters passed in, not a notebook-level global.
        acc, pred, loss = eval_step(params, events, targets)

        accs.append(acc)
        preds.append(pred)
        tgts.append(targets)
        losses.append(loss)

        pbar.set_postfix(Loss=np.mean(losses), Accuracy=np.mean(accs))

    return accs, _flatten(preds), _flatten(tgts), losses

## Training Time

We'll train the network for 50 epochs since SHD is more difficult than MNIST.

The SHD dataloader for Spyx has built-in leave-one-group-out cross-validation. This is because the test set for SHD contains two unseen speakers. When we train our model, we need to make it robust to speakers it isn't trained on, in the hope of improving generalization accuracy.

In [14]:
# Train from the freshly initialised params (gd defaults to 50 epochs) and keep the result.
grad_params = gd(SNN, params, shd_dl)

Epoch #0: 100%|█████████████████| 28/28 [00:28<00:00,  1.02s/it, Loss=11.417565]
Validating: 100%|█████████| 3/3 [00:02<00:00,  1.10it/s, Accuracy=0.082, Loss=4]
Epoch #1: 100%|█████████████████| 27/27 [00:16<00:00,  1.60it/s, Loss=10.771423]
Validating: 100%|██████| 3/3 [00:04<00:00,  1.52s/it, Accuracy=0.137, Loss=3.57]
Epoch #2: 100%|██████████████████| 29/29 [00:26<00:00,  1.08it/s, Loss=7.703622]
Validating: 100%|███████| 2/2 [00:02<00:00,  1.25s/it, Accuracy=0.15, Loss=5.18]
Epoch #3: 100%|█████████████████| 30/30 [00:27<00:00,  1.08it/s, Loss=6.4249244]
Validating: 100%|███████| 1/1 [00:00<00:00,  1.14it/s, Accuracy=0.32, Loss=4.91]
Epoch #4: 100%|█████████████████| 28/28 [00:26<00:00,  1.05it/s, Loss=5.6477027]
Validating: 100%|██████| 3/3 [00:03<00:00,  1.08s/it, Accuracy=0.199, Loss=4.27]
Epoch #5: 100%|█████████████████| 28/28 [00:27<00:00,  1.01it/s, Loss=5.7807016]
Validating: 100%|██████| 3/3 [00:03<00:00,  1.28s/it, Accuracy=0.337, Loss=4.18]
Epoch #6: 100%|█████████████

In [27]:
# Continue training from grad_params; gd() builds a fresh optimizer state and
# learning-rate schedule on every call.
# NOTE(review): grad_params2 is not used by any later cell in this notebook --
# the inspection and test cells use grad_params. Confirm which was meant.
grad_params2 = gd(SNN, grad_params, shd_dl)

Epoch #0: 100%|██████████████████| 27/27 [00:18<00:00,  1.49it/s, Loss=2.505838]
Validating: 100%|██████| 3/3 [00:03<00:00,  1.09s/it, Accuracy=0.861, Loss=2.19]
Epoch #1: 100%|█████████████████| 29/29 [00:17<00:00,  1.70it/s, Loss=2.3373768]
Validating: 100%|██████| 2/2 [00:02<00:00,  1.08s/it, Accuracy=0.816, Loss=2.29]
Epoch #2: 100%|██████████████████| 30/30 [00:17<00:00,  1.74it/s, Loss=1.947675]
Validating: 100%|██████| 1/1 [00:00<00:00,  1.83it/s, Accuracy=0.891, Loss=2.39]
Epoch #3: 100%|█████████████████| 28/28 [00:16<00:00,  1.74it/s, Loss=2.1886528]
Validating: 100%|███████| 3/3 [00:02<00:00,  1.01it/s, Accuracy=0.695, Loss=2.4]
Epoch #4: 100%|█████████████████| 28/28 [00:25<00:00,  1.09it/s, Loss=1.8182712]
Validating: 100%|██████| 3/3 [00:03<00:00,  1.09s/it, Accuracy=0.691, Loss=2.34]
Epoch #5: 100%|█████████████████| 28/28 [00:16<00:00,  1.72it/s, Loss=1.9707742]
Validating: 100%|███████| 3/3 [00:01<00:00,  1.65it/s, Accuracy=0.678, Loss=2.6]
Epoch #6: 100%|█████████████

In [28]:
# Rewind the training iterator to the beginning of the training set.
shd_dl.train_reset()

In [29]:
# Run the trained network on one training batch. `spks` is the Haiku state that
# apply returns, i.e. the spike counts accumulated by the ActivityRegularization layers.
x = shd_dl.train_step().obs

readout, spks = SNN.apply(grad_params, reg_init, x)

In [30]:
# Build the two activity penalties separately, with the same hyper-parameters
# gd() passes to `lasso`, to inspect each term on the counts in `spks`.
# l2_reg is the clipped squared-error penalty (`clipped_sq_err` is not defined
# anywhere in this notebook and raised NameError).
derp = l2_reg(0.4, .25, 100, 20)
yut = l1_reg(0.4, .25, 100, 20)

In [31]:
# Clipped squared-error activity penalty on this batch's spike counts.
derp(spks)

Array(0.00594469, dtype=float32)

In [32]:
# L1 activity penalty on the same spike counts.
yut(spks)

Array(0.1218, dtype=float16)

In [33]:
# Spike totals per neuron for the ActivityRegularization layer named "ActReg", summed over axis 0.
# NOTE(review): Haiku presumably names the second ActivityRegularization layer "ActReg_1",
# so this inspects only the first one -- confirm.
yeet = jnp.sum(spks["ActReg"]["spike_count"], axis=0) # (neurons,)
yeet

Array([ 656.,  780.,  408.,  367.,  507.,  836., 1178.,  420.,  353.,
       1223.,  410.,  383.,  427.,  419., 1277.,  478.,  302.,  415.,
       1193.,  372.,  425.,  352.,  377.,  328.,  417.,  316.,  606.,
        866.,  308.,  344.,  452.,  678.,  360.,  360.,  357.,  369.,
        314.,  788.,  275.,  585.,  578.,  306.,  283.,  458., 1189.,
        244.,  427.,  319.,  359.,  385.,  545.,  773.,  408.,  358.,
       1152.,  310.,  344.,  429.,  309., 1210.,   10.,  423.,  447.,
        324.,  274.,  338.,  433.,  335.,  487.,  354.,  401.,  254.,
        427.,  375.,  341.,  392.,  789.,  517., 1143.,  305.,  397.,
        420.,  387.,  890.,  401.,  345.,  812.,  427.,  317.,  397.,
        338.,  236.,  676.,  497.,  286.,  307.,  373.,  313.,  428.,
        652.,  638.,  562.,  531.,  222.,  215.,  406.,  393.,  355.,
        599.,  819., 1114.,  516.,  394.,  426.,  327.,  398., 1388.,
        402.,  308.,  388.,  294.,  401.,  313.,  395.,  608.,  565.,
        254.,  297.]

In [34]:
# Spike totals per sample for the same layer, summed over neurons (axis 1).
jnp.sum(spks["ActReg"]["spike_count"], axis=1) # (n_samples,)

Array([219., 210., 246., 260., 196., 189., 259., 294., 336., 237., 245.,
       251., 195., 209., 196., 227., 254., 257., 209., 234., 250., 270.,
       280., 216., 242., 251., 149., 321., 365., 200., 247., 243., 242.,
       257., 213., 214., 294., 162., 268., 220., 358., 241., 235., 159.,
       266., 244., 256., 285., 269., 234., 283., 215., 216., 178., 222.,
       274., 238., 218., 206., 246., 244., 201., 217., 260., 235., 247.,
       308., 255., 267., 265., 220., 283., 308., 281., 272., 268., 293.,
       269., 255., 249., 225., 224., 204., 187., 265., 214., 203., 218.,
       244., 247., 261., 261., 298., 210., 284., 307., 232., 217., 240.,
       163., 251., 258., 266., 264., 223., 215., 214., 232., 241., 270.,
       168., 276., 284., 212., 217., 164., 265., 262., 289., 147., 225.,
       255., 226., 151., 317., 253., 261., 210., 261., 271., 236., 257.,
       222., 249., 233., 305., 260., 316., 189., 218., 251., 203., 202.,
       244., 276., 242., 241., 341., 276., 216., 20

In [35]:
# Display the trained parameters (a nested dict keyed by Haiku module name).
grad_params

{'RLIF': {'b': Array([-0.0944717 ,  0.99519026,  0.31260842, -0.10548417, -0.32441178,
         -0.40041173, -0.6967001 , -0.26890615, -0.06818847,  0.30054495,
         -0.32884738, -0.21742369, -0.21318726, -0.11800125,  0.10058511,
          0.2807367 , -0.5804435 , -0.25983083,  0.7259979 , -0.24318625,
          0.04920679, -0.12058985, -0.08122146, -0.01540718, -0.38196623,
         -0.06906015, -0.3503448 ,  0.36554062, -0.18387362, -0.6539099 ,
         -0.13762917, -0.15069705, -0.3129064 , -0.45684862, -0.3761815 ,
          0.11548117, -0.5932166 , -0.38878626,  0.22265586, -0.3929718 ,
          0.0575144 , -0.2874209 , -0.03142158, -0.3582539 ,  0.6709823 ,
          0.14946057,  0.07277688,  0.11376859,  0.10186752, -0.08696722,
          0.05933382, -0.06813437, -0.31672868, -0.14652406,  0.82735276,
          0.1369703 , -0.01241367, -0.16617315,  0.03167582,  0.7404069 ,
         -0.27163747, -0.22815987,  0.02088248, -0.10537993, -0.21954314,
         -0.29780814, -0.

In [36]:
# Spike counts of every neuron in the "ActReg" layer for the first sample of the batch.
spks["ActReg"]["spike_count"][0]

Array([12.,  1.,  0.,  4., 10.,  4.,  0.,  0.,  2.,  0.,  1.,  6.,  1.,
        2.,  0.,  2.,  0.,  3.,  1.,  0.,  2.,  0.,  0.,  0.,  2.,  0.,
        0.,  0.,  1.,  0.,  0.,  0.,  4.,  1.,  0.,  2.,  1.,  0.,  5.,
        2.,  3.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  0.,  1.,  0.,
        0.,  0.,  0., 11.,  0.,  4.,  0.,  0.,  0.,  1.,  0.,  0.,  6.,
        1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  3.,  3.,  3.,
        1.,  3.,  0.,  2.,  3.,  5.,  0.,  0.,  6.,  0.,  0.,  2.,  4.,
        1.,  7.,  5.,  1.,  0.,  3.,  2.,  0.,  1.,  0.,  8.,  0.,  5.,
        0.,  3.,  3.,  1.,  0.,  0.,  1.,  5.,  3.,  7.,  0.,  4.,  2.,
        3.,  0.,  1.,  2.,  0.,  1.,  0.,  0.,  5.,  2.,  7.],      dtype=float16)

## Evaluation Time

Now we'll run the network on the test set and see what happens:

In [37]:
# Evaluate on the held-out test split. Returns per-batch accuracies and losses plus
# the flattened predictions and targets used for the confusion matrix.
# NOTE(review): this passes grad_params (first training run), not grad_params2.
acc, preds, tgts, losses = test_gd(SNN, grad_params, shd_dl)

Validating: 100%|██████| 8/8 [00:04<00:00,  1.72it/s, Accuracy=0.645, Loss=2.65]


In [38]:
# Confusion matrix over the test set (sklearn convention: rows are true labels,
# columns are predicted labels).
cm = confusion_matrix(y_true=tgts, y_pred=preds)
ConfusionMatrixDisplay(confusion_matrix=cm).plot()
plt.show()

<IPython.core.display.Javascript object>