# Quant net 1.0

## Importing the Python libraries

In [None]:
# PyTorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.utils.data import DataLoader , TensorDataset
# from torchvision import datasets, transforms
from torch.optim import Adam
from torch.utils.data import random_split

# Additional Imports
import snntorch as snn
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import math
import pandas as pd
import random

#metrics evaluator 
from sklearn.metrics import classification_report
# Seed PyTorch's global random number generator so repeated runs draw the
# same random numbers.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy are
# imported above but left unseeded — confirm nothing relies on them.
torch.manual_seed(0)

## Loading the dataset and configuration

In [None]:
# Read the two CSV files from the current working directory into DataFrames
# (test split first, then train split).
test_df, train_df = (
    pd.read_csv(csv_name) for csv_name in ("mnist_test.csv", "mnist_train.csv")
)

In [None]:
# Column layout is taken from the test frame and reused for the train frame:
# the first column is treated as the class label, every other column as an
# input feature.
cols = test_df.columns
label, vals = cols[0], cols[1:]

x_train_temp = train_df[vals]
y_train_temp = train_df[label]

# Unshuffled 80/20 split of the training CSV: the leading rows are used for
# training, the trailing rows are held out for validation.
split_idx = round(0.8 * len(x_train_temp))
_train_features = x_train_temp.to_numpy()
_train_labels = y_train_temp.to_numpy()

x_train = torch.tensor(_train_features[:split_idx], dtype=torch.float32)
y_train = torch.tensor(_train_labels[:split_idx], dtype=torch.long)
x_val = torch.tensor(_train_features[split_idx:], dtype=torch.float32)
y_val = torch.tensor(_train_labels[split_idx:], dtype=torch.long)

# The test CSV is converted as a whole; no rows are held out.
x_test = torch.tensor(test_df[vals].to_numpy(), dtype=torch.float32)
y_test = torch.tensor(test_df[label].to_numpy(), dtype=torch.long)

In [None]:
# Pair each feature tensor with its label tensor so a DataLoader can serve
# (features, labels) batches.
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
val_dataset = TensorDataset(x_val, y_val)

BATCH_SIZE = 512

# Only the training loader is shuffled. The evaluation loaders keep a fixed
# order: shuffling them does not change the aggregate metrics, but it makes
# evaluation order non-deterministic and consumes global RNG state between
# training epochs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Experiment settings, grouped by purpose.
# NOTE(review): in the code visible in this notebook only "lr", "epochs",
# "min_delta", "patience_es", "n_quant", "min" and "max" are read; the
# remaining entries appear unused — confirm before relying on them.
config_2 = {
    # --- Spiking-neuron settings ---
    "threshold1": 2.3599835635698114,
    "threshold2": 7.985043705972782,
    "threshold3": 3.849629060468402,
    "beta": 0.44154740154430405,
    "num_steps": 5,
    # --- Network options ---
    "batch_norm": False,
    "dropout": 0.3276864426153669,
    # --- Optimizer ---
    "lr": 0.002,
    # --- Early stopping ---
    "min_delta": 1e-6,
    "patience_es": 3,
    # --- Training length (upper bound on epochs) ---
    "epochs": 100,
    # --- Weight quantization: the range [min, max] is divided by n_quant ---
    "n_quant": 16,
    "min": -4,
    "max": 4,
}

## Definition of submodules

In [None]:
# Spacing between two adjacent quantization levels: the configured range
# split into `n_quant` equal steps.
precision = (config_2["max"] - config_2["min"]) / config_2["n_quant"]
# Number of quantization steps per unit. The previous form,
# pow(10, log(1/precision)/log(10)), is mathematically the same value but can
# pick up floating-point round-off, which would push the rounded values
# slightly off the intended grid.
multiplier = 1 / precision


