In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".9"

import numpy as np

import jax
import jax.numpy as jnp
import spyx
import haiku as hk
import optax
from jax_tqdm import scan_tqdm

In [2]:
# Sanity check: display the device set JAX placed this scalar on.
test = jnp.asarray(1.0)
test.devices()

{cuda(id=0)}

### NMNIST Data Loading

In [3]:

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from collections import namedtuple

# Lightweight (obs, labels) record type. NOTE(review): not referenced by any visible cell.
State = namedtuple("State", ["obs", "labels"])



In [4]:
sensor_size = tonic.datasets.NMNIST.sensor_size

# Per-sample preprocessing:
#   1. ToFrame bins each recording's events into 64 dense frames.
#   2. np.packbits along axis 0 packs 8 consecutive frames into one byte,
#      shrinking storage 8x. Any nonzero event count becomes a 1 bit, so the
#      frames are binarized. The training step reverses this with
#      jnp.unpackbits(..., axis=0).
# 64 bins is a multiple of 8, so the packing is exact.
frame_transform = transforms.Compose(
    [
        transforms.ToFrame(sensor_size=sensor_size, n_time_bins=64),
        lambda frames: np.packbits(frames, axis=0),
    ]
)

train_dataset = tonic.datasets.NMNIST(
    save_to='./tmp/data',
    transform=frame_transform,
    train=True,
)
# The test split is intentionally not loaded; this notebook only times training.

In [5]:
# Materialize a random half of the training split as a single host batch.
# The shuffle is seeded so every run benchmarks the same subset; without a
# generator, DataLoader draws from torch's unseeded global RNG.
import torch

train_dl = iter(
    DataLoader(
        train_dataset,
        batch_size=len(train_dataset) // 2,
        collate_fn=tonic.collation.PadTensors(batch_first=True),
        drop_last=True,
        shuffle=True,
        generator=torch.Generator().manual_seed(0),
    )
)

x_train, y_train = next(train_dl)

In [6]:
#test_dl = iter(DataLoader(test_dataset, batch_size=len(test_dataset),
#                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
#        
#x_test, y_test = next(test_dl)

In [7]:
# Move the DataLoader batch onto the JAX device as uint8 arrays. The bit-packed
# frames are already bytes, and the 10 class labels fit in a byte.
x_train, y_train = (
    jnp.array(x_train, dtype=jnp.uint8),
    jnp.array(y_train, dtype=jnp.uint8),
)
# (Test-split conversion omitted: the test loader is disabled in this notebook.)

In [8]:
def _shuffle(dataset, shuffle_rng, batch_size):
    x, y = dataset

    full_batches = y.shape[0] // batch_size

    indices = jax.random.permutation(shuffle_rng, y.shape[0])[:full_batches*batch_size]
    obs, labels = x[indices], y[indices]

    obs = jnp.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = jnp.reshape(labels, (-1, batch_size)) # should make batch size a global

    return obs, labels

shuffle = jax.jit(_shuffle, static_argnums=2)

### Spyx NMNIST

