In [None]:
import os
import random
import glob
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import seaborn as sns

import nni
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from snntorch import surrogate
import snntorch as snn
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import ConfusionMatrix
import csv

In [None]:
# Hyperparameters and Constants
input_size = 16
hidden_size = 24
output_size = 4

num_epochs = 2
batch_size = 32
scheduler_step_size = 3
scheduler_gamma = 0.5  # halving the LR every step_size epochs

# Output-layer threshold used during training. It is set very high and combined
# with output_reset_mechanism='none', presumably so the output layer only
# integrates membrane potential and (practically) never spikes.
# main() lowers the output threshold to 1.0 (subtractive reset) after training,
# before validation/testing.
output_threshold = 1e7
hidden_threshold = 1

# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 1 so this stays an int >= 1.
num_workers = max(1, (os.cpu_count() or 1) - 1)
hidden_reset_mechanism = 'subtract'
output_reset_mechanism = 'none'

# Offset added to the one-hot targets before the MSE loss on the output
# membrane potential.
Vmem_shift_for_MSELoss = 0.2

In [None]:
# Class names; the position in this list is the one-hot index used for labels
# (CAR -> 0, STREET -> 1, HOME -> 2, CAFE -> 3).
CLASSES = [
    "CAR",
    "STREET",
    "HOME",
    "CAFE",
]

def get_label_from_folder(folder_name):
    """Return the class label encoded in a folder path.

    The label is the part of the last path component before the first '-',
    e.g. ``"data/CAR-something"`` -> ``"CAR"``. A name without '-' is
    returned whole.
    """
    label, _separator, _rest = os.path.basename(folder_name).partition('-')
    return label

def one_hot_encode(label):
    """Encode ``label`` as a float32 one-hot vector following the ``CLASSES`` order.

    E.g. ``"CAR"`` -> ``[1, 0, 0, 0]``, ``"STREET"`` -> ``[0, 1, 0, 0]``.

    Raises:
        ValueError: if ``label`` is not in ``CLASSES`` (from ``list.index``).
    """
    position = CLASSES.index(label)
    encoding = np.zeros(len(CLASSES), dtype=np.float32)
    encoding[position] = 1.0
    return encoding

def get_nni_params():
    """Return training hyperparameters, overridden by the NNI tuner when available.

    Defaults are used for every key the tuner does not supply. ``nni.get_next_parameter()``
    can return ``None`` (e.g. when the tuner has no more parameter sets) and may return
    an empty dict outside an experiment, depending on the NNI version. In those cases
    the defaults are returned unchanged instead of crashing on ``dict.update(None)``.

    Returns:
        dict with keys 'learning_rate', 'optimizer_betas', 'fast_sigmoid_slope'
        (plus any extra keys the tuner provides).
    """
    params = {
        'learning_rate': 0.01,
        'optimizer_betas': (0.9, 0.99),
        'fast_sigmoid_slope': 10,
    }
    tuner_params = nni.get_next_parameter()
    if tuner_params:
        params.update(tuner_params)
    return params

In [None]:
class FakeQuantize5bit(nn.Module):
    """Fake-quantize a tensor to ``levels`` evenly spaced values in ``[w_min, w_max]``.

    The defaults give 5-bit (32-level) quantization of ``[0.001, 1]``. Inputs are
    clamped to the range first, so e.g. negative weights map to ``w_min``.

    Rounding uses a straight-through estimator. The forward output is the usual
    ``round(...)`` result, but the backward pass treats rounding as the identity.
    Gradients therefore reach the underlying float weights (they are still zero
    wherever the clamp is active). Plain ``torch.round`` has zero gradient
    everywhere, which would leave every quantized weight frozen during training.

    Args:
        levels: number of quantization levels (>= 2). Default 32 (5 bits).
        w_min: lower bound of the representable range. Default 0.001.
        w_max: upper bound of the representable range. Default 1.

    Raises:
        ValueError: if ``levels < 2`` (the step size would be undefined).
    """

    def __init__(self, levels=32, w_min=0.001, w_max=1):
        super().__init__()
        if levels < 2:
            raise ValueError(f"levels must be >= 2, got {levels}")
        self.levels = levels  # 2^5 = 32 levels = 5 bits by default
        self.w_min = w_min
        self.w_max = w_max

    def forward(self, input):
        input_clamped = torch.clamp(input, self.w_min, self.w_max)
        scale = (self.w_max - self.w_min) / (self.levels - 1)
        scaled = (input_clamped - self.w_min) / scale
        # Straight-through rounding: the value equals torch.round(scaled), the
        # gradient w.r.t. `scaled` is 1 because the correction term is detached.
        quant_indices = scaled + (torch.round(scaled) - scaled).detach()
        return quant_indices * scale + self.w_min

