[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width="300">](https://github.com/jeshraghian/snntorch/)

# Discover SNN Hyperparameters with Optuna
### Tutorial written by Reto Stamm
<a href="https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_optuna.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width="28">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width="80">](https://github.com/jeshraghian/snntorch/)

*This tutorial shows how to optimize Spiking Neural Network (SNN) hyperparameters with Optuna. In this example, we search for hyperparameters that keep accuracy high while minimizing power consumption.*

**[Optuna](https://optuna.org)** is an efficient, open-source hyperparameter optimization framework. It helps automatically figure out the best settings for machine learning models.

In this tutorial, we make the reasonable assumption that the more spikes a network emits, the more power it consumes. To **minimize power consumption** while keeping accuracy high, we adjust the model's **hyperparameters**: the shape of the network (number of hidden layers and neurons per layer), the number of time steps, the learning rate, and the first layer's decay rate β. The number of training epochs is not tuned directly; early stopping decides it.

For a comprehensive overview of how SNNs work and what is going on under the hood, [you might be interested in the snnTorch tutorial series available here.](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)
The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:

> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. "Training Spiking Neural Networks Using Lessons From Deep Learning". Proceedings of the IEEE, 111(9) September 2023.](https://ieeexplore.ieee.org/abstract/document/10242251) </cite>

In [None]:
!pip install optuna snntorch optunacy --quiet

In [None]:
# Import all the libraries
import copy
import logging
import random
import numbers
import sys
import time # To see how long each iteration takes
import multiprocessing # To check how many cores we have

import optuna # the optimizer
from optuna.exceptions import TrialPruned # To abort, or prune inefficient parameter sets
from concurrent.futures import ThreadPoolExecutor
from optuna.trial import TrialState

# Basic torch tools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

# Image processing tools
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Extra plotting tools
from optunacy.oplot import OPlot
import scipy as scipy

# Spiking Neural Networks!
import snntorch as snn
import snntorch.functional as SF
# from snntorch import utils

## 1. The MNIST Dataset
### 1.1 Dataloading
Define variables for dataloading.

In [None]:
batch_size = 128
data_path='/tmp/data/mnist'

Load the dataset. This is mostly boilerplate. Note that we hold out a fixed 8% subset of the training set as a validation set.

We do not use the test set to tune hyperparameters; doing so would leak test data into model selection.

In [None]:
# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

# The MNIST dataset contains black and white images (28x28) of digits from 0-9
# 60,000 training images
mnist_train = datasets.MNIST(data_path, train=True, download=True,
                             transform=transform)

# Split out a validation subset
total_size = len(mnist_train)
val_size = int(total_size * 0.08)  # 8% for validation
train_size = total_size - val_size  # Remaining for training

# Split the dataset, the same way every time
mnist_val = Subset(mnist_train, range(train_size, total_size))
mnist_train = Subset(mnist_train, range(0, train_size))

# 10,000 test images
mnist_test = datasets.MNIST(data_path, train=False, download=True,
                            transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(mnist_val, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)
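The split above is deterministic: the last 8% of the 60,000 training images always form the validation set. The plain-Python sketch below mirrors that arithmetic to check the resulting sizes, and how many batches the validation loader yields, assuming `DataLoader`'s default `drop_last=False` (so the final batch is smaller than `batch_size`).

```python
import math

# Mirrors the split above: the last 8% of MNIST's 60,000 training images become validation data
total_size = 60_000
val_size = int(total_size * 0.08)
train_size = total_size - val_size

train_indices = range(0, train_size)
val_indices = range(train_size, total_size)

# The two index ranges are disjoint and together cover the whole training set
assert set(train_indices).isdisjoint(val_indices)
assert len(train_indices) + len(val_indices) == total_size

# With batch_size=128 and drop_last=False, the last validation batch is only partially filled
batch_size = 128
num_val_batches = math.ceil(val_size / batch_size)
last_batch_size = val_size - (num_val_batches - 1) * batch_size
print(train_size, val_size, num_val_batches, last_batch_size)
```

The partial last batch matters later: any statistic averaged per batch rather than per image is slightly skewed by it.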

## 2. A parameterizable network with snnTorch

With this MNIST dataset, some things always remain the same:

* The input image size (28x28), and the fact that we want to recognize one of 10 digits. Those are hardwired.

* The output layer is fixed to classify digits 0-9. There is one neuron for each digit.

* The first layer's decay rate *beta1* is a hyperparameter that we tune; it is not learned. The decay rates of the deeper layers are randomly initialized and learned during training.

* The number of timesteps before we get the result is configurable.

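We also count the spikes emitted by all layers, so that we can estimate how much spiking activity each digit generates. The total is divided by the number of forward passes; since each forward pass processes a whole batch, the resulting "spikes per digit" figure is strictly spikes per batch, which is (nearly) proportional to spikes per digit.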
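To build intuition for why the decay rate β affects the spike count, here is a simplified, plain-Python leaky integrate-and-fire neuron in the spirit of `snn.Leaky`, using its default threshold of 1.0 and reset by subtraction. It is a sketch, not snnTorch's exact implementation: each step the membrane potential decays by β, integrates the input, and the neuron spikes when the potential exceeds the threshold.

```python
def simulate_lif(inputs, beta=0.9, threshold=1.0):
    """Simplified leaky integrate-and-fire neuron with reset by subtraction.

    Returns a list of 0/1 spikes, one per input time step.
    """
    mem = 0.0
    spikes = []
    for cur in inputs:
        # Subtract the threshold if the neuron spiked on the previous step
        reset = threshold if spikes and spikes[-1] == 1 else 0.0
        # Leak (decay by beta), integrate the input, then apply the reset
        mem = beta * mem + cur - reset
        spikes.append(1 if mem > threshold else 0)
    return spikes


# Same constant input for 25 time steps, two different decay rates
low_beta = simulate_lif([0.5] * 25, beta=0.5)
high_beta = simulate_lif([0.5] * 25, beta=0.95)
print(sum(low_beta), sum(high_beta))
```

With β = 0.5 the potential approaches 0.5 / (1 - 0.5) = 1.0 from below and never crosses the threshold, so the neuron stays silent; with β = 0.95 it fires repeatedly. This is one way the first layer's β feeds directly into the spike count we want to minimize.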

The model definition below constructs the layers in a loop to ensure they are parameterizable.

In [None]:
class Net(nn.Module):

    def __init__(self, num_steps, num_hidden_neurons=299, num_hidden_layers=1, beta1=0.9):
        super().__init__()
        assert 0 <= beta1 <= 1, "Beta1 must be between 0 and 1"
        # With zero hidden layers no 10-neuron output layer would be built
        assert num_hidden_layers >= 1, "At least one hidden layer is required"

        num_inputs = 28 * 28 # image is 28x28 pixels
        num_outputs = 10 # we want to get digits 0-9
        self.num_steps = num_steps
        self.num_hidden_neurons = num_hidden_neurons
        self.num_hidden_layers = num_hidden_layers

        # Initialize layers
        self.layers = []
        for n in range(num_hidden_layers + 1):
            layer = {}
            if n == 0:
                # First layer
                layer['fc'] = nn.Linear(num_inputs, num_hidden_neurons)
                layer['lif'] = snn.Leaky(beta=beta1)
            elif n < num_hidden_layers:
                # Inner layers
                layer['fc'] = nn.Linear(num_hidden_neurons, num_hidden_neurons)
                beta2 = torch.rand((num_hidden_neurons), dtype=torch.float)
                layer['lif'] = snn.Leaky(beta=beta2, learn_beta=True)
            else:
                # Output layer
                layer['fc'] = nn.Linear(num_hidden_neurons, num_outputs)
                beta2 = torch.rand((num_outputs), dtype=torch.float)
                layer['lif'] = snn.Leaky(beta=beta2, learn_beta=True)

            # Add the layers to the internal representation
            self.add_module(f'fc{n}', layer['fc'])
            self.add_module(f'lif{n}', layer['lif'])

            # Add the layers to our layer list.
            self.layers.append(layer)

        # Reset spike counter
        self.reset_spikes()

    def forward(self, x):
        # The forward pass.

        # Initialize all the neurons in all layers
        for layer in self.layers:
            layer['mem'] = layer['lif'].init_leaky()

        spk_rec, mem_rec = [], []

        # process each timestep
        for step in range(self.num_steps):
            cur = x.flatten(1)

            # process each layer
            for index, layer in enumerate(self.layers):
                # process the data
                cur, layer['mem'] = layer['lif'](layer['fc'](cur), layer['mem'])

                # update the total spike count
                self.total_spike_count += cur.sum().item()
            # update the spike records
            spk_rec.append(cur)
            mem_rec.append(self.layers[-1]['mem'])

        self.forward_count += 1 # one per batch, so we can normalize the spike count later
        return torch.stack(spk_rec), torch.stack(mem_rec)

    def get_spikes_per_digit(self):
        # Returns the average number of spikes per forward pass, i.e. per batch of images.
        # This is (nearly) proportional to the number of spikes per digit.
        return self.total_spike_count/self.forward_count

    def reset_spikes(self):
        # Reset all the spike counting information
        self.total_spike_count = 0 # How many spikes have been generated, in all layers
        self.forward_count = 0 # How many forward passes have been made, altogether
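`get_spikes_per_digit()` averages over forward passes. If you want a true per-image figure instead, one option is to normalize by the number of samples. The hypothetical `SpikeCounter` below (plain Python, not part of snnTorch) contrasts the two normalizations using toy numbers.

```python
class SpikeCounter:
    """Hypothetical helper: accumulates spike counts per batch and normalizes them two ways."""

    def __init__(self):
        self.total_spikes = 0
        self.num_batches = 0
        self.num_samples = 0

    def update(self, spikes_in_batch, batch_size):
        # Called once per forward pass, like the counting inside Net.forward()
        self.total_spikes += spikes_in_batch
        self.num_batches += 1
        self.num_samples += batch_size

    def per_batch(self):
        # What Net.get_spikes_per_digit() computes
        return self.total_spikes / self.num_batches

    def per_sample(self):
        # Spikes per image, i.e. per digit
        return self.total_spikes / self.num_samples


counter = SpikeCounter()
counter.update(spikes_in_batch=51_200, batch_size=128)  # toy numbers for illustration
counter.update(spikes_in_batch=25_600, batch_size=64)   # a smaller final batch
print(counter.per_batch(), counter.per_sample())
```

Inside `Net.forward`, the equivalent change would be to increment `forward_count` by `x.size(0)` instead of 1. The `z_clip` ranges used in the plots later in this notebook were chosen for the per-batch values, so they would need rescaling.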

## 3. The hyperparameter trainer

The trainer class defines how training takes place, given a network and a few training hyperparameters. Keeping this logic in a class makes the objective below more readable.

This class includes automatic early stopping: training ends once the loss has not improved significantly for more than *patience=300* consecutive batches. The check only applies after the first epoch.

What does "significant improvement" mean? Whatever you want. In our case, we make it a function of the number of hidden layers. Deep networks take longer to train, and their improvements per batch are, on average, smaller. The formula below is a rough way to account for this:

$$\text{significance}_{\text{actual}} = \frac{\text{significance}_{\text{base}}}{\text{layers}^3}$$

By default (base significance 0.05), a network with a single hidden layer stops training if no batch loss has beaten the best loss so far by more than 5% within the last 300 batches.

For a ten-layer network, the threshold drops to 0.05/1000 = 0.00005 (0.005%), so pretty much any tiny improvement counts as significant. This heuristic was determined empirically and seems to work quite well.
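The early-stopping rule can be sketched in plain Python. This mirrors the logic in `SNNTrainer.train` below but simplifies it: epochs are ignored (the trainer only allows stopping after the first epoch), and the function names are illustrative rather than part of the notebook.

```python
def significance_threshold(base, hidden_layers):
    # Required relative improvement: base / layers^3
    return base / hidden_layers ** 3


def early_stop_batch(losses, base=0.05, hidden_layers=1, patience=300):
    """Return the index of the batch at which training would stop, or None."""
    sig = significance_threshold(base, hidden_layers)
    best_loss = float("inf")
    counter = 0
    for i, loss in enumerate(losses):
        if loss < (1 - sig) * best_loss:  # a significant improvement
            best_loss = loss
            counter = 0
        else:  # no significant improvement
            counter += 1
        if counter > patience:
            return i
    return None


for layers in (1, 2, 10):
    print(layers, significance_threshold(0.05, layers))
```

A flat loss curve stops after `patience + 1` non-improving batches, while a loss that keeps shrinking by more than the threshold each batch never triggers the stop.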

In [None]:
# Use Optuna's logger so our messages interleave cleanly with Optuna's own output
logger = logging.getLogger('optuna')

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class SNNTrainer:

    def __init__(self, net, trial, num_epochs=30, num_steps=25,
                 learning_rate=2e-3, patience = 300, sig_improvement = 0.05):

        self.net = net.to(device)
        self.num_epochs = num_epochs
        self.num_steps = num_steps
        self.learning_rate = learning_rate
        self.trial = trial
        self.patience = patience

        # Calculate what we mean by significant improvement
        self.sig_improvement = sig_improvement/(net.num_hidden_layers**3)

        self.optimizer = torch.optim.Adam(self.net.parameters(),
                                          lr=learning_rate,
                                          betas=(0.9, 0.999))

        self.loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)
        self.loss_hist = []
        self.acc_hist = []
        self.epochs = 0
        self.batches = 0

    def train(self, train_loader):
        acc = 0
        best_loss = float("inf")
        loss_counter = 0

        for epoch in range(self.num_epochs):
            for i, (data, targets) in enumerate(train_loader):
                data = data.to(device)
                targets = targets.to(device)

                self.net.train()
                spk_rec, _ = self.net(data)
                loss_val = self.loss_fn(spk_rec, targets)
                self.optimizer.zero_grad()
                loss_val.backward()
                self.optimizer.step()

                current_loss = loss_val.item()
                self.loss_hist.append(current_loss)

                # Update display every few iterations.
                if i % 100 == 0 and i != 0:
                    acc = SF.accuracy_rate(spk_rec, targets)
                    self.acc_hist.append(acc)
                    logger.info(f"Trial {self.trial.number}: Training: Epoch {epoch}, Batch {i} "+
                                 f"Loss: {loss_val.item():.4f} (best:{best_loss:.4f} t-{loss_counter}) "+
                                 f"Accuracy: {acc * 100:.2f}%")

                # rudimentary early stop:
                # After the first epoch, if there is no improvement, call it a day
                if current_loss < (1-self.sig_improvement)*best_loss: # an improvement!
                    best_loss = current_loss
                    loss_counter = 0
                else: # No improvement
                    loss_counter += 1

                self.batches += 1

                if epoch > 0 and loss_counter > self.patience:
                    logger.info("Early stopping.")
                    return

            self.epochs += 1

    def get_accuracy(self, test_loader):
        # Get the normal test accuracy for the dataset provided.
        total_acc = 0
        total = 0
        with torch.no_grad():
            self.net.eval()
            for data, targets in test_loader:
                data = data.to(device)
                targets = targets.to(device)
                spk_rec, _ = self.net(data)

                acc = SF.accuracy_rate(spk_rec, targets)
                total_acc += acc * data.size(0)
                total += data.size(0)
        return total_acc / total
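`SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)` trains on spike *counts*: over `num_steps` time steps, the correct output neuron is pushed toward firing on about 80% of the steps and every other output neuron toward about 20%. `SF.accuracy_rate` then takes the neuron with the most spikes as the prediction. The plain-Python sketch below illustrates that idea for a single sample; it omits the batching and any normalization details of snnTorch's actual implementation, and the function names are illustrative.

```python
def count_targets(target_class, num_steps, num_classes=10,
                  correct_rate=0.8, incorrect_rate=0.2):
    # Target spike count per output neuron over num_steps time steps
    return [num_steps * (correct_rate if c == target_class else incorrect_rate)
            for c in range(num_classes)]


def mse_count(spike_counts, target_class, num_steps):
    # Mean squared error between observed and target spike counts (single sample)
    targets = count_targets(target_class, num_steps, len(spike_counts))
    return sum((s - t) ** 2 for s, t in zip(spike_counts, targets)) / len(spike_counts)


def predicted_class(spike_counts):
    # The output neuron with the most spikes wins
    return max(range(len(spike_counts)), key=lambda c: spike_counts[c])


num_steps = 25
ideal = [5] * 10
ideal[3] = 20  # digit 3 fires on 80% of the steps, the others on 20%
print(mse_count(ideal, target_class=3, num_steps=num_steps), predicted_class(ideal))
```

With `num_steps=25`, a network that hits these targets exactly still emits about 20 + 9 × 5 = 65 output spikes per image, so the loss itself keeps the output layer from going silent.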

## 4. The Objective

We have two targets:

* maximize accuracy
* minimize spikes

We direct Optuna to search specific parameters within carefully chosen ranges that are neither too broad nor too narrow. This balance matters especially for parameters like the number of time steps and the network size, since overly high values can significantly increase training time. Even these hyper-hyperparameters have to be chosen somehow.

We have included a pruning policy in the objective below that skips (prunes) configurations that would be too expensive to train, before any training starts.

Following the training run, accuracy is measured on the separate validation set. It contains images this network has never seen during training, so it shows how well the network generalizes from the training data.
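The pruning rule in the objective caps the product of hidden layers, neurons per hidden layer, and time steps at 300 × 15 = 4500. A plain-Python sketch of the same check (the helper name is illustrative) makes it easy to see how much of the search grid is skipped.

```python
BUDGET = 300 * 15  # same threshold as in optuna_objective


def is_too_expensive(hidden_layers, neurons, steps):
    # Mirrors: num_hidden_layers * num_hidden_neurons * num_steps > 300 * 15
    return hidden_layers * neurons * steps > BUDGET


# Fraction of the integer grid (layers 1-10, neurons 5-300, steps 10-50) that gets pruned
grid = [(l, n, s) for l in range(1, 11) for n in range(5, 301) for s in range(10, 51)]
pruned = sum(is_too_expensive(*cfg) for cfg in grid)
print(f"{pruned / len(grid):.1%} of grid points are pruned")
```

In other words, anything deeper than one layer must be fairly narrow or run for few time steps: ten hidden layers at 10 time steps leave room for at most 45 neurons per layer.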

In [None]:
def optuna_objective(trial, train_loader, test_loader):
    # Suggest hyperparameters within the approximate ranges we want to optimize
    num_steps = trial.suggest_int('Timesteps', 10, 50)
    num_hidden_layers = trial.suggest_int('Hidden Layers', 1, 10)
    num_hidden_neurons = trial.suggest_int('Neurons per Hidden Layer', 5, 300)

    learning_rate = trial.suggest_float('Learning Rate', 1e-4, 1e-2, log=True)
    first_layer_beta = trial.suggest_float('First Layer β', 0.5, 1)

    logger.info(f"Trial {trial.number}: Training: Layers={num_hidden_layers} "+
          f"Neurons={num_hidden_neurons} Steps={num_steps} l1Beta={first_layer_beta:.2f}")

    # Skip large networks with many steps, they take too long to train
    # This cuts off a large corner of the parameter space - and the runtime
    if num_hidden_layers*num_hidden_neurons*num_steps > 300*15:
        raise TrialPruned("Too computationally intensive.")

    logger.info(f"Trial {trial.number}: Running")

    net = Net(num_steps, num_hidden_neurons, num_hidden_layers, first_layer_beta)

    # Run the training!
    trainer = SNNTrainer(net, trial, num_steps=num_steps, learning_rate=learning_rate)
    training_start_time = time.time()
    trainer.train(train_loader)

    # Training info - so we can plot it later
    trial.set_user_attr("Training Time [s]", time.time() - training_start_time)
    trial.set_user_attr("Epochs", trainer.epochs)
    trial.set_user_attr("Batches", trainer.batches)

    logger.info(f"Trial {trial.number}: Run on validation set")
    net.reset_spikes() # Only consider spikes/digit after training is complete
    # Evaluate on the global validation_loader; the test set is never used for tuning
    validation_accuracy = trainer.get_accuracy(validation_loader)

    # The thing we really want to optimize for!
    spikes_per_digit = net.get_spikes_per_digit()

    # Define the objective: maximize validation accuracy and minimize spike count
    return               validation_accuracy,   spikes_per_digit

# The objectives have a printable name and direction
# Optuna keeps track of the objectives returned as an ordered array,
# so we do, too, all here in one place.
objective_names      = ["Validation Accuracy", "Spikes per Digit"]
objective_directions = ["maximize",            "minimize"]
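Because the study has two objectives, there is usually no single best trial; Optuna instead keeps the Pareto-optimal trials, available as `study.best_trials`. A trial is on the Pareto front if no other trial is at least as good on both objectives and strictly better on one. The plain-Python sketch below applies the same directions to made-up toy values (not results from this notebook).

```python
def dominates(a, b):
    """True if a = (accuracy, spikes) dominates b: maximize accuracy, minimize spikes."""
    no_worse = a[0] >= b[0] and a[1] <= b[1]
    strictly_better = a[0] > b[0] or a[1] < b[1]
    return no_worse and strictly_better


def pareto_front(points):
    # Keep only the points that no other point dominates
    return [p for p in points if not any(dominates(q, p) for q in points)]


# Toy (validation accuracy, spikes per digit) pairs for illustration only
trials = [(0.90, 500), (0.95, 800), (0.85, 600), (0.95, 900)]
print(pareto_front(trials))
```

After the study has run, each trial in `study.best_trials` carries its objective values in `.values`, in the same order as `objective_names`.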

## 5. The study

Now we can run the things we just defined and see the results! This will take a considerable amount of time. Reduce `additional_trials` if you'd like to speed things up.

In [None]:
# Define the Optuna study
# maximize accuracy
# minimize spikes
study = optuna.create_study(study_name="Minimize spikes, maximize accuracy",
                            directions=objective_directions)


In [None]:
# Helper to figure out how many trials have successfully completed
def completed_trials(study):
    # Counts the completed, successful trials
    return sum(1 for trial in study.trials if trial.state == TrialState.COMPLETE)

In [None]:
# Need at least 3 for the plots below
additional_trials = 50

# Bookkeeping
start_time = time.time()
start_trials = completed_trials(study)
target_trials = start_trials + additional_trials
logger.info(f"Running on device={device}.")
logger.info(f"{start_trials} completed. Running {additional_trials} more to have {target_trials} in total.")

while completed_trials(study) < target_trials:
    # Run trials one at a time so we can stop the code block and keep whatever has been learned
    study.optimize( lambda trial:
                    optuna_objective(trial, train_loader, test_loader),
                    n_trials=1)

    # Bookkeeping and message generation
    elapsed = time.time() - start_time
    total_completed = completed_trials(study)
    completed = total_completed - start_trials
    remaining_trials = target_trials - total_completed
    logger.info(f"#### Remaining trials {remaining_trials} ####")
    if completed > 0:
        rate = elapsed/completed
        remaining_time = remaining_trials*rate
        logger.info(f"Completed {total_completed}/{target_trials} trials at {rate/60:.1f}min/trial")
        if total_completed < target_trials:
            logger.info(f"Remaining time: {remaining_time/60:.1f} minutes to do {remaining_trials} trials.")

logger.info("DONE")

## 6. Ponder the Results

Now it's time to actually look at the parameters and think about them!

In [None]:
# Initialize the optunacy plotter
see = OPlot(study, objective_names)

### 6.1 Cause and Effect

We can look at how much each hyperparameter influences each outcome metric, and at what impact a change in a hyperparameter has on an output.

In [None]:
see.parameters()

I am curious whether deep networks with many hidden layers are effective here. Let's see:

In [None]:
see.plot("Spikes per Digit", "Validation Accuracy", "Hidden Layers")

In this plot, each dot is a network, and the color indicates its number of hidden layers.
Spikes per Digit is roughly proportional to power consumption, and Validation Accuracy measures how well the network works, so we want to be in the top left corner.
But we can already see that the top left corner is dominated by one-layer networks. So my hypothesis was not right: deep networks produce lots of spikes.

It's a bit chaotic, and we absolutely don't care about accuracies below 60%. So let's zoom in a bit:

In [None]:
see.plot("Spikes per Digit", "Validation Accuracy", "Hidden Layers", y_range=(0.60, 1), z_clip=(1,5))

Deeper networks are definitely to the right.

I'd guess that network size and spike rate are correlated.

In [None]:
see.plot("Neurons per Hidden Layer", "Hidden Layers", "Validation Accuracy")

First, note that there are no data points in the top right part of the graph. That's because we prune those configurations: deep networks with many neurons per layer are very computationally expensive.

In any case, the graph is not very informative; we mostly care about accuracies of at least 80%.

In [None]:
see.plot("Neurons per Hidden Layer", "Hidden Layers", "Validation Accuracy", z_clip=(.8,1))

That's nice: it seems we need about 100-200 neurons (if you mouse over a point you can see the data) on one or two layers, or more on three layers. Large networks seem not to be very accurate, and neither are networks with very few neurons (in the bottom left corner).

Let's look at the spike rate on that same picture. I'll clip it to see the interesting parts.

In [None]:
see.plot("Neurons per Hidden Layer", "Hidden Layers", "Spikes per Digit", z_clip=(20000,80000))

From this, it's clear that large networks are not power efficient.

What about the other parameters?

In [None]:
see.plot("First Layer β", "Validation Accuracy", "Spikes per Digit", z_clip=(30000, 80000), y_range=(0.8,1))

That does not look particularly helpful. It seems like all values for β can provide high accuracy results, some even with low spike counts. It appears that there are more low spike count nets with high accuracy where β is close to 1, so maybe β should be greater than 0.95.

What about timesteps?

In [None]:
see.plot("Timesteps", "Validation Accuracy", "Spikes per Digit", z_clip=(30000, 80000), y_range=(0.8,1))

As expected, the more time steps a network runs, the more spikes it produces. The optimum number of time steps seems to be around 15-20.

What about Learning Rate?

In [None]:
see.plot("Learning Rate", "Validation Accuracy", "Spikes per Digit", z_clip=(30000, 80000), y_range=(0.8,1))

Here, most of the results seem to be in the top left corner. There's an area in the top right corner that may be underexplored. That's because the learning rate was sampled from a log distribution:

```python
learning_rate = trial.suggest_float('Learning Rate', 1e-4, 1e-2, log=True)
```

Maybe in the next run, remove `log=True` and explore the top right corner too!
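To see why a log-scaled range leaves the high-learning-rate end sparse on a linear axis, compare where samples land. This plain-Python sketch uses `random.uniform` as a stand-in for Optuna's sampler (an assumption: Optuna's samplers adapt over the course of a study and are not purely random). Drawing the exponent uniformly gives a log-uniform distribution, which is what `log=True` asks for.

```python
import math
import random

rng = random.Random(0)  # fixed seed for reproducibility
low, high, n = 1e-4, 1e-2, 100_000


def share_above(samples, cutoff=1e-3):
    # Fraction of samples above the cutoff
    return sum(s > cutoff for s in samples) / len(samples)


# log=True: draw the exponent uniformly, so each decade gets the same share of samples
log_samples = [10 ** rng.uniform(math.log10(low), math.log10(high)) for _ in range(n)]
# log=False: draw the value itself uniformly
lin_samples = [rng.uniform(low, high) for _ in range(n)]

print(f"log-uniform: {share_above(log_samples):.2f}, uniform: {share_above(lin_samples):.2f}")
```

Roughly half of the log-scaled samples fall below 1e-3, even though that interval covers only about a tenth of the linear range, so the right side of a linear plot is sparsely covered.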

Optuna has some more built in [plotting features](https://optuna.readthedocs.io/en/stable/reference/visualization/index.html), for example, a way to plot the importance of a parameter for a particular optimization target.

In [None]:
optuna.visualization.plot_param_importances(study,
                                  target=lambda t: t.values[0],
                                  target_name = "Validation accuracy").show()

This plot shows that the hyperparameter with the longest bar has the largest impact on validation accuracy. It does not say whether that value needs to be large or small.

### 6.2 Summary

From looking at the data, we've found that the optimal network likely has:

- 1-2 layers
- 100-200 total neurons
- 15-20 time steps
- a first-layer β of at least 0.9

This drastically reduces our search space, and we can re-run the optimizer focusing on that area.