In [1]:
import os
# Ask XLA to preallocate up to 95% of the accelerator memory for this process.
# The variable is written before `import jax` below so that it is already present
# in the environment when JAX is loaded.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".95"

import numpy as np

import jax
import jax.numpy as jnp
import spyx
import haiku as hk
import optax
from jax_tqdm import scan_tqdm

### NMNIST Dataloading

In [2]:

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from collections import namedtuple

# Lightweight container pairing inputs (`obs`) with their targets (`labels`);
# `shuffle` returns the batched dataset in this form.
State = namedtuple("State", "obs labels")



In [3]:
sensor_size = tonic.datasets.NMNIST.sensor_size

# Preprocessing applied to every recording:
#   1. ToFrame bins the event stream into 48 frames.
#   2. np.packbits compresses axis 0 of the resulting array eight-to-one; the
#      training step reverses this with jnp.unpackbits before running the network.
# NOTE(review): assumes axis 0 of the ToFrame output is the time axis, and that
# packbits reduces every non-zero bin count to a single bit — confirm both.
frame_transform = transforms.Compose([
                                      transforms.ToFrame(sensor_size=sensor_size, 
                                                         n_time_bins=48),
                                      lambda x: np.packbits(x, axis=0)
                                     ])

# Train and test splits share the same transform and are stored under ./tmp/data.
train_dataset = tonic.datasets.NMNIST(save_to='./tmp/data', transform=frame_transform, train=True)
test_dataset = tonic.datasets.NMNIST(save_to='./tmp/data', transform=frame_transform, train=False)

In [4]:
# Build a loader whose batch size is one third of the training split and wrap it in
# an iterator, so a single `next` call yields one large, shuffled, padded batch.
train_dl = iter(DataLoader(train_dataset, batch_size=len(train_dataset)//3,
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=True))
        
# Only the first batch is pulled; it becomes the in-memory training set used below.
x_train, y_train = next(train_dl)

In [5]:
#test_dl = iter(DataLoader(test_dataset, batch_size=len(test_dataset),
#                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
#        
#x_test, y_test = next(test_dl)

In [6]:
# Convert the collated batch into JAX arrays, stored as uint8.
x_train = jnp.array(x_train, dtype=jnp.uint8)
y_train = jnp.array(y_train, dtype=jnp.uint8)

# Test-split conversion is disabled, matching the commented-out test loader above.
#x_test = jnp.array(x_test, dtype=jnp.uint8)
#y_test = jnp.array(y_test, dtype=jnp.uint8)

In [7]:
def shuffle(dataset, shuffle_rng, batch_size):
    """Shuffle a dataset and split it into equally sized mini-batches.

    Args:
        dataset: Pair ``(x, y)`` of arrays that share their leading (sample) axis.
        shuffle_rng: JAX PRNG key that determines the permutation.
        batch_size: Number of samples per mini-batch.

    Returns:
        ``State`` whose ``obs`` has shape ``(num_batches, batch_size, *x.shape[1:])``
        and whose ``labels`` has shape ``(num_batches, batch_size)``. Samples that
        do not fill a complete batch are dropped.
    """
    x, y = dataset

    num_samples = y.shape[0]
    # Count of samples that fill whole batches. Slicing with an explicit end index
    # (instead of a negative `[:-remainder]`) keeps every sample when the dataset
    # size is an exact multiple of batch_size; `[:-0]` would return nothing.
    usable = num_samples - num_samples % batch_size

    # Draw one permutation of the sample indices and apply it to both arrays so
    # that observations and labels stay aligned.
    order = jax.random.permutation(shuffle_rng, num_samples)[:usable]
    obs = x[order]
    labels = y[order]
    # Kept from the original: reports how many samples survive the truncation.
    print(labels.shape)

    obs = jnp.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = jnp.reshape(labels, (-1, batch_size))

    return State(obs=obs, labels=labels)

### Spyx NMNIST

In [8]:
def build_snn(batch_size):
    """Create the convolutional spiking network and initialise its parameters.

    Args:
        batch_size: Mini-batch size used to shape the sample batch that
            initialises the parameters.

    Returns:
        Tuple ``(SNN, params)``: the Haiku-transformed network (with the RNG
        argument removed from ``apply``) and its initial parameters.
    """

    def nmnist_snn(x):
        # Every spiking layer receives its own arctan surrogate activation.
        def surrogate():
            return spyx.axn.Axon(spyx.axn.arctan())

        layers = [
            hk.Conv2D(12, 5),
            spyx.nn.IF((2, 34, 12,), activation=surrogate()),
            hk.MaxPool((2,2), (2,2), "SAME"),
            hk.Conv2D(32, 5),
            spyx.nn.LIF((2, 17, 32,), activation=surrogate()),
            hk.MaxPool((2,2), (2,2), "SAME"),
            hk.Flatten(),
            hk.Linear(10, with_bias=False),
            spyx.nn.LIF((10,), activation=surrogate()),
        ]
        core = hk.DeepRNN(layers)

        # Batch-major input is cast to float32 and run through the recurrent core
        # with hk.dynamic_unroll, unrolling 16 steps per loop iteration.
        inputs = x.astype(jnp.float32)
        start_state = core.initial_state(x.shape[0])
        spikes, V = hk.dynamic_unroll(core, inputs, start_state, time_major=False, unroll=16)

        return spikes, V

    key = jax.random.PRNGKey(0)
    # One batch from the training set is enough to initialise the parameters.
    sample = shuffle((x_train, y_train), key, batch_size)
    print(sample.obs.shape)

    # Nothing in the network is stochastic, so apply() needs no RNG argument.
    SNN = hk.without_apply_rng(hk.transform(nmnist_snn))
    first_batch = jnp.float32(sample.obs[0])
    params = SNN.init(rng=key, x=first_batch)

    return SNN, params




