<a href="https://colab.research.google.com/github/dattaayon7/Cifar10_torch/blob/main/CIFAR10_dataset_with_SnnTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install snntorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [None]:
# Leaky neuron model, overriding the backward pass with a custom function
class LeakySurrogate(nn.Module):
    """Leaky neuron whose spike non-linearity is a custom autograd Function.

    Forward pass: Heaviside step on ``mem - threshold``.
    Backward pass: the Heaviside derivative (a Dirac delta) is replaced by the
    spike itself, so gradient flows only where the neuron fired.

    Args:
        beta: decay factor multiplied into ``mem`` every step.
        threshold: firing threshold; also the amount subtracted on reset.
    """

    def __init__(self, beta, threshold=1.0):
        super().__init__()

        # initialize decay rate beta and threshold
        self.beta = beta
        self.threshold = threshold
        # Function.apply is the entry point for a custom autograd op.
        self.spike_op = self.SpikeOperator.apply

    # the forward function is called each time we call Leaky
    def forward(self, input_, mem):
        """Advance the neuron by one time step.

        Args:
            input_: input current for this step.
            mem: membrane potential carried over from the previous step.

        Returns:
            (spk, mem): spikes computed from the incoming ``mem`` and the
            updated membrane potential.
        """
        spk = self.spike_op(mem - self.threshold)  # call the Heaviside function
        reset = (spk * self.threshold).detach()  # removes spike_op gradient from reset
        mem = self.beta * mem + input_ - reset  # Eq (1)
        return spk, mem

    # A plain nested class: wrapping a class in @staticmethod is not needed
    # (and not idiomatic) for it to be reachable as self.SpikeOperator.
    class SpikeOperator(torch.autograd.Function):
        """Heaviside forward; backward passes gradient only where spk == 1."""

        @staticmethod
        def forward(ctx, mem):
            spk = (mem > 0).float()  # Heaviside on the forward pass: Eq(2)
            ctx.save_for_backward(spk)  # store the spike for use in the backward pass
            return spk

        @staticmethod
        def backward(ctx, grad_output):
            (spk,) = ctx.saved_tensors  # retrieve the spike
            return grad_output * spk  # scale the gradient by the spike: 1/0

In [None]:
# Instantiate the hand-written surrogate-gradient neuron (threshold spelled out).
# NOTE(review): the next cell rebinds `lif1` to snntorch's built-in neuron.
lif1 = LeakySurrogate(beta=0.9, threshold=1.0)

In [None]:
# The same leaky neuron, this time using snntorch's built-in implementation.
lif1 = snn.Leaky(beta=0.9, threshold=1.0)

In [None]:
# DataLoader settings
batch_size = 128
data_path = '/data/CIFAR10'

# Tensor defaults: float32 tensors, on the GPU whenever one is available.
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Define a transform: resize to 28x28, collapse to one grayscale channel,
# convert to a [0, 1] tensor, then shift/scale to [-1, 1].
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            # Normalize(mean, std) for a single channel. A third positional
            # argument would be taken as `inplace`, so none is passed.
            transforms.Normalize((0.5,), (0.5,))])

CIFAR10_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
CIFAR10_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Create DataLoaders. drop_last=True guarantees every minibatch holds exactly
# `batch_size` samples; the test loader is shuffled too because the training
# loop draws a fresh random test minibatch from it each iteration.
loader_kwargs = dict(batch_size=batch_size, shuffle=True, drop_last=True)
train_loader = DataLoader(CIFAR10_train, **loader_kwargs)
test_loader = DataLoader(CIFAR10_test, **loader_kwargs)




In [None]:
#   Convert from tensor image

import matplotlib.pyplot as plt
%matplotlib inline

