# Training an SNN using surrogate gradients!

Train your first SNN in JAX in less than 10 minutes without needing a heavy-duty GPU!

In [1]:
import jax.numpy as jnp
warmup = jnp.array([1,2,3])
warmup * 12.34

Array([12.34, 24.68, 37.02], dtype=float32, weak_type=True)

In [2]:
import spyx
import spyx.nn as snn

# JAX imports

import jax
from jax import numpy as jnp
import jmp
import numpy as np

from jax_tqdm import scan_tqdm
from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# for surrogate loss training.
import evosax
from evosax.strategies import CR_FM_NES as CRFMNES

from evosax import FitnessShaper

# rendering tools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
%matplotlib notebook

## Set Mixed Precision Policy

In [3]:
policy = jmp.get_policy('half')


hk.mixed_precision.set_policy(hk.Linear, policy)
hk.mixed_precision.set_policy(snn.ALIF, policy)
hk.mixed_precision.set_policy(snn.LIF, policy)

## Data Loading

In [4]:
shd_dl = spyx.data.SHD_loader(256,128,128)

In [5]:
key = jax.random.PRNGKey(0)
x, y = shd_dl.train_epoch(key)

In [None]:
plt.imshow(np.unpackbits(x[0][69], axis=0).T)
plt.show()

In [6]:
x.shape

(25, 256, 16, 128)

In [7]:
y.shape

(25, 256)

## SNN



