In [None]:
# Mount Google Drive into this Colab runtime at /content/drive.
# NOTE(review): the active dataset cell saves to "./data"; only config["dataset_path"]
# (read by a commented-out cell) points at the mounted Drive.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
! pip install tonic snntorch icecream optuna --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/110.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.7/110.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.0/109.0 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.4/413.4 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.5/107.5 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.2/76.2 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m233.4/233.4 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━

# Imports

In [None]:
import tonic
import tonic.transforms as transforms
from tonic.datasets import DVSGesture

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.autograd import Function
from torch.optim import Adam

import snntorch as snn

import matplotlib.pyplot as plt

from icecream import ic
import time
import numpy as np

import optuna

import multiprocessing as mp

import os


# Config

In [None]:
# Central settings dictionary read by the transform pipeline, the data loaders,
# the model constructor, train()/validate() and the Optuna objective.
config = {
    # Transforms
    "filter_time": 10_000,  # passed to transforms.Denoise(filter_time=...)
    "new_size": (32, 32),  # target size handed to ResizeEvents
    "sensor_size": (32, 32, 2),  # passed to transforms.ToFrame(sensor_size=...)
    "n_time_bins": 64,  # passed to transforms.ToFrame; the commented-out transform used 16
    "transform_size": [128, 128, 2],  # only referenced by the commented-out transform cell

    # Batch
    # NOTE(review): train()/validate() skip every batch whose size differs from this value.
    "batch_size": 64,

    # make BSNN if True else SNN
    # NOTE(review): no active cell reads this key (the commented-out set-up cell checks "binarize").
    "is_binarized": True,

    # BSNN / SNN Values (thresholds and beta are forwarded to snn.Leaky;
    # objective() overwrites these keys for every Optuna trial)
    "thresh_1": 2,
    "thresh_2": 2,
    "thresh_3": 3,
    "beta": 0.4,
    "num_steps": 15,  # number of iterations of the simulation loop in the model's forward()

    # Network
    "batch_norm": True,  # toggles the BatchNorm2d layers in the model's forward()
    "dropout": 0.25,  # probability given to nn.Dropout

    # Hyper Params
    "lr": 7e-3,  # Adam learning rate; overwritten per trial by objective()

    # Training
    "epochs": 50,  # epochs run inside each Optuna trial
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),

    # Only referenced by the commented-out dataset-loading cell.
    "dataset_path": '/content/drive/My Drive/Colab Notebooks/CSE290D/data',

}


# Data Preparation

## Resize Events

In [None]:
class ResizeEvents:
    """Rescale event x/y coordinates from the source resolution to ``new_size``.

    Args:
        new_size: ``(width, height)`` of the target grid; index 0 scales the
            x coordinate and index 1 the y coordinate.
        original_size: ``(width, height)`` of the grid the incoming events
            are expressed in. Defaults to ``(128, 128)``, the value that was
            previously hard-coded.

    Calling the transform returns a rescaled *copy* of the events, so the
    caller's array is left untouched and applying the transform twice to the
    same raw array cannot shrink the coordinates twice.
    """

    def __init__(self, new_size, original_size=(128, 128)):
        self.new_size = new_size
        self.original_size = original_size

    def __call__(self, events):
        """Return a copy of ``events`` with x/y rescaled and clipped to the new grid.

        ``events`` is either a structured NumPy array with ``'x'`` and ``'y'``
        fields, or a plain 2-D array whose columns 0 and 1 hold x and y.
        """
        target_w, target_h = self.new_size[0], self.new_size[1]
        scale_x = target_w / self.original_size[0]
        scale_y = target_h / self.original_size[1]

        # Never write into the caller's buffer.
        events = events.copy()

        if isinstance(events, np.ndarray) and events.dtype.names is not None:
            # Structured array with named fields.
            x_key, y_key = 'x', 'y'
        else:
            # Plain array: column 0 is x, column 1 is y.
            x_key, y_key = (slice(None), 0), (slice(None), 1)

        # Truncate to integers, then clip so coordinates stay inside the new grid.
        events[x_key] = np.clip((events[x_key] * scale_x).astype(int), 0, target_w - 1)
        events[y_key] = np.clip((events[y_key] * scale_y).astype(int), 0, target_h - 1)

        return events