def imshow(img):
    """Display a normalized channels-first image with matplotlib.

    Args:
        img: tensor or array of shape (C, H, W) whose values were mapped to
            roughly [-1, 1] by ``Normalize(0.5, 0.5)``.

    Single-channel images (as produced by ``transforms.Grayscale()``) are drawn
    with a gray colormap, since ``plt.imshow`` rejects an (H, W, 1) array.
    """
    img = img / 2 + 0.5                              # unnormalize: [-1, 1] -> [0, 1]
    if hasattr(img, "detach"):                       # torch.Tensor -> NumPy on CPU
        img = img.detach().cpu().numpy()
    img = np.transpose(np.asarray(img), (1, 2, 0))   # (C, H, W) -> (H, W, C)
    if img.shape[-1] == 1:
        plt.imshow(img[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
    else:
        plt.imshow(img)

In [None]:
#Convert to Grayscale

import torchvision as tv
import numpy as np
import torch.utils.data as data
dataDir         = '/data/CIFAR10'
# Grayscale replicated to 3 channels, so images keep a (3, H, W) layout.
trainTransform  = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=3),
                                    tv.transforms.ToTensor(),
                                    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainSet        = tv.datasets.CIFAR10(dataDir, train=True, download=True, transform=trainTransform)
# Use the DataLoader class imported at the top rather than the `data` module
# alias, because later cells rebind the name `data` to a tensor.
dataloader      = DataLoader(trainSet, batch_size=1, shuffle=True, num_workers=0)

# DataLoader iterators only implement __next__ in recent PyTorch releases (the
# legacy `.next()` alias is gone), so use the builtin next().
images, labels = next(iter(dataloader))

print(images.size())





Files already downloaded and verified
torch.Size([1, 3, 32, 32])


In [None]:
# Network Architecture: flattened 28x28 input -> hidden layer -> 10 classes
num_inputs, num_hidden, num_outputs = 28 * 28, 1000, 10

# Temporal Dynamics: simulated time steps and membrane decay rate
num_steps, beta = 25, 0.95

In [None]:
# Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network: fc1 -> LIF -> fc2 -> LIF.

    Layer sizes, the decay rate and the simulation length come from the
    notebook-level constants ``num_inputs``, ``num_hidden``, ``num_outputs``,
    ``beta`` and ``num_steps``.
    """

    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a static input.

        Args:
            x: flattened input batch; its last dimension must be ``num_inputs``.

        Returns:
            (spk2_rec, mem2_rec): output-layer spikes and membrane potentials,
            stacked along a new leading time dimension (dim 0).
        """
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # x is the same at every step, so its projection through fc1 is
        # loop-invariant: compute it once instead of num_steps times.
        cur1 = self.fc1(x)

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load the network onto CUDA if available
net = Net().to(device)

In [None]:
# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target

def print_batch_accuracy(data, targets, train=False, model=None):
    """Print the spike-count accuracy of a network on one minibatch.

    The predicted class is the output neuron with the most spikes summed over
    the network's time dimension (dim 0 of its spike output).

    Args:
        data: input minibatch; flattened to (batch, -1) before the forward pass.
        targets: class labels, on the same device as the network output.
        train: only selects the "Train"/"Test" wording of the message.
        model: network to evaluate; defaults to the notebook-level ``net``.
    """
    if model is None:
        model = net
    # Reporting only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        # Flatten by the actual batch size so partial batches work as well.
        output, _ = model(data.view(data.size(0), -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    """Print a progress report for the current training iteration.

    Reads the notebook globals maintained by the training loop: ``epoch``,
    ``iter_counter``, ``counter``, ``loss_hist``, ``test_loss_hist``, and the
    current ``data``/``targets`` and ``test_data``/``test_targets`` minibatches.
    """
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    batches = ((data, targets, True), (test_data, test_targets, False))
    for inputs, labels, is_train in batches:
        print_batch_accuracy(inputs, labels, train=is_train)
    print("\n")

In [None]:
# Cross-entropy on class scores (applied below to the output membrane potential
# at each time step); the default mean reduction is spelled out.
loss = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Adam over every trainable parameter of `net`.
optimizer = torch.optim.Adam(params=net.parameters(), lr=5e-4, betas=(0.9, 0.999))

In [None]:
# Take a single training minibatch and move both tensors to the compute device.
data, targets = next(iter(train_loader))
data, targets = data.to(device), targets.to(device)

In [None]:

# Forward pass on that minibatch; flatten by its actual size (as the final
# evaluation cell does) rather than assuming exactly `batch_size` samples.
spk_rec, mem_rec = net(data.view(data.size(0), -1))


In [None]:
print(mem_rec.size())
# Check the layout instead of leaving a pasted literal: the network stacks the
# time dimension first, then batch, then output neurons.
assert mem_rec.shape == (num_steps, data.size(0), num_outputs)

torch.Size([25, 128, 10])


torch.Size([25, 128, 10])

In [None]:
# Accumulate the cross-entropy of the output membrane potential at every time
# step into a single (1,)-shaped loss tensor.
loss_val = torch.zeros((1), dtype=dtype, device=device)
for step in range(num_steps):
    step_loss = loss(mem_rec[step], targets)
    loss_val += step_loss

In [None]:
train_loss = loss_val.item()
print(f"Training loss: {train_loss:.3f}")

Training loss: 61.524


In [None]:
# Accuracy of the untrained network on the same minibatch (reported as "Train").
print_batch_accuracy(data, targets, True)

Train set accuracy for a single minibatch: 7.81%


In [None]:
# One manual optimisation step on the loss accumulated in the previous cells.
# clear previously stored gradients
optimizer.zero_grad()

# calculate the gradients (backprop through every time step summed into loss_val)
loss_val.backward()

# weight update (Adam, as configured above)
optimizer.step()

In [None]:
# calculate new network outputs using the same data (after one weight update);
# flatten by the actual minibatch size rather than assuming `batch_size`.
spk_rec, mem_rec = net(data.view(data.size(0), -1))

# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
for step in range(num_steps):
    loss_val += loss(mem_rec[step], targets)

In [None]:
# Loss and accuracy after the single update, on the same minibatch.
train_loss = loss_val.item()
print(f"Training loss: {train_loss:.3f}")
print_batch_accuracy(data, targets, True)

Training loss: 48.400
Train set accuracy for a single minibatch: 41.41%


In [None]:
# Full training run. NOTE: the top-level names assigned here (epoch,
# iter_counter, counter, loss_hist, test_loss_hist, data, targets, test_data,
# test_targets) are read as globals by train_printer(), so keep them as-is.
num_epochs = 3
loss_hist = []
test_loss_hist = []
counter = 0  # global iteration index across all epochs (indexes loss_hist)

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0  # iteration index within the current epoch
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        # view(batch_size, -1) relies on train_loader using drop_last=True
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            # A new iterator every iteration: with a shuffled test_loader this
            # yields a fresh random test minibatch each time.
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            # (every 30 global iterations)
            if counter % 30 == 0:
                train_printer()
            counter += 1
            iter_counter +=1

Epoch 0, Iteration 0
Train Set Loss: 55.12
Test Set Loss: 56.81
Train set accuracy for a single minibatch: 35.16%
Test set accuracy for a single minibatch: 27.34%


Epoch 0, Iteration 30
Train Set Loss: 43.74
Test Set Loss: 61.92
Train set accuracy for a single minibatch: 42.97%
Test set accuracy for a single minibatch: 29.69%


Epoch 0, Iteration 60
Train Set Loss: 53.97
Test Set Loss: 56.81
Train set accuracy for a single minibatch: 32.03%
Test set accuracy for a single minibatch: 29.69%


Epoch 0, Iteration 90
Train Set Loss: 56.18
Test Set Loss: 53.47
Train set accuracy for a single minibatch: 43.75%
Test set accuracy for a single minibatch: 31.25%


Epoch 0, Iteration 120
Train Set Loss: 49.41
Test Set Loss: 53.46
Train set accuracy for a single minibatch: 41.41%
Test set accuracy for a single minibatch: 35.16%


Epoch 0, Iteration 150
Train Set Loss: 47.89
Test Set Loss: 49.67
Train set accuracy for a single minibatch: 39.84%
Test set accuracy for a single minibatch: 33.59%


Epo

In [None]:
total = 0
correct = 0

# A dedicated loader for the full pass: drop_last=False keeps every test sample,
# and no shuffling is needed because order does not affect accuracy. A new name
# is used so the training-time `test_loader` is not silently replaced.
full_test_loader = DataLoader(CIFAR10_test, batch_size=batch_size, shuffle=False, drop_last=False)

with torch.no_grad():
    net.eval()
    for data, targets in full_test_loader:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass; flatten by the actual size since the last batch may be partial
        test_spk, _ = net(data.view(data.size(0), -1))

        # calculate total accuracy: prediction = neuron with most spikes over time
        _, predicted = test_spk.sum(dim=0).max(1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

In [None]:
print(f"Total correctly classified test set images: {correct}/{total}")
accuracy_pct = 100 * correct / total
print(f"Test Set Accuracy: {accuracy_pct:.2f}%")

Total correctly classified test set images: 3400/10000
Test Set Accuracy: 34.00%
