In [1]:
import os
# Limit the share of accelerator memory that JAX preallocates to 80%.
# The variable is written before `jax` is imported below so that it is already
# in the environment when the XLA client is created.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".80"

import numpy as np

import jax
import jax.numpy as jnp
import spyx
import haiku as hk
import optax
from jax_tqdm import scan_tqdm

### SHD Dataloading

In [2]:

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from collections import namedtuple

# Container pairing batched observations (`obs`) with their `labels`.
State = namedtuple(typename="State", field_names="obs labels")



In [3]:
class _SHD2Raster():
    """
    Rasterise SHD event samples into binary frames, bit-packed along time.

    Events are binned into a (time_steps, encoding_dim) grid that holds 1
    wherever at least one event occurred. The grid is limited to the first
    ``sample_T`` time steps and its time axis is then compressed with
    ``np.packbits`` for memory efficiency. The user therefore has to apply
    ``jnp.unpackbits(events, axis=<time axis>)`` before feeding the data to
    the network.

    Args:
        encoding_dim: number of channels (second axis of the raster).
        sample_T: maximum number of time steps kept per sample.
    """

    def __init__(self, encoding_dim, sample_T = 100):
        self.encoding_dim = encoding_dim
        self.sample_T = sample_T

    def __call__(self, events):
        """
        Convert the events of one sample into a bit-packed uint8 raster.

        Args:
            events: object indexable by "t" (integer time step) and "x"
                (integer channel), e.g. a NumPy structured array.

        Returns:
            uint8 array of shape (ceil(T / 8), encoding_dim) with
            T = min(last time step + 1, sample_T). A sample without any
            events yields an array with zero rows instead of raising.
        """
        times = np.asarray(events["t"])
        channels = np.asarray(events["x"])

        # max() of an empty array raises, so handle event-free samples up front.
        if times.size == 0:
            return np.zeros((0, self.encoding_dim), dtype=np.uint8)

        # Only allocate the rows that survive the sample_T cut instead of
        # building the raster for the full recording and slicing it afterwards.
        n_steps = min(int(times.max()) + 1, self.sample_T)
        raster = np.zeros((n_steps, self.encoding_dim), dtype=np.uint8)

        # Binary raster: repeated events in one bin still give a single 1.
        in_window = times < n_steps
        raster[times[in_window], channels[in_window]] = 1

        return np.packbits(raster, axis=0)

In [4]:
# Resolution of the raw recordings and of the network input.
shd_timestep, shd_channels = 1e-6, 700
sample_T, net_channels = 64, 128
net_dt = 1 / sample_T
batch_size = 256

# Per-sample input and output shapes (used by the NIR export at the end).
obs_shape = (net_channels,)
act_shape = (20,)

# Rescale time stamps and channel indices to the network resolution, then
# rasterise every sample into bit-packed frames.
transform = transforms.Compose(
    [
        transforms.Downsample(
            time_factor=shd_timestep / net_dt,
            spatial_factor=net_channels / shd_channels,
        ),
        _SHD2Raster(net_channels, sample_T=sample_T),
    ]
)

# Both splits share the same preprocessing pipeline.
train_dataset, test_dataset = (
    datasets.SHD("./data", train=is_train, transform=transform)
    for is_train in (True, False)
)



In [5]:
# The batch size equals the split size, so the first (and only) batch holds
# the whole training split as one collated tensor pair.
train_dl = iter(
    DataLoader(
        train_dataset,
        batch_size=len(train_dataset),
        shuffle=False,
        drop_last=True,
        collate_fn=tonic.collation.PadTensors(batch_first=True),
    )
)

x_train, y_train = next(train_dl)

In [6]:
# The batch size equals the split size, so the first (and only) batch holds
# the whole test split as one collated tensor pair.
test_dl = iter(
    DataLoader(
        test_dataset,
        batch_size=len(test_dataset),
        shuffle=False,
        drop_last=True,
        collate_fn=tonic.collation.PadTensors(batch_first=True),
    )
)

x_test, y_test = next(test_dl)

In [7]:
# Sanity check: container type and (samples, packed time, channels) shape.
print(type(x_train), x_train.shape, sep="\n")

<class 'torch.Tensor'>
torch.Size([8156, 8, 128])


In [8]:
# Convert every split to a uint8 JAX array so the whole dataset can be used
# inside jitted code.
x_train, y_train, x_test, y_test = (
    jnp.array(split, dtype=jnp.uint8)
    for split in (x_train, y_train, x_test, y_test)
)

