## Imports and Hyperparameters

In [2]:
def get_nni_params():
    # Default hyperparameters
    params = {
        'learning_rate': 0.001,
        'optimizer_betas': (0.9, 0.999),
        'fast_sigmoid_slope': 20.0,
    }
    # Update with parameters from NNI
    tuner_params = nni.get_next_parameter()
    params.update(tuner_params)
    return params

In [3]:
import os
import torch.nn as nn
import snntorch as snn
import torch.optim as optim
import torch.backends.mps
import random
import numpy as np
import nni
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import Dataset, DataLoader, random_split
from QUT_DataLoader import QUTDataset

In [4]:
input_size = 16
hidden_size = 24
output_size = 4
num_epochs = 2
batch_size = 32
optimizer_betas = (0.9, 0.999)
scheduler_step_size = 20
scheduler_gamma = 0.2
learning_rate = 0.05
num_workers = os.cpu_count()
cuba_tau = 2 # 2ms of tau_Vmem of the CUBA neuron
hidden_reset_mechanism = 'subtract'
output_reset_mechanism = 'none'
output_threshold = 10000
hid_threshold = 0.001
# logger = TensorBoardLogger('tb_logs', name='snn_QUT_TensorBoardLogger')

## Spiking Neural Network Model

In [5]:
class SNNQUT(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        beta_hidden_1,
        beta_hidden_2,
        beta_hidden_3,
        beta_output,
        cuba_beta,
        hidden_reset_mechanism,
        output_reset_mechanism,
        output_threshold,
        hid_threshold,
        fast_sigmoid_slope,
    ):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        # here for CUBA I am using alpha = the betas of the previous LIF neurons, and beta = cuba_beta, which is extracted from cuba_tau
        self.lif1 = snn.Synaptic(alpha=beta_hidden_1, beta=cuba_beta, spike_grad=snn.surrogate.fast_sigmoid(slope=fast_sigmoid_slope), reset_mechanism=hidden_reset_mechanism, threshold=hid_threshold, learn_threshold=False)

        self.fc2 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif2 = snn.Synaptic(alpha=beta_hidden_2, beta=cuba_beta, spike_grad=snn.surrogate.fast_sigmoid(slope=fast_sigmoid_slope), reset_mechanism=hidden_reset_mechanism, threshold=hid_threshold, learn_threshold=False)

        self.fc3 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif3 = snn.Synaptic(alpha=beta_hidden_3, beta=cuba_beta, spike_grad=snn.surrogate.fast_sigmoid(slope=fast_sigmoid_slope), reset_mechanism=hidden_reset_mechanism, threshold=hid_threshold, learn_threshold=False)

        self.fc4 = nn.Linear(hidden_size, output_size, bias=False)
        self.lif4 = snn.Leaky(
            beta=beta_output,
            reset_mechanism=output_reset_mechanism,
            threshold=output_threshold,
        )

        self._initialize_weights()

    # Initialization with normal distribution, let's try xavier for better convergence
    # def _initialize_weights(self):
    #     nn.init.normal_(self.fc1.weight, mean=1.0, std=0.1)
    #     nn.init.normal_(self.fc2.weight, mean=0.3, std=0.1)
    #     nn.init.normal_(self.fc3.weight, mean=0.2, std=0.1)
    #     nn.init.normal_(self.fc4.weight, mean=0.1, std=0.1)

    def _initialize_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.xavier_uniform_(self.fc3.weight)
        nn.init.xavier_uniform_(self.fc4.weight)


    def forward(self, x):
        x = x.to(torch.float32)  # Convert input to float32
        batch_size, time_steps, _ = x.shape

        # Initialization of membrane potentials
        mem1 = torch.zeros(batch_size, self.fc1.out_features, device=x.device)
        mem2 = torch.zeros(batch_size, self.fc2.out_features, device=x.device)
        mem3 = torch.zeros(batch_size, self.fc3.out_features, device=x.device)
        mem4 = torch.zeros(batch_size, self.fc4.out_features, device=x.device)

        syn1 = torch.zeros(batch_size, self.fc1.out_features, device=x.device)
        syn2 = torch.zeros(batch_size, self.fc2.out_features, device=x.device)
        syn3 = torch.zeros(batch_size, self.fc3.out_features, device=x.device)

        mem4_rec = []
        spk3_rec = []

        for step in range(time_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)
            cur2 = self.fc2(spk1)
            spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)
            cur3 = self.fc3(spk2)
            spk3, syn3, mem3 = self.lif3(cur3, syn3, mem3)
            cur4 = self.fc4(spk3)
            spk4, mem4 = self.lif4(cur4, mem4)

            # Record at every time step
            mem4_rec.append(mem4)
            spk3_rec.append(spk3)

        return torch.stack(mem4_rec, dim=0), torch.stack(spk3_rec, dim=0) # so they will be stacked along the time axis (1000 steps) on the first dimension (dim=0)

