In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".40"

import numpy as np

import jax
import jax.numpy as jnp
import spyx
import haiku as hk
import optax
from jax_tqdm import scan_tqdm

import torch
import snntorch

## SHD Dataloading

In [2]:
shd_dl = spyx.data.SHD_loader(256,128,128)

key = jax.random.PRNGKey(0)
x, y = shd_dl.train_epoch(key)

In [3]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
j2t_data = lambda data: torch.from_numpy(np.array(jnp.unpackbits(data, axis=1))).to(device)
j2t_targets = lambda tgt: torch.from_numpy(np.array(tgt)).to(device)

## Spyx SHD

In [4]:
def shd_snn(x):
        
    core = hk.DeepRNN([
        hk.Linear(64, with_bias=False),
        spyx.nn.LIF((64,), activation=spyx.axn.Axon(spyx.axn.triangular())),
        hk.Linear(64, with_bias=False),
        spyx.nn.LIF((64,), activation=spyx.axn.Axon(spyx.axn.triangular())),
        hk.Linear(20, with_bias=False),
        spyx.nn.LI((20,))
    ])
    
    # static unroll for maximum performance
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=32)
    
    return spikes, V

In [5]:
type(x)

jaxlib.xla_extension.ArrayImpl

In [6]:
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN = hk.without_apply_rng(hk.transform(shd_snn))
params = SNN.init(rng=key, x=jnp.float16(x[0]))

In [7]:
def gd(SNN, params, dl, epochs=300):
        
    opt = optax.adam(learning_rate=5e-4)
    
    # create and initialize the optimizer
    opt_state = opt.init(params)
    grad_params = params
        
    # define and compile our eval function that computes the loss for our SNN
    @jax.jit
    def net_eval(weights, events, targets):
        readout = SNN.apply(weights, events)
        traces, V_f = readout
        return spyx.fn.integral_crossentropy(traces, targets)
        
    # Use JAX to create a function that calculates the loss and the gradient!
    surrogate_grad = jax.value_and_grad(net_eval) 
        
    rng = jax.random.PRNGKey(0)        
    
    # compile the meat of our training loop for speed
    @jax.jit
    def train_step(state, data):
        grad_params, opt_state = state
        events, targets = data # fix this
        events = jnp.unpackbits(events, axis=1) # decompress temporal axis
        # compute loss and gradient                    # need better augment rng
        loss, grads = surrogate_grad(grad_params, events, targets)
        # generate updates based on the gradients and optimizer
        updates, opt_state = opt.update(grads, opt_state, grad_params)
        # return the updated parameters
        new_state = [optax.apply_updates(grad_params, updates), opt_state]
        return new_state, loss
    
    # For validation epochs, do the same as before but compute the
    # accuracy, predictions and losses (no gradients needed)
    @jax.jit
    def eval_step(grad_params, data):
        events, targets = data # fix
        events = jnp.unpackbits(events, axis=1)
        readout = SNN.apply(grad_params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        loss = spyx.fn.integral_crossentropy(traces, targets)
        return grad_params, jnp.array([acc, loss])
        
    
    val_data = dl.val_epoch()
    
    # Here's the start of our training loop!
    @scan_tqdm(epochs)
    def epoch(epoch_state, epoch_num):
        curr_params, curr_opt_state = epoch_state
        
        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = dl.train_epoch(shuffle_rng)
        
        # train epoch
        end_state, train_loss = jax.lax.scan(
            train_step,# func
            [curr_params, curr_opt_state],# init
            train_data,# xs
            train_data.obs.shape[0]# len
        )
        
        new_params, _ = end_state
            
        # val epoch
        _, val_metrics = jax.lax.scan(
            eval_step,# func
            new_params,# init
            val_data,# xs
            val_data.obs.shape[0]# len
        )

        
        return end_state, jnp.concatenate([jnp.expand_dims(jnp.mean(train_loss),0), jnp.mean(val_metrics, axis=0)])
    # end epoch
    
    # epoch loop
    final_state, metrics = jax.lax.scan(
        epoch,
        [grad_params, opt_state], # metric arrays
        jnp.arange(epochs), # 
        epochs # len of loop
    )
    
    final_params, _ = final_state
    
                
    # return our final, optimized network.       
    return final_params, metrics

In [8]:
def test_gd(SNN, params, dl):

    @jax.jit
    def test_step(params, data):
        events, targets = data
        events = jnp.unpackbits(events, axis=1)
        readout = SNN.apply(params, events)
        traces, V_f = readout
        acc, pred = spyx.fn.integral_accuracy(traces, targets)
        loss = spyx.fn.integral_crossentropy(traces, targets)
        return params, [acc, loss, pred, targets]
    
    test_data = dl.test_epoch()
    
    _, test_metrics = jax.lax.scan(
            test_step,# func
            params,# init
            test_data,# xs
            test_data.obs.shape[0]# len
    )
    
    acc = jnp.mean(test_metrics[0])
    loss = jnp.mean(test_metrics[1])
    preds = jnp.array(test_metrics[2]).flatten()
    tgts = jnp.array(test_metrics[3]).flatten()
    return acc, loss, preds, tgts

In [9]:
grad_params, metrics = gd(SNN, params, shd_dl, epochs=300)


  0%|          | 0/300 [00:00<?, ?it/s]

In [10]:
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*metrics[-1]))


