In [1]:
import os

import numpy as np

import jax
import jax.numpy as jnp
import spyx
import haiku as hk
import optax
from jax_tqdm import scan_tqdm

AttributeError: module 'jax.typing' has no attribute 'DTypeLike'

In [2]:
# Sanity check: list the device(s) JAX places a freshly created array on.
x = jnp.asarray([1, 2, 3])
x.devices()

CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12030, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


{CpuDevice(id=0)}

### SHD Data Loading

In [3]:

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from collections import namedtuple

State = namedtuple("State", ["obs", "labels"])  # one shuffled epoch: batched inputs and matching labels



In [4]:
class _SHD2Raster():
    """
    Tool for rastering SHD samples into binary frames.

    Events are binned into a (time_steps, encoding_dim) grid and capped at
    ``sample_T`` time steps. The grid is binarized, so a cell is 1 if at least
    one event hit it. It is then bit-packed along the temporal axis for memory
    efficiency. This means that the user will have to apply
    jnp.unpackbits(events, axis=<time axis>) prior to feeding the data to the
    network.

    Args:
        encoding_dim: number of channels (columns) of the raster; every event
            ``x`` coordinate must lie in ``[0, encoding_dim)``.
        sample_T: maximum number of time steps kept per sample (``None`` keeps all).
    """

    def __init__(self, encoding_dim, sample_T=100):
        self.encoding_dim = encoding_dim
        self.sample_T = sample_T

    def __call__(self, events):
        """
        Raster and bit-pack a single sample.

        Args:
            events: structured array with integer ``"t"`` and ``"x"`` fields.

        Returns:
            uint8 array of shape (ceil(T / 8), encoding_dim), where
            T = min(max(t) + 1, sample_T). An empty event array gives T = 0.
        """
        t = np.asarray(events["t"])
        x = np.asarray(events["x"])

        # Time steps represented: the sample's own length, capped at sample_T.
        # Guard the empty case, where t.max() would raise ValueError.
        num_steps = int(t.max()) + 1 if t.size else 0
        if self.sample_T is not None:
            num_steps = min(num_steps, self.sample_T)

        # Drop events past the cap up front instead of rastering the full sample
        # and slicing afterwards.
        keep = t < num_steps
        tensor = np.zeros((num_steps, self.encoding_dim), dtype=np.uint8)
        # Fancy assignment sets every hit cell to exactly 1 (duplicate events
        # collapse). This is the same binary raster that accumulate-then-clip yields.
        tensor[t[keep], x[keep]] = 1
        return np.packbits(tensor, axis=0)

In [5]:
# Raster resolution: sample_T time bins per sample, net_channels input channels.
sample_T = 256
shd_timestep = 1e-6  # assumed duration of one raw SHD timestamp unit (microseconds) — TODO confirm
shd_channels = 700   # channel count of the raw recordings (denominator of spatial_factor)
net_channels = 128   # channel count fed to the network
net_dt = 1 / sample_T

# Per-step input shape and output shape (20 outputs, matching the final Linear(20) layer).
obs_shape = (net_channels,)
act_shape = (20,)

# Downsample in time and space, then raster + bit-pack each sample.
transform = transforms.Compose(
    [
        transforms.Downsample(
            time_factor=shd_timestep / net_dt,
            spatial_factor=net_channels / shd_channels,
        ),
        _SHD2Raster(net_channels, sample_T=sample_T),
    ]
)

train_dataset = datasets.SHD("./data", train=True, transform=transform)
test_dataset = datasets.SHD("./data", train=False, transform=transform)



In [6]:
# Pull the entire training split as one padded batch (batch_size == dataset size).
train_dl = iter(
    DataLoader(
        train_dataset,
        batch_size=len(train_dataset),
        collate_fn=tonic.collation.PadTensors(batch_first=True),
        drop_last=True,
        shuffle=False,
    )
)
x_train, y_train = next(train_dl)

In [7]:
# Pull the entire test split as one padded batch (batch_size == dataset size).
test_dl = iter(
    DataLoader(
        test_dataset,
        batch_size=len(test_dataset),
        collate_fn=tonic.collation.PadTensors(batch_first=True),
        drop_last=True,
        shuffle=False,
    )
)
x_test, y_test = next(test_dl)

In [8]:
# Inspect the collated training tensor before it is converted to JAX.
print(type(x_train), x_train.shape, sep="\n")

<class 'torch.Tensor'>
torch.Size([8156, 32, 128])


In [9]:
# Materialize both splits as uint8 JAX arrays. Inputs stay bit-packed along time
# and are unpacked inside the training step. The test split is converted too,
# because the evaluation cell passes (x_test, y_test) to JAX code.
x_train = jnp.array(x_train, dtype=jnp.uint8)
y_train = jnp.array(y_train, dtype=jnp.uint8)

