# Training an SNN using surrogate gradients!



In [1]:
import jax.numpy as jnp
warmup = jnp.array([1,2,3])
warmup * 12.34

Array([12.34, 24.68, 37.02], dtype=float32, weak_type=True)

In [2]:
import spyx
import spyx.nn as snn

# JAX imports
import jax
from jax import numpy as jnp
import jmp
import numpy as np

from jax_tqdm import scan_tqdm
from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# for surrogate loss training.
import optax

# rendering tools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
%matplotlib notebook

## Set Mixed Precision Policy

In [3]:
policy = jmp.get_policy('half')


hk.mixed_precision.set_policy(hk.Linear, policy)
hk.mixed_precision.set_policy(snn.ALIF, policy)
hk.mixed_precision.set_policy(snn.LI, policy)

## Data Loading

In [4]:
shd_dl = spyx.data.SHD_loader(256,128,128)

In [5]:
key = jax.random.PRNGKey(0)
x, y = shd_dl.train_epoch(key)

In [6]:
y.shape

(25, 256)

## SNN


In [12]:
surrogate = spyx.axn.Axon(spyx.axn.arctan())

def snn_alif(x):
    
    x = hk.BatchApply(hk.Linear(64, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.ALIF((64,), activation=surrogate),
        hk.Linear(64, with_bias=False),
        snn.ALIF((64,), activation=surrogate),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=16)
    
    return spikes, V

## Gradient Descent

We define a training loop below.

We use the Lion optimizer from Optax, which is a more efficient competitor to the popular Adam. The eval steps and updates are JIT'ed to maximize time spent in optimized GPU code and minimize time spent in higher-level python.

The use of regularizers in the spiking network will be covered in a seperate tutorial.

In [8]:
def gd(SNN, params, dl, seed, epochs=300, schedule=4e-4):
    
    aug = spyx.data.shift_augment(max_shift=8) # need to make this stateless

    opt = optax.chain(
        optax.centralize(),
        optax.lion(learning_rate=schedule),
    )
    # create and initialize the optimizer
    opt_state = opt.init(params)
    grad_params = params
        
    # define and compile our eval function that computes the loss for our SNN
    @jax.jit
    def net_eval(weights, events, targets):
        readout = SNN.apply(weights, events)
        traces, V_f = readout
        return spyx.fn.integral_crossentropy(traces, targets)
        
    # Use JAX to create a function that calculates the loss and the gradient!
    surrogate_grad = jax.value_and_grad(net_eval) 
        
    rng = seed        
    
    # compile the meat of our training loop for speed
    @jax.jit
    def train_step(state, data):
        grad_params, opt_state = state
        events, targets = data # fix this
        events = jnp.unpackbits(events, axis=1) # decompress temporal axis
        # compute loss and gradient                    # need better augment rng
        loss, grads = surrogate_grad(grad_params, aug(events, jax.random.fold_in(rng,jnp.sum(targets))), targets)
        # generate updates based on the gradients and optimizer
        updates, opt_state = opt.update(grads, opt_state, grad_params)
        # return the updated parameters
        new_state = [optax.apply_updates(grad_params, updates), opt_state]
        return new_state, loss
    
    # For validation epochs, do the same as before but compute the
    # accuracy, predictions and losses (no gradients needed)
    @jax.jit
    def eval_step(grad_params, data):
        events, targets = data # fix
        events = jnp.unpackbits(events, axis=1)
        readout = SNN.apply(grad_params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        loss = spyx.fn.integral_crossentropy(traces, targets)
        return grad_params, jnp.array([acc, loss])
        
    
    val_data = dl.val_epoch()
    
    # Here's the start of our training loop!
    @scan_tqdm(epochs)
    def epoch(epoch_state, epoch_num):
        curr_params, curr_opt_state = epoch_state
        
        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = dl.train_epoch(shuffle_rng)
        
        # train epoch
        end_state, train_loss = jax.lax.scan(
            train_step,# func
            [curr_params, curr_opt_state],# init
            train_data,# xs
            train_data.obs.shape[0]# len
        )
        
        new_params, _ = end_state
            
        # val epoch
        _, val_metrics = jax.lax.scan(
            eval_step,# func
            new_params,# init
            val_data,# xs
            val_data.obs.shape[0]# len
        )

        
        return end_state, jnp.concatenate([jnp.expand_dims(jnp.mean(train_loss),0), jnp.mean(val_metrics, axis=0)])
    # end epoch
    
    # epoch loop
    final_state, metrics = jax.lax.scan(
        epoch,
        [grad_params, opt_state], # metric arrays
        jnp.arange(epochs), # 
        epochs # len of loop
    )
    
    final_params, _ = final_state
    
                
    # return our final, optimized network.       
    return final_params, metrics

In [9]:
def test_gd(SNN, params, dl):

    @jax.jit
    def test_step(params, data):
        events, targets = data
        events = jnp.unpackbits(events, axis=1)
        readout = SNN.apply(params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        loss = spyx.fn.integral_crossentropy(traces, targets)
        return params, [acc, loss, pred, targets]
    
    test_data = dl.test_epoch()
    
    _, test_metrics = jax.lax.scan(
            test_step,# func
            params,# init
            test_data,# xs
            test_data.obs.shape[0]# len
    )
    
    acc = jnp.mean(test_metrics[0])
    loss = jnp.mean(test_metrics[1])
    preds = jnp.array(test_metrics[2]).flatten()
    tgts = jnp.array(test_metrics[3]).flatten()
    return acc, loss, preds, tgts

## Training Time


In [10]:
from time import time

# Seed: 42

In [13]:
schedule = 2e-4

key = jax.random.PRNGKey(42)

# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN_alif = hk.without_apply_rng(hk.transform(snn_alif))
params_alif = SNN_alif.init(rng=key, x=x[0])

start = time()
grad_params_alif, metrics_alif = gd(SNN_alif, params_alif, shd_dl, key, epochs=500, schedule=schedule) # 1:09
elapsed = time() - start
print(elapsed)
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*metrics_alif[-1]))
acc, loss, preds, tgts = test_gd(SNN_alif, grad_params_alif, shd_dl)
print("Accuracy:", acc, "Loss:", loss)

  0%|          | 0/500 [00:00<?, ?it/s]

99.79465556144714
Performance: train_loss=2.001251220703125, val_acc=0.7584635615348816, val_loss=2.007822036743164
Accuracy: 0.68847656 Loss: 2.1133943


# Seed: 12345

In [14]:
schedule = 2e-4

key = jax.random.PRNGKey(12345)

# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN_alif = hk.without_apply_rng(hk.transform(snn_alif))
params_alif = SNN_alif.init(rng=key, x=x[0])

start = time()
grad_params_alif, metrics_alif = gd(SNN_alif, params_alif, shd_dl, key, epochs=500, schedule=schedule) # 1:09
elapsed = time() - start
print(elapsed)
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*metrics_alif[-1]))
acc, loss, preds, tgts = test_gd(SNN_alif, grad_params_alif, shd_dl)
print("Accuracy:", acc, "Loss:", loss)

  0%|          | 0/500 [00:00<?, ?it/s]