## Transforms

In [None]:
# transform = transforms.Compose([
#     transforms.Denoise(filter_time=config["filter_time"]),
#     transforms.ToFrame(n_time_bins=16, sensor_size=config["transform_size"])
# ])

In [None]:
# Event pipeline applied to every DVSGesture sample, in order:
#   1. Denoise with config["filter_time"]
#   2. ResizeEvents rescales x/y coordinates to config["new_size"]
#   3. ToFrame bins the events into config["n_time_bins"] frames on a
#      config["sensor_size"] grid
transform = transforms.Compose([
    transforms.Denoise(filter_time=config["filter_time"]),
    ResizeEvents(new_size=config["new_size"]),
    transforms.ToFrame(n_time_bins=config["n_time_bins"], sensor_size=config["sensor_size"])
])


## Train / Test Sets

In [None]:
# def dataset_exists(path):
#     # Example check - adjust based on actual dataset files or directories
#     required_files = ['train', 'test']  # Placeholder names, replace with actual dataset file or folder names
#     return all(os.path.exists(os.path.join(path, f)) for f in required_files)


In [None]:
# start_time = time.time()

# dataset_path = config["dataset_path"]
# if not dataset_exists(dataset_path):
#     print("Downloading the dataset...")
#     train_set = DVSGesture(save_to=dataset_path, train=True, transform=transform)
#     test_set = DVSGesture(save_to=dataset_path, train=False, transform=transform)
# else:
#     print("Dataset already downloaded. Loading...")
#     train_set = DVSGesture(save_to=dataset_path, train=True, transform=transform, download=False)
#     test_set = DVSGesture(save_to=dataset_path, train=False, transform=transform, download=False)

# end_time = time.time()

In [None]:
# Build the official train and test splits under ./data (train first, then test)
# and report how long construction took.
start_time = time.time()

train_set, test_set = (
    DVSGesture(save_to="./data", train=is_train, transform=transform)
    for is_train in (True, False)
)

end_time = time.time()

# divmod on the float difference gives the floor-minutes and the leftover seconds.
elapsed_min, elapsed_sec = divmod(end_time - start_time, 60)
print(f"Time to download data {int(elapsed_min)}min {int(elapsed_sec)}sec")


Downloading https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/38022171/ibmGestureTrain.tar.gz?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20240212/eu-west-1/s3/aws4_request&X-Amz-Date=20240212T233503Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=d53059755f7134d2b3e515e81e1df0fc12f145b6035f2e065cb278ae9f59ca83 to ./data/DVSGesture/ibmGestureTrain.tar.gz


  0%|          | 0/2443675558 [00:00<?, ?it/s]

Extracting ./data/DVSGesture/ibmGestureTrain.tar.gz to ./data/DVSGesture
Downloading https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/38020584/ibmGestureTest.tar.gz?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20240212/eu-west-1/s3/aws4_request&X-Amz-Date=20240212T233642Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=26b7f8d8dca6e5a1d128aed5804e1deb429cba13f12225e9e1fefbcc7a10869a to ./data/DVSGesture/ibmGestureTest.tar.gz


  0%|          | 0/691455012 [00:00<?, ?it/s]

Extracting ./data/DVSGesture/ibmGestureTest.tar.gz to ./data/DVSGesture
Time to download data 2min 6sec


In [None]:
# Report how many samples each official split contains.
ic(len(train_set))
ic(len(test_set))

ic| len(train_set): 1077
ic| len(test_set): 264


264

## Train / Val Split

In [None]:
# 80/20 train/validation split of the official training set.
train_size = int(0.8 * len(train_set))
val_size = len(train_set) - train_size

# Seed the split so every run (and every resumed Optuna study) trains and
# validates on the same samples; override with config["seed"] if present.
split_generator = torch.Generator().manual_seed(config.get("seed", 42))
train_subset, val_subset = random_split(
    train_set, [train_size, val_size], generator=split_generator
)


In [None]:
# Pull the first (frames, label) pair of the training subset as a sanity check.
data, labels = train_subset[0]

# Show the frame tensor layout produced by the transform pipeline.
print("Data shape:", data.shape)


