In [1]:
# from google.colab import drive
# drive.mount('/content/drive')


In [2]:
# %pip install tonic snntorch icecream optuna --quiet

# Imports

In [3]:
import tonic
import tonic.transforms as transforms
from tonic.datasets import DVSGesture
from tonic import DiskCachedDataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader, Dataset
from torch.autograd import Function
from torch.optim import Adam

import torchvision


import snntorch as snn

import matplotlib.pyplot as plt

from icecream import ic
import time
import numpy as np

import optuna

import multiprocessing as mp

import os


# Config

In [4]:
config = {
    # Binning
    "filter_time": 10_000,
    "sensor_size": tonic.datasets.DVSGesture.sensor_size,
    "time_window": 1_000,

    # Batch
    "batch_size": 1,

    # BSNN / SNN Values
    "thresh_1": 2,
    "thresh_2": 2,
    "thresh_3": 3,
    "beta": 0.4,
    "num_steps": 10,

    # Network
    "batch_norm": True,
    "dropout": 0.25,

    # Hyper Params
    "lr": 7e-3,

    # Training
    "epochs": 10,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),

}


# Data Preperation

## Train / Val Split

In [5]:
def train_val_split(train_set):
    
    train_size = int(0.8 * len(train_set))
    val_size = len(train_set) - train_size
    train_set, val_set = random_split(train_set, [train_size, val_size])

    return train_set, val_set

## Download Data

In [6]:
# Initial transforms for preprocessing
transform_preprocess = tonic.transforms.Compose([
    transforms.Denoise(filter_time=config["filter_time"]),
    transforms.ToFrame(sensor_size=config["sensor_size"], time_window=config["time_window"]),
])

# Load the datasets
train_set = DVSGesture(save_to="/home/emercad3", transform=transform_preprocess, train=True)
test_set = DVSGesture(save_to="/home/emercad3", transform=transform_preprocess, train=False)

# Cache path correction (ensure unique paths for train and test)
cached_train_set_path = "/home/emercad3/cache/gesture/train"
cached_test_set_path = "/home/emercad3/cache/gesture/test"  # Corrected path for test cache

# Apply initial caching
cached_train_set = DiskCachedDataset(train_set, cache_path=cached_train_set_path)
cached_test_set = DiskCachedDataset(test_set, cache_path=cached_test_set_path)

# Augmentation transforms
transform_augment = tonic.transforms.Compose([
    torch.from_numpy,  # Ensure this is compatible with your data format
    torchvision.transforms.RandomRotation([-10, 10]),
])

# Reapply DiskCachedDataset with augmentation for training set
# Note: You may want to use a different cache path if the preprocessing changes significantly
cached_train_set = DiskCachedDataset(train_set, cache_path=cached_train_set_path, transform=transform_augment)

# For the test set, typically, you don't apply augmentation, so no need to redefine it

# DataLoaders with multithreading
train_loader = DataLoader(cached_train_set, batch_size=config["batch_size"], collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True, num_workers=0)
test_loader = DataLoader(cached_test_set, batch_size=config["batch_size"], collate_fn=tonic.collation.PadTensors(batch_first=False), num_workers=0)


# Binarization Model

## Binarization

In [7]:
class Binarize(Function):
  @staticmethod
  def forward(weight_ref, inpt):
    return inpt.sign().clamp(min=-1)

  @staticmethod
  def backward(weight_ref, gradient_out):
    gradient_in = gradient_out.clone()
    return gradient_in

## Binary Conv2d Layer

In [8]:
class BinaryConv2d(nn.Conv2d):
  def __init__(self, *kargs, **kwargs):
    super(BinaryConv2d, self).__init__(*kargs, **kwargs)

  def forward(self, inpt):
    binarized_weights = Binarize.apply(self.weight)
    return F.conv2d(inpt, binarized_weights)

  def reset_params(self):
    nn.init.xavier_normal_(self.weight)
    if self.bias is not None:
      nn.init.constant(self.bias, 0)


## Binary Linear Layer

In [9]:
class BinaryLinear(nn.Linear):
  def __init__(self, *kargs, **kwargs):
    super(BinaryLinear, self).__init__(*kargs, **kwargs)

    def forward(self, inpt):
        bin_weights = Binarize.apply(self.weight)
        if self.bias is None:
            return F.linear(inpt, bin_weights)
        else:
            return F.linear(inpt, bin_weights, self.bias)

    def reset_parameters(self):
        torch.nn.init.xavier_normal_(self.weight)
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0)


## BSNN Architecture