In [8]:
def snn_alif(x):
    
    x = hk.BatchApply(hk.Linear(64, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.ALIF((64,)),
        hk.Linear(64, with_bias=False),
        snn.ALIF((64,)),
        hk.Linear(20, with_bias=False),
        snn.LIF((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]),
                                  time_major=False, unroll=16)
    
    return spikes, V

In [9]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN_alif = hk.without_apply_rng(hk.transform(snn_alif))
params_alif = SNN_alif.init(rng=key, x=x[0])

## evolution


In [10]:
import optax


def mse_spikerate(traces, targets, sparsity=.35, smoothing=0):
    """
    Calculate the mean squared error of the mean spike rate.
    Allows for label smoothing to discourage silencing 
    the other neurons in the readout layer.

    Attributes:
        traces: the output of the final layer of the SNN
        targets: the integer labels for each class
        smoothing: [optional] rate at which to smooth labels.
    """
    t = traces.shape[1]
    logits = jnp.sum(traces, axis=-2) # time axis.
    labels = optax.smooth_labels(jax.nn.one_hot(targets, logits.shape[-1]), smoothing)
    return jnp.mean(optax.squared_error(logits, labels * sparsity * t))

In [11]:
def evo(SNN, params, dl, key, epochs=300): # rename this
    rng = key        
    aug = spyx.data.shift_augment(8) # need to make this stateless
    
    param_reshaper = evosax.ParameterReshaper(params)
        

    # Instantiate jittable fitness shaper (e.g. for Open ES)
    fit_shaper = FitnessShaper(maximize=False)

    # Shape the evaluated fitness scores
    
    strategy = CRFMNES(popsize=512, # 192
                num_dims=param_reshaper.total_params,
                sigma_init=0.15
                )

    opt_state = strategy.initialize(rng)
    
            
    # define and compile our eval function that computes the loss for our SNN
    @jax.jit
    def net_eval(weights, events, targets):
        readout = SNN.apply(weights, events)
        traces, V_f = readout
        #acc, _ = spyx.fn.integral_accuracy(traces, targets)
        loss = mse_spikerate(traces, targets)
        return loss #acc
        
    sim_fn = jax.vmap(net_eval, (0, None, None))
    
    # compile the meat of our training loop for speed
    @jax.jit
    def step(state, data):
        old_state, rng = state
        rng, rng_ask, rng_aug = jax.random.split(rng, 3)
        events, targets = data # fix this
        events = jnp.unpackbits(events, axis=1) # decompress temporal axis
        pop, evo_state = strategy.ask(rng_ask, old_state)
        population_params = param_reshaper.reshape(pop.astype(jnp.float16)) 
        # EVAL
        fit = sim_fn(population_params, aug(events, rng_aug), targets) 
        # TELL
        fit_shaped = fit_shaper.apply(pop, fit) 
        new_state = [strategy.tell(pop, fit_shaped, evo_state), rng]
        
        return new_state, fit
    
    # For validation epochs, do the same as before but compute the
    # accuracy, predictions and losses (no gradients needed)
    @jax.jit
    def eval_step(grad_params, data):
        events, targets = data # fix
        events = jnp.unpackbits(events, axis=1) # decompress temporal axis
        readout = SNN.apply(grad_params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        return grad_params, acc
        
    
    val_data = dl.val_epoch()
    
    # Here's the start of our training loop!
    @scan_tqdm(epochs)
    def epoch(curr_opt_state, epoch_num):
        
        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = dl.train_epoch(shuffle_rng)
        
        # train epoch
        end_state, train_acc = jax.lax.scan(
            step,# func
            curr_opt_state,# init
            train_data,# xs
            train_data.obs.shape[0]# len
        )
        
        elite = param_reshaper.reshape(jnp.array([end_state[0].best_member]))
        new_params = jax.tree_util.tree_map(lambda x: x[0], elite)
                        
        # val epoch
        _, val_acc = jax.lax.scan(
            eval_step,# func
            new_params,# init
            val_data,# xs
            val_data.obs.shape[0]# len
        )

        
        return end_state, jnp.hstack([jnp.max(train_acc), jnp.mean(val_acc)])
    # end epoch
    
    # epoch loop
    final_state, metrics = jax.lax.scan(
        epoch,
        [opt_state, rng], # metric arrays
        jnp.arange(epochs), # 
        epochs # len of loop
    )
    
    elite = param_reshaper.reshape(jnp.array([final_state[0].best_member]))
    final_params = jax.tree_util.tree_map(lambda x: x[0], elite)
                
    # return our final, optimized network.       
    return final_params, metrics

In [12]:
def test_evo(SNN, params, dl):

    @jax.jit
    def test_step(params, data):
        events, targets = data
        events = jnp.unpackbits(events, axis=1)
        readout = SNN.apply(params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        return params, [acc, pred, targets]
    
    test_data = dl.test_epoch()
    
    _, test_metrics = jax.lax.scan(
            test_step,# func
            params,# init
            test_data,# xs
            test_data.obs.shape[0]# len
    )
    
    acc = jnp.mean(test_metrics[0])
    preds = jnp.array(test_metrics[1]).flatten()
    tgts = jnp.array(test_metrics[2]).flatten()
    return acc, preds, tgts,

## Training Time


In [13]:
from time import time

# Seed: 42

In [14]:
seed = jax.random.PRNGKey(42)

In [16]:
start = time()
evolved_params_alif, metrics_alif = evo(SNN_alif, params_alif, shd_dl, seed, epochs=1000) # started 13:33
elapsed = time() - start
print(elapsed)
print("Performance: train_acc={}, val_acc={}".format(*metrics_alif[-1]))
acc, preds, tgts = test_evo(SNN_alif, evolved_params_alif, shd_dl)
print(acc)

ParameterReshaper: 13844 parameters detected for optimization.
128


  0%|          | 0/1000 [00:00<?, ?it/s]

1750.874398946762
Performance: train_acc=109.28594207763672, val_acc=0.7649739980697632
0.7597656


In [24]:
acc, preds, tgts = test_evo(SNN_alif, evolved_params_alif, shd_dl)
acc

Array(0.75634766, dtype=float32)

# Seed: 12345

In [25]:
seed = jax.random.PRNGKey(12345)

In [27]:
start = time()
evolved_params_alif, metrics_alif = evo(SNN_alif, params_alif, shd_dl, seed, epochs=1000) # started 13:33
elapsed = time() - start
print(elapsed)

ParameterReshaper: 13844 parameters detected for optimization.


  0%|          | 0/1000 [00:00<?, ?it/s]

1756.8226170539856


In [29]:
print("Performance: train_acc={}, val_acc={}".format(*metrics_alif[-1]))

Performance: train_acc=61.52933883666992, val_acc=0.75


In [31]:
acc, preds, tgts = test_evo(SNN_alif, evolved_params_alif, shd_dl)
acc

Array(0.70166016, dtype=float32)

# Seed: 54321

In [32]:
seed = jax.random.PRNGKey(54321)

In [34]:
start = time()
evolved_params_alif, metrics_alif = evo(SNN_alif, params_alif, shd_dl, seed, epochs=1000) # started 13:33
elapsed = time() - start
print(elapsed)

ParameterReshaper: 13844 parameters detected for optimization.


  0%|          | 0/1000 [00:00<?, ?it/s]

1756.6970422267914


In [36]:
print("Performance: train_acc={}, val_acc={}".format(*metrics_alif[-1]))

Performance: train_acc=101.5151596069336, val_acc=0.75390625


In [38]:
acc, preds, tgts = test_evo(SNN_alif, evolved_params_alif, shd_dl)
acc

Array(0.75341797, dtype=float32)

# Seed: 0

In [16]:
seed = jax.random.PRNGKey(0)

In [19]:
start = time()
evolved_params_alif, metrics_alif = evo(SNN_alif, params_alif, shd_dl, seed, epochs=1000) # started 13:33
elapsed = time() - start
print(elapsed)

ParameterReshaper: 13844 parameters detected for optimization.


  0%|          | 0/1000 [00:00<?, ?it/s]

1756.853637456894


In [20]:
print("Performance: train_acc={}, val_acc={}".format(*metrics_alif[-1]))

Performance: train_acc=79.02629852294922, val_acc=0.73828125


In [22]:
acc, preds, tgts = test_evo(SNN_alif, evolved_params_alif, shd_dl)
acc

Array(0.73046875, dtype=float32)

# Seed: 7

In [23]:
seed = jax.random.PRNGKey(7)

In [25]:
start = time()
evolved_params_alif, metrics_alif = evo(SNN_alif, params_alif, shd_dl, seed, epochs=1000) # started 13:33
elapsed = time() - start
print(elapsed)

ParameterReshaper: 13844 parameters detected for optimization.


  0%|          | 0/1000 [00:00<?, ?it/s]

1755.4052205085754


In [27]:
print("Performance: train_acc={}, val_acc={}".format(*metrics_alif[-1]))

Performance: train_acc=57.57223129272461, val_acc=0.751953125


In [29]:
acc, preds, tgts = test_evo(SNN_alif, evolved_params_alif, shd_dl)
acc

Array(0.7211914, dtype=float32)