In [1]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

In [2]:

# Spiking Basic Block for ResNet
class SpikingBasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions and ``snn.Leaky`` neurons.

    Main path: conv3x3 -> LIF -> conv3x3 -> LIF. The result is added to a
    shortcut and the sum is passed through a third LIF layer. The shortcut is
    the unmodified input, or a strided 1x1 convolution followed by its own LIF
    layer whenever the stride or the channel count changes.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution (when one is created).
        beta: ``beta`` argument forwarded to every ``snn.Leaky`` layer.
        spike_grad: ``spike_grad`` argument forwarded to every ``snn.Leaky`` layer.
    """

    # Output-channel multiplier; SpikingResNet reads it to size its classifier.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, beta=0.9, spike_grad=surrogate.fast_sigmoid(slope=25)):
        super().__init__()
        shortcut_channels = out_channels * self.expansion

        def spiking_layer():
            # All neuron layers of the block share one configuration.
            return snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)

        # Sub-modules are created in the same order as before so that module
        # registration (and therefore state_dict layout) does not change.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.lif1 = spiking_layer()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.lif2 = spiking_layer()
        self.lif3 = spiking_layer()  # applied to main path + shortcut

        if stride == 1 and in_channels == shortcut_channels:
            # Shapes already match: the shortcut is the identity.
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, shortcut_channels, kernel_size=1, stride=stride, bias=False),
                spiking_layer(),
            )

    def forward(self, x):
        """Run one time step of the block on ``x`` and return the output of ``lif3``."""
        shortcut = self.downsample(x) if self.downsample is not None else x
        main = self.lif1(self.conv1(x))
        main = self.lif2(self.conv2(main))
        return self.lif3(main + shortcut)

# Modified Spiking ResNet for MNIST
class SpikingResNet(nn.Module):
    """Spiking ResNet for 1-channel 28x28 inputs (MNIST).

    Layout: a 3x3 convolution with an LIF layer and 2x2 max pooling, three
    stages of two ``SpikingBasicBlock``s each (16, 32 and 64 channels), global
    average pooling, and a linear classifier followed by an output LIF layer.

    All neuron layers are created with ``init_hidden=True``, i.e. they keep
    their membrane state internally between calls. ``forward`` therefore resets
    that state before simulating, so that one batch cannot influence the next.

    Args:
        num_classes: Number of output neurons / classes.
        beta: ``beta`` argument forwarded to every ``snn.Leaky`` layer.
        spike_grad: ``spike_grad`` argument forwarded to every ``snn.Leaky`` layer.
        num_steps: Number of simulation time steps performed by ``forward``.
    """

    def __init__(self, num_classes=10, beta=0.9, spike_grad=surrogate.fast_sigmoid(slope=25), num_steps=25):
        super().__init__()
        self.num_steps = num_steps
        # Stem adjusted for 1-channel 28x28 MNIST input: small kernel, no stride.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Smaller channel sizes for MNIST.
        self.layer1 = self._make_layer(16, 16, blocks=2, stride=1, beta=beta, spike_grad=spike_grad)
        self.layer2 = self._make_layer(16, 32, blocks=2, stride=2, beta=beta, spike_grad=spike_grad)
        self.layer3 = self._make_layer(32, 64, blocks=2, stride=2, beta=beta, spike_grad=spike_grad)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * SpikingBasicBlock.expansion, num_classes)
        # output=True makes this layer return (spikes, membrane potential).
        self.lif_fc = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

    def _make_layer(self, in_channels, out_channels, blocks, stride, beta, spike_grad):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the channel change; the
        remaining blocks map ``out_channels`` to ``out_channels`` with the
        default stride.
        """
        layers = [SpikingBasicBlock(in_channels, out_channels, stride, beta, spike_grad)]
        for _ in range(1, blocks):
            layers.append(SpikingBasicBlock(out_channels, out_channels, beta=beta, spike_grad=spike_grad))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps on a static input.

        Args:
            x: Input batch of shape (batch, 1, 28, 28). The same tensor is
                presented at every time step.

        Returns:
            Tuple ``(spk_rec, mem_rec)`` holding the output layer's spikes and
            membrane potentials stacked along a new leading time dimension,
            i.e. each of shape (num_steps, batch, num_classes).
        """
        # Function-scope import: the notebook's import cell only pulls in
        # `snntorch`, `surrogate` and `functional`.
        from snntorch import utils

        # The neurons keep their membrane state between calls (init_hidden=True).
        # Without this reset, state from the previous batch (or the previous
        # evaluation pass) would carry over into this simulation.
        utils.reset(self)

        spk_rec = []
        mem_rec = []
        for _ in range(self.num_steps):
            # The input is static, so it is fed directly at every step instead
            # of materialising `num_steps` copies of the batch.
            out = self.conv1(x)
            out = self.lif1(out)
            out = self.maxpool(out)

            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)

            out = self.avgpool(out)
            out = out.view(out.size(0), -1)
            spk, mem = self.lif_fc(self.fc(out))
            spk_rec.append(spk)
            mem_rec.append(mem)

        return torch.stack(spk_rec, dim=0), torch.stack(mem_rec, dim=0)


In [3]:

# MNIST data pipeline: run configuration, preprocessing, datasets and loaders
data_path = './data'
batch_size = 128
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to every image, in this order.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Train / test splits (downloaded into `data_path` when missing).
mnist_train = datasets.MNIST(root=data_path, train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root=data_path, train=False, transform=transform, download=True)

# Both loaders shuffle and drop the final incomplete batch.
train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)


In [4]:
def print_batch_accuracy(data, targets, train=False):
    """Print and return the accuracy of the global ``net`` on one minibatch.

    Args:
        data: Input batch, passed unchanged to ``net``.
        targets: Labels compared with the recorded output spikes by
            ``SF.accuracy_rate``.
        train: Selects the wording of the printed message ("Train" when True,
            "Test" otherwise). It does not change how accuracy is computed.

    Returns:
        The value returned by ``SF.accuracy_rate`` for this batch.
    """
    # This forward pass is for reporting only. Disabling autograd here avoids
    # building a graph over all time steps when the caller has not already
    # entered ``torch.no_grad()``.
    with torch.no_grad():
        spk_rec, _ = net(data)
    acc = SF.accuracy_rate(spk_rec, targets)
    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")
    return acc

def train_printer():
    """Print a progress report for the current training iteration.

    Reads the notebook globals ``epoch``, ``iter_counter``, ``counter``,
    ``loss_hist``, ``test_loss_hist``, ``data``, ``targets``, ``test_data`` and
    ``test_targets``. Prints the losses stored at index ``counter``, then the
    single-minibatch train and test accuracies via ``print_batch_accuracy``.
    """
    print(
        f"Epoch {epoch}, Iteration {iter_counter}",
        f"Train Set Loss: {loss_hist[counter]:.2f}",
        f"Test Set Loss: {test_loss_hist[counter]:.2f}",
        sep="\n",
    )
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    # Blank separator between reports (print adds its own newline as well).
    print("\n")


In [7]:
# Hyperparameters, model, optimizer and loss function
num_steps = 25
num_epochs = 1

# Build the network, then move its parameters to the selected device.
net = SpikingResNet(num_classes=10, num_steps=num_steps)
net = net.to(device)

optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-2, betas=(0.9, 0.999))

# Rate-coded cross-entropy loss; note this is ce_rate_loss, not ce_count_loss.
loss_fn = SF.ce_rate_loss()


In [None]:
# Sanity check: run one training minibatch through the network and show its loss.
data, targets = next(iter(train_loader))
# `net` was moved to `device`, so the batch has to live there too; otherwise the
# forward pass fails with a device mismatch whenever CUDA is available.
data, targets = data.to(device), targets.to(device)
spk_rec, mem_rec = net(data)
loss_fn(spk_rec, targets)

In [None]:
# One manual optimisation step on the `spk_rec` / `targets` left by the previous
# cell: clear old gradients, compute the loss, backpropagate, update parameters.
optimizer.zero_grad()
loss_val = loss_fn(spk_rec, targets)
loss_val.backward()
optimizer.step()

In [36]:
# NOTE(review): this repeats the optimizer step without a new backward pass, and
# the cell's execution count (36) is out of sequence with its neighbours. It
# looks like a leftover from interactive debugging — confirm and remove.
optimizer.step()

In [8]:
# Per-iteration histories; `counter` indexes into them across all epochs.
loss_hist, test_loss_hist, test_acc_hist = [], [], []
counter = 0

# Training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    for data, targets in train_batch:
        data, targets = data.to(device), targets.to(device)

        # --- Training step: forward, loss, backward, parameter update ---
        net.train()
        spk_rec, mem_rec = net(data)  # spk_rec: (num_steps, batch_size, num_classes)
        loss_val = loss_fn(spk_rec, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        loss_hist.append(loss_val.item())

        # --- Evaluation on one test minibatch drawn from a fresh iterator ---
        net.eval()
        with torch.no_grad():
            test_data, test_targets = next(iter(test_loader))
            test_data, test_targets = test_data.to(device), test_targets.to(device)

            test_spk, test_mem = net(test_data)
            test_loss = loss_fn(test_spk, test_targets)
            test_loss_hist.append(test_loss.item())

            test_acc = SF.accuracy_rate(test_spk, test_targets)
            test_acc_hist.append(test_acc)

            # Report every fifth iteration; train_printer reads the globals
            # set above (epoch, counters, histories, current batches).
            if not counter % 5:
                train_printer()

        counter += 1
        iter_counter += 1


Epoch 0, Iteration 0
Train Set Loss: 2.31
Test Set Loss: 2.31
Train set accuracy for a single minibatch: 5.47%
Test set accuracy for a single minibatch: 10.16%


Epoch 0, Iteration 5
Train Set Loss: 2.30
Test Set Loss: 2.30
Train set accuracy for a single minibatch: 10.94%
Test set accuracy for a single minibatch: 7.81%


Epoch 0, Iteration 10
Train Set Loss: 2.30
Test Set Loss: 2.30
Train set accuracy for a single minibatch: 12.50%
Test set accuracy for a single minibatch: 13.28%


Epoch 0, Iteration 15
Train Set Loss: 2.30
Test Set Loss: 2.31
Train set accuracy for a single minibatch: 10.16%
Test set accuracy for a single minibatch: 9.38%




KeyboardInterrupt: 

In [9]:
# Display the training losses recorded so far (one entry per training iteration).
loss_hist

[2.3067405223846436,
 2.3067405223846436,
 2.305178165435791,
 2.3025853633880615,
 2.3025853633880615,
 2.3025853633880615,
 2.3025853633880615,
 2.305803060531616,
 2.3064279556274414,
 2.305490493774414,
 2.3039281368255615,
 2.304865598678589,
 2.304865598678589,
 2.3045530319213867,
 2.3036155700683594,
 2.304865598678589]

In [1]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate

class SpikingBasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions and ``snn.Leaky`` neurons.

    Main path: conv3x3 -> LIF -> conv3x3 -> LIF. The result is added to a
    shortcut and the sum is passed through a third LIF layer. The shortcut is
    the unmodified input, or a strided 1x1 convolution followed by its own LIF
    layer whenever the stride or the channel count changes.

    NOTE(review): a class with the same name and body is defined earlier in
    this notebook; this later definition replaces it once the cell runs.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution (when one is created).
        beta: ``beta`` argument forwarded to every ``snn.Leaky`` layer.
        spike_grad: ``spike_grad`` argument forwarded to every ``snn.Leaky`` layer.
    """

    # Output-channel multiplier; SpikingResNet reads it to size its classifier.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, beta=0.9, spike_grad=surrogate.fast_sigmoid(slope=25)):
        super().__init__()
        shortcut_channels = out_channels * self.expansion

        def spiking_layer():
            # All neuron layers of the block share one configuration.
            return snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)

        # Sub-modules are created in the same order as before so that module
        # registration (and therefore state_dict layout) does not change.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.lif1 = spiking_layer()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.lif2 = spiking_layer()
        self.lif3 = spiking_layer()  # applied to main path + shortcut

        if stride == 1 and in_channels == shortcut_channels:
            # Shapes already match: the shortcut is the identity.
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, shortcut_channels, kernel_size=1, stride=stride, bias=False),
                spiking_layer(),
            )

    def forward(self, x):
        """Run one time step of the block on ``x`` and return the output of ``lif3``."""
        shortcut = self.downsample(x) if self.downsample is not None else x
        main = self.lif1(self.conv1(x))
        main = self.lif2(self.conv2(main))
        return self.lif3(main + shortcut)