Data shape: (64, 2, 32, 32)


## Data Loaders Train, Val, Test

In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
# NOTE(review): drop_last is not set, so the final batch of each loader can be
# smaller than config["batch_size"]; train()/validate() skip such batches.
train_loader = DataLoader(train_subset, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_subset, batch_size=config["batch_size"], shuffle=False)
test_loader = DataLoader(test_set, batch_size=config["batch_size"], shuffle=False)

# Define Binarization Model

## Binarization

In [None]:
class Binarize(Function):
  """Sign binarization whose backward pass hands the gradient through unchanged."""

  @staticmethod
  def forward(ctx, inpt):
    # sign() yields -1, 0 or +1 per element; the lower clamp at -1 is then applied.
    signs = torch.sign(inpt)
    return torch.clamp(signs, min=-1)

  @staticmethod
  def backward(ctx, gradient_out):
    # Identity gradient: return a copy of the incoming gradient.
    return gradient_out.clone()

## Binary Conv2d Layer

In [None]:
class BinaryConv2d(nn.Conv2d):
  """2-D convolution whose weights are passed through ``Binarize`` on every forward.

  The real-valued weights stay in ``self.weight`` and are what the optimizer
  updates; only the values used in the convolution are binarized.
  """

  def __init__(self, *kargs, **kwargs):
    # nn.Conv2d's constructor calls self.reset_parameters(), so the
    # initialisation defined below is applied at construction time.
    super(BinaryConv2d, self).__init__(*kargs, **kwargs)

  def forward(self, inpt):
    """Convolve ``inpt`` with the binarized weights."""
    binarized_weights = Binarize.apply(self.weight)
    # NOTE(review): the layer's bias, stride, padding, dilation and groups are
    # NOT forwarded here, so the convolution is always unpadded with stride 1
    # (padding="same" has no effect). This is kept as-is because the
    # classifier input size in BSNN (32*6*6) is derived from the unpadded
    # output; switching to self._conv_forward(...) must be done together with
    # that size.
    return F.conv2d(inpt, binarized_weights)

  def reset_parameters(self):
    """Xavier-normal weights and zero bias.

    Previously this was named ``reset_params`` and therefore never ran; the
    name now overrides the hook nn.Conv2d invokes from its constructor.
    """
    nn.init.xavier_normal_(self.weight)
    if self.bias is not None:
      # constant_ is the in-place initialiser; nn.init.constant is deprecated.
      nn.init.constant_(self.bias, 0)

  # Backward-compatible alias for code that called the old method name.
  reset_params = reset_parameters


## Binary Linear Layer

In [None]:
class BinaryLinear(nn.Linear):
  """Fully connected layer whose weights are passed through ``Binarize`` on every forward.

  ``forward`` and ``reset_parameters`` used to be indented inside
  ``__init__``, which made them unused local functions: the layer silently
  behaved like a plain ``nn.Linear``. They are now real methods.
  """

  def __init__(self, *kargs, **kwargs):
    # nn.Linear's constructor calls self.reset_parameters(), so the
    # initialisation defined below is applied at construction time.
    super(BinaryLinear, self).__init__(*kargs, **kwargs)

  def forward(self, inpt):
    """Apply the linear map using binarized weights (bias, if any, is left real-valued)."""
    bin_weights = Binarize.apply(self.weight)
    # F.linear accepts bias=None, so one call covers both the bias and no-bias case.
    return F.linear(inpt, bin_weights, self.bias)

  def reset_parameters(self):
    """Xavier-normal weights and zero bias."""
    torch.nn.init.xavier_normal_(self.weight)
    if self.bias is not None:
      torch.nn.init.constant_(self.bias, 0)


## BSNN Architecture

