In [1]:
import snntorch as snn
from snntorch import utils
from snntorch import functional as SF

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [41]:
# --- MNIST configuration ---------------------------------------------------
data_path = './data'
batch_size = 128

dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to every image: force 28x28, single channel, convert
# to a tensor, then normalise with mean 0 / std 1.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

# Train / test splits (downloaded into data_path when missing)
mnist_train = datasets.MNIST(root=data_path, train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root=data_path, train=False, transform=transform, download=True)

# Both loaders shuffle and drop the final incomplete batch, so every batch
# holds exactly batch_size samples.
train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

In [18]:
class Net(nn.Module):
    """Spiking CNN: two conv -> max-pool -> leaky-neuron stages and a linear readout.

    The same input tensor is presented at each of ``num_steps`` simulation
    steps; the output layer's spikes and membrane potentials are recorded at
    every step.

    Args:
        beta: Value passed as ``beta`` to every ``snn.Leaky`` layer.
        spike_grad: Value passed as ``spike_grad`` to every ``snn.Leaky`` layer.
        num_steps: Number of simulation steps run by :meth:`forward`.

    The defaults mirror the values the notebook assigns to the globals of the
    same names, which this class previously read implicitly.
    """

    def __init__(self, beta=0.9, spike_grad=None, num_steps=25):
        super().__init__()
        # Kept on the instance so forward() does not depend on notebook globals.
        self.num_steps = num_steps

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 12, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 64, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        # 64*4*4 is the flattened conv2 stage output for 28x28 inputs
        # (28 -> 24 after conv1 -> 12 after pool -> 8 after conv2 -> 4 after pool).
        self.fc1 = nn.Linear(64*4*4, 10)
        # Unlike lif1/lif2 this layer is built with init_hidden=True and
        # output=True: forward() calls it with the input current only and
        # unpacks two return values (spikes, membrane potential).
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

    def forward(self, x):
        """Run the network for ``self.num_steps`` steps on a static input.

        Args:
            x: Input batch; conv1 expects one input channel, and fc1's input
                size corresponds to 28x28 images.

        Returns:
            tuple: ``(spk_rec, mem_rec)`` -- the output layer's spikes and
            membrane potentials, each stacked along a new leading time
            dimension of length ``num_steps``.
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        # lif3 is called without a membrane argument below, so the returned
        # value is not needed; the call itself is kept from the original.
        self.lif3.init_leaky()

        spk_rec = []
        mem_rec = []

        # record spike and membrane
        for step in range(self.num_steps):
            cur1 = F.max_pool2d(self.conv1(x), 2)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = F.max_pool2d(self.conv2(spk1), 2)
            spk2, mem2 = self.lif2(cur2, mem2)

            # Flatten using the tensor's own batch dimension rather than the
            # global batch_size, so batches of any size are handled.
            cur3 = self.fc1(spk2.view(spk2.size(0), -1))
            spk3, mem3 = self.lif3(cur3)

            spk_rec.append(spk3)
            mem_rec.append(mem3)

        return torch.stack(spk_rec), torch.stack(mem_rec)

In [29]:
class Net(nn.Module):
    """Spiking CNN with a skip connection around the second convolution.

    Both convolutions use ``padding=2`` with a 5x5 kernel, so they preserve
    spatial size; each stage halves it with 2x2 max-pooling. The pooled conv1
    current is added to the conv2 output before the second pooling.

    Args:
        beta: Value passed as ``beta`` to every ``snn.Leaky`` layer.
        spike_grad: Value passed as ``spike_grad`` to every ``snn.Leaky`` layer.
        num_steps: Number of simulation steps run by :meth:`forward`.
    """

    def __init__(self, beta, spike_grad, num_steps):
        super().__init__()
        self.num_steps = num_steps

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 12, kernel_size=5, padding=2)  # Output: 12x28x28
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 12, kernel_size=5, padding=2)  # Output: 12x14x14
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(12*7*7, 10)  # 12 channels, 7x7 after the second pooling
        # Built with init_hidden=True and output=True: forward() calls it with
        # the input current only and it returns two values (spikes, membrane).
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

    def forward(self, x):
        """Present ``x`` for ``self.num_steps`` steps and record output spikes.

        Args:
            x: Input batch; conv1 expects one input channel, and fc1's input
                size corresponds to 28x28 images.

        Returns:
            Tensor of output-layer spikes stacked along a new leading time
            dimension of length ``num_steps``.
        """
        # Initialize hidden states
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        # lif3 is called without a membrane argument below, so the returned
        # value is not needed; the call itself is kept from the original.
        self.lif3.init_leaky()

        spk_rec = []

        for step in range(self.num_steps):
            # Conv1 + pooling
            cur1 = self.conv1(x)  # 12x28x28
            cur1_pooled = F.max_pool2d(cur1, 2)  # 12x14x14
            spk1, mem1 = self.lif1(cur1_pooled, mem1)

            # Conv2 + skip connection + pooling
            cur2 = self.conv2(spk1)  # 12x14x14
            cur2 = cur2 + cur1_pooled  # Skip connection: add pooled conv1 current (12x14x14)
            cur2_pooled = F.max_pool2d(cur2, 2)  # 12x7x7
            spk2, mem2 = self.lif2(cur2_pooled, mem2)

            # Fully connected layer; flatten using the tensor's own batch size
            cur3 = self.fc1(spk2.view(spk2.size(0), -1))
            # Only the spikes are returned by forward(), so the output
            # membrane potential is not recorded.
            spk3, _ = self.lif3(cur3)

            spk_rec.append(spk3)

        return torch.stack(spk_rec)

In [33]:
# Minimal spiking CNN for 3-channel 32x32 inputs (skip connection currently disabled)
class SCNN(nn.Module):
    """Two conv + batch-norm + leaky-neuron stages followed by a linear readout.

    Args:
        beta: Value passed as ``beta`` to every ``snn.Leaky`` layer.
        spike_grad: Value passed as ``spike_grad`` to every ``snn.Leaky`` layer.
        num_steps: Number of simulation steps run by :meth:`forward`.
    """

    def __init__(self, beta=0.9, spike_grad=None, num_steps=25):
        super().__init__()
        self.num_steps = num_steps

        # Stage 1: 3x32x32 -> 16x32x32 (3x3 kernel with padding=1 keeps the size)
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(16)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        # Stage 2: operates on the pooled 16x16x16 spikes and keeps that shape
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(16)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

        # Readout: flattened 16x16x16 feature map -> 10 outputs
        self.fc = nn.Linear(16 * 16 * 16, 10)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        """Present ``x`` for ``self.num_steps`` steps and return the output spikes.

        Returns:
            Output-layer spikes stacked along a new leading time dimension.
        """
        # Fresh membrane state for each of the three neuron layers
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        out_spikes = []

        for _ in range(self.num_steps):
            # Stage 1: conv -> batch norm -> LIF, then 2x2 max-pool on the spikes
            spk1, mem1 = self.lif1(self.bn1(self.conv1(x)), mem1)
            spk1_pooled = F.max_pool2d(spk1, kernel_size=2)  # 16x16x16

            # Stage 2: conv -> batch norm -> LIF. The skip connection is
            # disabled; it is kept here for reference:
            # cur2 = cur2 + spk1_pooled  # Skip connection: same dimensions (16x16x16)
            cur2 = self.bn2(self.conv2(spk1_pooled))
            spk2, mem2 = self.lif2(cur2, mem2)

            # Readout on the flattened stage-2 spikes
            flat = spk2.view(spk2.shape[0], -1)
            spk3, mem3 = self.lif3(self.fc(flat), mem3)
            out_spikes.append(spk3)

        return torch.stack(out_spikes)

In [38]:
# Training setup: hyperparameters, model, optimizer, loss and bookkeeping
beta = 0.9
spike_grad = None
num_steps = 25
num_epochs = 1

net = Net(beta=beta, spike_grad=spike_grad, num_steps=num_steps)
# net = SCNN()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
loss_fn = SF.ce_rate_loss()

# Histories filled while training; counter numbers the iterations across epochs
loss_hist, test_acc_hist = [], []
counter = 0


In [39]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loading for CIFAR-10 (3-channel input, as SCNN's conv1 expects).
# The batch size, transform and loader carry a cifar_ prefix so that this cell
# does not overwrite the MNIST `batch_size`, `transform` and `train_loader`
# that the 1-channel networks in this notebook are trained on. To train SCNN,
# iterate over `cifar_train_loader`.
cifar_batch_size = 64

cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=cifar_transform, download=True)
cifar_train_loader = DataLoader(train_dataset, batch_size=cifar_batch_size, shuffle=True)

In [42]:
# Train `net` for num_epochs passes over train_loader, logging the loss of every batch
for epoch in range(num_epochs):
    for data, targets in train_loader:
        net.train()
        optimizer.zero_grad()

        # Reset the hidden state of every LIF neuron in net, then run the forward pass
        utils.reset(net)
        spk_rec = net(data)

        # Loss computed by loss_fn from the recorded output spikes
        loss_val = loss_fn(spk_rec, targets)
        print(f'iter {counter}: loss {loss_val.item():.4f}')

        # Backpropagate and update the weights
        loss_val.backward()
        optimizer.step()

        # Keep the loss for later plotting
        loss_hist.append(loss_val.item())
        counter += 1

iter 0: loss 2.3248
iter 1: loss 2.3191
iter 2: loss 2.2888
iter 3: loss 2.2224
iter 4: loss 2.2012
iter 5: loss 2.1336
iter 6: loss 2.1131
iter 7: loss 2.0952
iter 8: loss 2.1050
iter 9: loss 2.1173
iter 10: loss 2.0628
iter 11: loss 2.0534
iter 12: loss 2.1113
iter 13: loss 2.0825
iter 14: loss 2.0576


KeyboardInterrupt: 

In [None]:
# Sequential spiking CNN: every Leaky layer is created with init_hidden=True,
# and the last one additionally with output=True.
cnn = nn.Sequential(
    # Stage 1
    nn.Conv2d(1, 12, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    snn.Leaky(beta=beta, init_hidden=True),
    # Stage 2
    nn.Conv2d(12, 64, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    snn.Leaky(beta=beta, init_hidden=True),
    # Readout
    nn.Flatten(),
    nn.Linear(64 * 4 * 4, 10),
    snn.Leaky(beta=beta, init_hidden=True, output=True),
)

def forward_pass(net, num_steps, data):
    """Run ``net`` on ``data`` for ``num_steps`` steps and collect its outputs.

    ``utils.reset(net)`` is called once before the first step. Each call
    ``net(data)`` must return a ``(spikes, membrane)`` pair.

    Returns:
        tuple: ``(spk_rec, mem_rec)`` -- the per-step spikes and membrane
        values, each stacked along a new leading dimension.
    """
    utils.reset(net)  # resets hidden states for all LIF neurons in net

    spikes, potentials = [], []
    for _ in range(num_steps):
        spk_out, mem_out = net(data)
        spikes.append(spk_out)
        potentials.append(mem_out)

    return torch.stack(spikes), torch.stack(potentials)

# Loss, optimizer and bookkeeping for training `cnn`
loss_fn = SF.ce_rate_loss()
cnn_optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-2, betas=(0.9, 0.999))

num_epochs = 1
counter = 0
loss_hist, test_acc_hist = [], []

# overall accuracy
# def batch_accuracy(train_loader, net, num_steps):
#   with torch.no_grad():
#     total = 0
#     acc = 0
#     net.eval()
#     train_loader = iter(train_loader)
#     for data, targets in train_loader:
#       spk_rec, _ = forward_pass(net, num_steps, data)
#       acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
#       total += spk_rec.size(1)
#   return acc/total

# Train `cnn` for num_epochs passes over train_loader, logging the loss of every batch
for epoch in range(num_epochs):
    for data, targets in train_loader:
        cnn.train()
        cnn_optimizer.zero_grad()

        # Forward pass over num_steps steps; only the spikes are used for the loss
        spk_rec, _ = forward_pass(cnn, num_steps, data)

        # Loss computed by loss_fn from the recorded output spikes
        loss_val = loss_fn(spk_rec, targets)
        print(f'iter {counter}: loss {loss_val.item():.4f}')

        # Backpropagate and update the weights
        loss_val.backward()
        cnn_optimizer.step()

        # Keep the loss for later plotting
        loss_hist.append(loss_val.item())

        # Periodic test-set evaluation (disabled)
        # if counter % 5 == 0:
        #     with torch.no_grad():
        #         cnn.eval()

        #         # Test set forward pass
        #         test_acc = batch_accuracy(test_loader, cnn, num_steps)
        #         print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
        #         test_acc_hist.append(test_acc.item())

        counter += 1