Performance: train_loss=1.735216736793518, val_acc=0.8795573115348816, val_loss=1.7838948965072632


In [11]:
acc, loss, preds, tgts = test_gd(SNN, grad_params, shd_dl)
print("Accuracy:", acc, "Loss:", loss)

Accuracy: 0.7519531 Loss: 1.9914175


## snnTorch SHD

In [12]:
num_hidden = 64
beta = 0.8
# Define Network
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = torch.nn.Linear(128, num_hidden)
        self.lif1 = snntorch.Leaky(beta=beta)
        self.fc2 = torch.nn.Linear(num_hidden, num_hidden)
        self.lif2 = snntorch.Leaky(beta=beta)
        self.fc3 = torch.nn.Linear(num_hidden, 20)
        self.lif3 = snntorch.Leaky(beta=beta, threshold=10e6)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        
        # Record the final layer
        V = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc2(spk2)
            spk3, mem3 = self.lif2(cur3, mem3)
            
            V.append(mem3)

        return torch.stack(V, dim=0)
        
# Load the network onto CUDA if available
net = Net().to(device)

In [13]:
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

In [14]:
num_epochs = 10
loss_hist = []
test_loss_hist = []
counter = 0
batch_size=256
num_steps=128

rng = jax.random.PRNGKey(0)

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    
    shuffle_rng = jax.random.fold_in(rng, iter_counter)
    train_batch = shd_dl.train_epoch(shuffle_rng)
    train_data, targets = train_batch
    
    
    # Minibatch training loop
    for data, targets in zip(train_data, targets):

        data = j2t_data(data).to(dtype=torch.float32)
        targets = j2t_targets(targets)
        # forward pass
        net.train()
        spk_rec, mem_rec = net(data)

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        ## Test set
        #with torch.no_grad():
        #    net.eval()
        #    test_data, test_targets = next(iter(test_loader))
        #    test_data = test_data.to(device)
        #    test_targets = test_targets.to(device)

        #    # Test set forward pass
        #    test_spk, test_mem = net(test_data.view(batch_size, -1))

        #    # Test set loss
        #    test_loss = torch.zeros((1), dtype=dtype, device=device)
        #    for step in range(num_steps):
        #        test_loss += loss(test_mem[step], test_targets)
        #    test_loss_hist.append(test_loss.item())

        #    # Print train/test loss/accuracy
        #    if counter % 50 == 0:
        #        train_printer(
        #            data, targets, epoch,
        #            counter, iter_counter,
        #            loss_hist, test_loss_hist,
        #            test_data, test_targets)
        counter += 1
        iter_counter +=1

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 5.80 GiB total capacity; 2.31 GiB already allocated; 367.69 MiB free; 2.90 GiB allowed; 2.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
train_data