In [10]:
class BSNN(nn.Module):
    def __init__(self, config):
        super(BSNN, self).__init__()

        self.thresh_1 = config["thresh_1"]
        self.thresh_2 = config["thresh_2"]
        self.thresh_3 = config["thresh_3"]
        self.beta = config["beta"]
        self.num_steps = config["num_steps"]

        self.batch_norm = config["batch_norm"]
        self.dropout = config["dropout"]

        self.bin_conv_1 = BinaryConv2d(in_channels=2, out_channels=12, kernel_size=5, bias=False)
        self.batch_norm_1 = nn.BatchNorm2d(num_features=12)
        self.max_pool_1 = nn.MaxPool2d(kernel_size=2)
        self.lif_1 = snn.Leaky(beta=self.beta, threshold=self.thresh_1)

        self.bin_conv_2 = BinaryConv2d(in_channels=12, out_channels=24, kernel_size=5, bias=False)
        self.batch_norm_2 = nn.BatchNorm2d(num_features=24)
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2)
        self.lif_2 = snn.Leaky(beta=self.beta, threshold=self.thresh_2)

        self.flatten = nn.Flatten()
        # self.bin_fully_connected_1 = BinaryLinear(in_features=128*15*15, out_features=11)
        self.bin_fully_connected_1 = BinaryLinear(in_features=24*29*29, out_features=128)
        self.dropout = nn.Dropout(self.dropout)
        self.bin_fully_connected_2 = BinaryLinear(in_features=128, out_features=11)
        self.lif_3 = snn.Leaky(beta=self.beta, threshold=self.thresh_3)


    def forward(self, inpt):
        inpt = inpt.float()
        batch_size, time_bins, polarities, height, width = inpt.size()
        print(f"Input shape: {inpt.shape}")  # Print input shape

        # Merge time bins with batch dimension
        inpt = inpt.view(batch_size * time_bins, polarities, height, width)
        print(f"Shape after merging time bins: {inpt.shape}")

        mem_1 = self.lif_1.init_leaky()
        mem_2 = self.lif_2.init_leaky()
        mem_3 = self.lif_3.init_leaky()

        spike_3_rec = []
        mem_3_rec = []

        for step in range(self.num_steps):
            current_1 = self.bin_conv_1(inpt)
            print(f"Shape after bin_conv_1: {current_1.shape}")

            current_1 = self.batch_norm_1(current_1) if self.batch_norm else current_1
            print(f"Shape after batch_norm_1: {current_1.shape}")

            current_1 = self.max_pool_1(current_1)
            print(f"Shape after max_pool_1: {current_1.shape}")

            spike_1, mem_1 = self.lif_1(current_1, mem_1)

            current_2 = self.bin_conv_2(spike_1)
            print(f"Shape after bin_conv_2: {current_2.shape}")

            current_2 = self.batch_norm_2(current_2) if self.batch_norm else current_2
            print(f"Shape after batch_norm_2: {current_2.shape}")

            current_2 = self.max_pool_2(current_2)
            print(f"Shape after max_pool_2: {current_2.shape}")

            spike_2, mem_2 = self.lif_2(current_2, mem_2)

            current_3 = self.flatten(spike_2)
            print(f"Shape after flatten: {current_3.shape}")

            current_3 = self.bin_fully_connected_1(current_3)
            print(f"Shape after bin_fully_connected_1: {current_3.shape}")

            current_3 = self.dropout(current_3)
            print(f"Shape after dropout: {current_3.shape}")
            
            current_3 = self.bin_fully_connected_2(current_3)
            print(f"Shape after bin_fully_connected_2: {current_3.shape}")

            spike_3, mem_3 = self.lif_3(current_3, mem_3)

            spike_3_rec.append(spike_3)
            mem_3_rec.append(mem_3)

        return torch.stack(spike_3_rec, dim=0), torch.stack(mem_3_rec, dim=0)

# Training

## Training Func

In [11]:
def train(
    train_loader,
    model,
    optimizer,
    loss_func,
    device=config["device"]):

  model.train()
  running_loss = 0.0
  correct = 0
  total = 0

  for data, labels in train_loader:

    # # Bypass when Training Batch < config["batch_size"]
    # if data.size(0) != config["batch_size"]:
    #   continue

    data, labels = data.to(device), labels.to(device)

    optimizer.zero_grad()
    spike_out, _ = model(data)
    output = spike_out.mean(dim=0)

    loss = loss_func(output, labels)
    running_loss += loss.item()

    _, prediction = torch.max(output.data, 1)
    total += labels.size(0)
    correct += (prediction == labels).sum().item()

    loss.backward()
    optimizer.step()

  train_loss = running_loss / len(train_loader)
  train_accuracy = 100 * correct / total

  return train_loss, train_accuracy


## Validation Func

In [12]:
def validate(
    val_loader,
    model,
    loss_func,
    device=config["device"]):

  model.eval()
  running_loss = 0.0
  correct = 0
  total = 0


  with torch.no_grad():
    for data, labels in val_loader:

      # # Bypass when Val Batch < config["batch_size"]
      # if data.size(0) != config["batch_size"]:
      #   continue

      data, labels = data.to(device), labels.to(device)
      spike_out, _ = model(data)
      output = spike_out.mean(dim=0)
      loss = loss_func(output, labels)
      running_loss += loss.item()

      _, prediction = torch.max(output.data, 1)
      total += labels.size(0)
      correct += (prediction == labels).sum().item()


  val_loss = running_loss / len(val_loader)
  val_accuracy = 100 * correct / total

  return val_loss, val_accuracy


# Run Code

In [13]:
device = config["device"]