class quantize(Function):
    """Snap a tensor onto the quantization grid with a pass-through gradient.

    Forward: clamp every element to [config_2["min"], config_2["max"]] and
    round it to the nearest multiple of ``precision``.

    Backward: return the incoming gradient unchanged, so the rounding step is
    treated as the identity when gradients are computed.
    NOTE(review): the gradient is also passed through for elements that were
    clamped in the forward pass — confirm that this is intended.
    """

    @staticmethod
    def forward(ctx, input):
        # `ctx` is the autograd context object; nothing needs to be saved on
        # it because backward does not depend on the forward input.
        clamped = input.clamp(min=config_2["min"], max=config_2["max"])
        return torch.round(clamped * multiplier) / multiplier

    @staticmethod
    def backward(ctx, gradient_out):
        # Clone so the returned gradient does not alias the incoming tensor.
        return gradient_out.clone()

In [None]:
class BinaryLinear(nn.Linear):
    """``nn.Linear`` variant that quantizes its weight on every forward call.

    The stored weight is left untouched; a quantized copy produced by
    ``quantize.apply`` is what gets multiplied with the input. The bias, when
    present, is used as-is. Despite the class name, the weights are not
    restricted to two values — the level set is whatever ``quantize`` yields.
    """

    def forward(self, input):
        """Apply the linear map using the quantized weight matrix."""
        quantized_weight = quantize.apply(self.weight)
        # F.linear treats bias=None as "no bias", so one call covers both the
        # biased and the bias-free configuration.
        return F.linear(input, quantized_weight, self.bias)

    def reset_parameters(self):
        """Initialise the weight with Xavier-normal values and zero the bias."""
        nn.init.xavier_normal_(self.weight)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