In [9]:
def build_snn(batch_size, channel_multiplier, sample_events=None):
    """Build the Haiku NMNIST spiking network and initialise its parameters.

    Architecture (time-major input, conv layers in NCHW):
        Conv2D(12*m, 5) -> LIF -> MaxPool 2 -> Conv2D(32*m, 5) -> LIF
        -> MaxPool 2 -> Flatten -> Linear(10) -> LI readout

    Args:
        batch_size: number of samples in the dummy batch used for ``SNN.init``.
        channel_multiplier: width multiplier ``m`` applied to both conv layers.
        sample_events: optional packed event batch laid out like
            ``x_train[:batch_size]``. Defaults to the first ``batch_size``
            samples of the global ``x_train``.

    Returns:
        ``(SNN, params)``: the ``hk.without_apply_rng`` transformed network and
        its initial parameters (PRNG key 0).
    """
    mult = channel_multiplier

    def nmnist_snn(x):
        # x is time-major: hk.static_unroll(time_major=True) walks axis 0, and
        # x.shape[1] is used as the batch size for each core's initial state.
        x = hk.BatchApply(hk.Conv2D(12*mult, 5, padding="VALID", data_format="NCHW", with_bias=False))(x.astype(jnp.float32))

        core1 = spyx.nn.LIF((12*mult, 30, 30, ), beta=0.5, )
        x, V = hk.static_unroll(core1, x, core1.initial_state(x.shape[1]), time_major=True)

        fused1 = hk.Sequential([
            hk.MaxPool((2,2,), (2,2,), "VALID"),
            hk.Conv2D(32*mult, 5, padding="VALID", data_format="NCHW", with_bias=False)
        ])
        x = hk.BatchApply(fused1)(x)

        core2 = spyx.nn.LIF(( 32*mult, 11, 11, ), beta=0.5, )
        x, V = hk.static_unroll(core2, x, core2.initial_state(x.shape[1]), time_major=True)

        fused2 = hk.Sequential([
            hk.MaxPool((2,2), (2,2), "VALID",),
            hk.Flatten(),
            hk.Linear(10, with_bias=False)
        ])
        x = hk.BatchApply(fused2)(x)

        # static_unroll returns (per-step outputs, final core state). For the LI
        # readout the per-step outputs are presumably membrane traces rather than
        # spikes (the training loop names them `traces`). TODO confirm in spyx.
        core3 = spyx.nn.LI((10,), beta=0.5, )
        spikes, V = hk.static_unroll(core3, x, core3.initial_state(x.shape[1]), time_major=True)

        return spikes, V

    key = jax.random.PRNGKey(0)
    if sample_events is None:
        # A plain slice is enough for init; shuffling the whole training set
        # gathered every sample just to read one batch.
        sample_events = x_train[:batch_size]
    # Reproduce the layout train_step feeds to SNN.apply: swap the batch and
    # packed-time axes, then unpack the bits along the (now leading) time axis.
    # Assumes sample_events is (batch, packed_time, ...) as produced by the
    # packbits frame transform. TODO confirm if a custom batch is passed.
    sample_events = jnp.unpackbits(jnp.swapaxes(sample_events, 0, 1), axis=0)

    # Nothing in the network is stochastic, so apply() needs no RNG argument.
    SNN = hk.without_apply_rng(hk.transform(nmnist_snn))
    params = SNN.init(rng=key, x=sample_events.astype(jnp.float32))

    return SNN, params




In [10]:
def benchmark(SNN, params, dataset, epochs, batch_size):
    """Train ``SNN`` for ``epochs`` epochs on ``dataset`` and time the run.

    The whole training loop (epochs x minibatches) is compiled into a single
    XLA program via nested ``jax.lax.scan``. Compilation is timed and printed
    separately; only execution of the compiled program is returned.

    Args:
        SNN: Haiku transformed network; ``SNN.apply(params, events)`` returns
            ``(traces, final_state)``.
        params: initial network parameters.
        dataset: ``(x, y)`` pair of packed event frames and labels, in the form
            accepted by ``shuffle``.
        epochs: number of epochs (Python int; fixes the outer scan length).
        batch_size: minibatch size (Python int; static argument of ``shuffle``).

    Returns:
        Wall-clock seconds spent executing the compiled training run.
    """
    # Local import: the notebook otherwise imports `time` only in a later cell.
    from time import perf_counter

    opt = optax.adam(learning_rate=2e-2)

    # create and initialize the optimizer
    opt_state = opt.init(params)

    Loss = spyx.fn.integral_crossentropy(time_axis=0)

    # Loss of the SNN's readout traces against the targets (time axis 0).
    @jax.jit
    def net_eval(weights, events, targets):
        traces, V_f = SNN.apply(weights, events)
        return Loss(traces, targets)

    # Loss and gradient w.r.t. the parameters in one pass.
    surrogate_grad = jax.value_and_grad(net_eval)

    rng = jax.random.PRNGKey(0)

    def train_step(state, data):
        """One optimizer step on a single packed minibatch (scan body)."""
        grad_params, opt_state = state
        events, targets = data
        # (batch, packed_time, ...) -> time-major, then decompress the time bits.
        events = jnp.swapaxes(events, 0, 1)
        events = jnp.unpackbits(events, axis=0)
        loss, grads = surrogate_grad(grad_params, events, targets)
        updates, opt_state = opt.update(grads, opt_state, grad_params)
        new_state = [optax.apply_updates(grad_params, updates), opt_state]
        return new_state, loss

    def _run(init_state, data):
        """Full training run. State and data are arguments, not closures, so
        the dataset and parameters are not embedded as compile-time constants."""

        @scan_tqdm(epochs)
        def epoch(epoch_state, epoch_num):
            curr_params, curr_opt_state = epoch_state

            # Fresh, deterministic permutation per epoch.
            shuffle_rng = jax.random.fold_in(rng, epoch_num)
            train_data = shuffle(data, shuffle_rng, batch_size)

            end_state, train_loss = jax.lax.scan(
                train_step,
                [curr_params, curr_opt_state],  # carry: params + optimizer state
                train_data,                     # xs: (obs, labels) minibatches
            )
            return end_state, jnp.mean(train_loss)

        final_state, metrics = jax.lax.scan(
            epoch,
            init_state,          # carry: [params, optimizer state]
            jnp.arange(epochs),  # epoch indices (also drive the progress bar)
            length=epochs,
        )
        return metrics  # mean training loss per epoch

    init_state = [params, opt_state]

    comp_start = perf_counter()
    run = jax.jit(_run).lower(init_state, dataset).compile()
    comp_end = perf_counter() - comp_start
    print("Compile Time:", comp_end)

    start = perf_counter()
    results = run(init_state, dataset)
    results.block_until_ready()
    end = perf_counter() - start

    return end

