In [1]:
import snntorch as snn
import torch
from torchvision import datasets, transforms
from snntorch import utils

from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import matplotlib.pyplot as plt
from snntorch import spikegen
import numpy as np

dtype = torch.float
torch.set_default_dtype(dtype)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

  warn(


In [2]:
audio_stim = np.load('auditory_stimuli.npz')
X_train = audio_stim['X_train']
X_test = audio_stim['X_test']
y_train = audio_stim['y_train']
y_test = audio_stim['y_test']

In [7]:
# Dataloader arguments
batch_size = 128*2

# Define a transform
transform = transforms.Compose([
    transforms.ToTensor()
])

training_dataset = TensorDataset(
    torch.from_numpy(X_train.astype(np.float32)),
    torch.from_numpy(y_train.astype(np.float32))
)
testing_dataset = TensorDataset(
    torch.from_numpy(X_test.astype(np.float32)),
    torch.from_numpy(y_test.astype(np.float32))
)

# Create DataLoaders
train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(testing_dataset, batch_size=batch_size, shuffle=True, drop_last=True)


In [12]:
# Network Architecture
num_hidden = 3000
num_outputs = 100

num_epochs = 30
loss_hist = []
test_loss_hist = []
counter = 0

# Temporal Dynamics
beta = 0.95

# Define Network
class AudNet(nn.Module):
    def __init__(self, num_inputs):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_hidden)
        self.lif2 = snn.Leaky(beta=beta)
        self.fc3 = nn.Linear(num_hidden, num_hidden)
        self.lif3 = snn.Leaky(beta=beta)
        self.fc4 = nn.Linear(num_hidden, num_outputs)
        self.lif4 = snn.Leaky(beta=beta)

    def forward(self, x):

        x.to(torch.float32)

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif3.init_leaky()

        cur1 = self.fc1(x)
        spk1, mem1 = self.lif1(cur1, mem1)

        cur2 = self.fc2(spk1)
        spk2, mem2 = self.lif2(cur2, mem2)

        cur3 = self.fc3(spk2)
        spk3, mem3 = self.lif3(cur3, mem3)

        cur4 = self.fc3(spk3)
        spk4, mem4 = self.lif3(cur3, mem4)

        return  spk4, mem4

# Load the network onto CUDA if available
net = AudNet(num_inputs=np.size(X_train,1)).to(device)

# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target

def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(batch_size, -1))
    _, idx = output.max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    # if train:
    #     print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    # else:
    print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}, Train loss = {loss_hist[counter]:.2f} Test loss = {test_loss_hist[counter]:.2f} \n")
    print_batch_accuracy(data, targets)
    print('\n')


loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))


# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        # loss_val = torch.zeros((1), dtype=dtype, device=device)
        loss_val = loss(mem_rec, targets.type(torch.long))

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            # test_loss = torch.zeros((1), dtype=dtype, device=device)
            # for step in range(num_steps):
            test_loss = loss(test_mem, test_targets.type(torch.long))
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 10 == 0:
                train_printer()
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 0, Train loss = 8.01 Test loss = 7.83 

Test set accuracy for a single minibatch: 7.42%


Epoch 1, Iteration 2, Train loss = 2.33 Test loss = 2.32 

Test set accuracy for a single minibatch: 10.16%


Epoch 2, Iteration 4, Train loss = 2.31 Test loss = 2.30 

Test set accuracy for a single minibatch: 8.98%


Epoch 3, Iteration 6, Train loss = 2.31 Test loss = 2.31 

Test set accuracy for a single minibatch: 7.81%


Epoch 5, Iteration 0, Train loss = 2.31 Test loss = 2.32 

Test set accuracy for a single minibatch: 9.77%


Epoch 6, Iteration 2, Train loss = 2.32 Test loss = 2.32 

Test set accuracy for a single minibatch: 9.77%


Epoch 7, Iteration 4, Train loss = 2.30 Test loss = 2.31 

Test set accuracy for a single minibatch: 9.77%


Epoch 8, Iteration 6, Train loss = 2.33 Test loss = 2.32 

Test set accuracy for a single minibatch: 12.11%


Epoch 10, Iteration 0, Train loss = 2.30 Test loss = 2.32 

Test set accuracy for a single minibatch: 10.16%


Epoch 11, Itera

In [None]:
output, _ = net(data.view(batch_size, -1))
output.shape

In [None]:
_, idx = output.max(1)
# acc = np.mean((targets == idx).detach().cpu().numpy())

In [None]:
test_mem.shape, test_targets.type(torch.long).shape

In [None]:
X_train.shape

In [None]:
print(data.shape, targets.shape)

In [None]:
mem_rec.shape

In [None]:
test_mem.shape

In [None]:
test_targets.shape

In [None]:
spk_rec, mem_rec = net(data.view(batch_size, -1))
spk_rec.shape, mem_rec.shape

In [None]:
for param in net.parameters():
    print(param.dtype)