x_test = jnp.array(x_test, dtype=jnp.uint8)
y_test = jnp.array(y_test, dtype=jnp.uint8)

In [10]:
def shuffle(dataset, shuffle_rng, batch_size):
    """
    Shuffle a dataset and split it into full minibatches.

    Inputs and labels are reordered with one shared permutation, so every sample
    stays aligned with its label. Trailing samples that do not fill a complete
    batch are dropped.

    Args:
        dataset: tuple ``(x, y)``. ``x`` holds samples on axis 0; ``y`` is a 1-D
            label array of the same length.
        shuffle_rng: JAX PRNG key controlling the permutation.
        batch_size: number of samples per minibatch.

    Returns:
        State with ``obs`` of shape (num_batches, batch_size, *x.shape[1:]) and
        ``labels`` of shape (num_batches, batch_size).
    """
    x, y = dataset
    num_samples = y.shape[0]

    # Slice with [:num_kept] rather than [:-remainder]: when the remainder is 0,
    # [:-0] == [:0] would return an EMPTY array and the epoch would run no steps.
    num_kept = num_samples - num_samples % batch_size

    # One index permutation applied to both arrays keeps (x, y) pairs together.
    perm = jax.random.permutation(shuffle_rng, num_samples)[:num_kept]
    obs = jnp.take(x, perm, axis=0)
    labels = jnp.take(y, perm, axis=0)

    obs = jnp.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = jnp.reshape(labels, (-1, batch_size))

    return State(obs=obs, labels=labels)

### Spyx SHD: Model Definition and Training Benchmark

In [22]:
def build_snn(hidden_shape, batch_size):
    """
    Build the SHD spiking network and initialize its parameters.

    Architecture: an input Linear(hidden_shape, no bias) applied over every
    (batch, time) slice, followed per time step by
    LIF -> Linear(hidden_shape) -> LIF -> Linear(20) -> LI.

    Args:
        hidden_shape: width of the hidden layers.
        batch_size: number of training samples in the example input used for init.

    Returns:
        (SNN, params): the Haiku-transformed network (``apply`` takes no rng) and
        its initial parameters.
    """

    def shd_snn(x):
        # Input projection applied to all (batch, time) positions at once.
        x = hk.BatchApply(hk.Linear(hidden_shape, with_bias=False))(x)
        core = hk.DeepRNN([
            spyx.nn.LIF((hidden_shape,), activation=spyx.axn.Axon(spyx.axn.arctan())),
            hk.Linear(hidden_shape, with_bias=False),
            spyx.nn.LIF((hidden_shape,), activation=spyx.axn.Axon(spyx.axn.arctan())),
            hk.Linear(20, with_bias=False),
            spyx.nn.LI((20,))
        ])

        # Loop over the (batch-major) time axis with hk.dynamic_unroll;
        # unroll=32 unrolls 32 time steps per iteration of the underlying scan.
        spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=32)

        return spikes, V

    key = jax.random.PRNGKey(0)
    # Initialize on input shaped like a training step's. The stored data is
    # bit-packed along the time axis (axis 1 after batching), and the training
    # step unpacks it before calling the network. Taking the first batch directly
    # avoids permuting the entire training set just to get an example input.
    sample_x = jnp.unpackbits(x_train[:batch_size], axis=1).astype(jnp.float32)
    # Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
    SNN = hk.without_apply_rng(hk.transform(shd_snn))
    params = SNN.init(rng=key, x=sample_x)

    return SNN, params