model = BSNN(config).to(device)
optimizer = Adam(model.parameters(), lr=config["lr"], betas=(0.9, 0.999))
loss_func = nn.CrossEntropyLoss()

In [14]:
train_losses, train_accuracies, val_losses, val_accuracies = [], [], [], []
# best_val_accuracy = 0
# model_path = "best_BSNN_model.pth"

for epoch in range(config["epochs"]):
  start_time = time.time()

  train_loss, train_accuracy = train(train_loader=train_loader, model=model, optimizer=optimizer, loss_func=loss_func)
  train_losses.append(train_loss)
  train_accuracies.append(train_accuracy)

  # val_loss, val_accuracy = validate(val_loader=val_loader, model=model, loss_func=loss_func)
  # val_losses.append(val_loss)
  # val_accuracies.append(val_accuracy)

  end_time = time.time()

  # print(f"Epoch: {epoch + 1}, Training Loss: {train_loss:.5f}, Training Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_loss:.5f}, Validation Accuracy: {val_accuracy:.2f}%")
  print(f"Epoch: {epoch + 1}, Training Loss: {train_loss:.5f}, Training Accuracy: {train_accuracy:.2f}%")
  print(f"Time complete Epoch {epoch + 1}: {int((end_time - start_time) // 60)}min {int((end_time - start_time) % 60)}sec")

  # if val_accuracy > best_val_accuracy:
  #   best_val_accuracy = val_accuracy
  #   torch.save(model.state_dict(), model_path)
  #   print(f"Saved model with improved validation accuracy: {val_accuracy:.2f}% \n")


Input shape: torch.Size([4920, 1, 2, 128, 128])
Shape after merging time bins: torch.Size([4920, 2, 128, 128])
Shape after bin_conv_1: torch.Size([4920, 12, 124, 124])
Shape after batch_norm_1: torch.Size([4920, 12, 124, 124])
Shape after max_pool_1: torch.Size([4920, 12, 62, 62])
Shape after bin_conv_2: torch.Size([4920, 24, 58, 58])
Shape after batch_norm_2: torch.Size([4920, 24, 58, 58])
Shape after max_pool_2: torch.Size([4920, 24, 29, 29])
Shape after flatten: torch.Size([4920, 20184])
Shape after bin_fully_connected_1: torch.Size([4920, 128])
Shape after dropout: torch.Size([4920, 128])
Shape after bin_fully_connected_2: torch.Size([4920, 11])
Shape after bin_conv_1: torch.Size([4920, 12, 124, 124])
Shape after batch_norm_1: torch.Size([4920, 12, 124, 124])
Shape after max_pool_1: torch.Size([4920, 12, 62, 62])
Shape after bin_conv_2: torch.Size([4920, 24, 58, 58])
Shape after batch_norm_2: torch.Size([4920, 24, 58, 58])
Shape after max_pool_2: torch.Size([4920, 24, 29, 29])
Shap

: 

: 

: 

# Hyper Param Sweep

In [None]:
# def objective(trial):
#     config["thresh_1"] = trial.suggest_float("thresh_1", 1, 20)
#     config["thresh_2"] = trial.suggest_float("thresh_2", 1, 20)
#     config["thresh_3"] = trial.suggest_float("thresh_3", 1, 20)
#     config["beta"] = trial.suggest_float("beta", 0.1, 0.9)
#     config["lr"] = trial.suggest_float("lr", 1e-10, 1e-3)
#     config["num_steps"] = trial.suggest_int("num_steps", 5, 10)
#     config["dropout"] = trial.suggest_float("dropout", 0.1, 0.75)

#     device = config["device"]

#     model = BSNN(config).to(device)
#     loss_func = nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

#     for epoch in range(config["epochs"]):
#         train(train_loader, model, optimizer, loss_func)
#         val_loss, _ = validate(val_loader, model, loss_func)

#     return val_loss.item()




In [None]:
# def run_study():
#     study_name = "bsnn_gesture_study"
#     storage_name = "sqlite:///{}.db".format(study_name)

#     study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True, direction="minimize")
#     study.optimize(objective, n_trials=10)

#     study.trials_dataframe().to_csv("study_results.csv")

#     print("Study statistics: ")
#     print("  Number of finished trials: ", len(study.trials))
#     print("  Best trial:")
#     trial = study.best_trial
#     print("    Value: ", trial.value)
#     print("    Params: ")
#     for key, value in trial.params.items():
#         print(f"      {key}: {value}")

In [None]:
# run_study()

# Visualization and Analysis

In [None]:
# # Plotting training, validation, and test losses
# plt.figure(figsize=(10, 5))
# plt.plot(train_losses, label='Training Loss')
# plt.plot(val_losses, label='Validation Loss')
# plt.title('Loss over Epochs')
# plt.xlabel('Epochs')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()

# # Plotting training, validation, and test accuracies
# plt.figure(figsize=(10, 5))
# plt.plot(train_accuracies, label='Training Accuracy')
# plt.plot(val_accuracies, label='Validation Accuracy')
# plt.title('Accuracy over Epochs')
# plt.xlabel('Epochs')
# plt.ylabel('Accuracy (%)')
# plt.legend()
# plt.show()