100.25066351890564
Performance: train_loss=2.056994676589966, val_acc=0.767578125, val_loss=2.009467601776123
Accuracy: 0.6660156 Loss: 2.147407


# Seed: 54321

In [15]:
schedule = 2e-4

key = jax.random.PRNGKey(54321)

# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN_alif = hk.without_apply_rng(hk.transform(snn_alif))
params_alif = SNN_alif.init(rng=key, x=x[0])

start = time()
grad_params_alif, metrics_alif = gd(SNN_alif, params_alif, shd_dl, key, epochs=500, schedule=schedule) # 1:09
elapsed = time() - start
print(elapsed)
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*metrics_alif[-1]))
acc, loss, preds, tgts = test_gd(SNN_alif, grad_params_alif, shd_dl)
print("Accuracy:", acc, "Loss:", loss)

  0%|          | 0/500 [00:00<?, ?it/s]

102.2600417137146
Performance: train_loss=2.046523094177246, val_acc=0.7584635615348816, val_loss=2.0117335319519043
Accuracy: 0.64941406 Loss: 2.167109


# Seed: 0

In [16]:
schedule = 2e-4

key = jax.random.PRNGKey(0)

# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN_alif = hk.without_apply_rng(hk.transform(snn_alif))
params_alif = SNN_alif.init(rng=key, x=x[0])

start = time()
grad_params_alif, metrics_alif = gd(SNN_alif, params_alif, shd_dl, key, epochs=500, schedule=schedule) # 1:09
elapsed = time() - start
print(elapsed)
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*metrics_alif[-1]))
acc, loss, preds, tgts = test_gd(SNN_alif, grad_params_alif, shd_dl)
print("Accuracy:", acc, "Loss:", loss)

  0%|          | 0/500 [00:00<?, ?it/s]

101.34287238121033
Performance: train_loss=2.002713203430176, val_acc=0.7747396230697632, val_loss=1.9932903051376343
Accuracy: 0.69433594 Loss: 2.099601


# Seed: 7

In [17]:
schedule = 2e-4

key = jax.random.PRNGKey(7)

# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN_alif = hk.without_apply_rng(hk.transform(snn_alif))
params_alif = SNN_alif.init(rng=key, x=x[0])

start = time()
grad_params_alif, metrics_alif = gd(SNN_alif, params_alif, shd_dl, key, epochs=500, schedule=schedule) # 1:09
elapsed = time() - start
print(elapsed)
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*metrics_alif[-1]))
acc, loss, preds, tgts = test_gd(SNN_alif, grad_params_alif, shd_dl)
print("Accuracy:", acc, "Loss:", loss)

  0%|          | 0/500 [00:00<?, ?it/s]

103.24864721298218
Performance: train_loss=1.9891881942749023, val_acc=0.7825521230697632, val_loss=1.9762778282165527
Accuracy: 0.73583984 Loss: 2.0580053