In [12]:
def benchmark(SNN, params, dataset, epochs, batch_size, learning_rate=5e-4):
    """
    Train ``SNN`` with Adam for ``epochs`` epochs.

    Each epoch reshuffles ``dataset`` with a key derived from the epoch index and
    splits it into full minibatches via ``shuffle``. It then runs one optimizer
    step per minibatch. Both the epoch loop and the minibatch loop are
    ``jax.lax.scan``s.

    Args:
        SNN: Haiku-transformed network whose ``apply(params, x)`` returns
            ``(traces, V)``.
        params: initial network parameters.
        dataset: tuple ``(x, y)``. Inputs are bit-packed along axis 1 of each
            minibatch and are unpacked inside the train step.
        epochs: number of training epochs.
        batch_size: minibatch size.
        learning_rate: Adam step size (default is the previously hard-coded 5e-4).

    Returns:
        (final_params, metrics). ``metrics`` holds the mean training loss of each
        epoch, shape (epochs,).
    """
    opt = optax.adam(learning_rate=learning_rate)

    # create and initialize the optimizer
    opt_state = opt.init(params)

    # define and compile our eval function that computes the loss for our SNN
    @jax.jit
    def net_eval(weights, events, targets):
        traces, _ = SNN.apply(weights, events)
        return spyx.fn.integral_crossentropy(traces, targets)  # smoothing needs to be more explicit in docs...

    # Use JAX to create a function that calculates the loss and the gradient!
    surrogate_grad = jax.value_and_grad(net_eval)

    # Fixed base key: every call to benchmark sees the same sequence of epoch shuffles.
    rng = jax.random.PRNGKey(0)

    # compile the meat of our training loop for speed
    @jax.jit
    def train_step(state, data):
        grad_params, opt_state = state
        # data is one State(obs, labels) slice, i.e. a single minibatch.
        events, targets = data
        events = jnp.unpackbits(events, axis=1)  # decompress temporal axis
        # compute loss and gradient
        loss, grads = surrogate_grad(grad_params, events, targets)
        # generate updates based on the gradients and optimizer
        updates, opt_state = opt.update(grads, opt_state, grad_params)
        # return the updated parameters
        new_state = [optax.apply_updates(grad_params, updates), opt_state]
        return new_state, loss

    # One epoch: reshuffle, then scan train_step over all minibatches.
    @scan_tqdm(epochs)
    def epoch(epoch_state, epoch_num):
        curr_params, curr_opt_state = epoch_state

        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = shuffle(dataset, shuffle_rng, batch_size)

        end_state, train_loss = jax.lax.scan(
            train_step,  # func
            [curr_params, curr_opt_state],  # init carry: (params, optimizer state)
            train_data,  # xs: minibatches along axis 0
            train_data.obs.shape[0]  # length: number of minibatches
        )

        return end_state, jnp.mean(train_loss)

    # Epoch loop. The carry is (params, optimizer state); each iteration emits its mean loss.
    final_state, metrics = jax.lax.scan(
        epoch,
        [params, opt_state],  # init carry
        jnp.arange(epochs),  # epoch indices; scan_tqdm uses the iteration index for its progress bar
        epochs  # len of loop
    )

    final_params, _ = final_state

    # return our final, optimized network.
    return final_params, metrics

In [13]:
from time import time

def run_bench(trials, num_epochs, net_width, batch_size):
    """
    Time ``benchmark`` over several training runs and print summary statistics.

    Performs ``trials + 1`` runs. Run 0 is treated as a warm-up and excluded from
    the mean and standard deviation printed at the end.

    Args:
        trials: number of timed runs included in the statistics (after the warm-up).
        num_epochs: epochs per training run.
        net_width: hidden-layer width passed to ``build_snn``.
        batch_size: minibatch size.
    """
    from time import perf_counter  # monotonic, high-resolution clock for intervals

    SNN, params = build_snn(net_width, batch_size)

    times = []
    for t in range(trials + 1):
        print(t, ":", end="")
        start = perf_counter()
        # JAX dispatches work asynchronously. Wait until the returned arrays are
        # actually computed so the interval covers the full training run.
        jax.block_until_ready(benchmark(SNN, params, (x_train, y_train), num_epochs, batch_size))
        times.append(perf_counter() - start)
        print(times[t])

    if trials < 1:
        # Only the warm-up ran; statistics over an empty list would be nan.
        return
    print("Mean:", np.mean(times[1:]), "Std. Dev.:", np.std(times[1:]))

In [14]:
run_bench(trials=5, num_epochs=100, net_width=128, batch_size=256)

0 :

  0%|          | 0/100 [00:00<?, ?it/s]

36.123450756073
1 :

  0%|          | 0/100 [00:00<?, ?it/s]

35.33270812034607
2 :

  0%|          | 0/100 [00:00<?, ?it/s]

35.5319561958313
3 :

  0%|          | 0/100 [00:00<?, ?it/s]

35.423484802246094
4 :

  0%|          | 0/100 [00:00<?, ?it/s]

35.56348395347595
5 :

  0%|          | 0/100 [00:00<?, ?it/s]

35.573920488357544
Mean: 35.485110712051394 Std. Dev.: 0.09300359499217725


In [15]:
run_bench(trials=5, num_epochs=100, net_width=128, batch_size=128)

0 :

  0%|          | 0/100 [00:00<?, ?it/s]

46.61739468574524
1 :

  0%|          | 0/100 [00:00<?, ?it/s]

46.04706287384033
2 :

  0%|          | 0/100 [00:00<?, ?it/s]

45.88801288604736
3 :

  0%|          | 0/100 [00:00<?, ?it/s]

46.05511808395386
4 :

  0%|          | 0/100 [00:00<?, ?it/s]