class QuantLinear(nn.Linear):
    """``nn.Linear`` whose weight is passed through ``FakeQuantize5bit`` on every forward.

    The stored ``weight`` stays a full-precision parameter; only the copy used in
    the matmul is quantized. The bias (off by default) is not quantized.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)
        self.fake_quant = FakeQuantize5bit()

    def forward(self, input):
        return F.linear(input, self.fake_quant(self.weight), self.bias)

def finalize_quantization(model):
    """Overwrite each fake-quantized layer's weight with its quantized value.

    Every submodule (including ``model`` itself) that has both a ``fake_quant``
    callable and a ``weight`` gets ``weight <- fake_quant(weight)`` in place.
    Runs without autograd tracking and leaves ``model`` in eval mode.
    """
    with torch.no_grad():
        model.eval()
        quantizable_layers = (
            layer
            for layer in model.modules()
            if hasattr(layer, 'fake_quant') and hasattr(layer, 'weight')
        )
        for layer in quantizable_layers:
            layer.weight.data.copy_(layer.fake_quant(layer.weight))

In [None]:
class SNNQUT(nn.Module):
    """Feed-forward spiking network with three hidden LIF layers and quantized weights.

    Layout: ``fc1 -> lif1 -> fc2 -> lif2 -> fc3 -> lif3 -> fc4 -> lif4``. Every
    ``fc*`` is a bias-free ``QuantLinear`` and every ``lif*`` is an ``snn.Leaky``.
    The hidden layers share ``hidden_reset_mechanism``, ``hidden_threshold`` and a
    fast-sigmoid surrogate gradient. The output layer uses ``output_reset_mechanism``
    and ``output_threshold``; no ``spike_grad`` is passed for it, so snnTorch's
    default surrogate applies.

    Args:
        input_size, hidden_size, output_size: layer widths.
        beta_hidden_1/2/3, beta_output: membrane decay factors passed to each
            ``snn.Leaky`` (scalars or per-neuron tensors).
        hidden_reset_mechanism, output_reset_mechanism: snnTorch reset modes.
        hidden_threshold, output_threshold: firing thresholds.
        fast_sigmoid_slope: slope of the hidden layers' fast-sigmoid surrogate.
    """

    def __init__(self, input_size, hidden_size, output_size,
                 beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output,
                 hidden_reset_mechanism, output_reset_mechanism,
                 hidden_threshold, output_threshold, fast_sigmoid_slope):
        super().__init__()

        def make_hidden_lif(beta):
            # Identical configuration for every hidden layer; only beta differs.
            return snn.Leaky(beta=beta, reset_mechanism=hidden_reset_mechanism,
                             threshold=hidden_threshold,
                             spike_grad=snn.surrogate.fast_sigmoid(slope=fast_sigmoid_slope))

        # Registration order (fc1, lif1, fc2, ...) is kept so state_dict keys are unchanged.
        self.fc1 = QuantLinear(input_size, hidden_size, bias=False)
        self.lif1 = make_hidden_lif(beta_hidden_1)

        self.fc2 = QuantLinear(hidden_size, hidden_size, bias=False)
        self.lif2 = make_hidden_lif(beta_hidden_2)

        self.fc3 = QuantLinear(hidden_size, hidden_size, bias=False)
        self.lif3 = make_hidden_lif(beta_hidden_3)

        self.fc4 = QuantLinear(hidden_size, output_size, bias=False)
        self.lif4 = snn.Leaky(beta=beta_output, reset_mechanism=output_reset_mechanism,
                              threshold=output_threshold)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal initialisation of the four (full-precision) weight matrices."""
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Simulate the network over every time step of ``x``.

        Args:
            x: input of shape (batch, time_steps, input_size); cast to float32.

        Returns:
            ``(spk1, spk2, spk3, spk4, mem4)``: per-step recordings stacked on a new
            leading time dimension, each of shape (time_steps, batch, layer_width).
        """
        x = x.float()
        batch_size, time_steps, _ = x.shape
        device = x.device

        # Membrane potentials start at zero for every sequence.
        mem1 = torch.zeros(batch_size, self.fc1.out_features, device=device)
        mem2 = torch.zeros(batch_size, self.fc2.out_features, device=device)
        mem3 = torch.zeros(batch_size, self.fc3.out_features, device=device)
        mem4 = torch.zeros(batch_size, self.fc4.out_features, device=device)

        spk1_rec, spk2_rec, spk3_rec, spk4_rec, mem4_rec = [], [], [], [], []

        for step in range(time_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)
            cur4 = self.fc4(spk3)
            spk4, mem4 = self.lif4(cur4, mem4)

            spk1_rec.append(spk1)
            spk2_rec.append(spk2)
            spk3_rec.append(spk3)
            spk4_rec.append(spk4)
            mem4_rec.append(mem4)

        return (torch.stack(spk1_rec, dim=0), torch.stack(spk2_rec, dim=0),
                torch.stack(spk3_rec, dim=0), torch.stack(spk4_rec, dim=0),
                torch.stack(mem4_rec, dim=0))


In [None]:
class Lightning_SNNQUT(pl.LightningModule):
    """Lightning wrapper that trains and evaluates ``SNNQUT``.

    Loss: MSE between the output membrane potential at every time step and the
    one-hot target shifted by the module-level ``Vmem_shift_for_MSELoss``.
    Training and validation classify by the output membrane potential summed over
    time; testing classifies by the output spike count.

    The ``beta_*`` arguments are forwarded to the network but are not stored in
    ``hparams``; every other constructor argument is.

    Each stage has a ``ConfusionMatrix`` that is updated manually in the step
    methods. Lightning only auto-resets metrics that are passed to ``self.log``,
    so each matrix is reset at the start of its stage's epoch. After a run it
    therefore holds the counts of the most recent epoch; read them with ``.compute()``.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        beta_hidden_1,
        beta_hidden_2,
        beta_hidden_3,
        beta_output,
        hidden_reset_mechanism,
        output_reset_mechanism,
        hidden_threshold,
        output_threshold,
        learning_rate,
        scheduler_step_size,
        scheduler_gamma,
        optimizer_betas,
        fast_sigmoid_slope,
    ):
        super().__init__()
        self.save_hyperparameters(
            'input_size',
            'hidden_size',
            'output_size',
            'hidden_reset_mechanism',
            'output_reset_mechanism',
            'hidden_threshold',
            'output_threshold',
            'learning_rate',
            'scheduler_step_size',
            'scheduler_gamma',
            'optimizer_betas',
            'fast_sigmoid_slope',
        )
        self.beta_hidden_1 = beta_hidden_1
        self.beta_hidden_2 = beta_hidden_2
        self.beta_hidden_3 = beta_hidden_3
        self.beta_output = beta_output

        self.train_confmat = ConfusionMatrix(task='multiclass', num_classes=self.hparams.output_size)
        self.val_confmat = ConfusionMatrix(task='multiclass', num_classes=self.hparams.output_size)
        self.test_confmat = ConfusionMatrix(task='multiclass', num_classes=self.hparams.output_size)

        self.model = SNNQUT(
            input_size=self.hparams.input_size,
            hidden_size=self.hparams.hidden_size,
            output_size=self.hparams.output_size,
            beta_hidden_1=self.beta_hidden_1,
            beta_hidden_2=self.beta_hidden_2,
            beta_hidden_3=self.beta_hidden_3,
            beta_output=self.beta_output,
            hidden_reset_mechanism=self.hparams.hidden_reset_mechanism,
            output_reset_mechanism=self.hparams.output_reset_mechanism,
            output_threshold=self.hparams.output_threshold,
            hidden_threshold=self.hparams.hidden_threshold,
            fast_sigmoid_slope=self.hparams.fast_sigmoid_slope,
        )

        self.loss_function = nn.MSELoss()

    def forward(self, x):
        return self.model(x)

    def compute_loss_and_metrics(self, spk4_rec, mem4_rec, labels, mode='Vmem'):
        """Compute the loss and batch classification metrics.

        The loss always uses ``mem4_rec``: the one-hot ``labels`` (batch, classes)
        are broadcast over the leading time dimension of ``mem4_rec`` and shifted
        by ``Vmem_shift_for_MSELoss``.

        Args:
            spk4_rec: output spikes, (time, batch, classes).
            mem4_rec: output membrane potentials, (time, batch, classes).
            labels: one-hot targets, (batch, classes).
            mode: 'Vmem' predicts from ``mem4_rec`` summed over time; any other
                value predicts from ``spk4_rec`` summed over time (spike counts).

        Returns:
            ``(loss, accuracy, predicted, targets)``. ``accuracy`` is a Python float
            in [0, 1]; ``predicted``/``targets`` are class-index tensors.
        """
        labels_expanded = labels.unsqueeze(0).expand(mem4_rec.size(0), -1, -1)
        loss = self.loss_function(mem4_rec, (labels_expanded + Vmem_shift_for_MSELoss))

        if mode == 'Vmem':
            final_out = mem4_rec.sum(0)  # (batch, output_size)
        else:
            final_out = spk4_rec.sum(0)  # (batch, output_size)

        # Ties (e.g. no output spikes at all) resolve to the first index returned by max.
        _, predicted = final_out.max(-1)
        _, targets = labels.max(-1)

        correct = predicted.eq(targets).sum().item()
        total = targets.numel()
        accuracy = correct / total if total > 0 else 0.0

        return loss, accuracy, predicted, targets

    def on_train_epoch_start(self):
        # Manually updated metrics are not reset by Lightning; start each epoch clean.
        self.train_confmat.reset()

    def on_validation_epoch_start(self):
        # Also drops the counts accumulated during the sanity-check run.
        self.val_confmat.reset()

    def on_test_epoch_start(self):
        self.test_confmat.reset()

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        spk1_rec, spk2_rec, spk3_rec, spk4_rec, mem4_rec = self(inputs)

        # Classify by membrane potential during training (matches the loss target).
        loss, accuracy, predicted, targets = self.compute_loss_and_metrics(spk4_rec, mem4_rec, labels, mode='Vmem')

        self.train_confmat.update(predicted, targets)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', accuracy * 100, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        spk1_rec, spk2_rec, spk3_rec, spk4_rec, mem4_rec = self(inputs)

        # Validation also classifies by membrane potential, consistent with training.
        loss, accuracy, predicted, targets = self.compute_loss_and_metrics(spk4_rec, mem4_rec, labels, mode='Vmem')

        self.val_confmat.update(predicted, targets)
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_accuracy', accuracy * 100, on_step=False, on_epoch=True, prog_bar=True)

        return {'val_loss': loss, 'val_accuracy': accuracy}

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        spk1_rec, spk2_rec, spk3_rec, spk4_rec, mem4_rec = self(inputs)

        # Testing classifies by output spike count.
        loss, accuracy, predicted, targets = self.compute_loss_and_metrics(spk4_rec, mem4_rec, labels, mode='spike')

        self.test_confmat.update(predicted, targets)
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_accuracy', accuracy * 100, on_step=False, on_epoch=True, prog_bar=True)

        return {'test_loss': loss, 'test_accuracy': accuracy}

    def configure_optimizers(self):
        """Adamax with a StepLR schedule (LR multiplied by ``scheduler_gamma`` every ``scheduler_step_size`` epochs)."""
        optimizer = optim.Adamax(
            self.parameters(),
            lr=self.hparams.learning_rate,
            betas=self.hparams.optimizer_betas,
        )
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.hparams.scheduler_step_size,
            gamma=self.hparams.scheduler_gamma,
        )
        return [optimizer], [scheduler]