In [9]:
def shuffle(dataset, shuffle_rng):
    """
    Shuffle a dataset and split it into equally sized batches.

    The batch size is read from the notebook-level ``batch_size`` global.

    Args:
        dataset: pair ``(x, y)`` of arrays sharing the same leading (sample) axis.
        shuffle_rng: JAX PRNG key that determines the permutation.

    Returns:
        State with ``obs`` of shape (n_batches, batch_size, *x.shape[1:]) and
        ``labels`` of shape (n_batches, batch_size). Samples that do not fill
        a complete batch are dropped.
    """
    x, y = dataset
    n_samples = y.shape[0]

    # Number of samples that split evenly into batches. Slicing with an explicit
    # end index (instead of [:-remainder]) keeps every sample when n_samples is
    # already a multiple of batch_size; [:-0] would return an empty array.
    n_keep = n_samples - n_samples % batch_size

    # Draw a single index permutation and apply it to both arrays so that
    # observations and labels are guaranteed to stay paired.
    order = jax.random.permutation(shuffle_rng, n_samples)[:n_keep]
    obs = jnp.take(x, order, axis=0)
    labels = jnp.take(y, order, axis=0)

    obs = jnp.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = jnp.reshape(labels, (-1, batch_size))

    return State(obs=obs, labels=labels)

### Spyx SHD

In [39]:
def shd_snn(x):
    """
    Spiking network for SHD: two hidden LIF stages and a leaky-integrator readout.

    Each hidden stage is a bias-free linear projection to 64 units followed by
    an LIF layer using the arctan surrogate; the readout is a bias-free linear
    projection to 20 units followed by an LI layer.

    Args:
        x: batch-major input sequence (the core is unrolled with
            ``time_major=False``); ``x.shape[0]`` is used as the batch size.

    Returns:
        Tuple ``(spikes, V)`` as returned by ``hk.dynamic_unroll``: the output
        sequence of the last layer and the final state of the core.
    """
    hidden_units, n_outputs = 64, 20

    # Modules are created in the same order as they are applied.
    layers = []
    for _ in range(2):
        layers.append(hk.Linear(hidden_units, with_bias=False))
        layers.append(
            spyx.nn.LIF((hidden_units,), activation=spyx.axn.Axon(spyx.axn.arctan()))
        )
    layers.append(hk.Linear(n_outputs, with_bias=False))
    layers.append(spyx.nn.LI((n_outputs,)))
    core = hk.DeepRNN(layers)

    # Dynamic unroll over time, with unroll=32 passed through for performance.
    initial_state = core.initial_state(x.shape[0])
    spikes, V = hk.dynamic_unroll(
        core,
        x,
        initial_state,
        time_major=False,
        unroll=32,
    )

    return spikes, V

In [40]:
# Nothing in the network is stochastic, so the apply function is wrapped to
# take no RNG argument.
key = jax.random.PRNGKey(0)
sample_x, sample_y = shuffle((x_train, y_train), key)

# A single batch is enough to trace the network and create its parameters.
SNN = hk.without_apply_rng(hk.transform(shd_snn))
params = SNN.init(key, sample_x[0].astype(jnp.float32))

In [41]:
def gd(SNN, params, dataset, epochs=300, learning_rate=5e-4, seed=0):
    """
    Train the SNN with Adam using surrogate gradients.

    Every epoch reshuffles ``dataset`` into batches via ``shuffle`` and runs one
    optimizer step per batch. Both the batch loop and the epoch loop are
    expressed as ``jax.lax.scan`` so the whole run is compiled.

    Args:
        SNN: transformed network; only ``SNN.apply(params, events)`` is used.
        params: initial network parameters.
        dataset: pair ``(x, y)`` handed to ``shuffle``. ``x`` is bit-packed along
            axis 1 and is unpacked inside the training step.
        epochs: number of passes over the dataset.
        learning_rate: Adam learning rate.
        seed: seed of the PRNG key from which the per-epoch shuffle keys are
            derived.

    Returns:
        Tuple ``(final_params, metrics)`` where ``metrics`` holds the mean
        training loss of every epoch.
    """
    opt = optax.adam(learning_rate=learning_rate)

    # create and initialize the optimizer
    opt_state = opt.init(params)

    # define and compile our eval function that computes the loss for our SNN
    @jax.jit
    def net_eval(weights, events, targets):
        traces, _ = SNN.apply(weights, events)
        return spyx.fn.integral_crossentropy(traces, targets)

    # Use JAX to create a function that calculates the loss and the gradient!
    surrogate_grad = jax.value_and_grad(net_eval)

    rng = jax.random.PRNGKey(seed)

    # compile the meat of our training loop for speed
    @jax.jit
    def train_step(state, data):
        grad_params, opt_state = state
        events, targets = data
        events = jnp.unpackbits(events, axis=1) # decompress temporal axis
        # compute loss and gradient
        loss, grads = surrogate_grad(grad_params, events, targets)
        # generate updates based on the gradients and optimizer
        updates, opt_state = opt.update(grads, opt_state, grad_params)
        # return the updated parameters
        new_state = [optax.apply_updates(grad_params, updates), opt_state]
        return new_state, loss

    # Here's the start of our training loop!
    @scan_tqdm(epochs)
    def epoch(epoch_state, epoch_num):
        curr_params, curr_opt_state = epoch_state

        # a fresh key per epoch gives a different batch order every pass
        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = shuffle(dataset, shuffle_rng)

        # train epoch: scan the optimizer step over all batches
        end_state, train_loss = jax.lax.scan(
            train_step,# func
            [curr_params, curr_opt_state],# init
            train_data,# xs
            train_data.obs.shape[0]# len
        )

        return end_state, jnp.mean(train_loss)
    # end epoch

    # epoch loop
    final_state, metrics = jax.lax.scan(
        epoch,
        [params, opt_state], # carry: parameters and optimizer state
        jnp.arange(epochs), # epoch indices
        epochs # len of loop
    )

    final_params, _ = final_state

    # return our final, optimized network.
    return final_params, metrics