45.91830658912659
5 :

  0%|          | 0/100 [00:00<?, ?it/s]

45.91482067108154
Mean: 45.96466422080994 Std. Dev.: 0.07138665976859078


In [16]:
run_bench(trials=5, num_epochs=100, net_width=128, batch_size=64)

0 :

  0%|          | 0/100 [00:00<?, ?it/s]

70.3422839641571
1 :

  0%|          | 0/100 [00:00<?, ?it/s]

69.3633873462677
2 :

  0%|          | 0/100 [00:00<?, ?it/s]

69.38180875778198
3 :

  0%|          | 0/100 [00:00<?, ?it/s]

69.17154335975647
4 :

  0%|          | 0/100 [00:00<?, ?it/s]

69.1550395488739
5 :

  0%|          | 0/100 [00:00<?, ?it/s]

69.49734258651733
Mean: 69.31382431983948 Std. Dev.: 0.13131169535746676


In [17]:
run_bench(trials=5, num_epochs=100, net_width=512, batch_size=64)

0 :

  0%|          | 0/100 [00:00<?, ?it/s]

97.50239658355713
1 :

  0%|          | 0/100 [00:00<?, ?it/s]

98.45136499404907
2 :

  0%|          | 0/100 [00:00<?, ?it/s]

99.08661842346191
3 :

  0%|          | 0/100 [00:00<?, ?it/s]

98.69157314300537
4 :

  0%|          | 0/100 [00:00<?, ?it/s]

98.5738685131073
5 :

  0%|          | 0/100 [00:00<?, ?it/s]

98.65601754188538
Mean: 98.69188852310181 Std. Dev.: 0.213952710561666


In [18]:
run_bench(trials=5, num_epochs=100, net_width=512, batch_size=128)

0 :

  0%|          | 0/100 [00:00<?, ?it/s]

81.29359841346741
1 :

  0%|          | 0/100 [00:00<?, ?it/s]

80.37861514091492
2 :

  0%|          | 0/100 [00:00<?, ?it/s]

80.15967798233032
3 :

  0%|          | 0/100 [00:00<?, ?it/s]

80.07884001731873
4 :

  0%|          | 0/100 [00:00<?, ?it/s]

80.34240293502808
5 :

  0%|          | 0/100 [00:00<?, ?it/s]

79.97054433822632
Mean: 80.18601608276367 Std. Dev.: 0.15502239966829467


In [19]:
run_bench(trials=5, num_epochs=100, net_width=512, batch_size=256)

0 :

  0%|          | 0/100 [00:00<?, ?it/s]

64.72566843032837
1 :

  0%|          | 0/100 [00:00<?, ?it/s]

64.08358216285706
2 :

  0%|          | 0/100 [00:00<?, ?it/s]

64.47463417053223
3 :

  0%|          | 0/100 [00:00<?, ?it/s]

63.99449133872986
4 :

  0%|          | 0/100 [00:00<?, ?it/s]

63.921780586242676
5 :

  0%|          | 0/100 [00:00<?, ?it/s]

63.9120135307312
Mean: 64.0773003578186 Std. Dev.: 0.2079793682085679


In [None]:
# `metrics` only exists inside `benchmark`, so train one model here (the fastest
# benchmarked configuration). This yields the per-epoch mean training losses and
# defines SNN / params / grad_params for the evaluation and export cells below.
SNN, params = build_snn(128, 256)
grad_params, metrics = benchmark(SNN, params, (x_train, y_train), 100, 256)
metrics

In [None]:
# Evaluate the trained network on the held-out test split.
# NOTE(review): `test_gd` is not defined in any cell of this notebook; define or
# import it before running this cell. It must return four values
# (accuracy, loss, predictions, targets). `SNN` and `grad_params` must be
# defined by an earlier cell that builds and trains the network.
acc, loss, preds, tgts = test_gd(SNN, grad_params, (x_test,y_test))
print("Accuracy:", acc, "Loss:", loss)

### Use NIR to save our network and then load it up later, in any framework of our choosing!

In [None]:
# Export the trained network to a NIR file so it can be loaded by other frameworks.
# NOTE(review): requires `params` (initial) and `grad_params` (trained) to be
# defined by an earlier cell that builds and trains the network.
import nir
# presumably maps the trained parameters onto the layer ordering NIR expects — TODO confirm
export_params = spyx.nir.reorder_layers(params, grad_params)
# obs_shape / act_shape come from the config cell (network input and output shapes).
# NOTE(review): meaning of the trailing `1` argument (time step?) is not shown here — confirm.
G = spyx.nir.to_nir(export_params, obs_shape, act_shape, 1)
nir.write("./spyx_shd.nir", G)