In [11]:
from time import time

def run_bench(trials, num_epochs, batch_size, mult):
    """Build one network and time ``trials`` training runs of it.

    ``trials + 1`` runs are executed; the first is treated as a warm-up and
    excluded from the reported mean and standard deviation. Every run starts
    from the same initial parameters and uses the global ``(x_train, y_train)``.

    Returns:
        ``(SNN, params)`` as produced by ``build_snn``; the parameters are the
        untrained initial ones.
    """
    SNN, params = build_snn(batch_size, mult)

    runtimes = []
    for _ in range(trials + 1):
        elapsed = benchmark(SNN, params, (x_train, y_train), num_epochs, batch_size)
        runtimes.append(elapsed)
        print(elapsed)

    measured = runtimes[1:]  # drop the warm-up run
    print("Runtime Mean:", np.mean(measured), "Std. Dev.:", np.std(measured))
    return SNN, params

In [12]:
snn, p = run_bench(trials=5, num_epochs=20, batch_size=32, mult=1)  # 5 timed trials (+1 warm-up)

Compile Time: 37.64500951766968


  0%|          | 0/20 [00:00<?, ?it/s]

133.9073519706726
Compile Time: 27.27926778793335


  0%|          | 0/20 [00:00<?, ?it/s]

135.0098569393158
Compile Time: 27.616517543792725


  0%|          | 0/20 [00:00<?, ?it/s]

134.95792269706726
Compile Time: 26.505394458770752


  0%|          | 0/20 [00:00<?, ?it/s]

134.97410702705383
Compile Time: 27.926079273223877


  0%|          | 0/20 [00:00<?, ?it/s]

135.05541467666626
Compile Time: 26.408557653427124


  0%|          | 0/20 [00:00<?, ?it/s]

134.86229848861694
Runtime Mean: 134.971919965744 Std. Dev.: 0.06423107617170751


In [13]:
snn, p = run_bench(trials=5, num_epochs=20, batch_size=64, mult=1)  # 5 timed trials (+1 warm-up)

Compile Time: 37.998034715652466


  0%|          | 0/20 [00:00<?, ?it/s]

128.1665997505188
Compile Time: 26.181029796600342


  0%|          | 0/20 [00:00<?, ?it/s]

128.13299107551575
Compile Time: 25.68013906478882


  0%|          | 0/20 [00:00<?, ?it/s]

