In [None]:
import os
import random

from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import seaborn as sns

import nni
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from snntorch import surrogate
import snntorch as snn
from torch.utils.data import DataLoader, Dataset, random_split
from torchmetrics.classification import ConfusionMatrix

from QUT_DataLoader import QUTDataset

In [None]:
# Hyperparameters and Constants
# Network dimensions
input_size = 16   # features per time step fed to fc1
hidden_size = 24  # neurons in each of the three hidden layers
output_size = 4   # output neurons == number of classes in the confusion matrices

# Training schedule
num_epochs = 2
batch_size = 32
scheduler_step_size = 3  # StepLR: number of epochs between learning-rate decays
scheduler_gamma = 0.5 # 0.5 means halving the learning rate every step_size epochs

# Neuron thresholds. The output threshold is set very high: the output layer is
# read out through its membrane potential (mem4_rec), not through its spikes.
output_threshold = 1e7
hidden_threshold = 1

# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to 1 so that num_workers is always a positive int.
num_workers = max(1, (os.cpu_count() or 1) - 1)
hidden_reset_mechanism = 'subtract'
output_reset_mechanism = 'none'

# Vmem shift used to build the MSELoss target: target = label * shift + shift,
# i.e. 2 * shift for the labelled class and shift for the others (for 0/1 labels).
Vmem_shift_for_MSELoss = 0.5

In [None]:
from typing import Dict, Any

def get_nni_params() -> Dict[str, Any]:
    """Return the tunable training hyperparameters, overridden by NNI when available.

    Defaults::

        learning_rate       0.01          Adamax learning rate
        optimizer_betas     (0.9, 0.99)   Adamax beta coefficients
        fast_sigmoid_slope  10            slope of the fast-sigmoid surrogate gradient

    When the script runs as an NNI trial, ``nni.get_next_parameter()`` returns a
    dict of tuned values, and these overwrite the matching defaults. Outside an
    NNI experiment the call may return an empty dict or ``None`` (depending on the
    NNI version), or raise. In all of those cases the defaults are used, and a
    message is printed for the ``None``/exception cases.

    Returns:
        Dict[str, Any]: a fresh dict of hyperparameters. ``optimizer_betas`` is
        always a tuple, because NNI search spaces are JSON and deliver a list.
    """
    params: Dict[str, Any] = {
        'learning_rate': 0.01,
        'optimizer_betas': (0.9, 0.99),
        'fast_sigmoid_slope': 10,
    }
    try:
        tuner_params = nni.get_next_parameter()
    except Exception as exc:  # NNI missing or misconfigured: keep the defaults, but say why
        print(f"NNI is not being used or failed to retrieve parameters ({exc!r}). Using default hyperparameters.")
        return params

    if tuner_params is None:
        print("NNI is not being used or failed to retrieve parameters. Using default hyperparameters.")
        return params

    params.update(tuner_params)
    # A tuned pair of betas arrives as a JSON list; normalise to the default's type.
    params['optimizer_betas'] = tuple(params['optimizer_betas'])
    return params

In [None]:
# Spiking Neural Network Model
class SNNQUT(nn.Module):
    """Four-layer feed-forward spiking network built from snnTorch ``Leaky`` neurons.

    The three hidden stages (``fc1``/``lif1`` .. ``fc3``/``lif3``) share the same
    reset mechanism, threshold and fast-sigmoid surrogate slope, but each has its
    own decay factor ``beta``. The output stage (``fc4``/``lif4``) has its own
    reset mechanism and threshold, and no ``spike_grad`` is passed to it, so
    ``snn.Leaky``'s default surrogate is used. All linear layers are bias-free
    and Xavier-normal initialised.
    """

    def __init__(self, input_size, hidden_size, output_size, beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output, hidden_reset_mechanism, output_reset_mechanism, hidden_threshold, output_threshold, fast_sigmoid_slope,):
        super().__init__()

        def make_hidden_lif(beta):
            # Every hidden layer gets its own surrogate-gradient instance.
            return snn.Leaky(
                beta=beta,
                reset_mechanism=hidden_reset_mechanism,
                threshold=hidden_threshold,
                spike_grad=snn.surrogate.fast_sigmoid(slope=fast_sigmoid_slope),
            )

        # Registration order (fc1, lif1, fc2, ...) fixes parameter/state_dict order.
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.lif1 = make_hidden_lif(beta_hidden_1)

        self.fc2 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif2 = make_hidden_lif(beta_hidden_2)

        self.fc3 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.lif3 = make_hidden_lif(beta_hidden_3)

        self.fc4 = nn.Linear(hidden_size, output_size, bias=False)
        self.lif4 = snn.Leaky(beta=beta_output, reset_mechanism=output_reset_mechanism, threshold=output_threshold)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal initialise the weights of fc1 -> fc4, in that order."""
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            nn.init.xavier_normal_(layer.weight)
        # To start with a fixed fraction of negative weights, also call:
        # self.enforce_weight_constraints()

    def enforce_weight_constraints(self):
        """Re-sign all trainable parameters so that exactly half of them are negative.

        Magnitudes are preserved. The positions that become negative are chosen
        with ``torch.randperm`` jointly over every trainable parameter, and the
        result is written back in place through ``param.data``.
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        magnitudes = torch.cat([p.data.view(-1) for p in trainable]).abs()
        n_total = magnitudes.numel()
        n_negative = int(0.5 * n_total)

        chosen = torch.randperm(n_total, device=magnitudes.device)[:n_negative]
        magnitudes[chosen] *= -1

        # Scatter the flat vector back into each parameter's original shape.
        offset = 0
        for p in trainable:
            count = p.numel()
            p.data.copy_(magnitudes[offset:offset + count].view(p.shape))
            offset += count

    def forward(self, x):
        """Simulate the network over every time step of ``x``.

        Args:
            x: tensor laid out as (batch, time_steps, input_size); it is cast to float32.

        Returns:
            Tuple ``(spk1, spk2, spk3, mem4)`` of time-major stacks: the hidden
            spikes, each shaped (time_steps, batch, hidden_size), and the output
            membrane potential shaped (time_steps, batch, output_size).
        """
        x = x.to(torch.float32)
        n_batch, n_steps, _ = x.shape
        device = x.device

        # Membrane potentials start at zero for every stage.
        mem1, mem2, mem3, mem4 = (
            torch.zeros(n_batch, fc.out_features, device=device)
            for fc in (self.fc1, self.fc2, self.fc3, self.fc4)
        )

        spikes1, spikes2, spikes3, potentials4 = [], [], [], []
        for t in range(n_steps):
            spk1, mem1 = self.lif1(self.fc1(x[:, t, :]), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spk3, mem3 = self.lif3(self.fc3(spk2), mem3)
            _, mem4 = self.lif4(self.fc4(spk3), mem4)

            spikes1.append(spk1)
            spikes2.append(spk2)
            spikes3.append(spk3)
            potentials4.append(mem4)

        return tuple(torch.stack(trace, dim=0) for trace in (spikes1, spikes2, spikes3, potentials4))

In [None]:
# Lightning Module
class Lightning_SNNQUT(pl.LightningModule):
    """LightningModule that trains ``SNNQUT`` by regressing its output membrane potential.

    Loss: MSE between the output-layer membrane potential at every time step and
    ``labels * Vmem_shift_for_MSELoss + Vmem_shift_for_MSELoss`` broadcast over time.
    Prediction: argmax over classes of the output membrane potential summed over
    all time steps. Labels are expected as (batch, output_size) with the class
    given by their argmax (presumably one-hot -- TODO confirm against QUTDataset).

    The four ``beta_*`` arguments are tensors, which ``save_hyperparameters``
    cannot JSON-serialise, so they are stored as plain attributes. Every other
    constructor argument is available via ``self.hparams``.

    When a logger is attached, per-step hidden spike counts and per-epoch
    train/val/test confusion matrices are written to ``self.logger.experiment``
    (a TensorBoard SummaryWriter when a TensorBoardLogger is used).
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        beta_hidden_1,
        beta_hidden_2,
        beta_hidden_3,
        beta_output,
        hidden_reset_mechanism,
        output_reset_mechanism,
        hidden_threshold,
        output_threshold,
        learning_rate,
        scheduler_step_size,
        scheduler_gamma,
        optimizer_betas,
        fast_sigmoid_slope,
    ):
        super().__init__()
        self.save_hyperparameters(
            'input_size',
            'hidden_size',
            'output_size',
            'hidden_reset_mechanism',
            'output_reset_mechanism',
            'hidden_threshold',
            'output_threshold',
            'learning_rate',
            'scheduler_step_size',
            'scheduler_gamma',
            'optimizer_betas',
            'fast_sigmoid_slope',
        )
        # Beta tensors are not JSON serialisable, so keep them out of hparams.
        self.beta_hidden_1 = beta_hidden_1
        self.beta_hidden_2 = beta_hidden_2
        self.beta_hidden_3 = beta_hidden_3
        self.beta_output = beta_output

        # One confusion matrix per stage; each is plotted and reset at epoch end.
        num_classes = self.hparams.output_size
        self.train_confmat = ConfusionMatrix(task='multiclass', num_classes=num_classes)
        self.val_confmat = ConfusionMatrix(task='multiclass', num_classes=num_classes)
        self.test_confmat = ConfusionMatrix(task='multiclass', num_classes=num_classes)

        self.model = SNNQUT(
            input_size=self.hparams.input_size,
            hidden_size=self.hparams.hidden_size,
            output_size=self.hparams.output_size,
            beta_hidden_1=self.beta_hidden_1,
            beta_hidden_2=self.beta_hidden_2,
            beta_hidden_3=self.beta_hidden_3,
            beta_output=self.beta_output,
            hidden_reset_mechanism=self.hparams.hidden_reset_mechanism,
            output_reset_mechanism=self.hparams.output_reset_mechanism,
            output_threshold=self.hparams.output_threshold,
            hidden_threshold=self.hparams.hidden_threshold,
            fast_sigmoid_slope=self.hparams.fast_sigmoid_slope,
        )

        self.loss_function = nn.MSELoss()

    # ------------------------------------------------------------------ helpers

    def _logger_experiment(self):
        """Return the attached logger's ``experiment`` object, or None if no logger is attached."""
        logger = self.logger
        return None if logger is None else logger.experiment

    @staticmethod
    def _metric_was_updated(metric):
        """Return True if ``metric.update`` was called since the last reset.

        torchmetrics >= 1.0 exposes the public ``update_called`` property and
        deprecates the private ``_update_called``. Older releases only have the
        private attribute, so fall back to it.
        """
        if hasattr(metric, 'update_called'):
            return bool(metric.update_called)
        return bool(metric._update_called)

    def _log_confusion_matrix(self, confmat_metric, tag, title, description):
        """Plot ``confmat_metric`` as a heatmap under ``tag`` at the current epoch, then reset it.

        If the metric received no updates since its last reset, only a warning
        is printed.
        """
        if not self._metric_was_updated(confmat_metric):
            print(f"Warning: No data to compute {description} confusion matrix")
            return
        confmat = confmat_metric.compute()
        experiment = self._logger_experiment()
        if experiment is not None:
            fig, ax = plt.subplots(figsize=(8, 8))
            sns.heatmap(confmat.cpu().numpy(), annot=True, fmt='d', cmap='Blues', ax=ax)
            ax.set_ylabel('True Label')
            ax.set_xlabel('Predicted Label')
            ax.set_title(title)
            experiment.add_figure(tag, fig, self.current_epoch)
            plt.close(fig)
        confmat_metric.reset()

    def _loss_and_accuracy(self, mem4_rec, labels, confmat_metric):
        """Compute the MSE loss and batch accuracy, and update ``confmat_metric``.

        Args:
            mem4_rec: output membrane potential, (time_steps, batch, output_size)
                as returned by ``SNNQUT.forward``.
            labels: (batch, output_size) label tensor; the target class is its argmax.
            confmat_metric: confusion-matrix metric updated with (predicted, targets).

        Returns:
            (loss tensor, accuracy as a Python float in [0, 1]).
        """
        # Broadcast labels over the time axis so every step is supervised.
        labels_expanded = labels.unsqueeze(0).expand(mem4_rec.size(0), -1, -1)
        target = (labels_expanded * Vmem_shift_for_MSELoss) + Vmem_shift_for_MSELoss
        loss = self.loss_function(mem4_rec, target)

        # Predict from the membrane potential summed over all time steps.
        summed_mem4 = mem4_rec.sum(0)
        _, predicted = summed_mem4.max(-1)
        _, targets = labels.max(-1)

        accuracy = predicted.eq(targets).sum().item() / targets.numel()
        confmat_metric.update(predicted, targets)
        return loss, accuracy

    # --------------------------------------------------------------- epoch hooks

    def on_train_epoch_end(self):
        self.log_weight_statistics('train')
        self._log_confusion_matrix(self.train_confmat, 'train_confusion_matrix', 'Train Confusion Matrix', 'train')

    def on_validation_epoch_end(self):
        self.log_weight_statistics('val')
        self._log_confusion_matrix(self.val_confmat, 'val_confusion_matrix', 'Validation Confusion Matrix', 'validation')

    def on_test_epoch_end(self):
        # Without this hook test_confmat would keep accumulating across test runs.
        self._log_confusion_matrix(self.test_confmat, 'test_confusion_matrix', 'Test Confusion Matrix', 'test')

    def log_weight_statistics(self, mode):
        """Log mean, std and negative count of every weight tensor, plus the overall % negative."""
        total_neg_weights = 0
        total_weights = 0
        for name, param in self.model.named_parameters():
            if 'weight' not in name:
                continue
            weights = param.data
            num_neg = (weights < 0).sum().item()
            total_neg_weights += num_neg
            total_weights += weights.numel()
            self.log(f'{mode}_weight_mean_{name}', weights.mean(), on_epoch=True, prog_bar=False)
            self.log(f'{mode}_weight_std_{name}', weights.std(), on_epoch=True, prog_bar=False)
            self.log(f'{mode}_num_neg_{name}', num_neg, on_epoch=True, prog_bar=False)
        if total_weights:  # avoid ZeroDivisionError for a model without weight tensors
            percent_neg_weights = (total_neg_weights / total_weights) * 100
            self.log(f'{mode}_percent_neg_weights', percent_neg_weights, on_epoch=True, prog_bar=False)

    # ------------------------------------------------------------------- steps

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        spk1_rec, spk2_rec, spk3_rec, mem4_rec = self(inputs)

        # Log the total number of spikes per hidden layer for this batch.
        experiment = self._logger_experiment()
        if experiment is not None:
            for tag, spikes in (('spk1_sum', spk1_rec), ('spk2_sum', spk2_rec), ('spk3_sum', spk3_rec)):
                experiment.add_scalar(tag, spikes.sum(), self.global_step)

        loss, accuracy = self._loss_and_accuracy(mem4_rec, labels, self.train_confmat)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', accuracy * 100, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def on_after_backward(self):
        # Per parameter: mean |grad| (logged as "grad_norm") and signed mean grad.
        for name, param in self.named_parameters():
            if param.grad is None:
                print(f'No gradient for parameter: {name}')
                continue
            self.log(f'grad_norm_{name}', param.grad.abs().mean(), on_step=True, on_epoch=True, prog_bar=False)
            self.log(f'grad_mean/{name}', param.grad.mean(), on_step=True, on_epoch=True, prog_bar=False)

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        _, _, _, mem4_rec = self(inputs)
        loss, accuracy = self._loss_and_accuracy(mem4_rec, labels, self.val_confmat)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_accuracy', accuracy * 100, on_step=False, on_epoch=True, prog_bar=True)
        return {'val_loss': loss, 'val_accuracy': accuracy}

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        _, _, _, mem4_rec = self(inputs)
        loss, accuracy = self._loss_and_accuracy(mem4_rec, labels, self.test_confmat)
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('test_accuracy', accuracy * 100, on_step=True, on_epoch=True, prog_bar=True)
        return {'test_loss': loss, 'test_accuracy': accuracy}

    # Optional weight constraints -- enable at most ONE of these hooks.
    #
    # Clamp all weights to >= 0.001 (no negative weights):
    # def on_before_zero_grad(self, optimizer):
    #     for param in self.model.parameters():
    #         if param.requires_grad:
    #             param.data.clamp_(min=0.001)
    #
    # Force some % of negative weights via SNNQUT.enforce_weight_constraints:
    # def on_before_zero_grad(self, optimizer):
    #     self.model.enforce_weight_constraints()

    def configure_optimizers(self):
        """Adamax with a StepLR schedule (lr *= gamma every ``scheduler_step_size`` epochs)."""
        optimizer = optim.Adamax(
            self.parameters(),
            lr=self.hparams.learning_rate,
            betas=self.hparams.optimizer_betas,
        )
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.hparams.scheduler_step_size,
            gamma=self.hparams.scheduler_gamma,
        )
        return [optimizer], [scheduler]

In [None]:
# Data Module
class QUTDataModule(pl.LightningDataModule):
    """Split a ``QUTDataset`` 65% / 15% / 20% into train / val / test loaders.

    Args:
        data_dir: directory passed to ``QUTDataset``.
        batch_size: batch size for every loader.
        num_workers: DataLoader worker processes. Workers are kept persistent
            only when this is > 0, because DataLoader rejects
            ``persistent_workers=True`` with ``num_workers=0``.
        seed: seed for the split's dedicated RNG, so the split is identical
            across runs and across fit/validate/test. ``None`` uses torch's
            global RNG.
    """

    def __init__(self, data_dir, batch_size=32, num_workers=4, seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        # Lightning calls setup() at the start of every fit / validate / test run.
        # Re-splitting each time (with an advanced global RNG) would give a test
        # split that overlaps the training data, so split exactly once.
        if self.train_dataset is not None:
            return
        dataset = QUTDataset(self.data_dir)
        n_total = len(dataset)
        train_size = int(0.65 * n_total)
        val_size = int(0.15 * n_total)
        test_size = n_total - train_size - val_size
        split_kwargs = {} if self.seed is None else {'generator': torch.Generator().manual_seed(self.seed)}
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the module's batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

In [None]:
# Function to Generate Tau and Beta Values
def generate_tau_beta_values(hidden_size, output_size, delta_t=1, tau_output=10):
    """Build per-neuron membrane decay factors (beta) for every SNNQUT layer.

    Hidden-layer time constants are powers of two, spread over the neurons in
    ascending blocks of (near-)equal length:

    * layer 1: tau in {2, 4}
    * layer 2: tau in {2, 4, 8, 16}
    * layer 3: tau in {2, 4, ..., 256}

    All output neurons share ``tau_output``. Each tau becomes
    ``beta = exp(-delta_t / tau)``.

    Args:
        hidden_size: neurons per hidden layer (length of each hidden beta vector).
        output_size: number of output neurons (length of the output beta vector).
        delta_t: simulation time step, in the same unit as tau (1 ms originally).
        tau_output: time constant shared by all output neurons.

    Returns:
        ``(beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output)`` as float32
        tensors of shape (hidden_size,), (hidden_size,), (hidden_size,), (output_size,).
    """

    def create_power_vector(n, size):
        """Return exactly ``size`` values from 2**1 .. 2**n in ascending blocks.

        When ``size`` is not a multiple of ``n``, the first ``size % n`` powers
        get one extra element. The vector therefore always has ``size`` entries
        and broadcasts against a (batch, size) membrane tensor.
        """
        powers = 2 ** np.arange(1, n + 1)
        base, extra = divmod(size, n)
        counts = [base + (1 if i < extra else 0) for i in range(n)]
        return np.repeat(powers, counts)

    def to_beta(tau):
        # beta = exp(-dt / tau), computed in float32.
        tau_tensor = torch.tensor(tau, dtype=torch.float32)
        return torch.exp(-torch.tensor(delta_t, dtype=torch.float32) / tau_tensor)

    beta_hidden_1 = to_beta(create_power_vector(n=2, size=hidden_size))
    beta_hidden_2 = to_beta(create_power_vector(n=4, size=hidden_size))
    beta_hidden_3 = to_beta(create_power_vector(n=8, size=hidden_size))
    beta_output = to_beta(np.full(output_size, tau_output))

    return beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output

In [None]:
# Main Function, so it's easy to convert to script
def main():
    """Run one trial: train, validate and test the SNN, then report to NNI.

    The tunable hyperparameters come from ``get_nni_params`` (NNI tuner values or
    defaults). Everything else comes from the module-level constants.

    The *validation* accuracy is reported to NNI as the trial's final result.
    Reporting the test accuracy instead would let the tuner select
    hyperparameters on the test split, and the printed test accuracy would no
    longer be an unbiased estimate.
    """
    params = get_nni_params()
    beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output = generate_tau_beta_values(hidden_size, output_size)

    # Set random seeds for reproducibility (weight init, data shuffling).
    pl.seed_everything(42)

    # data_dir is either TEST or 4_one_second_samples
    data_dir = 'data/TEST'
    data_module = QUTDataModule(
        data_dir, batch_size=batch_size, num_workers=num_workers
    )

    model = Lightning_SNNQUT(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        beta_hidden_1=beta_hidden_1,
        beta_hidden_2=beta_hidden_2,
        beta_hidden_3=beta_hidden_3,
        beta_output=beta_output,
        hidden_reset_mechanism=hidden_reset_mechanism,
        output_reset_mechanism=output_reset_mechanism,
        learning_rate=params['learning_rate'],
        optimizer_betas=params['optimizer_betas'],
        scheduler_step_size=scheduler_step_size,
        scheduler_gamma=scheduler_gamma,
        output_threshold=output_threshold,
        hidden_threshold=hidden_threshold,
        fast_sigmoid_slope=params['fast_sigmoid_slope'],
    )

    logger = TensorBoardLogger(save_dir='logs', name='Mikel_LIF')

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        log_every_n_steps=10,
        logger=logger,
        # accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices='auto',
    )

    trainer.fit(model, datamodule=data_module)

    trainer.validate(model, datamodule=data_module)
    val_accuracy = trainer.callback_metrics['val_accuracy'].item()
    print(f"Validation Accuracy: {val_accuracy:.2f}%")

    trainer.test(model, datamodule=data_module)
    test_accuracy = trainer.callback_metrics['test_accuracy'].item()
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    # Tune on validation accuracy so that the test split stays held out.
    nni.report_final_result(val_accuracy)

In [None]:
if __name__ == "__main__":
    # Also true inside a Jupyter kernel, so running this cell starts training.
    main()

In [None]:
# Load (or reload) the TensorBoard notebook extension and point it at the run
# directory written by TensorBoardLogger(save_dir='logs', name='Mikel_LIF') in main().
%reload_ext tensorboard
%tensorboard --logdir=logs/Mikel_LIF

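## To do

- Model checkpointing, and continuing training from a checkpoint for extra epochs
- Log spike trains and membrane potentials (Vmem) of each layer, not only per-step spike counts
- Learning-rate finder (optional)
- Thorough documentation
- A loss term that depends on the number of spikes (e.g. a spike-count regulariser)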