# Training an SNN using surrogate gradients!

Train your first SNN in JAX in less than 10 minutes without needing a heavy-duty GPU!

In [1]:
import spyx
import spyx.nn as snn

# JAX imports
import os
import jax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".80"
from jax import numpy as jnp
import jmp
import numpy as np

from jax_tqdm import scan_tqdm
from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# for surrogate loss training.
import optax

# rendering tools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
%matplotlib notebook

## Set Mixed Precision Policy

In [2]:
policy = jmp.get_policy('half')


hk.mixed_precision.set_policy(hk.Linear, policy)
hk.mixed_precision.set_policy(snn.LIF, policy)
hk.mixed_precision.set_policy(snn.LI, policy)

## Data Loading

In [3]:
shd_dl = spyx.data.SHD_loader(256,128,256)

In [4]:
key = jax.random.PRNGKey(0)
x, y = shd_dl.train_epoch(key)

In [5]:
y.shape

(25, 256)

## SNN

Here we define a simple feed-forward SNN using Haiku's RNN features, incorporating our
LIF neuron models where activation functions would usually go. Haiku manages all of the state for us, so when we transform the function and get an apply() function we just need to pass the params!

Since spiking neurons have a discrete all-or-nothing activation, in order to do gradient descent we'll have to approximate the derivative of the Heaviside function with something smoother. In this case, we use the SuperSpike surrogate gradient from Zenke & Ganguli 2017.
Also not that we aren't using bias terms on the linear layers and since the inputs are images, we flatten the data before feeding it to the first layer.

Depending on computational constraints, we can use haiku's dynamic unroll to iterate the SNN, or we can use static unroll where the SNN will be unrolled during the JIT compiling process to further increase speed when training on GPU. Note that the static unroll will take longer to compile, but once it runs the iterations per second will be 2x-3x greater than the dynamic unroll.

In [6]:
def arctan_snn(x):
    
    x = hk.BatchApply(hk.Linear(128, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.arctan())),
        hk.Linear(128, with_bias=False),
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.arctan())),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=10)
    
    return spikes, V

In [7]:
def superspike_snn(x):
    
    x = hk.BatchApply(hk.Linear(128, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.superspike())),
        hk.Linear(128, with_bias=False),
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.superspike())),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=10)
    
    return spikes, V

In [8]:
def tanh_snn(x):
    
    x = hk.BatchApply(hk.Linear(128, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.tanh())),
        hk.Linear(128, with_bias=False),
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.tanh())),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=10)
    
    return spikes, V

In [9]:
def sigmoid_snn(x):
    
    x = hk.BatchApply(hk.Linear(128, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.sigmoid())),
        hk.Linear(128, with_bias=False),
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.sigmoid())),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=10)
    
    return spikes, V

In [10]:
def boxcar_snn(x):
    
    x = hk.BatchApply(hk.Linear(128, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.boxcar())),
        hk.Linear(128, with_bias=False),
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.boxcar())),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=10)
    
    return spikes, V

In [11]:
def triangular_snn(x):
    
    x = hk.BatchApply(hk.Linear(128, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.triangular())),
        hk.Linear(128, with_bias=False),
        snn.LIF((128,), activation=spyx.axn.Axon(spyx.axn.triangular())),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=10)
    
    return spikes, V