## Lightning Module

In [6]:
class Lightning_SNNQUT(pl.LightningModule):
    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        beta_hidden_1,
        beta_hidden_2,
        beta_hidden_3,
        beta_output,
        hidden_reset_mechanism,
        output_reset_mechanism,
        hid_threshold,
        output_threshold,
        cuba_beta,
        learning_rate=learning_rate,
        optimizer_betas=optimizer_betas,
        scheduler_step_size=scheduler_step_size,
        scheduler_gamma=scheduler_gamma,
        fast_sigmoid_slope=20.0,
    ):
        super().__init__()
        self.save_hyperparameters('learning_rate', 'optimizer_betas', 'fast_sigmoid_slope')

        # Initialize the SNN model
        self.model = SNNQUT(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            beta_hidden_1=beta_hidden_1,
            beta_hidden_2=beta_hidden_2,
            beta_hidden_3=beta_hidden_3,
            beta_output=beta_output,
            cuba_beta=cuba_beta,
            hidden_reset_mechanism=hidden_reset_mechanism,
            output_reset_mechanism=output_reset_mechanism,
            output_threshold=output_threshold,
            hid_threshold = hid_threshold,
            fast_sigmoid_slope=fast_sigmoid_slope,
        )

        # Initialize the loss function
        self.loss_function = nn.MSELoss()
        # self.loss_function = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        mem4_rec, spk3_rec = self(inputs)

        # Expanding labels to match mem4_rec's shape
        labels_expanded = labels.unsqueeze(0).expand(mem4_rec.size(0), -1, -1)


        # Calculate loss
        loss = self.loss_function(mem4_rec, (labels_expanded*5)+5)
        # loss = self.loss_function(mem4_rec, labels_expanded)

        # Use the final membrane potential for prediction
        final_mem4 = mem4_rec.sum(0)

        # Predicted class is the one with the highest membrane potential
        _, predicted = final_mem4.max(-1)
        _, targets = labels.max(-1)

        # Calculate accuracy
        correct = predicted.eq(targets).sum().item()
        total = targets.numel()
        accuracy = correct / total

        # Log training loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', accuracy*100, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        mem4_rec, spk3_rec = self(inputs)

        # Expanding labels to match mem4_rec's shape
        labels_expanded = labels.unsqueeze(0).expand(mem4_rec.size(0), -1, -1)

        # Calculate loss
        loss = self.loss_function(mem4_rec, (labels_expanded*5)+5)
        # loss = self.loss_function(mem4_rec, labels_expanded)

        # Use the final membrane potential for prediction
        final_mem4 = mem4_rec.sum(0)

        # Predicted class is the one with the highest membrane potential
        _, predicted = final_mem4.max(-1)
        _, targets = labels.max(-1)

        # Calculate accuracy
        correct = predicted.eq(targets).sum().item()
        total = targets.numel()
        accuracy = correct / total

        # Log validation loss
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_accuracy', accuracy * 100, on_step=False, on_epoch=True, prog_bar=True)
    
        return {'val_loss': loss, 'val_accuracy': accuracy}

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        mem4_rec, spk3_rec = self(inputs)

        # Expanding labels to match mem4_rec's shape
        labels_expanded = labels.unsqueeze(0).expand(mem4_rec.size(0), -1, -1)

        # Calculate loss
        loss = self.loss_function(mem4_rec, (labels_expanded*5)+5)
        # loss = self.loss_function(mem4_rec, labels_expanded)

        # Use the final membrane potential for prediction
        final_mem4 = mem4_rec.sum(0)

        # Predicted class is the one with the highest membrane potential
        _, predicted = final_mem4.max(-1)
        _, targets = labels.max(-1)

        # For checking mems
        #print(mem4_rec[:,0])
        
        # Calculate accuracy
        correct = predicted.eq(targets).sum().item()
        total = targets.numel()
        accuracy = correct / total

        # Log test loss
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('test_accuracy', accuracy*100, on_step=True, on_epoch=True, prog_bar=True)
        
        return {'test_loss': loss, 'test_accuracy': accuracy}

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            betas=self.hparams.optimizer_betas,
        )
        # scheduler = optim.lr_scheduler.StepLR(
        #     optimizer,
        #     step_size=self.hparams.scheduler_step_size,
        #     gamma=self.hparams.scheduler_gamma,
        # )
        return [optimizer]#, [scheduler]

