<a href="https://colab.research.google.com/github/chardiwall/DPSNN/blob/main/DPSNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Project Description
**Title:** DPSNN (A Differentially Private Spiking Neural Network with Temporal Enhanced Pooling).

**Objectives:** Implementation of the [paper](https://arxiv.org/pdf/2205.12718.pdf).

**Description:** The project is set in the context of the growing field of social robotics, which involves the deployment of robots in human-centric environments. These environments can range from healthcare facilities and educational institutions to customer service and domestic settings. The unique aspect of this project lies in its focus on privacy-preserving mechanisms, an increasingly critical concern in today's data-driven world.


[GitHub Link](https://github.com/chardiwall/DPSNN)

---
## Frameworks
SNN framework: [SNNTorch](https://snntorch.readthedocs.io/en/latest/)

DP framework: [Opacus](https://github.com/pytorch/opacus)

---
## Datasets

| Static      | Neuromorphic |
--------------|----------------
| [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)      |[N-MNIST](https://www.garrickorchard.com/datasets/n-mnist#h.p_ID_38) |
| [MNIST](https://yann.lecun.com/exdb/mnist/)      | [CIFAR10-DVS](https://figshare.com/articles/dataset/CIFAR10-DVS_New/4724671/2) |
| [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist)      |   |

---


# Loading Essential Libraries

In [1]:
# !pip install -q opacus
!pip install -q snntorch

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.0/109.0 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.2/76.2 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Mount Google Drive at an absolute path so it always matches the
# '/content/drive/MyDrive/...' paths used by the data cells below,
# independent of the current working directory.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at drive; to attempt to forcibly remount, call drive.mount("drive", force_remount=True).


In [2]:
import os
import sys
import itertools
from copy import copy
from glob import glob
from tqdm import tqdm
from zipfile import ZipFile
from IPython.display import clear_output

import numpy as np
import matplotlib.pyplot as plt

import seaborn as sns


import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from torchsummary import summary
from torchvision import datasets, transforms


import snntorch as snn
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils, backprop, surrogate

  from snntorch import utils, backprop, surrogate


# Unzipping and Loading Data

## Neuromorphic Data

In [None]:
# N-MNIST: extract the train and test archives stored on Drive.
# Both archives are unpacked directly into base_dir_path (no per-archive subfolder).
base_dir_path = '/content/drive/MyDrive/Colab Notebooks/DPSNN/N-MNIST/'
train_zip_name = 'Copy of Train'
test_zip_name = 'Copy of Test'

for zip_file in (train_zip_name, test_zip_name):
  zip_path = os.path.join(base_dir_path, f'{zip_file}.zip')
  with ZipFile(zip_path, mode='r') as zip_ref:
    zip_ref.extractall(path=base_dir_path)


In [None]:
# CIFAR10-DVS: unpack the archive into the same folder that holds it.
base_dir_path = '/content/drive/MyDrive/Colab Notebooks/DPSNN/CIFAR10-DVS'
file_name = 'CIFAR10-DVS.zip'

with ZipFile(os.path.join(base_dir_path, file_name), mode='r') as zip_ref:
  zip_ref.extractall(path=base_dir_path)

## Static Data

In [3]:
def fetch_static_data(name = 'mnist', transform = None, batch_size = 64, shuffle = True):
  """
  Fetch a static (non-neuromorphic) torchvision dataset wrapped in DataLoaders.

  The data is downloaded to the current directory if it is not already there.

  Args:
  - name (str): Dataset name, case-insensitive: 'mnist', 'fashion-mnist' or
    'cifar-10' ('cifar10' is accepted as an alias).
  - transform (callable, optional): Transform applied to every sample.
  - batch_size (int, optional): Number of samples in each batch.
  - shuffle (bool, optional): Whether to shuffle the training set.

  Returns:
  - train_loader (DataLoader): Training loader; the last incomplete batch is
    dropped so every training batch has exactly `batch_size` samples.
  - test_loader (DataLoader): Test loader in fixed order that covers every
    test sample.

  Raises:
  - ValueError: If `name` is not one of the supported datasets.
  """
  # Accepted names -> torchvision.datasets class names. Resolved only after
  # validation so an unknown name fails fast without touching torchvision.
  dataset_classes = {
      'mnist': 'MNIST',
      'fashion-mnist': 'FashionMNIST',
      'cifar-10': 'CIFAR10',
      'cifar10': 'CIFAR10',
  }

  key = name.lower()
  if key not in dataset_classes:
    raise ValueError('Error! ' + name + ' dataset not found...')

  dataset_cls = getattr(datasets, dataset_classes[key])
  train = dataset_cls(root = '.', train = True, transform = transform, download = True)
  test = dataset_cls(root = '.', train = False, transform = transform, download = True)

  # Uniform batch sizes during training.
  train_loader = DataLoader(train, batch_size = batch_size, shuffle = shuffle, drop_last = True)
  # Evaluation must see every test sample exactly once, so neither shuffle nor drop.
  test_loader = DataLoader(test, batch_size = batch_size, shuffle = False, drop_last = False)

  return train_loader, test_loader

# Building SNN

## Initializing The Network


<details>
<summary> Show Code Block </summary>
```
class SNN(nn.Module):
  """
  Class-based variant of the spiking CNN. The training cells in this notebook
  use the nn.Sequential models instead of this class.

  Layout: three Conv2d -> GroupNorm -> snntorch Leaky blocks, with a 3x3
  average pool after the second block, then Flatten -> Linear -> Leaky
  (output=True) producing 10 class outputs.

  Shapes for a 1x28x28 input (assumed, matching the 28x28 grayscale pipeline):
    cb_01 -> 12x24x24, cb_02 -> 64x20x20, pooling -> 64x6x6,
    cb_03 -> 128x2x2, flatten -> 512 = 128*2*2 (the Linear's in_features).
  """
  def __init__(self):
    super().__init__()

    # First convolutional block (1 -> 12 channels, 5x5 kernel, no padding)
    self.cb_01 = nn.Sequential(
        nn.Conv2d(in_channels = 1, out_channels = 12, kernel_size = 5),
        nn.GroupNorm(num_groups = 6, num_channels= 12),
        snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),
    )

    # Second convolutional block (12 -> 64 channels)
    self.cb_02 = nn.Sequential(
        nn.Conv2d(in_channels = 12, out_channels = 64, kernel_size = 5),
        nn.GroupNorm(num_groups = 8, num_channels = 64),
        snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),
    )

    # Average pooling layer (3x3 kernel; stride defaults to the kernel size)
    self.pooling = nn.AvgPool2d(3)

    # Third convolutional block (64 -> 128 channels)
    self.cb_03 = nn.Sequential(
        nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 5),
        nn.GroupNorm(num_groups = 8, num_channels = 128),
        snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),
    )

    # Fully connected readout; in_features assumes a 128x2x2 feature map
    # (see the shape trace in the class docstring).
    self.FC = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features = 128 * 2 * 2, out_features= 10),
        snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True, output=True)
    )
    
  def forward(self, x):
    """
    Run one simulation time step on `x`.

    Returns whatever the final Leaky (output=True) emits; presumably the
    (spikes, membrane) pair that fwrd_pass unpacks — confirm.
    """
    # Forward pass through the network
    x = self.cb_01(x)
    x = self.cb_02(x)
    x = self.pooling(x)
    x = self.cb_03(x)
    x = self.FC(x)

    return x
```
</details>

In [4]:
# Use the GPU when one is available. Both branches produce a torch.device,
# so code that inspects `device.type` works on CPU-only runtimes too.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Mnist Architecture

In [5]:
# Spiking CNN for single-channel 28x28 inputs (used for MNIST-sized data).
# Each block is Conv2d -> GroupNorm -> snntorch Leaky (beta=0.5, fast-sigmoid
# surrogate gradient with slope 25), followed by 2x2 average pooling.
# Shape trace for a 1x28x28 input (agrees with the torchsummary output):
#   conv 7x7 -> 32x22x22, pool -> 32x11x11,
#   conv 4x4 -> 64x8x8,   pool -> 64x4x4, flatten -> 1024 = 64*4*4.
# The output Leaky has output=True, so the model emits two tensors per step,
# which fwrd_pass unpacks as (spk_out, mem_out).
mnist_net = nn.Sequential(
  # First convolution block: 1x28x28 -> 32x22x22
  nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 7),
  nn.GroupNorm(num_groups = 8, num_channels= 32),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling: -> 32x11x11
  nn.AvgPool2d(2),

  # Second convolution block: -> 64x8x8
  nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 4),
  nn.GroupNorm(num_groups = 8, num_channels = 64),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling: -> 64x4x4
  nn.AvgPool2d(2),

  # Fully connected or output block: 1024 features -> 10 class outputs
  nn.Flatten(),
  nn.Linear(in_features = 64 * 4 * 4, out_features= 10),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True, output=True)
).to(device)

### CIFAR-10

In [6]:
# Spiking CNN intended for CIFAR-10: eight 3x3 conv blocks
# (Conv2d -> GroupNorm -> Leaky) with average pooling after every second
# block, then a Linear + Leaky readout with output=True.
# in_channels=1: expects single-channel images (the `transform` pipeline
# converts images to grayscale).
# NOTE(review): no convolution is padded, and the four pooling layers
# (kernels 5, 5, 5, 4) shrink the feature map very quickly. For a 28x28 input
# the spatial size goes 28 -> 26 -> 24 -> pool -> 4 -> 2 -> 0, so the fourth
# convolution cannot run. The Linear layer's 256*4*4 input is only reached for
# inputs of roughly 2624x2624 pixels. Padding and pool sizes need revisiting
# before this model can be trained.
cifar_10 = nn.Sequential(
  # First convolution block
  nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 64),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Second convolution block
  nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 64),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 5, stride 5)
  nn.AvgPool2d(5),

  # Third convolution block
  nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 128),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Fourth convolution block
  nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 128),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 5, stride 5)
  nn.AvgPool2d(5),

  # Fifth convolution block
  nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 128),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Sixth convolution block
  nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 128),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 5, stride 5)
  nn.AvgPool2d(5),

  # Seventh convolution block
  nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 256),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Eighth convolution block
  nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 256),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 4, stride 4)
  nn.AvgPool2d(4),

  # Fully connected or output block (expects a 256x4x4 feature map)
  nn.Flatten(),
  nn.Linear(in_features = 256 * 4 * 4, out_features= 10),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True, output=True)
).to(device)