In [12]:
def ste_snn(x):
    
    x = hk.BatchApply(hk.Linear(128, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((128,)),
        hk.Linear(128, with_bias=False),
        snn.LIF((128,)),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=10)
    
    return spikes, V

In [13]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
arctan_SNN = hk.without_apply_rng(hk.transform(arctan_snn))
arctan_params = arctan_SNN.init(rng=key, x=x[0])

In [14]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
superspike_SNN = hk.without_apply_rng(hk.transform(superspike_snn))
superspike_params = superspike_SNN.init(rng=key, x=x[0])

In [15]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
tanh_SNN = hk.without_apply_rng(hk.transform(tanh_snn))
tanh_params = tanh_SNN.init(rng=key, x=x[0])

In [16]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
sigmoid_SNN = hk.without_apply_rng(hk.transform(sigmoid_snn))
sigmoid_params = sigmoid_SNN.init(rng=key, x=x[0])

In [17]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
boxcar_SNN = hk.without_apply_rng(hk.transform(boxcar_snn))
boxcar_params = boxcar_SNN.init(rng=key, x=x[0])

In [18]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
triangular_SNN = hk.without_apply_rng(hk.transform(triangular_snn))
triangular_params = triangular_SNN.init(rng=key, x=x[0])

In [19]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
ste_SNN = hk.without_apply_rng(hk.transform(ste_snn))
ste_params = ste_SNN.init(rng=key, x=x[0])

## Gradient Descent

We define a training loop below.

We use the Lion optimizer from Optax, which is a more efficient competitor to the popular Adam. The eval steps and updates are JIT'ed to maximize time spent in optimized GPU code and minimize time spent in higher-level python.

The use of regularizers in the spiking network will be covered in a seperate tutorial.

In [20]:
def gd(SNN, params, dl, epochs=500, schedule=3e-4):
    
    aug = spyx.data.shift_augment(max_shift=16) 
    
    
    opt = optax.chain(
        optax.centralize(),
        optax.lion(learning_rate=schedule)
    )
    # create and initialize the optimizer
    opt_state = opt.init(params)
    grad_params = params
        
    # define and compile our eval function that computes the loss for our SNN
    @jax.jit
    def net_eval(weights, events, targets):
        readout = SNN.apply(weights, events)
        traces, V_f = readout
        return spyx.fn.integral_crossentropy(traces, targets, smoothing=0.3)
        
    # Use JAX to create a function that calculates the loss and the gradient!
    surrogate_grad = jax.value_and_grad(net_eval) 
        
    rng = jax.random.PRNGKey(0)        
    
    # compile the meat of our training loop for speed
    @jax.jit
    def train_step(state, data):
        grad_params, opt_state = state
        events, targets = data # fix this
        events = jnp.unpackbits(events, axis=1) # decompress temporal axis
        # compute loss and gradient                    # need better augment rng
        loss, grads = surrogate_grad(grad_params, aug(events, jax.random.fold_in(rng,jnp.sum(targets))), targets)
        # generate updates based on the gradients and optimizer
        updates, opt_state = opt.update(grads, opt_state, grad_params)
        # return the updated parameters
        new_state = [optax.apply_updates(grad_params, updates), opt_state]
        return new_state, loss
    
    # For validation epochs, do the same as before but compute the
    # accuracy, predictions and losses (no gradients needed)
    @jax.jit
    def eval_step(grad_params, data):
        events, targets = data # fix
        events = jnp.unpackbits(events, axis=1)
        readout = SNN.apply(grad_params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        loss = spyx.fn.integral_crossentropy(traces, targets, smoothing=0.3)
        return grad_params, jnp.array([acc, loss])
        
    
    val_data = dl.val_epoch()
    
    # Here's the start of our training loop!
    @scan_tqdm(epochs)
    def epoch(epoch_state, epoch_num):
        curr_params, curr_opt_state = epoch_state
        
        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = dl.train_epoch(shuffle_rng)
        
        # train epoch
        end_state, train_loss = jax.lax.scan(
            train_step,# func
            [curr_params, curr_opt_state],# init
            train_data,# xs
            train_data.obs.shape[0]# len
        )
        
        new_params, _ = end_state
            
        # val epoch
        _, val_metrics = jax.lax.scan(
            eval_step,# func
            new_params,# init
            val_data,# xs
            val_data.obs.shape[0]# len
        )

        
        return end_state, jnp.concatenate([jnp.expand_dims(jnp.mean(train_loss),0), jnp.mean(val_metrics, axis=0)])
    # end epoch
    
    # epoch loop
    final_state, metrics = jax.lax.scan(
        epoch,
        [grad_params, opt_state], # metric arrays
        jnp.arange(epochs), # 
        epochs # len of loop
    )
    
    final_params, _ = final_state
    
                
    # return our final, optimized network.       
    return final_params, metrics

In [21]:
def test_gd(SNN, params, dl):

    @jax.jit
    def test_step(params, data):
        events, targets = data
        events = jnp.unpackbits(events, axis=1)
        readout = SNN.apply(params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        loss = spyx.fn.integral_crossentropy(traces, targets, smoothing=0.3)
        return params, [acc, loss, pred, targets]
    
    test_data = dl.test_epoch()
    
    _, test_metrics = jax.lax.scan(
            test_step,# func
            params,# init
            test_data,# xs
            test_data.obs.shape[0]# len
    )
    
    acc = jnp.mean(test_metrics[0])
    loss = jnp.mean(test_metrics[1])
    preds = jnp.array(test_metrics[2]).flatten()
    tgts = jnp.array(test_metrics[3]).flatten()
    return acc, loss, preds, tgts

## Training Time

## Arctan

In [22]:
arctan_grad_params, arctan_metrics = gd(arctan_SNN, arctan_params, 
                                        shd_dl, epochs=500, schedule=1e-4) 
# 1:08

  0%|          | 0/500 [00:00<?, ?it/s]

In [23]:
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*arctan_metrics[-1]))

Performance: train_loss=1.8070510625839233, val_acc=0.8802083730697632, val_loss=1.8233730792999268


In [30]:
plt.plot(arctan_metrics, label=["train loss", "val acc", "val loss"])
plt.title("LIF Neurons, Arctan Surrogate")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

# Superspike

First experiment was using the same learning rate schedule as arctan.

In [24]:
superspike_grad_params, superspike_metrics = gd(superspike_SNN,
                                     superspike_params,
                                     shd_dl, epochs=500,
                                     schedule=1e-4) #1:08

  0%|          | 0/500 [00:00<?, ?it/s]

In [25]:
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*superspike_metrics[-1]))