In [None]:
class BSNN(nn.Module):
    """Binarized spiking network: two BinaryConv2d blocks and a BinaryLinear read-out.

    Each block is conv -> (optional) batch norm -> max pool -> snn.Leaky.
    The read-out is flatten -> BinaryLinear -> dropout -> snn.Leaky with 11 outputs.

    Args:
        config: mapping providing "thresh_1", "thresh_2", "thresh_3", "beta",
            "num_steps", "batch_norm" and "dropout".
    """

    def __init__(self, config):
        super(BSNN, self).__init__()

        self.thresh_1 = config["thresh_1"]
        self.thresh_2 = config["thresh_2"]
        self.thresh_3 = config["thresh_3"]
        self.beta = config["beta"]
        self.num_steps = config["num_steps"]

        self.batch_norm = config["batch_norm"]
        # Kept in a local so self.dropout only ever refers to the nn.Dropout module.
        dropout_rate = config["dropout"]

        self.bin_conv_1 = BinaryConv2d(in_channels=2, out_channels=16, kernel_size=3, padding="same", bias=False)
        self.batch_norm_1 = nn.BatchNorm2d(num_features=16)
        self.max_pool_1 = nn.MaxPool2d(kernel_size=2)
        self.lif_1 = snn.Leaky(beta=self.beta, threshold=self.thresh_1)

        self.bin_conv_2 = BinaryConv2d(in_channels=16, out_channels=32, kernel_size=3, padding="same", bias=False)
        self.batch_norm_2 = nn.BatchNorm2d(num_features=32)
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2)
        self.lif_2 = snn.Leaky(beta=self.beta, threshold=self.thresh_2)

        self.flatten = nn.Flatten()
        # 32 channels x 6 x 6 for 32x32 input frames: BinaryConv2d.forward as
        # written in this notebook convolves without padding, so the spatial
        # size goes 32 -> 30 -> (pool) 15 -> 13 -> (pool) 6.
        self.bin_fully_connected_1 = BinaryLinear(in_features=32*6*6, out_features=11)
        self.dropout = nn.Dropout(dropout_rate)
        self.lif_3 = snn.Leaky(beta=self.beta, threshold=self.thresh_3)

    @staticmethod
    def _average_over_time_bins(step_records, batch_size, time_bins):
        """Stack per-step records and average the time bins of each sample.

        Each record has shape (batch_size * time_bins, classes) with the batch
        index varying slowest, so the stack can be viewed as
        (num_steps, batch_size, time_bins, classes) and reduced over axis 2.
        """
        stacked = torch.stack(step_records, dim=0)
        stacked = stacked.view(len(step_records), batch_size, time_bins, -1)
        return stacked.mean(dim=2)

    def forward(self, inpt):
        """Run ``num_steps`` simulation steps on a batch of event frames.

        Args:
            inpt: tensor of shape (batch, time_bins, polarities, height, width).

        Returns:
            (final_spike, final_mem), each of shape (num_steps, batch, 11):
            the output layer's spikes and membrane values per simulation
            step, averaged over the time bins of each sample. Averaging the
            first tensor over dim 0 therefore yields one score vector per
            sample.

            The previous implementation viewed the stacked
            (num_steps, batch*time_bins, 11) records as (batch, time_bins, -1),
            which mixed steps, samples and classes together.
        """
        inpt = inpt.float()
        batch_size, time_bins, polarities, height, width = inpt.size()

        # Merge time bins with the batch dimension so every frame is processed
        # as an independent image (batch index varies slowest).
        inpt = inpt.reshape(batch_size * time_bins, polarities, height, width)

        # Initial membrane state for each spiking layer.
        mem_1 = self.lif_1.init_leaky()
        mem_2 = self.lif_2.init_leaky()
        mem_3 = self.lif_3.init_leaky()

        spike_3_rec = []
        mem_3_rec = []

        # The same frames are presented at every step; only the membrane
        # states carry information from one step to the next.
        for _ in range(self.num_steps):
            current_1 = self.bin_conv_1(inpt)
            current_1 = self.batch_norm_1(current_1) if self.batch_norm else current_1
            current_1 = self.max_pool_1(current_1)
            spike_1, mem_1 = self.lif_1(current_1, mem_1)

            current_2 = self.bin_conv_2(spike_1)
            current_2 = self.batch_norm_2(current_2) if self.batch_norm else current_2
            current_2 = self.max_pool_2(current_2)
            spike_2, mem_2 = self.lif_2(current_2, mem_2)

            current_3 = self.flatten(spike_2)
            current_3 = self.bin_fully_connected_1(current_3)
            current_3 = self.dropout(current_3)
            spike_3, mem_3 = self.lif_3(current_3, mem_3)

            spike_3_rec.append(spike_3)
            mem_3_rec.append(mem_3)

        final_spike = self._average_over_time_bins(spike_3_rec, batch_size, time_bins)
        final_mem = self._average_over_time_bins(mem_3_rec, batch_size, time_bins)

        return final_spike, final_mem

