In [1]:
import torch
import snntorch
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Cap this process at 85% of GPU 0's memory. The call initialises CUDA and
    # raises on machines without it, so it must be skipped on the CPU fallback.
    torch.cuda.set_per_process_memory_fraction(0.85, device=0)



import numpy as np

import tonic
from tonic import datasets, transforms
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from collections import namedtuple

# Immutable (obs, labels) pair; shuffle() returns its batched tensors in this form.
State = namedtuple(typename="State", field_names="obs labels")



In [2]:
class _SHD2Raster():
    """ 
    Tool for rastering SHD samples into frames. Packs bits along the temporal axis for memory efficiency. This means
        that the used will have to apply jnp.unpackbits(events, axis=<time axis>) prior to feeding the data to the network.
    """

    def __init__(self, encoding_dim, sample_T = 100):
        self.encoding_dim = encoding_dim
        self.sample_T = sample_T
        
    def __call__(self, events):
        # tensor has dimensions (time_steps, encoding_dim)
        tensor = np.zeros((events["t"].max()+1, self.encoding_dim), dtype=int)
        np.add.at(tensor, (events["t"], events["x"]), 1)
        #return tensor[:self.sample_T,:]
        tensor = tensor[:self.sample_T,:]
        tensor = np.minimum(tensor, 1)
        #tensor = np.packbits(tensor, axis=0) pytorch does not have an unpack feature.
        return tensor

In [3]:
# --- Experiment configuration -------------------------------------------------
batch_size = 256

# Raw dataset geometry.
shd_channels = 700
shd_timestep = 1e-6  # presumably seconds per raw SHD time unit -- confirm against tonic

# Geometry the network sees after downsampling.
net_channels = 128
sample_T = 64          # maximum number of time steps kept per sample
net_dt = 1 / sample_T  # network step, in the same unit as shd_timestep

# Per-step observation size and number of output classes, as 1-tuples.
obs_shape = (net_channels,)
act_shape = (20,)

# Per-sample preprocessing: rescale event times and channel indices onto the
# network grid, then rasterise the events into (time, channel) frames.
transform = transforms.Compose(
    [
        transforms.Downsample(
            time_factor=shd_timestep / net_dt,
            spatial_factor=net_channels / shd_channels,
        ),
        _SHD2Raster(net_channels, sample_T=sample_T),
    ]
)

# Both splits share the same data directory and preprocessing.
train_dataset, test_dataset = (
    datasets.SHD("./data", train=is_train, transform=transform)
    for is_train in (True, False)
)



In [4]:
# Pull the entire training split through the loader as a single padded batch so
# it can live on the device and be re-batched by shuffle() every epoch.
train_dl = iter(
    DataLoader(
        train_dataset,
        batch_size=len(train_dataset),
        shuffle=False,
        drop_last=True,
        collate_fn=tonic.collation.PadTensors(batch_first=True),
    )
)
x_train, y_train = next(train_dl)

# Shrink to uint8 first, then transfer. The rasters only contain 0/1;
# labels are assumed to fit in uint8 as well -- confirm for other datasets.
x_train = x_train.to(torch.uint8).to(device)
y_train = y_train.to(torch.uint8).to(device)

In [5]:
def shuffle(dataset):
    """
    Randomly permute a dataset and split it into equally sized mini-batches.

    Uses the module-level ``batch_size``. Samples that do not fill a whole
    batch are dropped (a different random subset on every call).

    Args:
        dataset: pair ``(x, y)`` of tensors sharing their leading (sample) axis.

    Returns:
        State with
            obs:    x reshaped to (num_batches, batch_size, *x.shape[1:])
            labels: y reshaped to (num_batches, batch_size)
    """
    x, y = dataset

    num_samples = y.shape[0]
    num_batches = num_samples // batch_size
    # Slice by the number of samples to KEEP. Slicing with [:-remainder] breaks
    # when the remainder is 0, because [:-0] is [:0] and would drop every sample.
    num_kept = num_batches * batch_size

    indices = torch.randperm(num_samples)[:num_kept]
    obs, labels = x[indices], y[indices]

    # Explicit batch count (instead of -1) keeps the reshape well-defined even
    # when no full batch is available.
    obs = torch.reshape(obs, (num_batches, batch_size) + obs.shape[1:])
    labels = torch.reshape(labels, (num_batches, batch_size))

    return State(obs=obs, labels=labels)

In [6]:
# Pull the entire test split through the loader as a single padded batch.
test_dl = iter(
    DataLoader(
        test_dataset,
        batch_size=len(test_dataset),
        shuffle=False,
        drop_last=True,
        collate_fn=tonic.collation.PadTensors(batch_first=True),
    )
)
x_test, y_test = next(test_dl)

# Shrink to uint8 first, then transfer to the compute device.
x_test = x_test.to(torch.uint8).to(device)
y_test = y_test.to(torch.uint8).to(device)

# Split once into fixed mini-batches. NOTE: shuffle() drops the samples that do
# not fill a whole batch, and this rebinding makes the cell non-idempotent.
x_test, y_test = shuffle((x_test, y_test))