Performance: train_loss=1.8405778408050537, val_acc=0.890625, val_loss=1.8338732719421387


In [49]:
plt.plot(superspike_metrics, label=["train loss", "val acc", "val loss"])
plt.title("LIF Neurons, SuperSpike Surrogate")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

# Tanh

In [26]:
tanh_grad_params, tanh_metrics = gd(tanh_SNN, tanh_params, shd_dl, 500, 1.5e-4) #1:08

  0%|          | 0/500 [00:00<?, ?it/s]

In [27]:
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*tanh_metrics[-1]))

Performance: train_loss=1.8004976511001587, val_acc=0.8873698115348816, val_loss=1.8130784034729004


In [36]:
plt.plot(tanh_metrics, label=["train loss", "val acc", "val loss"])
plt.title("LIF Neurons, Tanh Surrogate")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

# Sigmoid

In [28]:
sigmoid_grad_params, sigmoid_metrics = gd(sigmoid_SNN, sigmoid_params, shd_dl, 500, 2.5e-4) #1:10

  0%|          | 0/500 [00:00<?, ?it/s]

In [29]:
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*sigmoid_metrics[-1]))

Performance: train_loss=1.8050484657287598, val_acc=0.884765625, val_loss=1.8104445934295654


In [39]:
plt.plot(sigmoid_metrics, label=["train loss", "val acc", "val loss"])
plt.title("LIF Neurons, Sigmoid Surrogate")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

# Boxcar

In [30]:
boxcar_grad_params, boxcar_metrics = gd(boxcar_SNN, boxcar_params, shd_dl, 500, 2e-4) #1:07

  0%|          | 0/500 [00:00<?, ?it/s]

In [31]:
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*boxcar_metrics[-1]))