class SpikingResNet(nn.Module):
    """Spiking ResNet for 3-channel image inputs.

    Layout: a strided 7x7 convolution with an LIF layer and 3x3 max pooling,
    four stages of two ``SpikingBasicBlock``s each (64, 128, 256 and 512
    channels), global average pooling, and a linear classifier followed by an
    output LIF layer.

    All neuron layers are created with ``init_hidden=True``, i.e. they keep
    their membrane state internally between calls. ``forward`` therefore resets
    that state before simulating, so that one batch cannot influence the next.

    Args:
        num_classes: Number of output neurons / classes.
        beta: ``beta`` argument forwarded to every ``snn.Leaky`` layer.
        spike_grad: ``spike_grad`` argument forwarded to every ``snn.Leaky`` layer.
        num_steps: Number of simulation time steps performed by ``forward``.
    """

    def __init__(self, num_classes=10, beta=0.9, spike_grad=surrogate.fast_sigmoid(slope=25), num_steps=25):
        super().__init__()
        self.num_steps = num_steps
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 64, blocks=2, stride=1, beta=beta, spike_grad=spike_grad)
        self.layer2 = self._make_layer(64, 128, blocks=2, stride=2, beta=beta, spike_grad=spike_grad)
        self.layer3 = self._make_layer(128, 256, blocks=2, stride=2, beta=beta, spike_grad=spike_grad)
        self.layer4 = self._make_layer(256, 512, blocks=2, stride=2, beta=beta, spike_grad=spike_grad)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * SpikingBasicBlock.expansion, num_classes)
        # output=True makes this layer return (spikes, membrane potential).
        self.lif_fc = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

    def _make_layer(self, in_channels, out_channels, blocks, stride, beta, spike_grad):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the channel change; the
        remaining blocks map ``out_channels`` to ``out_channels`` with the
        default stride.
        """
        layers = [SpikingBasicBlock(in_channels, out_channels, stride, beta, spike_grad)]
        for _ in range(1, blocks):
            layers.append(SpikingBasicBlock(out_channels, out_channels, beta=beta, spike_grad=spike_grad))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps on a static input.

        Args:
            x: Input batch of shape (batch, channels, height, width). The same
                tensor is presented at every time step.

        Returns:
            Tuple ``(spk_rec, mem_rec)`` holding the output layer's spikes and
            membrane potentials stacked along a new leading time dimension,
            i.e. each of shape (num_steps, batch, num_classes).
        """
        # Function-scope import: the surrounding import lines only pull in
        # `snntorch` and `surrogate`.
        from snntorch import utils

        # The neurons keep their membrane state between calls (init_hidden=True).
        # Without this reset, state from a previous call would carry over into
        # this simulation.
        utils.reset(self)

        spk_rec = []
        mem_rec = []
        for _ in range(self.num_steps):
            # The input is static, so it is fed directly at every step instead
            # of materialising `num_steps` copies of the batch.
            out = self.conv1(x)
            out = self.lif1(out)
            out = self.maxpool(out)

            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)

            out = self.avgpool(out)
            out = out.view(out.size(0), -1)
            spk, mem = self.lif_fc(self.fc(out))
            spk_rec.append(spk)
            mem_rec.append(mem)

        return torch.stack(spk_rec, dim=0), torch.stack(mem_rec, dim=0)

# Example usage
if __name__ == "__main__":
    num_steps = 25
    model = SpikingResNet(num_classes=10, num_steps=num_steps)
    print(model)

    # Smoke test with a dummy input (batch_size=1, channels=3, 224x224 for ResNet input)
    dummy_input = torch.randn(1, 3, 224, 224)
    # Only the output shape is inspected, so gradients are disabled: this avoids
    # keeping an autograd graph across all unrolled time steps.
    with torch.no_grad():
        spk, mem = model(dummy_input)
    print(f"Output spikes shape: {spk.shape}")  # (num_steps, batch_size, num_classes)

SpikingResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (lif1): Leaky()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): SpikingBasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lif1): Leaky()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lif2): Leaky()
      (lif3): Leaky()
    )
    (1): SpikingBasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lif1): Leaky()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lif2): Leaky()
      (lif3): Leaky()
    )
  )
  (layer2): Sequential(
    (0): SpikingBasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (lif1): Leaky()
      (conv2): Conv2d(1