<a href="https://colab.research.google.com/github/forsaken27/net4/blob/main/SNN_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

SNN classifer for MNIST dataset

In [2]:
# This mounts your Google Drive to the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

# TODO: Enter the foldername in your Drive where you have saved the unzipped
# assignment folder, e.g. 'cs231n/assignments/assignment2/'
FOLDERNAME = 'SNN_classifier/'
assert FOLDERNAME is not None, "[!] Enter the foldername."

# Now that we've mounted your Drive, this ensures that
# the Python interpreter of the Colab VM can load

# python files from within it.
import sys
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install spikingjelly

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl.metadata (15 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->spikingjelly)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->spikingjelly)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->spikingjelly)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->spikingjelly)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->spikingjelly)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->spikingjelly)
  Downl

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from spikingjelly.activation_based import neuron, functional, surrogate, encoding, layer

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np

USE_GPU = True
dtype = torch.float32

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
transform = T.Compose([
                T.ToTensor(),
            ])
train_dataset = dset.MNIST(root='./drive/MyDrive/SNN_classifier/data', train=True, download=True, transform=transform)
test_dataset = dset.MNIST(root='./drive/MyDrive/SNN_classifier/data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(train_dataset.data.shape)
print(test_dataset.data.shape)

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])


In [7]:
class Flatten(nn.Module):
  def forward(self, x):
    return x.view(x.shape[0], -1)
tau = 2.0
learning_rate = 1e-3


net_4 = nn.Sequential(
    #(T, 64, 1, 28, 28)
    Flatten(),                         # (B, 1*28*28)
    layer.Linear(28*28, 100, bias=False),
    neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
    nn.Dropout(0.2),
    layer.Linear(100, 10, bias=False),
    neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
    )



optimizer = optim.Adam(net_4.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
encoder = encoding.PoissonEncoder()
epochs = 10

In [8]:
import time
from torch import amp
import torch.nn.functional as F
scaler = amp.GradScaler('cuda')
T = 30

def train_net(T, epochs):
  net_4.train()
  for epoch in range(epochs):
    start_time = time.time()
    net_4.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_loader:
      optimizer.zero_grad()
      img = img.to(device)
      label = label.to(device)
      label_onehot = F.one_hot(label, 10).float()

          # Mixed-precision training
      if scaler is not None:
        with amp.autocast('cuda'):
          out_fr = 0.
          # Run T time steps
          for t in range(T):
            encoded_img = encoder(img)
            out_fr += net_4(encoded_img)
          out_fr = out_fr / T
          # out_fr is tensor whose shape is [batch_size, 10]
          # The firing rate of 10 neurons in the output layer was recorded during the whole simulation period
          loss = F.mse_loss(out_fr, label_onehot)
          # The loss function is the MSE between the firing rate of the output layer and the true category.
          # The loss function will cause the firing rate of the correct neuron in the output layer to approach 1 when the label i is given, and the firing rate of the other neurons to approach 0.
        scaler.scale(loss).backward(retain_graph=True)
        scaler.step(optimizer)
        scaler.update()
      else:
        out_fr = 0.
        for t in range(T):
          encoded_img = encoder(img)
          out_fr += net_4(encoded_img)
        out_fr = out_fr / T
        loss = F.mse_loss(out_fr, label_onehot)
        loss.backward()
        optimizer.step()

      train_samples += label.numel()

      train_loss += loss.item() * label.numel()
      # The correct rate is calculated as follows. The subscript i of the neuron with the highest firing rate in the output layer is considered as the result of classification.
      train_acc += (out_fr.argmax(1) == label).float().sum().item()

      # After optimizing the parameters, the state of the network should be reset because the neurons of the SNN have “memory”.
      functional.reset_net(net_4)
    print(f"Epoch: {epoch}")
    print(f"Train Loss: {train_loss / train_samples}")
    print(f"Train Acc: {train_acc /  train_samples}")
    print(f"Time: {time.time() - start_time}")

train_net(T, epochs)



Epoch: 0
Train Loss: 0.021236207398275534
Train Acc: 0.8816
Time: 100.16350102424622
Epoch: 1
Train Loss: 0.010012805598974227
Train Acc: 0.9452833333333334
Time: 96.05491185188293
Epoch: 2
Train Loss: 0.007879040757318339
Train Acc: 0.9575333333333333
Time: 94.27358913421631
Epoch: 3
Train Loss: 0.006639468543728193
Train Acc: 0.9655166666666667
Time: 96.61773729324341
Epoch: 4
Train Loss: 0.005814811121672392
Train Acc: 0.9711666666666666
Time: 97.08322596549988
Epoch: 5
Train Loss: 0.005279079673749705
Train Acc: 0.9747
Time: 94.7437355518341
Epoch: 6
Train Loss: 0.00480752413480853
Train Acc: 0.9777333333333333
Time: 95.9030168056488
Epoch: 7
Train Loss: 0.004419301891575257
Train Acc: 0.9798666666666667
Time: 94.89069271087646
Epoch: 8
Train Loss: 0.004165172279067338
Train Acc: 0.9815666666666667
Time: 96.27680349349976
Epoch: 9
Train Loss: 0.003881313017755747
Train Acc: 0.9831833333333333
Time: 96.3009831905365


In [9]:
def test_net():
  net_4.eval()                         # evaluation mode
  test_acc = test_samples = 0
  T = 20
  start_time = time.time()

  with torch.no_grad():                # no gradients for speed / memory
      for img, label in test_loader:
          img,  label  = img.to(device), label.to(device)


          # ----- forward (AMP optional) -----
          with amp.autocast('cuda'):                  # works even if scaler is None
              out_fr = 0.
              for t in range(T):
                  encoded_img = encoder(img)              # or let net handle the loop
                  out_fr += net_4(encoded_img)
              out_fr = out_fr / T              # average firing rate

          # ----- metrics -----
          test_samples += label.size(0)
          test_acc     += (out_fr.argmax(1) == label).sum().item()

          functional.reset_net(net_4)          # clear membranes for next batch

  print(f"Test Acc: {test_acc / test_samples:.4f}")
  print(f"Time: {time.time() - start_time:.2f}s")

test_net()

Test Acc: 0.9726
Time: 6.22s


In [10]:
SAVE_PATH = 'best_model_net_4.pth'
torch.save({
    'model_state_dict': net_4.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epochs}, SAVE_PATH)
print("saved to", SAVE_PATH)

saved to best_model_net_4.pth
