In [1]:
import snntorch as snn
import torch
from torchvision import datasets, transforms
from snntorch import utils
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
from snntorch import spikegen
import numpy as np

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

  warn(


In [3]:
# Dataloader arguments
batch_size = 128
data_path='/tmp/data/mnist'

# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)


# Network Architecture
num_inputs = 28*28
num_hidden = 1000
num_outputs = 50

num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

# Temporal Dynamics
num_steps = 25
beta = 0.95

# Define Network
class VisNet(nn.Module):
    def __init__(self, num_inputs):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_hidden)
        self.lif2 = snn.Leaky(beta=beta)
        self.fc3 = nn.Linear(num_hidden, num_outputs)
        self.lif3 = snn.Leaky(beta=beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # record the middle layer
        # spk2_rec = []
        # mem2_rec = []

        # Record the final layer
        spk3_rec = []
        mem3_rec = []

        for step in range(num_steps):

            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.fc2(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            # spk2_rec.append(spk2)
            spk3_rec.append(spk3)
            # mem2_rec.append(mem2)
            mem3_rec.append(mem3)

        return  torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)

# Load the network onto CUDA if available
net = VisNet(num_inputs=28*28).to(device)

# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target

def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(batch_size, -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}, Train loss = {loss_hist[counter]:.2f} Test loss = {test_loss_hist[counter]:.2f} \n")


loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 0, Train loss = 172.18 Test loss = 165.13 

Epoch 0, Iteration 50, Train loss = 21.63 Test loss = 18.51 

Epoch 0, Iteration 100, Train loss = 20.06 Test loss = 19.36 

Epoch 0, Iteration 150, Train loss = 19.61 Test loss = 20.97 

Epoch 0, Iteration 200, Train loss = 14.11 Test loss = 15.08 

Epoch 0, Iteration 250, Train loss = 16.51 Test loss = 13.86 

Epoch 0, Iteration 300, Train loss = 14.83 Test loss = 13.82 

Epoch 0, Iteration 350, Train loss = 14.21 Test loss = 13.04 

Epoch 0, Iteration 400, Train loss = 13.01 Test loss = 15.66 

Epoch 0, Iteration 450, Train loss = 14.98 Test loss = 14.60 



In [6]:
torch.save(net.state_dict(), 'visnet_model_state_dict.pt')