In [None]:
class CustomSNNTrainValDataset(Dataset):
    """In-memory dataset of CSV recordings for training and validation.

    Expected layout: ``root_dir/<LABEL>-<anything>/*.csv``, where ``<LABEL>`` is the
    folder-name prefix before the first '-'. Non-directories and folders whose
    label is not in ``class_list`` are skipped. Each CSV becomes one sample
    ``(data, one_hot_label)``.

    Folders and files are visited in sorted order. ``os.listdir`` and ``glob``
    return entries in arbitrary order, so sorting keeps the sample order, and
    therefore any seeded split built on it, identical across filesystems.

    NOTE: labels are encoded with ``one_hot_encode``, which indexes into the
    module-level ``CLASSES``, so ``class_list`` should be a subset of ``CLASSES``.
    """

    def __init__(self, root_dir, class_list=CLASSES):
        self.samples = []  # list of (float32 data array, float32 one-hot label)
        self.class_list = class_list

        for folder in sorted(os.listdir(root_dir)):
            folder_path = os.path.join(root_dir, folder)
            if not os.path.isdir(folder_path):
                continue
            label = get_label_from_folder(folder_path)
            if label not in self.class_list:
                continue
            label_vec = one_hot_encode(label)
            for csv_file in sorted(glob.glob(os.path.join(folder_path, '*.csv'))):
                data_array = self.load_csv(csv_file, time_steps=100, input_dim=16)
                self.samples.append((data_array, label_vec))

    def load_csv(self, csv_path, time_steps=1000, input_dim=16):
        """Load one comma-separated recording as a float32 array.

        Only the column count is validated against ``input_dim``. ``time_steps`` is
        accepted for interface compatibility, but the row count is not checked.
        NOTE(review): the constructor passes 100 while this default is 1000; confirm
        the real recording length before enforcing it.

        Raises:
            ValueError: if the file does not load as a 2-D array with
                ``input_dim`` columns.
        """
        data = np.loadtxt(csv_path, delimiter=',', dtype=np.float32)
        if data.ndim != 2 or data.shape[1] != input_dim:
            raise ValueError(f"CSV {csv_path} shape mismatch: {data.shape}")
        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        data, label = self.samples[idx]
        # Returned as fresh tensors (copies of the stored arrays).
        return torch.tensor(data), torch.tensor(label)