### CIFAR-10-DVS

In [7]:
# Deeper spiking CNN intended for CIFAR10-DVS: eight 3x3 conv blocks
# (Conv2d -> GroupNorm -> Leaky) widening from 64 to 1024 channels, five
# AvgPool2d(4) layers, then a Linear + Leaky readout with output=True.
# NOTE(review): in_channels=1 — confirm the DVS preprocessing yields
# single-channel frames (event data is often split into ON/OFF polarity
# channels).
# NOTE(review): convolutions are unpadded and each AvgPool2d(4) divides the
# spatial size by 4. The 1024*4*4 input expected by the Linear layer is only
# reached for inputs of roughly 5324x5324 pixels. For example, a 128x128 input
# shrinks 128 -> 126 -> 124 -> 31 -> 29 -> 7 -> 5 -> 3 and then cannot pass
# the third pooling layer.
cifar_10_dvs = nn.Sequential(
  # First convolution block
  nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 64),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Second convolution block
  nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 64),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 4, stride 4)
  nn.AvgPool2d(4),

  # Third convolution block
  nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 128),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 4, stride 4)
  nn.AvgPool2d(4),

  # Fourth convolution block
  nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 256),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Fifth convolution block
  nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 256),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 4, stride 4)
  nn.AvgPool2d(4),

  # Sixth convolution block
  nn.Conv2d(in_channels = 256, out_channels = 512, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 512),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 4, stride 4)
  nn.AvgPool2d(4),

  # Seventh convolution block
  nn.Conv2d(in_channels = 512, out_channels = 1024, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 1024),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Eighth convolution block
  nn.Conv2d(in_channels = 1024, out_channels = 1024, kernel_size = 3, stride = 1),
  nn.GroupNorm(num_groups = 8, num_channels= 1024),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True),

  # Pooling (kernel 4, stride 4)
  nn.AvgPool2d(4),

  # Fully connected or output block (expects a 1024x4x4 feature map)
  nn.Flatten(),
  nn.Linear(in_features = 1024 * 4 * 4, out_features= 10),
  snn.Leaky(beta = 0.5, spike_grad = surrogate.fast_sigmoid(slope = 25), init_hidden = True, output=True)
).to(device)