In [42]:
def test_gd(SNN, params, dataset):
    """
    Evaluate the network on a dataset, batch by batch.

    Batching is delegated to ``shuffle`` with a fixed key, so samples that do
    not fill a complete batch are left out of the evaluation.

    Args:
        SNN: transformed network; only ``SNN.apply(params, events)`` is used.
        params: network parameters to evaluate.
        dataset: pair ``(x, y)``; ``x`` is bit-packed along axis 1.

    Returns:
        Tuple ``(acc, loss, preds, tgts)``: accuracy and loss averaged over the
        batches, followed by the flattened predictions and targets.
    """
    batches = shuffle(dataset, jax.random.PRNGKey(0))

    @jax.jit
    def test_step(weights, batch):
        packed, labels = batch
        # decompress the temporal axis before running the network
        traces, _ = SNN.apply(weights, jnp.unpackbits(packed, axis=1))
        batch_acc, batch_pred = spyx.fn.integral_accuracy(traces, labels)
        batch_loss = spyx.fn.integral_crossentropy(traces, labels)
        # the parameters are carried through unchanged
        return weights, [batch_acc, batch_loss, batch_pred, labels]

    _, (accs, losses, batch_preds, batch_tgts) = jax.lax.scan(
        test_step,
        params,
        batches,
        length=batches.obs.shape[0],
    )

    return (
        jnp.mean(accs),
        jnp.mean(losses),
        jnp.array(batch_preds).flatten(),
        jnp.array(batch_tgts).flatten(),
    )

In [43]:
# Train on the full training split; `metrics` is the per-epoch mean loss.
grad_params, metrics = gd(
    SNN,
    params,
    (x_train, y_train),
    epochs=300,
)


  0%|          | 0/300 [00:00<?, ?it/s]

In [44]:
# Mean training loss of every epoch, as returned by gd().
metrics

Array([3.3910465, 3.016463 , 3.030699 , 3.0076985, 2.9221492, 2.8107104,
       2.7359867, 2.65059  , 2.5880544, 2.5470932, 2.5164132, 2.465876 ,
       2.4436922, 2.4118633, 2.3856888, 2.352564 , 2.332061 , 2.315615 ,
       2.2930186, 2.2780628, 2.2673588, 2.2505302, 2.2261753, 2.2158225,
       2.194442 , 2.1847625, 2.169232 , 2.157136 , 2.1424909, 2.1340435,
       2.1233969, 2.1086009, 2.1064436, 2.0974517, 2.0898352, 2.0800953,
       2.0742557, 2.067817 , 2.0587747, 2.053098 , 2.053758 , 2.0434794,
       2.034027 , 2.0378797, 2.0217361, 2.0174112, 2.0115888, 2.0055044,
       1.9960711, 1.9975836, 1.9898484, 1.9836861, 1.9812824, 1.9706731,
       1.9612522, 1.9600667, 1.9502194, 1.9454533, 1.9411132, 1.9386741,
       1.9352958, 1.927839 , 1.92639  , 1.9222975, 1.9149606, 1.9142834,
       1.9100983, 1.9087315, 1.9014016, 1.8975301, 1.8981328, 1.8936834,
       1.8925552, 1.8859547, 1.8824438, 1.8849107, 1.8788638, 1.8789203,
       1.8728931, 1.8714441, 1.8716774, 1.8650261, 

In [45]:
# Evaluate the trained parameters on the test split.
acc, loss, preds, tgts = test_gd(
    SNN,
    grad_params,
    (x_test, y_test),
)
print("Accuracy:", acc, "Loss:", loss)

Accuracy: 0.7421875 Loss: 2.0260453


### Use NIR to save our network and then load it up later, in any framework of our choosing!

In [46]:
import nir
# Presumably restores the layer order of the initial `params` in the trained
# `grad_params` before export — confirm against spyx.nir.reorder_layers.
export_params = spyx.nir.reorder_layers(params, grad_params)
# Build the NIR graph from the parameters and the per-sample input/output shapes.
# NOTE(review): the trailing positional 1 looks like the time step (dt) — confirm.
G = spyx.nir.to_nir(export_params, obs_shape, act_shape, 1)
# Write the graph to disk.
nir.write("./spyx_shd.nir", G)