Performance: train_loss=1.7967289686203003, val_acc=0.8802083730697632, val_loss=1.8076056241989136


In [42]:
plt.plot(boxcar_metrics, label=["train loss", "val acc", "val loss"])
plt.title("LIF Neurons, Boxcar Surrogate")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

# Triangular

In [32]:
triangular_grad_params, triangular_metrics = gd(triangular_SNN,
                                                triangular_params,
                                                shd_dl, 500, 1e-4) #1:07

  0%|          | 0/500 [00:00<?, ?it/s]

In [33]:
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*triangular_metrics[-1]))

Performance: train_loss=1.845432996749878, val_acc=0.8763021230697632, val_loss=1.844926118850708


In [45]:
plt.plot(triangular_metrics, label=["train loss", "val acc", "val loss"])
plt.title("LIF Neurons, Triangular Surrogate")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

# Straight Through Estimator

In [34]:
ste_grad_params, ste_metrics = gd(ste_SNN,
                                  ste_params,
                                  shd_dl, 500, 7e-5) #1:00

  0%|          | 0/500 [00:00<?, ?it/s]

In [35]:
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*ste_metrics[-1]))

Performance: train_loss=2.4279189109802246, val_acc=0.5084635615348816, val_loss=2.3556199073791504


In [36]:
plt.plot(ste_metrics, label=["train loss", "val acc", "val loss"])
plt.title("LIF Neurons, S.T.E. Surrogate")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

## Evaluation Time

Now we'll run the network on the test set and see what happens:

In [38]:
plt.plot(superspike_metrics[:,1], label="SuperSpike Val. Acc.")
plt.plot(arctan_metrics[:,1], label="Arctan Val. Acc.")
plt.plot(tanh_metrics[:,1], label="Tanh Val. Acc.")
plt.plot(sigmoid_metrics[:,1], label="Sigmoid Val. Acc.")
plt.plot(boxcar_metrics[:,1], label="Boxcar Val. Acc.")
plt.plot(triangular_metrics[:,1], label="Triangular Val. Acc.")
plt.plot(ste_metrics[:,1], label="S.T.E. Val. Acc.")
plt.xlabel("Epochs")
plt.ylabel("Validation Accuracy")
# need to fix line labeling to make easier to distinguish.
plt.title("LIF Neurons, Surrogate Gradient Comparison")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

In [39]:
thetas = [arctan_grad_params, tanh_grad_params, sigmoid_grad_params,
          superspike_grad_params, boxcar_grad_params, triangular_grad_params,
          ste_grad_params]

spike_nets = [arctan_SNN, tanh_SNN, sigmoid_SNN, superspike_SNN,
              boxcar_SNN, triangular_SNN, ste_SNN]

names = ["Arctan", "Tanh", "Sigmoid", "SuperSpike", "Boxcar", "Triangular", "S.T.E."]

for name, net, theta in zip(names, spike_nets, thetas):
    acc, loss, preds, tgts = test_gd(net, theta, shd_dl)
    print(name, "Accuracy:", acc, "Loss:", loss)

Arctan Accuracy: 0.7729492 Loss: 1.991092
Tanh Accuracy: 0.7753906 Loss: 1.995582
Sigmoid Accuracy: 0.796875 Loss: 1.9538465
SuperSpike Accuracy: 0.7626953 Loss: 2.0083747
Boxcar Accuracy: 0.7573242 Loss: 1.9939734
Triangular Accuracy: 0.74902344 Loss: 2.0298672
S.T.E. Accuracy: 0.43652344 Loss: 2.465756


In [58]:
acc, loss, preds, tgts = test_gd(sigmoid_SNN, sigmoid_grad_params, shd_dl)

In [60]:
cm = confusion_matrix(tgts, preds)
ConfusionMatrixDisplay(cm).plot()
plt.title("Sigmoid LIF Test Confusion Matrix")
plt.show()

<IPython.core.display.Javascript object>