def stratified_split(dataset, train_ratio=0.8):
    """Split ``dataset`` into train/validation ``Subset``s, stratified by class.

    The class of each sample is the argmax of its label (the second element of
    ``dataset[i]``). Within each class the indices are shuffled and the first
    ``int(n_class * train_ratio)`` go to training, the rest to validation. Both
    index lists are then shuffled again. Shuffling uses NumPy's global RNG, so
    seed it beforehand for a reproducible split.

    Returns:
        ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` objects.
    """
    class_ids = np.array([
        torch.argmax(label).item()
        for _, label in (dataset[i] for i in range(len(dataset)))
    ])

    train_indices, val_indices = [], []
    for class_id in np.unique(class_ids):
        members = np.where(class_ids == class_id)[0]
        np.random.shuffle(members)
        cut = int(len(members) * train_ratio)
        train_indices.extend(members[:cut])
        val_indices.extend(members[cut:])

    # Mix classes so neither subset is ordered by label.
    for index_list in (train_indices, val_indices):
        np.random.shuffle(index_list)

    return (torch.utils.data.Subset(dataset, train_indices),
            torch.utils.data.Subset(dataset, val_indices))


class CustomSNNTestDataset(Dataset):
    """One long test recording per class folder.

    Expected layout: ``test_dir/<LABEL>-<anything>/*.csv``. For each folder whose
    label is in ``class_list``, only the first CSV in sorted filename order is
    loaded. It must be exactly 6000 rows x 16 columns (the values the constructor
    passes to ``load_csv``).
    NOTE(review): an earlier comment described these files as 60000 x 16, but the
    enforced length is 6000; confirm which is correct.

    Folders and files are visited in sorted order, so the chosen file and the
    sample order do not depend on filesystem listing order.

    Raises:
        ValueError: if a matching class folder contains no CSV, or a CSV has the
            wrong shape.
    """

    def __init__(self, test_dir, class_list=CLASSES):
        self.samples = []
        for folder in sorted(os.listdir(test_dir)):
            folder_path = os.path.join(test_dir, folder)
            if not os.path.isdir(folder_path):
                continue
            label = get_label_from_folder(folder_path)
            if label not in class_list:
                continue
            label_vec = one_hot_encode(label)
            csv_files = sorted(glob.glob(os.path.join(folder_path, '*.csv')))
            if not csv_files:
                raise ValueError(f"No test csv found in {folder_path}")
            # Any additional CSVs in the folder are ignored.
            data_array = self.load_csv(csv_files[0], time_steps=6000, input_dim=16)
            self.samples.append((data_array, label_vec))

    def load_csv(self, csv_path, time_steps=60000, input_dim=16):
        """Load a test CSV as float32 and require shape ``(time_steps, input_dim)``."""
        data = np.loadtxt(csv_path, delimiter=',', dtype=np.float32)
        if data.shape != (time_steps, input_dim):
            raise ValueError(f"Test CSV {csv_path} shape mismatch: {data.shape}")
        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        data, label = self.samples[idx]
        return torch.tensor(data), torch.tensor(label)