## SNN Architecture

In [None]:
# class SNN(nn.Module):
#   def __init__(self, config):
#       super(SNN, self).__init__()

#       self.thresh_1 = config["thresh_1"]
#       self.thresh_2 = config["thresh_2"]
#       self.thresh_3 = config["thresh_3"]
#       self.beta = config["beta"]
#       self.num_steps = config["num_steps"]

#       self.batch_norm = config["batch_norm"]
#       self.dropout_rate = config["dropout"]

#       # Standard Convolutional Layers
#       self.conv_1 = nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, padding="same")
#       self.batch_norm_1 = nn.BatchNorm2d(num_features=16)
#       self.max_pool_1 = nn.MaxPool2d(kernel_size=2)
#       self.lif_1 = snn.Leaky(beta=self.beta, threshold=self.thresh_1)

#       self.conv_2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding="same")
#       self.batch_norm_2 = nn.BatchNorm2d(num_features=32)
#       self.max_pool_2 = nn.MaxPool2d(kernel_size=2)
#       self.lif_2 = snn.Leaky(beta=self.beta, threshold=self.thresh_2)

#       self.flatten = nn.Flatten()
#       # self.fully_connected_1 = nn.Linear(in_features=128*15*15, out_features=11)
#       self.dropout = nn.Dropout(self.dropout_rate)
#       self.lif_3 = snn.Leaky(beta=self.beta, threshold=self.thresh_3)

#   def forward(self, inpt):
#       inpt = inpt.float()
#       batch_size, time_bins, polarities, height, width = inpt.size()

#       # Merge time bins with batch dimension
#       inpt = inpt.view(batch_size * time_bins, polarities, height, width)

#       mem_1 = self.lif_1.init_leaky()
#       mem_2 = self.lif_2.init_leaky()
#       mem_3 = self.lif_3.init_leaky()

#       spike_3_rec = []
#       mem_3_rec = []

#       for step in range(self.num_steps):
#           current_1 = self.conv_1(inpt)
#           current_1 = self.batch_norm_1(current_1) if self.batch_norm else current_1
#           current_1 = self.max_pool_1(current_1)
#           spike_1, mem_1 = self.lif_1(current_1, mem_1)

#           current_2 = self.conv_2(spike_1)
#           current_2 = self.batch_norm_2(current_2) if self.batch_norm else current_2
#           current_2 = self.max_pool_2(current_2)
#           spike_2, mem_2 = self.lif_2(current_2, mem_2)

#           current_3 = self.flatten(spike_2)
#           current_3 = self.fully_connected_1(current_3)
#           current_3 = self.dropout(current_3)
#           spike_3, mem_3 = self.lif_3(current_3, mem_3)

#           spike_3_rec.append(spike_3)
#           mem_3_rec.append(mem_3)

#       # Split the time dimension from the batch dimension
#       return torch.stack(spike_3_rec, dim=0).view(batch_size, time_bins, -1), torch.stack(mem_3_rec, dim=0).view(batch_size, time_bins, -1)


# Training

## Training Func