128.40380263328552
Compile Time: 25.271388053894043


  0%|          | 0/20 [00:00<?, ?it/s]

128.43401336669922
Compile Time: 25.997533559799194


  0%|          | 0/20 [00:00<?, ?it/s]

128.4314284324646
Compile Time: 26.139952659606934


  0%|          | 0/20 [00:00<?, ?it/s]

128.35471177101135
Runtime Mean: 128.3513894557953 Std. Dev.: 0.11286184466069653


In [18]:
snn, p = run_bench(trials=5, num_epochs=20, batch_size=128, mult=1)  # 5 timed trials (+1 warm-up)

  0%|          | 0/20 [00:00<?, ?it/s]

126.78099060058594


  0%|          | 0/20 [00:00<?, ?it/s]

127.30293607711792


  0%|          | 0/20 [00:00<?, ?it/s]

127.3445155620575


  0%|          | 0/20 [00:00<?, ?it/s]

127.34854197502136


  0%|          | 0/20 [00:00<?, ?it/s]

127.36679887771606


  0%|          | 0/20 [00:00<?, ?it/s]

127.38501191139221
Mean: 127.34956088066102 Std. Dev.: 0.027401787665478735


In [14]:
snn, p = run_bench(trials=5, num_epochs=20, batch_size=32, mult=2)  # 5 timed trials (+1 warm-up)

Compile Time: 35.90413522720337


  0%|          | 0/20 [00:00<?, ?it/s]

245.8349483013153
Compile Time: 25.80127263069153


  0%|          | 0/20 [00:00<?, ?it/s]

247.14030385017395
Compile Time: 26.146437406539917


  0%|          | 0/20 [00:00<?, ?it/s]

247.1255714893341
Compile Time: 26.303130388259888


  0%|          | 0/20 [00:00<?, ?it/s]

247.43779158592224
Compile Time: 26.25004291534424


  0%|          | 0/20 [00:00<?, ?it/s]

247.14701628684998
Compile Time: 26.31723427772522


  0%|          | 0/20 [00:00<?, ?it/s]

247.15912556648254
Runtime Mean: 247.20196175575256 Std. Dev.: 0.11841184280787566


In [15]:
snn, p = run_bench(trials=5, num_epochs=20, batch_size=64, mult=2)  # 5 timed trials (+1 warm-up)

Compile Time: 38.9309356212616


  0%|          | 0/20 [00:00<?, ?it/s]

241.11891746520996
Compile Time: 25.906715154647827


  0%|          | 0/20 [00:00<?, ?it/s]

241.0532741546631
Compile Time: 26.164097785949707


  0%|          | 0/20 [00:00<?, ?it/s]

241.21807718276978
Compile Time: 25.36104130744934


  0%|          | 0/20 [00:00<?, ?it/s]

241.2156093120575
Compile Time: 25.95780324935913


  0%|          | 0/20 [00:00<?, ?it/s]

241.29518818855286
Compile Time: 26.010263442993164


  0%|          | 0/20 [00:00<?, ?it/s]

241.12646198272705
Runtime Mean: 241.18172216415405 Std. Dev.: 0.08354297535757978


In [16]:
snn, p = run_bench(trials=5, num_epochs=20, batch_size=128, mult=2)  # 5 timed trials (+1 warm-up)

Compile Time: 41.75052571296692


  0%|          | 0/20 [00:00<?, ?it/s]

236.33672213554382
Compile Time: 26.364729642868042


  0%|          | 0/20 [00:00<?, ?it/s]

236.53257727622986
Compile Time: 25.484204292297363


  0%|          | 0/20 [00:00<?, ?it/s]

236.54382729530334
Compile Time: 26.426677227020264


  0%|          | 0/20 [00:00<?, ?it/s]

236.52414178848267
Compile Time: 25.715760946273804


  0%|          | 0/20 [00:00<?, ?it/s]

236.51134419441223
Compile Time: 26.352440118789673


  0%|          | 0/20 [00:00<?, ?it/s]

236.50918078422546
Runtime Mean: 236.5242142677307 Std. Dev.: 0.01300969199088646