In [7]:
num_hidden = 64  # width of both hidden layers
beta = 0.8       # membrane decay passed to every snntorch.Leaky layer


class Net(torch.nn.Module):
    """
    Fully connected spiking network: 128 -> num_hidden -> num_hidden -> 20.

    Every linear layer feeds a leaky integrate-and-fire layer. The network's
    output is the membrane potential of the last layer at every time step,
    not its spikes.
    """

    def __init__(self):
        super().__init__()

        # Two hidden spiking layers.
        self.fc1 = torch.nn.Linear(128, num_hidden)
        self.lif1 = snntorch.Leaky(beta=beta)
        self.fc2 = torch.nn.Linear(num_hidden, num_hidden)
        self.lif2 = snntorch.Leaky(beta=beta)

        # Readout layer. Only its membrane potential is used; the very large
        # threshold is presumably there so it effectively never fires.
        self.fc3 = torch.nn.Linear(num_hidden, 20)
        self.lif3 = snntorch.Leaky(beta=beta, threshold=10e6)

    def forward(self, x):
        """
        Run the network over a batch of rasters.

        Args:
            x: tensor laid out as (batch, time, channels); cast to float here.

        Returns:
            Tensor (batch, time, 20) with the readout membrane potential at
            every time step.
        """
        # Put time on the leading axis so the loop below walks over time steps.
        frames = x.float().permute(1, 0, 2)

        # Membrane states at t=0.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        readout = []
        for frame in frames:
            spk1, mem1 = self.lif1(self.fc1(frame), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            _, mem3 = self.lif3(self.fc3(spk2), mem3)
            readout.append(mem3)

        # Stack over time, then restore the batch-first layout.
        return torch.stack(readout, dim=0).permute(1, 0, 2)


# Load the network onto CUDA if available
net = Net().to(device)

In [14]:
# Label-smoothed cross-entropy objective.
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)

# Adam over all network parameters.
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=5e-4,
    betas=(0.9, 0.999),
)
def acc(predictions, targets):
    """Return the fraction of entries whose arg-max over the last axis equals the target."""
    predicted_class = predictions.argmax(dim=-1)
    num_correct = (predicted_class == targets).sum().item()
    return num_correct / len(targets)

In [16]:
num_epochs = 300
loss_hist = []

# Training mode only has to be selected once; nothing in the loop changes it.
net.train()

# Outer training loop
for epoch in range(num_epochs):
    # Fresh random permutation every epoch, already split into mini-batches.
    train_batch = shuffle((x_train, y_train))

    # Minibatch training loop. The loop variables get their own names so they
    # do not shadow the tensors being iterated over.
    for batch_obs, batch_labels in zip(train_batch.obs, train_batch.labels):
        # Forward pass: readout membrane potential at every time step.
        out_V = net(batch_obs)
        # Sum the readout over the time axis to get one score per class.
        loss_val = loss(torch.sum(out_V, dim=-2), batch_labels)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

    # Store loss history for future plotting. NOTE: this records the loss of the
    # epoch's final mini-batch only, not an average over the epoch.
    loss_hist.append(loss_val.item())


# Test set
# Evaluation mode only has to be selected once, before the batches are processed.
net.eval()
with torch.no_grad():
    batch_acc = []
    for test_data, test_targets in zip(x_test, y_test):
        # Test set forward pass
        out_V = net(test_data)
        # Accuracy of this batch, scoring each class by its time-summed readout.
        batch_acc.append(acc(torch.sum(out_V, dim=-2), test_targets))

    # All batches have the same size, so the mean of per-batch accuracies equals
    # the overall accuracy over the batched samples. NOTE: shuffle() dropped the
    # samples that did not fill a whole batch, so they are not evaluated.
    test_acc = np.mean(batch_acc)


In [None]:
loss_hist

[14.016836166381836,
 5.47111177444458,
 4.38897180557251,
 3.6949737071990967,
 3.1036036014556885,
 2.8180136680603027,
 2.6146748065948486,
 2.524266004562378,
 2.364579200744629,
 2.174179792404175,
 1.7333829402923584,
 1.7091920375823975,
 1.74224853515625,
 1.5204150676727295,
 1.436014175415039,
 1.3159263134002686,
 1.1343700885772705,
 1.1117918491363525,
 1.13065505027771,
 1.2796425819396973,
 0.9392330050468445,
 1.1473838090896606,
 0.9924920797348022,
 1.1947990655899048,
 1.0341171026229858,
 1.0296378135681152,
 0.8375968337059021,
 0.8479127287864685,
 0.8093426823616028,
 1.1181706190109253,
 0.9630967378616333,
 0.7824838757514954,
 1.0526676177978516,
 0.7316159605979919,
 0.7704435586929321,
 0.8092421293258667,
 0.8359063863754272,
 0.9393559098243713,
 0.9334381818771362,
 0.8651061058044434,
 0.7047986388206482,
 0.8912121653556824,
 0.9294992685317993,
 0.7106137275695801,
 0.8002204895019531,
 0.6694325804710388,
 0.8686801195144653,
 1.0326778888702393,
 0.6

In [17]:
test_acc

0.66015625