In [None]:
class quant_net(nn.Module):
    """Four-layer spiking classifier built from quantized linear layers.

    Layer sizes are 784 -> 128 -> 64 -> 32 -> 10. Every ``BinaryLinear`` layer
    is followed by an ``snn.Leaky`` neuron created with a learnable threshold
    and a learnable decay. Dropout is applied to the spikes of the first and
    the third neuron layer.

    Args:
        num_steps: Number of simulation timesteps per forward pass. Must be
            at least 1.
        beta: Initial decay value handed to every ``snn.Leaky`` layer.
        dropout: Drop probability of the shared ``nn.Dropout`` module.

    The defaults reproduce the values that used to be hard-coded, so
    ``quant_net()`` builds the same network as before.
    """

    def __init__(self, num_steps=100, beta=0.5, dropout=0.3):
        super().__init__()
        if num_steps < 1:
            raise ValueError("num_steps must be at least 1")

        self.num_steps = num_steps
        self.beta = beta
        self.drop_percent = dropout

        # Submodules are registered in the same order as before so that the
        # parameter order and state_dict keys are unchanged.
        self.quant_fc_1 = BinaryLinear(in_features=28 * 28, out_features=128)
        self.lif1 = snn.Leaky(beta=self.beta, learn_threshold=True, learn_beta=True)
        self.quant_fc_2 = BinaryLinear(in_features=128, out_features=64)
        self.lif2 = snn.Leaky(beta=self.beta, learn_threshold=True, learn_beta=True)
        self.quant_fc_3 = BinaryLinear(in_features=64, out_features=32)
        self.lif3 = snn.Leaky(beta=self.beta, learn_threshold=True, learn_beta=True)
        self.quant_fc_4 = BinaryLinear(in_features=32, out_features=10)
        self.lif4 = snn.Leaky(beta=self.beta, learn_threshold=True, learn_beta=True)
        self.dropout = nn.Dropout(p=self.drop_percent)

    def forward(self, input):
        """Simulate the network for ``self.num_steps`` timesteps.

        Args:
            input: Feature tensor whose last dimension is 784 (28*28), as
                required by the first linear layer. The same tensor is fed to
                the network at every timestep.

        Returns:
            A pair ``(spikes, membranes)``: the output-layer spikes and
            membrane values of every timestep, each stacked along a new
            leading dimension of length ``num_steps``.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        mem4 = self.lif4.init_leaky()

        spike4_rec = []
        mem4_rec = []

        # The first projection depends only on the (constant) input and the
        # layer weights, so it is computed once instead of once per timestep.
        val_1 = self.quant_fc_1(input)

        for _ in range(self.num_steps):
            spike1, mem1 = self.lif1(val_1, mem1)

            spike1 = self.dropout(spike1)
            val_2 = self.quant_fc_2(spike1)
            spike2, mem2 = self.lif2(val_2, mem2)

            val_3 = self.quant_fc_3(spike2)
            spike3, mem3 = self.lif3(val_3, mem3)

            spike3 = self.dropout(spike3)
            val_4 = self.quant_fc_4(spike3)
            spike4, mem4 = self.lif4(val_4, mem4)

            spike4_rec.append(spike4)
            mem4_rec.append(mem4)

        return torch.stack(spike4_rec, dim=0), torch.stack(mem4_rec, dim=0)

In [None]:
class EarlyStopping:
    """Signal that training should stop once the validation loss stalls.

    Call the instance with the latest validation loss after every epoch. A
    call counts as "no improvement" when the loss is greater than
    ``best_score - min_delta``. After ``patience`` such calls in a row,
    ``early_stop`` is set to True; any improving call resets the count.

    Attributes are plain instance fields:
        patience   -- number of consecutive stalled calls tolerated
        min_delta  -- minimum decrease that counts as an improvement
        counter    -- current number of consecutive stalled calls
        best_score -- lowest loss accepted so far (None before the first call)
        early_stop -- True once the patience budget is exhausted
    """

    def __init__(self, patience=config_2["patience_es"], min_delta=config_2["min_delta"]):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss):
        # The very first loss only establishes the baseline.
        if self.best_score is None:
            self.best_score = val_loss
            return

        stalled = val_loss > self.best_score - self.min_delta
        if not stalled:
            # Improvement: accept the new best value and start counting afresh.
            self.best_score = val_loss
            self.counter = 0
            return

        self.counter += 1
        print(f"Earlystop {self.counter}/{self.patience}\n")
        if self.counter >= self.patience:
            self.early_stop = True

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network with its default settings and move it to that device.
model = quant_net()
model = model.to(device)

# Cross-entropy loss, optimised with Adam at the configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=config_2["lr"])

# Early-stopping monitor; it is fed the validation loss after every epoch.
early_stopping = EarlyStopping(config_2["patience_es"], config_2["min_delta"])

In [None]:
def train(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    The model is put into training mode. For each batch the model's first
    output is summed over its leading dimension and the result is used as the
    class scores for both the loss and the prediction.

    Args:
        model: Callable module returning a pair whose first element is
            indexed along dimension 0 by timestep.
        train_loader: Sized iterable of ``(data, targets)`` batches.
        optimizer: Optimizer that updates the parameters of ``model``.
        criterion: Loss taking ``(scores, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` where the loss is averaged
        over batches and the accuracy is computed over all samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch, labels in train_loader:
        batch = batch.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        spikes, _ = model(batch)
        scores = spikes.sum(dim=0)
        loss = criterion(scores, labels)
        loss_sum += loss.item()

        # Predicted class is the index of the largest score in each row.
        predictions = scores.detach().max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predictions == labels).sum().item()

        loss.backward()
        optimizer.step()

    mean_loss = loss_sum / len(train_loader)
    accuracy = 100 * n_correct / n_seen
    return mean_loss, accuracy

In [None]:
def validate(model, val_loader, criterion, device):
    """Measure loss and accuracy on ``val_loader`` without updating the model.

    The model is switched to evaluation mode and every batch is processed
    with gradient tracking disabled. Class scores are the model's first
    output summed over its leading dimension.

    Args:
        model: Callable module returning a pair whose first element is
            indexed along dimension 0 by timestep.
        val_loader: Sized iterable of ``(data, targets)`` batches.
        criterion: Loss taking ``(scores, targets)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` where the loss is averaged
        over batches and the accuracy is computed over all samples seen.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch, labels in val_loader:
            batch = batch.to(device)
            labels = labels.to(device)

            spikes, _ = model(batch)
            scores = spikes.sum(dim=0)
            loss_sum += criterion(scores, labels).item()

            # Predicted class is the index of the largest score in each row.
            predictions = scores.detach().max(dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predictions == labels).sum().item()

    mean_loss = loss_sum / len(val_loader)
    accuracy = 100 * n_correct / n_seen
    return mean_loss, accuracy

In [None]:
# Per-epoch history, used for the plots further down.
train_losses, train_accuracies, val_losses, val_accuracies = [], [], [], []
best_val_accuracy = 0
model_path = "updated_best_BSNN_model.pth"

# Autograd anomaly detection is a debugging aid that adds extra checks to
# every backward pass and slows training down. It is therefore off by
# default; set DETECT_ANOMALY to True while hunting NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

for epoch in range(config_2["epochs"]):
    train_loss, train_accuracy = train(model, train_loader, optimizer, criterion, device)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    val_loss, val_accuracy = validate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(f"Epoch: {epoch + 1}, Training Loss: {train_loss:.5f}, Training Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_loss:.5f}, Validation Accuracy: {val_accuracy:.2f}%\n")

    # Checkpoint whenever the validation accuracy reaches a new maximum.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), model_path)
        print(f"Saved model with improved validation accuracy: {val_accuracy:.2f}% \n")

    # Early stopping watches the validation loss, whereas checkpointing above
    # is driven by the validation accuracy.
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("\nEarly stopping triggered")
        break

In [None]:
# Draw one figure per metric, each comparing the per-epoch training curve
# with the validation curve (no test-set curve is plotted here).
for metric, y_label, train_curve, val_curve in (
    ("Loss", "Loss", train_losses, val_losses),
    ("Accuracy", "Accuracy (%)", train_accuracies, val_accuracies),
):
    plt.figure(figsize=(10, 5))
    plt.plot(train_curve, label=f"Training {metric}")
    plt.plot(val_curve, label=f"Validation {metric}")
    plt.title(f"{metric} over Epochs")
    plt.xlabel("Epochs")
    plt.ylabel(y_label)
    plt.legend()
    plt.show()

In [None]:
def test(model, test_loader, criterion, device, model_path="best_BSNN_model.pth"):
    """Evaluate ``model`` on ``test_loader`` and print a per-class report.

    If ``model_path`` names an existing file, its state dict is loaded into
    ``model`` first; otherwise the model is evaluated with the weights it
    already has and a notice is printed. Class scores are the model's first
    output summed over its leading dimension, which is the same aggregation
    the training and validation functions in this notebook use, so the
    reported loss is on the same scale as theirs.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        test_loader: Sized iterable of ``(data, targets)`` batches.
        criterion: Loss taking ``(scores, targets)``.
        device: Device the batches (and the loaded checkpoint) are mapped to.
        model_path: Checkpoint file holding a state dict for ``model``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``.
    """
    # Restore the saved weights when a checkpoint is available. map_location
    # lets a checkpoint written on one device be loaded on another.
    if os.path.isfile(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"Loaded saved model from {model_path}\n")
    else:
        print(f"No checkpoint found at {model_path}; evaluating the model as passed in\n")

    model.eval()

    test_loss = 0.0
    correct_test = 0
    total_test = 0
    # Per-batch results are collected in lists and concatenated once at the
    # end; this keeps the integer dtype of the labels and avoids re-copying
    # the accumulated tensor on every batch.
    predicted_batches = []
    target_batches = []

    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)

            spikes, _ = model(data)
            outputs = spikes.sum(dim=0)

            test_loss += criterion(outputs, targets).item()

            # Predicted class is the index of the largest score in each row.
            predicted = outputs.max(dim=1).indices
            total_test += targets.size(0)
            correct_test += (predicted == targets).sum().item()

            predicted_batches.append(predicted.cpu())
            target_batches.append(targets.cpu())

    test_loss /= len(test_loader)
    test_accuracy = 100 * correct_test / total_test

    predicted_all = torch.cat(predicted_batches, dim=0).numpy()
    target_all = torch.cat(target_batches, dim=0).numpy()
    # classification_report expects the true labels first, then predictions;
    # swapping them exchanges precision and recall in the printed table.
    print(classification_report(target_all, predicted_all))
    return test_loss, test_accuracy

In [None]:
# Report final metrics on the held-out test split (not the validation split),
# loading the checkpoint file that the training loop wrote.
test_loss, test_accuracy = test(model, test_loader, criterion, device, model_path=model_path)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")