class CustomDataModule(pl.LightningDataModule):
    """Lightning data module over the train/val CSV folders and the test CSV folders.

    Args:
        train_val_dir: root of the training recordings. They are split 80/20,
            stratified per class, into train and validation subsets.
        test_dir: root of the per-class test recordings.
        batch_size: batch size for the train and validation loaders (the test
            loader always uses batch size 1).
        num_workers: DataLoader worker processes; 0 loads in the main process.
    """

    def __init__(self, train_val_dir, test_dir, batch_size=32, num_workers=4):
        super().__init__()
        self.train_val_dir = train_val_dir
        self.test_dir = test_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``, each at most once.

        Lightning calls ``setup`` at the start of each fit/validate/test run.
        Building once avoids re-reading every CSV and, for train/val, drawing a
        different random split on each call. The 'validate' stage is handled so
        ``trainer.validate`` also works without a preceding ``fit``.
        """
        if stage in ('fit', 'validate', None) and self.train_dataset is None:
            full_dataset = CustomSNNTrainValDataset(self.train_val_dir)
            self.train_dataset, self.val_dataset = stratified_split(full_dataset, train_ratio=0.8)

        if stage in ('test', None) and self.test_dataset is None:
            self.test_dataset = CustomSNNTestDataset(self.test_dir)

    def _make_loader(self, dataset, batch_size, shuffle):
        # DataLoader raises ValueError for persistent_workers=True with num_workers == 0.
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=self.num_workers,
                          persistent_workers=self.num_workers > 0)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, self.batch_size, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, 1, shuffle=False)

In [None]:
def generate_tau_beta_values(hidden_size, output_size, delta_t=1, tau_output=10):
    """Build per-neuron membrane decay factors ``beta = exp(-delta_t / tau)``.

    Hidden layer k uses time constants tau drawn from the first n powers of two
    (n = 2, 4, 8 for layers 1, 2, 3). Each power is repeated for an equal,
    contiguous share of the ``hidden_size`` neurons, in ascending order. Layer 1
    uses {2, 4}, layer 2 uses {2, ..., 16} and layer 3 uses {2, ..., 256}. All
    output neurons share ``tau_output``.

    Args:
        hidden_size: neurons per hidden layer; must be divisible by 2, 4 and 8.
        output_size: number of output neurons.
        delta_t: simulation time step, in the same units as tau. Default 1.
        tau_output: time constant of every output neuron. Default 10.

    Returns:
        ``(beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output)`` as float32
        1-D tensors of length ``hidden_size`` (first three) and ``output_size``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by a group count. The beta
            vector would otherwise be shorter than the layer it parameterizes.
    """
    def create_power_vector(n, size):
        if size % n != 0:
            raise ValueError(
                f"hidden_size={size} must be divisible by {n} to split neurons "
                f"evenly across {n} time constants"
            )
        powers = [2 ** i for i in range(1, n + 1)]
        return np.repeat(powers, size // n)

    def to_beta(tau):
        return torch.exp(-torch.tensor(delta_t) / torch.tensor(tau, dtype=torch.float32))

    beta_hidden_1 = to_beta(create_power_vector(n=2, size=hidden_size))
    beta_hidden_2 = to_beta(create_power_vector(n=4, size=hidden_size))
    beta_hidden_3 = to_beta(create_power_vector(n=8, size=hidden_size))
    beta_output = to_beta(np.repeat(tau_output, output_size))

    return beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output

In [None]:
def main():
    """Run one NNI trial: train, finalize quantization, validate, test and report.

    Steps:
      1. Read hyperparameters from NNI (defaults come from ``get_nni_params``) and
         build the per-neuron decay factors.
      2. Train ``Lightning_SNNQUT`` on the stratified train/val split.
      3. Bake the fake-quantized weights into the model.
      4. Switch the output layer to threshold 1.0 with subtractive reset for
         evaluation.
      5. Validate and test, then report the test accuracy (in percent) to NNI.
    """
    params = get_nni_params()
    beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output = generate_tau_beta_values(hidden_size, output_size)

    pl.seed_everything(42)

    # Directories
    train_val_dir = 'data/PROCESSED_YES_COCHLEA/CUT/TRAIN_VAL'  # directory containing CAR-..., STREET-..., HOME-..., CAFE-...
    test_dir = 'data/PROCESSED_YES_COCHLEA/CUT/TEST'            # directory containing test sets for each class

    # The Trainer calls data_module.setup() for each stage itself. Calling it
    # manually as well loaded (and re-split) the data a second time.
    data_module = CustomDataModule(
        train_val_dir=train_val_dir,
        test_dir=test_dir,
        batch_size=batch_size,
        num_workers=num_workers
    )

    model = Lightning_SNNQUT(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        beta_hidden_1=beta_hidden_1,
        beta_hidden_2=beta_hidden_2,
        beta_hidden_3=beta_hidden_3,
        beta_output=beta_output,
        hidden_reset_mechanism=hidden_reset_mechanism,
        output_reset_mechanism=output_reset_mechanism,
        learning_rate=params['learning_rate'],
        optimizer_betas=params['optimizer_betas'],
        scheduler_step_size=scheduler_step_size,
        scheduler_gamma=scheduler_gamma,
        output_threshold=output_threshold,
        hidden_threshold=hidden_threshold,
        fast_sigmoid_slope=params['fast_sigmoid_slope'],
    )

    logger = TensorBoardLogger(save_dir='logs', name='Mikel_LIF_quant_5bit')

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        log_every_n_steps=10,
        logger=logger,
    )

    # Train
    trainer.fit(model, datamodule=data_module)

    # After training, bake the quantized weights into the parameters.
    finalize_quantization(model.model)

    # Make the output layer spike for evaluation. snnTorch stores a non-learnable
    # `threshold` as a registered tensor buffer, and nn.Module raises TypeError when
    # a plain float is assigned to a buffer. Assign a tensor with the same
    # shape/dtype/device instead.
    lif4 = model.model.lif4
    lif4.threshold = torch.ones_like(torch.as_tensor(lif4.threshold))
    lif4.reset_mechanism = 'subtract'

    # Validate
    trainer.validate(model, datamodule=data_module)
    val_accuracy = trainer.callback_metrics.get('val_accuracy', torch.tensor(0)).item()
    print(f"Validation Accuracy: {val_accuracy:.2f}%")

    # Test
    trainer.test(model, datamodule=data_module)
    test_accuracy = trainer.callback_metrics.get('test_accuracy', torch.tensor(0)).item()
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    nni.report_final_result(test_accuracy)

In [None]:
# Start the trial only when run as a script or notebook, not when imported.
if __name__ == "__main__":
    main()

## To do

- Add checkpointing and support for resuming / extra training
- Log spikes, membrane potentials (Vmem), and other internal state
- Learning-rate finder (optional)
- Thorough documentation
- Loss term that depends on the number of spikes