In [9]:
def benchmark(SNN, params, dataset, epochs, batch_size):
    """Train the network for a number of epochs inside nested ``jax.lax.scan`` loops.

    Args:
        SNN: Transformed network exposing ``apply(weights, events)``.
        params: Initial network parameters.
        dataset: Pair ``(x, y)`` handed to ``shuffle`` at the start of every epoch.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size passed to ``shuffle``.

    Returns:
        Tuple ``(final_params, metrics)`` where ``metrics`` holds the mean training
        loss of each epoch.
    """
    optimizer = optax.adam(learning_rate=5e-4)
    initial_opt_state = optimizer.init(params)

    @jax.jit
    def net_eval(weights, events, targets):
        # Only the first network output is scored; the second output is ignored.
        traces, _ = SNN.apply(weights, events)
        return spyx.fn.integral_crossentropy(traces, targets)

    # Returns (loss, gradient with respect to weights).
    surrogate_grad = jax.value_and_grad(net_eval)

    base_key = jax.random.PRNGKey(0)

    @jax.jit
    def train_step(state, data):
        weights, optimizer_state = state
        packed_events, targets = data
        # Undo the bit-packing along axis 1 of the batched events.
        events = jnp.unpackbits(packed_events, axis=1)
        loss, grads = surrogate_grad(weights, events, targets)
        updates, optimizer_state = optimizer.update(grads, optimizer_state, weights)
        weights = optax.apply_updates(weights, updates)
        return [weights, optimizer_state], loss

    @scan_tqdm(epochs)
    def epoch(epoch_state, epoch_num):
        weights, optimizer_state = epoch_state

        # A distinct shuffle key per epoch, derived from the epoch index.
        epoch_key = jax.random.fold_in(base_key, epoch_num)
        batches = shuffle(dataset, epoch_key, batch_size)
        num_batches = batches.obs.shape[0]

        # Inner loop: one optimisation step per mini-batch.
        end_state, batch_losses = jax.lax.scan(
            train_step,
            [weights, optimizer_state],
            batches,
            num_batches,
        )

        return end_state, jnp.mean(batch_losses)

    # Outer loop: one iteration per epoch, collecting the per-epoch mean loss.
    epoch_indices = jnp.arange(epochs)
    final_state, metrics = jax.lax.scan(
        epoch,
        [params, initial_opt_state],
        epoch_indices,
        epochs,
    )

    final_params = final_state[0]
    return final_params, metrics

In [10]:
from time import time

def run_bench(trials, num_epochs, batch_size):
    """Time repeated training runs and print the per-run and summary timings.

    The network is built once; ``benchmark`` is then run ``trials + 1`` times.
    The first run is treated as a warm-up and excluded from the summary.

    Args:
        trials: Number of timed runs included in the mean / standard deviation.
        num_epochs: Epochs per run, forwarded to ``benchmark``.
        batch_size: Mini-batch size, forwarded to ``build_snn`` and ``benchmark``.
    """
    SNN, params = build_snn(batch_size)

    times = []
    for t in range(trials+1):
        print(t, ":", end="")
        start = time()
        result = benchmark(SNN, params, (x_train,y_train), num_epochs, batch_size)
        # JAX dispatches work asynchronously: wait until the returned arrays are
        # actually computed before stopping the clock, otherwise the measurement
        # may cover only the dispatch and not the training itself.
        jax.block_until_ready(result)
        times.append(time() - start)
        print(times[t])

    # Skip the first (warm-up) run when summarising.
    print("Mean:", np.mean(times[1:]), "Std. Dev.:", np.std(times[1:]))

In [11]:
# 5 timed trials (plus one extra initial run that is excluded from the summary),
# 10 epochs per trial, mini-batches of 32 samples.
run_bench(5, 10, 32)

(49984,)
(1562, 32, 8, 2, 34, 34)
0 :(49984,)


  0%|          | 0/10 [00:00<?, ?it/s]

166.84630870819092
1 :(49984,)


  0%|          | 0/10 [00:00<?, ?it/s]

166.1280755996704
2 :(49984,)


  0%|          | 0/10 [00:00<?, ?it/s]

165.97760009765625
3 :(49984,)


  0%|          | 0/10 [00:00<?, ?it/s]

166.11146712303162
4 :(49984,)


  0%|          | 0/10 [00:00<?, ?it/s]

166.00985074043274
5 :(49984,)


  0%|          | 0/10 [00:00<?, ?it/s]

166.16814804077148
Mean: 166.0790283203125 Std. Dev.: 0.07276463154332144