In [None]:
def train(
    train_loader,
    model,
    optimizer,
    loss_func,
    device=config["device"]):
  """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

  Args:
    train_loader: iterable of ``(data, labels)`` batches.
    model: module whose call returns a pair; the first element is averaged
      over its leading dimension before being passed to ``loss_func``.
      NOTE(review): this assumes the leading dimension is the simulation-step
      axis so the result is (batch, classes) - confirm against the model.
    optimizer: optimizer stepped once per processed batch.
    loss_func: callable ``loss_func(output, labels)``.
    device: device the batches are moved to.

  Returns:
    Mean loss over the batches that were actually processed (NaN if none
    were) and the accuracy in percent over the processed samples (0.0 if
    none were). Previously the loss was divided by ``len(train_loader)``,
    which also counted the skipped batches, and an epoch with no processed
    batch raised ZeroDivisionError.
  """
  model.train()
  running_loss = 0.0
  correct = 0
  total = 0
  batches_used = 0

  for data, labels in train_loader:

    # Bypass when Training Batch < config["batch_size"]
    # NOTE(review): this drops the tail of every epoch; confirm whether the
    # model output shape still requires full batches.
    if data.size(0) != config["batch_size"]:
      continue

    data, labels = data.to(device), labels.to(device)

    optimizer.zero_grad()
    spike_out, _ = model(data)
    output = spike_out.mean(dim=0)

    loss = loss_func(output, labels)
    loss.backward()
    optimizer.step()

    # Bookkeeping only; detach() so the metrics do not touch the graph.
    running_loss += loss.item()
    _, prediction = torch.max(output.detach(), 1)
    total += labels.size(0)
    correct += (prediction == labels).sum().item()
    batches_used += 1

  train_loss = running_loss / batches_used if batches_used else float("nan")
  train_accuracy = 100 * correct / total if total else 0.0

  return train_loss, train_accuracy


## Validation Func

In [None]:
def validate(
    val_loader,
    model,
    loss_func,
    device=config["device"]):
  """Evaluate ``model`` on ``val_loader`` and return ``(mean_loss, accuracy_percent)``.

  Args:
    val_loader: iterable of ``(data, labels)`` batches.
    model: module whose call returns a pair; the first element is averaged
      over its leading dimension before being passed to ``loss_func``.
      NOTE(review): this assumes the leading dimension is the simulation-step
      axis so the result is (batch, classes) - confirm against the model.
    loss_func: callable ``loss_func(output, labels)``.
    device: device the batches are moved to.

  Returns:
    Mean loss over the batches that were actually processed (NaN if none
    were) and the accuracy in percent over the processed samples (0.0 if
    none were). Both are plain Python floats. Previously the loss was divided
    by ``len(val_loader)``, which also counted the skipped batches, and a
    run with no processed batch raised ZeroDivisionError.
  """
  model.eval()
  running_loss = 0.0
  correct = 0
  total = 0
  batches_used = 0

  with torch.no_grad():
    for data, labels in val_loader:

      # Bypass when Val Batch < config["batch_size"]
      # NOTE(review): this leaves part of the validation set unscored; confirm
      # whether the model output shape still requires full batches.
      if data.size(0) != config["batch_size"]:
        continue

      data, labels = data.to(device), labels.to(device)
      spike_out, _ = model(data)
      output = spike_out.mean(dim=0)
      loss = loss_func(output, labels)
      running_loss += loss.item()

      _, prediction = torch.max(output, 1)
      total += labels.size(0)
      correct += (prediction == labels).sum().item()
      batches_used += 1

  val_loss = running_loss / batches_used if batches_used else float("nan")
  val_accuracy = 100 * correct / total if total else 0.0

  return val_loss, val_accuracy


## Training Set-Up

In [None]:
# device = config["device"]

# # if config["binarize"]:
# #     model = BSNN(config).to(config["device"])
# # else:
# #     model = SNN(config).to(config["device"])

# model = BSNN(config).to(config["device"])
# optimizer = Adam(model.parameters(), lr=config["lr"])
# loss_func = nn.CrossEntropyLoss()


## Training Loop

In [None]:
# train_losses, train_accuracies, val_losses, val_accuracies = [], [], [], []
# best_val_accuracy = 0
# model_path = "best_BSNN_model.pth"

# for epoch in range(config["epochs"]):
#   start_time = time.time()

#   train_loss, train_accuracy = train(train_loader=train_loader, model=model, optimizer=optimizer, loss_func=loss_func)
#   train_losses.append(train_loss)
#   train_accuracies.append(train_accuracy)

#   val_loss, val_accuracy = validate(val_loader=val_loader, model=model, loss_func=loss_func)
#   val_losses.append(val_loss)
#   val_accuracies.append(val_accuracy)

#   end_time = time.time()