## Training the Network

In [8]:
# Preprocessing shared by every static dataset: each image is resized to
# 28x28 and collapsed to a single channel, matching networks whose first
# Conv2d takes in_channels=1. Normalize with mean 0 and std 1 subtracts 0 and
# divides by 1, i.e. it leaves the ToTensor output unchanged.
transform = transforms.Compose(
    [
        transforms.Resize(size=(28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0,), std=(1,)),
    ]
)

In [9]:
def plot_metrics(train_acc, train_loss, test_acc, test_loss):
    """
    Plot accuracy (top panel) and loss (bottom panel) for train and test runs.

    Every sequence is drawn against its own index 0..len-1; both panels share
    the x-axis.

    Args:
    - train_acc, test_acc: Accuracy values drawn on the top panel.
    - train_loss, test_loss: Loss values drawn on the bottom panel.
    """
    sns.set(style="whitegrid")

    fig, (acc_ax, loss_ax) = plt.subplots(2, 1, figsize=(10, 8), sharex=True)

    # (axis, title, [(values, legend label), ...]) listed in drawing order.
    panels = [
        (acc_ax, 'Accuracy', [(train_acc, 'Train Accuracy'), (test_acc, 'Test Accuracy')]),
        (loss_ax, 'Loss', [(train_loss, 'Train Loss'), (test_loss, 'Test Loss')]),
    ]

    for ax, _, curves in panels:
        for values, label in curves:
            sns.lineplot(x=range(len(values)), y=values, ax=ax, label=label)

    for ax, title, _ in panels:
        ax.set_title(title)
    loss_ax.set_xlabel('Epoch')

    for ax, _, _ in panels:
        ax.legend()

    plt.tight_layout()
    plt.show()