## Data Module

In [7]:
class QUTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=32, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        dataset = QUTDataset(self.data_dir)
        train_size = int(0.65 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            dataset, [train_size, val_size, test_size]
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

## Tau and Beta

In [8]:
"""

        GEOMETRIC SERIES TAU
 
"""

def create_power_vector(n, size):        
    # Generate the powers of 2 up to 2^n
    powers = [2**i for i in range(1, n+1)]
    
    # Calculate how many times each power should be repeated
    repeat_count = size // n
    
    # Create the final vector by repeating each power equally
    vector = np.repeat(powers, repeat_count)
    
    return vector


size = 24
n = 2
tau_hidden_1 = create_power_vector(n, size)
print(f"Vector: tau_hidden_1 = {tau_hidden_1}")

size = 24
n = 4
tau_hidden_2 = create_power_vector(n, size)
print(f"Vector: tau_hidden_2 = {tau_hidden_2}")

size = 24
n = 8
tau_hidden_3 = create_power_vector(n, size)
print(f"Vector: tau_hidden_3 = {tau_hidden_3}")


Vector: tau_hidden_1 = [2 2 2 2 2 2 2 2 2 2 2 2 4 4 4 4 4 4 4 4 4 4 4 4]
Vector: tau_hidden_2 = [ 2  2  2  2  2  2  4  4  4  4  4  4  8  8  8  8  8  8 16 16 16 16 16 16]
Vector: tau_hidden_3 = [  2   2   2   4   4   4   8   8   8  16  16  16  32  32  32  64  64  64
 128 128 128 256 256 256]


In [9]:
"""

        BETA DISTRIBUTION FROM TAU

"""

delaT = 1 # 1ms time step


beta_hidden_1 = torch.exp(-torch.tensor(delaT, dtype=torch.float32) / torch.tensor(tau_hidden_1, dtype=torch.float32))
print(f"\nbeta_hidden_1: {beta_hidden_1}")

beta_hidden_2 = torch.exp(-torch.tensor(delaT, dtype=torch.float32) / torch.tensor(tau_hidden_2, dtype=torch.float32))
print(f"\nbeta_hidden_2: {beta_hidden_2}")

beta_hidden_3 = torch.exp(-torch.tensor(delaT, dtype=torch.float32) / torch.tensor(tau_hidden_3, dtype=torch.float32))
print(f"\nbeta_hidden_3: {beta_hidden_3}")

tau_output = np.repeat(10, output_size)
print(f"\ntau_output: {tau_output}")

beta_output = torch.exp(-torch.tensor(delaT, dtype=torch.float32) / torch.tensor(tau_output, dtype=torch.float32))
print(f"\nbeta_output: {beta_output}")

cuba_beta = torch.exp(-torch.tensor(delaT, dtype=torch.float32) / torch.tensor(cuba_tau, dtype=torch.float32))
print(f"\ncuba_beta: {cuba_beta}")


beta_hidden_1: tensor([0.6065, 0.6065, 0.6065, 0.6065, 0.6065, 0.6065, 0.6065, 0.6065, 0.6065,
        0.6065, 0.6065, 0.6065, 0.7788, 0.7788, 0.7788, 0.7788, 0.7788, 0.7788,
        0.7788, 0.7788, 0.7788, 0.7788, 0.7788, 0.7788])

beta_hidden_2: tensor([0.6065, 0.6065, 0.6065, 0.6065, 0.6065, 0.6065, 0.7788, 0.7788, 0.7788,
        0.7788, 0.7788, 0.7788, 0.8825, 0.8825, 0.8825, 0.8825, 0.8825, 0.8825,
        0.9394, 0.9394, 0.9394, 0.9394, 0.9394, 0.9394])

beta_hidden_3: tensor([0.6065, 0.6065, 0.6065, 0.7788, 0.7788, 0.7788, 0.8825, 0.8825, 0.8825,
        0.9394, 0.9394, 0.9394, 0.9692, 0.9692, 0.9692, 0.9845, 0.9845, 0.9845,
        0.9922, 0.9922, 0.9922, 0.9961, 0.9961, 0.9961])

tau_output: [10 10 10 10]

beta_output: tensor([0.9048, 0.9048, 0.9048, 0.9048])

cuba_beta: 0.6065306663513184


In [10]:
# # Set random seeds for reproducibility
# pl.seed_everything(42)

# # Device configuration is handled by Lightning
# # So we don't need to manually set device

# # Initialize the data module
# data_dir = 'data/4_one_second_samples' 
# data_module = QUTDataModule(
#     data_dir, batch_size=batch_size, num_workers=num_workers
# )

# # Initialize the Lightning model
# model = Lightning_SNNQUT(
#     input_size=input_size,
#     hidden_size=hidden_size,
#     output_size=output_size,
#     beta_hidden_1=beta_hidden_1,
#     beta_hidden_2=beta_hidden_2,
#     beta_hidden_3=beta_hidden_3,
#     beta_output=beta_output,
#     hidden_reset_mechanism=hidden_reset_mechanism,
#     output_reset_mechanism=output_reset_mechanism,
#     learning_rate=learning_rate,
#     optimizer_betas=optimizer_betas,
#     scheduler_step_size=scheduler_step_size,
#     scheduler_gamma=scheduler_gamma,
#     output_threshold=output_threshold,
#     cuba_beta=cuba_beta,
#     hid_threshold = hid_threshold
# )

# # Initialize the Trainer
# trainer = pl.Trainer(
#     max_epochs=num_epochs,
#     # logger=logger,
#     # Uncomment the following line if GPU is available
#     # devices=1 if torch.cuda.is_available() else None,
#     # accelerator='gpu' if torch.cuda.is_available() else 'cpu',
# )

# # Start training
# trainer.fit(model, datamodule=data_module)

# # Test the model
# trainer.test(model, datamodule=data_module)

In [11]:
def main():
    # Get hyperparameters from NNI
    params = get_nni_params()

    # Set random seeds for reproducibility
    pl.seed_everything(42)

    # Initialize the data module
    #data_dir = 'data/4_one_second_samples'
    data_dir = 'data/TEST'
    data_module = QUTDataModule(
        data_dir, batch_size=batch_size, num_workers=num_workers
    )

    # Initialize the Lightning model with hyperparameters from NNI
    model = Lightning_SNNQUT(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        beta_hidden_1=beta_hidden_1,
        beta_hidden_2=beta_hidden_2,
        beta_hidden_3=beta_hidden_3,
        beta_output=beta_output,
        hidden_reset_mechanism=hidden_reset_mechanism,
        output_reset_mechanism=output_reset_mechanism,
        learning_rate=params['learning_rate'],
        optimizer_betas=tuple(params['optimizer_betas']),
        scheduler_step_size=scheduler_step_size,
        scheduler_gamma=scheduler_gamma,
        output_threshold=output_threshold,
        cuba_beta=cuba_beta,
        hid_threshold=hid_threshold,
        fast_sigmoid_slope=params['fast_sigmoid_slope'],
    )

    # Initialize the Trainer
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # devices and accelerator settings
    )

    # Start training
    trainer.fit(model, datamodule=data_module)

    # Validate the model
    #val_result = trainer.validate(model, datamodule=data_module)
    # val_accuracy = val_result[0]['val_accuracy'] NOTE: Accessing val_accuracy as val_result[0]['val_accuracy'] only gives the accuracy for the first batch, not the entire validation dataset.

    # Validate the model
    trainer.validate(model, datamodule=data_module)
    val_accuracy = trainer.callback_metrics['val_accuracy'].item() # Accessing val_accuracy as trainer.callback_metrics['val_accuracy'] gives the accuracy for the entire validation dataset. .item() is used to convert the tensor to a Python number.



    # Report the result to NNI
    nni.report_final_result(val_accuracy)

if __name__ == '__main__':
    main()

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name          | Type    | Params | Mode 
--------------------------------------------------
0 | model         | SNNQUT  | 1.6 K  | train
1 | loss_function | MSELoss | 0      | train
--------------------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.007     Total estimated model params size (MB)


Loaded 2400 files from data/TEST


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/NeuroVecio/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
/opt/anaconda3/envs/NeuroVecio/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.
/opt/anaconda3/envs/NeuroVecio/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (49) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


Loaded 2400 files from data/TEST


Validation: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy           26.94444465637207
     val_loss_epoch         43.750003814697266
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[2024-11-21 17:07:56] [32mFinal result: 26.94444465637207[0m