#   print(f"Epoch: {epoch + 1}, Training Loss: {train_loss:.5f}, Training Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_loss:.5f}, Validation Accuracy: {val_accuracy:.2f}%")
#   print(f"Time complete Epoch {epoch + 1}: {int((end_time - start_time) // 60)}min {int((end_time - start_time) % 60)}sec")

#   if val_accuracy > best_val_accuracy:
#     best_val_accuracy = val_accuracy
#     torch.save(model.state_dict(), model_path)
#     print(f"Saved model with improved validation accuracy: {val_accuracy:.2f}% \n")


# Hyper Param Sweep

In [None]:
def objective(trial):
    """Optuna objective: train a BSNN with sampled hyper-parameters, return its validation loss.

    Samples the three thresholds, beta, learning rate, number of simulation
    steps and dropout, trains for ``config["epochs"]`` epochs on
    ``train_loader`` and returns the validation loss of the last epoch as a
    Python float (``inf`` if no epoch ran).
    """
    # Work on a copy so the sampled values of one trial do not leak into the
    # shared module-level config.
    trial_config = dict(config)
    trial_config["thresh_1"] = trial.suggest_float("thresh_1", 1, 20)
    trial_config["thresh_2"] = trial.suggest_float("thresh_2", 1, 20)
    trial_config["thresh_3"] = trial.suggest_float("thresh_3", 1, 20)
    trial_config["beta"] = trial.suggest_float("beta", 0.1, 0.9)
    # NOTE(review): a linear range from 1e-10 to 1e-3 rarely samples small
    # rates; consider log=True (this may require a fresh study, since the
    # parameter's distribution would change).
    trial_config["lr"] = trial.suggest_float("lr", 1e-10, 1e-3)
    trial_config["num_steps"] = trial.suggest_int("num_steps", 5, 20)
    trial_config["dropout"] = trial.suggest_float("dropout", 0.1, 0.75)

    device = trial_config["device"]

    model = BSNN(trial_config).to(device)
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=trial_config["lr"])

    val_loss = float("inf")
    for epoch in range(trial_config["epochs"]):
        train(train_loader, model, optimizer, loss_func)
        val_loss, _ = validate(val_loader, model, loss_func)

    # validate() returns a plain float, so calling .item() on it raised
    # AttributeError at the end of every trial.
    return float(val_loss)




In [None]:
def run_study(n_trials=10, study_name="bsnn_gesture_study", results_path="study_results.csv"):
    """Create or resume an Optuna study, run ``objective`` and print a summary.

    Args:
        n_trials: number of trials to run in this call (default 10, the
            previously hard-coded value).
        study_name: study name; also used for the SQLite file
            ``<study_name>.db`` that stores the trials.
        results_path: CSV file the trials dataframe is written to.
    """
    storage_name = "sqlite:///{}.db".format(study_name)

    # load_if_exists=True resumes the stored study instead of failing when the
    # database already contains one with this name.
    study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True, direction="minimize")
    study.optimize(objective, n_trials=n_trials)

    study.trials_dataframe().to_csv(results_path)

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))

    # best_trial raises ValueError when no trial has completed.
    try:
        trial = study.best_trial
    except ValueError:
        print("  No completed trial yet; nothing to report.")
        return

    print("  Best trial:")
    print("    Value: ", trial.value)
    print("    Params: ")
    for key, value in trial.params.items():
        print(f"      {key}: {value}")

In [None]:
# Launch the sweep: creates or resumes the SQLite-backed study defined in run_study().
run_study()

[I 2024-02-12 23:37:11,287] A new study created in RDB with name: bsnn_gesture_study


# Visualization and Analysis

In [None]:
# # Plotting training, validation, and test losses
# plt.figure(figsize=(10, 5))
# plt.plot(train_losses, label='Training Loss')
# plt.plot(val_losses, label='Validation Loss')
# plt.title('Loss over Epochs')
# plt.xlabel('Epochs')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()

# # Plotting training, validation, and test accuracies
# plt.figure(figsize=(10, 5))
# plt.plot(train_accuracies, label='Training Accuracy')
# plt.plot(val_accuracies, label='Validation Accuracy')
# plt.title('Accuracy over Epochs')
# plt.xlabel('Epochs')
# plt.ylabel('Accuracy (%)')
# plt.legend()
# plt.show()