In [16]:
def fwrd_pass(net, num_steps, data):
  """
  Simulate `net` for `num_steps` time steps, presenting `data` unchanged at each step.

  The hidden states of all LIF neurons in `net` are reset first, so every call
  starts from a clean state.

  Returns:
  - (spikes, membranes): the per-step outputs of `net`, each stacked along a
    new leading time dimension.
  """
  utils.reset(net)
  step_outputs = [net(data) for _ in range(num_steps)]
  spikes = torch.stack([spk for spk, _ in step_outputs])
  membranes = torch.stack([mem for _, mem in step_outputs])
  return spikes, membranes

In [10]:
def train_network(net = mnist_net, dataset_name = 'mnist', epoches = 1, lr = 1e-5, betas=(0.9, 0.999), transform = transform, batch_size = 128, shuffle = True, num_steps = 10):
  """
  Train a spiking network on a static dataset, evaluating it after every epoch.

  Args:
  - net (nn.Module): Network to train (updated in place). The default is bound
    once, when this function is defined, to the module-level `mnist_net`.
  - dataset_name (str): Dataset name forwarded to `fetch_static_data`.
  - epoches (int): Number of training epochs.
  - lr (float): AdamW learning rate.
  - betas (tuple): AdamW beta coefficients.
  - transform (callable): Transform forwarded to `fetch_static_data`; the
    default is the module-level `transform` at definition time.
  - batch_size (int): Samples per batch.
  - shuffle (bool): Forwarded to `fetch_static_data`.
  - num_steps (int, optional): Simulation time steps per forward pass.

  Returns:
  - dict with keys 'net', 'train accuracy', 'test accuracy', 'train loss' and
    'test loss'. The history lists get one entry per progress checkpoint
    (about every 10% of an epoch), not one per epoch. Accuracies are
    percentages, cumulative since the start of the current epoch; losses are
    the loss of the checkpoint batch only.
  """
  # Tracking training and evaluating record
  test_acc_hist = []
  train_acc_hist = []
  test_loss_hist = []
  train_loss_hist = []

  loss = SF.ce_rate_loss()
  optimizer = torch.optim.AdamW(net.parameters(), lr = lr, betas = betas)

  # Fetching dataset
  train_loader, test_loader = fetch_static_data(name = dataset_name, transform = transform, batch_size = batch_size, shuffle = shuffle)

  # Report progress about every 10% of an epoch. max(1, ...) keeps the modulo
  # below from dividing by zero when a loader has fewer than 10 batches.
  train_every = max(1, len(train_loader) // 10)
  test_every = max(1, len(test_loader) // 10)

  clear_output()
  for epoch in range(epoches):
    print('-'*10, 'ITERATION', epoch, '-'*10, )

    # ---- Training ----
    net.train()
    acc = 0
    total = 0
    print('Training...')
    for batch_no, (X, y) in enumerate(train_loader):
      X, y = X.to(device), y.to(device)

      # Present the batch for num_steps time steps; dim 0 of spk_rec is time.
      spk_rec, _ = fwrd_pass(net, num_steps, X)
      loss_val = loss(spk_rec, y)

      # Gradient calculation + weight update
      optimizer.zero_grad()
      loss_val.backward()
      optimizer.step()

      # Weight each batch's accuracy by its size (dim 1 of spk_rec) so that
      # acc / total is the running per-sample accuracy for this epoch.
      acc += SF.accuracy_rate(spk_rec, y) * spk_rec.size(1)
      total += spk_rec.size(1)

      if (batch_no + 1) % train_every == 0:
        train_acc = acc / total
        print(f'{(batch_no // train_every + 1) * 10}% training completed with {train_acc * 100:.2f}% accuracy and {loss_val.item():.2f} loss')

        # Store accuracy and loss history for future plotting; float() works
        # whether the accuracy is a Python float, a NumPy scalar or a tensor.
        train_loss_hist.append(loss_val.item())
        train_acc_hist.append(float(train_acc) * 100)

    # ---- Evaluation ----
    print('Testing...')
    # Must be called: the bare attribute access `net.eval` does not change the mode.
    net.eval()
    with torch.inference_mode():
      acc = 0
      total = 0
      for batch_no, (X, y) in enumerate(test_loader):
        X, y = X.to(device), y.to(device)

        spk_rec, _ = fwrd_pass(net, num_steps, X)
        loss_val = loss(spk_rec, y)

        acc += SF.accuracy_rate(spk_rec, y) * spk_rec.size(1)
        total += spk_rec.size(1)

        # The same interval drives both the checkpoint test and the percentage label.
        if (batch_no + 1) % test_every == 0:
          test_acc = acc / total
          print(f'{(batch_no // test_every + 1) * 10}% testing completed with {test_acc * 100:.2f}%  accuracy and {loss_val.item():.2f} loss')

          # Store accuracy and loss history for future plotting
          test_loss_hist.append(loss_val.item())
          test_acc_hist.append(float(test_acc) * 100)

  return {
      'net': net,
      'train accuracy' : train_acc_hist,
      'test accuracy': test_acc_hist,
      'train loss': train_loss_hist,
      'test loss': test_loss_hist
      }

In [11]:
# Layer-by-layer output shapes and parameter counts for a (1, 28, 28) input.
summary(model=mnist_net, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 22, 22]           1,600
         GroupNorm-2           [-1, 32, 22, 22]              64
             Leaky-3           [-1, 32, 22, 22]               0
         AvgPool2d-4           [-1, 32, 11, 11]               0
            Conv2d-5             [-1, 64, 8, 8]          32,832
         GroupNorm-6             [-1, 64, 8, 8]             128
             Leaky-7             [-1, 64, 8, 8]               0
         AvgPool2d-8             [-1, 64, 4, 4]               0
           Flatten-9                 [-1, 1024]               0
           Linear-10                   [-1, 10]          10,250
            Leaky-11       [[-1, 10], [-1, 10]]               0
Total params: 44,874
Trainable params: 44,874
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/ba

In [None]:
# Train on MNIST for 20 epochs with batch size 512, then plot the recorded
# accuracy/loss checkpoints.
# NOTE(review): `copy` (from the copy module) is a shallow copy. The returned
# module shares its submodules and parameter tensors with mnist_net, so
# training the "copy" also trains mnist_net, and later cells that copy
# mnist_net start from these trained weights. copy.deepcopy (or rebuilding the
# network) would isolate runs, but first confirm that snntorch's hidden-state
# reset (utils.reset in fwrd_pass) still reaches deep-copied Leaky layers.
mnist_history = train_network(net = copy(mnist_net), batch_size = 512, epoches=20)

plot_metrics(mnist_history['train accuracy'], mnist_history['train loss'], mnist_history['test accuracy'], mnist_history['test loss'])

---------- ITERATION 0 ----------
Training...
10% training completed with 9.46% accuracy and 2.30 loss


In [None]:
# Layer-by-layer output shapes and parameter counts for a (1, 28, 28) input.
summary(model=mnist_net, input_size=(1, 28, 28))

In [None]:
# Train on Fashion-MNIST with the same architecture and hyperparameters as the
# MNIST run, then plot the recorded accuracy/loss checkpoints.
# NOTE(review): `copy` is shallow, so this network shares its parameters with
# mnist_net and starts from whatever weights mnist_net currently holds.
fashion_mnist_history = train_network(net = copy(mnist_net), dataset_name = 'fashion-mnist', batch_size = 512, epoches=20)

plot_metrics(fashion_mnist_history['train accuracy'], fashion_mnist_history['train loss'], fashion_mnist_history['test accuracy'], fashion_mnist_history['test loss'])

In [None]:
# Layer-by-layer output shapes and parameter counts for a (1, 28, 28) input.
summary(model=mnist_net, input_size=(1, 28, 28))

In [None]:
# Train on CIFAR-10. The default `transform` resizes images to 28x28 grayscale,
# so this trains a copy of the single-channel 28x28 mnist_net architecture.
# The history is stored in its own variable (matching mnist_history and
# fashion_mnist_history) so the `cifar_10` network object is not overwritten.
# NOTE(review): `copy` is shallow; the trained parameters are shared with mnist_net.
cifar_10_history = train_network(net = copy(mnist_net), dataset_name = 'cifar-10', batch_size = 512, epoches=20)

plot_metrics(cifar_10_history['train accuracy'], cifar_10_history['train loss'], cifar_10_history['test accuracy'], cifar_10_history['test loss'])

In [None]:
l.