# CODE

### Imports

In [1]:
import os
import json
import random
from datetime import datetime

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import snntorch as snn
import optuna
import tqdm.notebook as tqdm

from WOR_dataset import Wave_Order_Dataset, split_dataset, get_dataloaders
from WOR_plot import (plot_wave, plot_accuracies, plot_loss_curve, plot_metrics, 
                      plot_equal_prediction_values, plot_beta_values, plot_tau_values, 
                      plot_layer_weights, plot_spike_counts, plot_snn_spikes, 
                      plot_membrane_potentials, plot_threshold_potentials)
from WOR_train_val_test import train_model, validate_model, test_model

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

Device: cpu


### Hyperparameters

In [2]:
hyperparams = {
    "seed": 42,
    "std_dev": 0.0,
    "num_samples": 1000,
    "train_ratio": 0.7,
    "validation_ratio": 0.2,
    "test_ratio": 0.1,
    "freq_min": 2,
    "freq_max": 2,
    "amp_min": 1.1,
    "amp_max": 1.1,
    "offset": 1.1,
    "sample_rate": 40,       # NOTE: Change this along with freq_max to avoid aliasing.
    "duration": 3,
    "input_size": 1,
    "output_size": 2,
    "optimizer_betas": (0.99, 0.999),
    "scheduler_step_size": 30,
    "scheduler_gamma": 0.5,
    "L1_lambda": 0.001,
    "N_hidden_weights_gaussian": 30,
    "N_output_weights_gaussian": 1,
    "hidden_reset_mechanism": 'zero',
    "output_reset_mechanism": 'zero',
    "weights_hidden_min_clamped": 0.0,
    "weights_hidden_max_clamped": 2.0,
    "weights_output_min_clamped": 0.0,
    "weights_output_max_clamped": 2.0,
    "N_hidden_tau": 0.1,
    "N_output_tau": 0.1,
    "learn_threshold_hidden": True,
    "learn_threshold_output": True,
    "learn_beta_hidden": True,
    "learn_beta_output": True,
    "phase1": "random_uniform_0_to_2pi",
    "phase2": "random_uniform_0_to_2pi",
    # These parameters will be tuned by HPO:
    "learning_rate": 0.02,
    "hidden_size": 70,
    "threshold_hidden_min": 1.0,
    "threshold_hidden_max": 1.1,
    "penalty_weight": 0.0,
    "num_epochs": 50,
    "batch_size": 32
}

### WOR Model Definition

In [3]:
class Wave_Order_Recognition_SNN(nn.Module):
    def __init__(self, hyperparams):
        super().__init__()
        hidden_size = hyperparams["hidden_size"]
        output_size = hyperparams["output_size"]
        input_size = hyperparams["input_size"]
        sample_rate = hyperparams["sample_rate"]
        duration = hyperparams["duration"]
        deltaT = 1 / sample_rate
        self.num_steps = int(duration / deltaT)
        
        # tau and beta values for hidden and output layers
        tau_hidden = torch.Tensor(hidden_size).uniform_(
            hyperparams["N_hidden_tau"] * hyperparams["freq_min"],
            hyperparams["N_hidden_tau"] * hyperparams["freq_max"]
        )
        tau_output = torch.Tensor(output_size).uniform_(
            hyperparams["N_output_tau"] * hyperparams["freq_min"],
            hyperparams["N_output_tau"] * hyperparams["freq_max"]
        )
        beta_hidden = torch.exp(-deltaT / tau_hidden)
        beta_output = torch.exp(-deltaT / tau_output)
        
        # thresholds: tuned for the hidden layer, fixed range for output
        threshold_hidden = np.random.uniform(
            hyperparams["threshold_hidden_min"],
            hyperparams["threshold_hidden_max"],
            hidden_size
        )
        threshold_output = np.random.uniform(1.0, 1.1, output_size)
        
        # Gaussian initialization parameters for weights.
        N_hidden_weights_gaussian = hyperparams["N_hidden_weights_gaussian"]
        N_output_weights_gaussian = hyperparams["N_output_weights_gaussian"]
        N_hidden_weights_std = np.sqrt(N_hidden_weights_gaussian)
        N_output_weights_std = np.sqrt(N_output_weights_gaussian)
        gaussian_mean_hidden_weights = N_hidden_weights_gaussian / sample_rate
        gaussian_std_hidden_weights = N_hidden_weights_std / sample_rate
        gaussian_mean_output_weights = N_output_weights_gaussian / (sample_rate * hidden_size)
        gaussian_std_output_weights = N_output_weights_std / (sample_rate * hidden_size)
        
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.lif1 = snn.Leaky(
            beta=beta_hidden,
            threshold=threshold_hidden,
            learn_beta=hyperparams["learn_beta_hidden"],
            learn_threshold=hyperparams["learn_threshold_hidden"],
            reset_mechanism=hyperparams["hidden_reset_mechanism"],
            reset_delay=False
        )
        self.fc2 = nn.Linear(hidden_size, output_size, bias=False)
        self.lif2 = snn.Leaky(
            beta=beta_output,
            threshold=threshold_output,
            learn_beta=hyperparams["learn_beta_output"],
            learn_threshold=hyperparams["learn_threshold_output"],
            reset_mechanism=hyperparams["output_reset_mechanism"],
            reset_delay=False
        )
        self._initialize_weights(
            gaussian_mean_hidden_weights, gaussian_std_hidden_weights,
            gaussian_mean_output_weights, gaussian_std_output_weights
        )
    
    def _initialize_weights(self, mean_hidden, std_hidden, mean_output, std_output):
        nn.init.normal_(self.fc1.weight, mean=mean_hidden, std=std_hidden)
        nn.init.normal_(self.fc2.weight, mean=mean_output, std=std_output)
    
    def forward(self, x, mem1=None, mem2=None):
        batch_size = x.size(0)
        if mem1 is None:
            mem1 = torch.zeros(batch_size, self.fc1.out_features, device=x.device)
        if mem2 is None:
            mem2 = torch.zeros(batch_size, self.fc2.out_features, device=x.device)
        
        spk1_rec, mem1_rec, spk2_rec, mem2_rec = [], [], [], []
        hidden_spike_count = 0
        output_spike_count = 0
        
        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step].unsqueeze(1))
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            
            spk1_rec.append(spk1)
            mem1_rec.append(mem1)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
            
            hidden_spike_count += spk1.sum().item()
            output_spike_count += spk2.sum().item()
        
        return (torch.stack(spk1_rec, dim=0),
                torch.stack(mem1_rec, dim=0),
                torch.stack(spk2_rec, dim=0),
                torch.stack(mem2_rec, dim=0),
                hidden_spike_count,
                output_spike_count)

### Optuna Objective for HPO

In [4]:
def objective(trial):
    # Override a few hyperparameters using trial suggestions.
    hyperparams["learning_rate"] = trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True)
    hyperparams["threshold_hidden_min"] = trial.suggest_float("threshold_hidden_min", 0.8, 1.2)
    hyperparams["threshold_hidden_max"] = trial.suggest_float("threshold_hidden_max", hyperparams["threshold_hidden_min"] + 0.1, 2.0)
    
    # Create dataset and dataloaders.
    dataset = Wave_Order_Dataset(
        hyperparams["num_samples"],
        hyperparams["sample_rate"],
        hyperparams["duration"],
        hyperparams["freq_min"],
        hyperparams["freq_max"],
        hyperparams["amp_min"],
        hyperparams["amp_max"],
        hyperparams["std_dev"],
        hyperparams["offset"]
    )
    train_dataset, validation_dataset, _ = split_dataset(
        dataset,
        hyperparams["train_ratio"],
        hyperparams["validation_ratio"],
        hyperparams["test_ratio"]
    )
    train_loader, validation_loader, _ = get_dataloaders(
        train_dataset, validation_dataset, [], hyperparams["batch_size"]
    )
    
    # Instantiate model, loss, optimizer, and scheduler.
    model = Wave_Order_Recognition_SNN(hyperparams).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adamax(model.parameters(), lr=hyperparams["learning_rate"],
                                   betas=hyperparams["optimizer_betas"])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=hyperparams["scheduler_step_size"],
                                                gamma=hyperparams["scheduler_gamma"])
    
    # Run a short training loop for HPO (using fewer epochs for speed).
    num_epochs = hyperparams["num_epochs"] // 10
    for epoch in range(num_epochs):
        train_loss, _, _, _, _ = train_model(model, train_loader, criterion, optimizer,
                                             epoch, num_epochs, hyperparams["batch_size"],
                                             hyperparams["hidden_size"], hyperparams["output_size"],
                                             hyperparams["weights_hidden_min_clamped"],
                                             hyperparams["weights_hidden_max_clamped"],
                                             hyperparams["weights_output_min_clamped"],
                                             hyperparams["weights_output_max_clamped"],
                                             hyperparams["penalty_weight"],
                                             hyperparams["L1_lambda"], device)
        val_accuracy, val_loss = validate_model(model, validation_loader, criterion, device)
        scheduler.step()
        if epoch % 5 == 0:
            print(f"Trial {trial.number} Epoch {epoch}: Val Acc {val_accuracy:.2f}% Loss {val_loss:.4f}")
    
    # validation accuracy (metric to maximize)
    return val_accuracy

### Full Training Run (after HPO)

In [5]:
def run_training(hyperparams):
    set_seed(hyperparams["seed"])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")
    
    # folder for saving outputs
    now = datetime.now()
    output_folder = os.path.join(os.getcwd(), "thesis_simulations", now.strftime("run_%Y%m%d_%H%M%S"))
    os.makedirs(output_folder, exist_ok=True)
    
    # Create dataset and dataloaders
    dataset = Wave_Order_Dataset(
        hyperparams["num_samples"],
        hyperparams["sample_rate"],
        hyperparams["duration"],
        hyperparams["freq_min"],
        hyperparams["freq_max"],
        hyperparams["amp_min"],
        hyperparams["amp_max"],
        hyperparams["std_dev"],
        hyperparams["offset"]
    )
    train_dataset, validation_dataset, test_dataset = split_dataset(
        dataset,
        hyperparams["train_ratio"],
        hyperparams["validation_ratio"],
        hyperparams["test_ratio"]
    )
    train_loader, validation_loader, test_loader = get_dataloaders(
        train_dataset, validation_dataset, test_dataset, hyperparams["batch_size"]
    )
    
    # model, loss, optimizer, and scheduler.
    model = Wave_Order_Recognition_SNN(hyperparams).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adamax(model.parameters(), lr=hyperparams["learning_rate"],
                                   betas=hyperparams["optimizer_betas"])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=hyperparams["scheduler_step_size"],
                                                gamma=hyperparams["scheduler_gamma"])
    
    # Training loop
    num_epochs = hyperparams["num_epochs"]
    for epoch in range(num_epochs):
        train_loss, _, _, _, _ = train_model(model, train_loader, criterion, optimizer,
                                             epoch, num_epochs, hyperparams["batch_size"],
                                             hyperparams["hidden_size"], hyperparams["output_size"],
                                             hyperparams["weights_hidden_min_clamped"],
                                             hyperparams["weights_hidden_max_clamped"],
                                             hyperparams["weights_output_min_clamped"],
                                             hyperparams["weights_output_max_clamped"],
                                             hyperparams["penalty_weight"],
                                             hyperparams["L1_lambda"], device)
        val_accuracy, val_loss = validate_model(model, validation_loader, criterion, device)
        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs}: Val Acc = {val_accuracy:.2f}% | Val Loss = {val_loss:.4f}")
    
    # Evaluate on test data
    test_metrics = test_model(model, test_loader, criterion, device)
    print("Test Metrics:")
    print(test_metrics)
    
    # Save the final model
    model_save_path = os.path.join(output_folder, "final_model.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")
    
    # Return a dictionary with useful outputs.
    return {
        "output_folder": output_folder,
        "train_loader": train_loader,
        "validation_loader": validation_loader,
        "test_loader": test_loader,
        "model": model,
        "test_metrics": test_metrics,
        "num_epochs": num_epochs,
        #"deltaT": 1 / hyperparams["sample_rate"]
    }

### 'mode' Selection and Execution

In [6]:
# Set mode = "hpo" for hyperparameter optimization, or "train" for a full training run.
mode = "hpo"  # change to "hpo" to run hyperparameter optimization

if mode == "hpo":
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=10)  # adjust n_trials as needed
    best_trial = study.best_trial
    print("Best trial:")
    print(f"  Value: {best_trial.value}")
    print("  Params:")
    for key, value in best_trial.params.items():
        print(f"    {key}: {value}")
    # Save best hyperparameters to a JSON file.
    with open("best_hyperparameters.json", "w") as f:
        json.dump(best_trial.params, f, indent=4)
    print("Best hyperparameters saved to best_hyperparameters.json")
    
elif mode == "train":
    # Load best hyperparameters (ensure you have run an HPO run first so that the JSON exists).
    with open("best_hyperparameters.json", "r") as f:
        best_tuned = json.load(f)
    # Merge tuned hyperparameters with the fixed ones.
    hyperparams = {**hyperparams, **best_tuned}
    results = run_training(hyperparams)

[I 2025-02-17 01:14:11,627] A new study created in memory with name: no-name-14a92d30-b320-40ab-82d9-8bac428edefd


Trial 0 Epoch 0: Val Acc 47.62% Loss 0.6931


[I 2025-02-17 01:14:15,665] Trial 0 finished with value: 50.476190476190474 and parameters: {'learning_rate': 0.0006063220466440857, 'threshold_hidden_min': 1.1549939282394002, 'threshold_hidden_max': 1.668972803525267}. Best is trial 0 with value: 50.476190476190474.


Trial 1 Epoch 0: Val Acc 50.52% Loss 0.6931


[I 2025-02-17 01:14:18,810] Trial 1 finished with value: 51.54639175257732 and parameters: {'learning_rate': 0.00016240583045087637, 'threshold_hidden_min': 0.829079692120461, 'threshold_hidden_max': 1.4698093570537463}. Best is trial 1 with value: 51.54639175257732.


Trial 2 Epoch 0: Val Acc 42.11% Loss 0.6931


[I 2025-02-17 01:14:21,902] Trial 2 finished with value: 56.8421052631579 and parameters: {'learning_rate': 0.008762316412516962, 'threshold_hidden_min': 0.9280475465880023, 'threshold_hidden_max': 1.8567201498310255}. Best is trial 2 with value: 56.8421052631579.


Trial 3 Epoch 0: Val Acc 47.47% Loss 0.6931


[I 2025-02-17 01:14:25,017] Trial 3 finished with value: 47.474747474747474 and parameters: {'learning_rate': 0.023637447009193763, 'threshold_hidden_min': 0.8597470196894832, 'threshold_hidden_max': 0.9603486480830844}. Best is trial 2 with value: 56.8421052631579.


Trial 4 Epoch 0: Val Acc 50.50% Loss 0.6931


[I 2025-02-17 01:14:28,550] Trial 4 finished with value: 42.57425742574257 and parameters: {'learning_rate': 0.009174011771649733, 'threshold_hidden_min': 1.0983704914118737, 'threshold_hidden_max': 1.6584158234076558}. Best is trial 2 with value: 56.8421052631579.


Trial 5 Epoch 0: Val Acc 55.10% Loss 0.6931


[I 2025-02-17 01:14:31,715] Trial 5 finished with value: 48.97959183673469 and parameters: {'learning_rate': 0.06302987116092815, 'threshold_hidden_min': 1.08741923637226, 'threshold_hidden_max': 1.8512458816691852}. Best is trial 2 with value: 56.8421052631579.


Trial 6 Epoch 0: Val Acc 47.42% Loss 0.6931


[I 2025-02-17 01:14:35,014] Trial 6 finished with value: 44.329896907216494 and parameters: {'learning_rate': 0.018371435314968373, 'threshold_hidden_min': 0.9356750170854905, 'threshold_hidden_max': 1.9991098258323243}. Best is trial 2 with value: 56.8421052631579.


Trial 7 Epoch 0: Val Acc 49.00% Loss 0.6931


[I 2025-02-17 01:14:38,258] Trial 7 finished with value: 51.0 and parameters: {'learning_rate': 0.0003791190668625142, 'threshold_hidden_min': 0.8888478344496423, 'threshold_hidden_max': 1.6290291333726536}. Best is trial 2 with value: 56.8421052631579.


Trial 8 Epoch 0: Val Acc 52.00% Loss 0.6931


[I 2025-02-17 01:14:41,488] Trial 8 finished with value: 54.0 and parameters: {'learning_rate': 0.007896062006436597, 'threshold_hidden_min': 0.9013967165843327, 'threshold_hidden_max': 1.9426632892972697}. Best is trial 2 with value: 56.8421052631579.


Trial 9 Epoch 0: Val Acc 52.13% Loss 0.6931


[I 2025-02-17 01:14:44,508] Trial 9 finished with value: 43.61702127659574 and parameters: {'learning_rate': 0.0036577622496122616, 'threshold_hidden_min': 0.8690648074588009, 'threshold_hidden_max': 1.0919515328299072}. Best is trial 2 with value: 56.8421052631579.


Best trial:
  Value: 56.8421052631579
  Params:
    learning_rate: 0.008762316412516962
    threshold_hidden_min: 0.9280475465880023
    threshold_hidden_max: 1.8567201498310255
Best hyperparameters saved to best_hyperparameters.json


# PLOTS

### Sample Dataset

In [7]:
plot_wave(train_loader, save_path=os.path.join(unique_folder_name, 'wave_samples.png'))

NameError: name 'train_loader' is not defined

### Accuracy

In [None]:
plot_accuracies(num_epochs, test_accuracies, validation_accuracies, os.path.join(unique_folder_name, 'accuracy_plot'))

### Loss

In [None]:
plot_loss_curve(loss_hist, test_loss_hist, num_epochs, os.path.join(unique_folder_name, 'loss_curve'))

### Confusion Matrix

In [None]:
plot_metrics(test_metrics,os.path.join(unique_folder_name, 'evaluation_plots'))

### Equal Prediction Count

In [None]:
plot_equal_prediction_values(equal_prediction_values, num_epochs, os.path.join(unique_folder_name, 'equal_prediction_values'))

### Beta Hidden Layer

In [None]:
plot_beta_values(beta1_values, num_epochs, os.path.join(unique_folder_name, 'beta_hidden_layer'), layer_name='Hidden')

### Beta Output Layer

In [None]:
plot_beta_values(beta2_values, num_epochs, os.path.join(unique_folder_name, 'beta_output_layer'), layer_name='Output')

### Tau Hidden Layer

In [None]:
plot_tau_values(beta1_values, num_epochs, deltaT, os.path.join(unique_folder_name, 'tau_hidden_layer'), layer_name='Hidden')

### Tau Output Layer

In [None]:
plot_tau_values(beta2_values, num_epochs, deltaT, os.path.join(unique_folder_name, 'tau_output_layer'), layer_name='Output')

### Weights Hidden Layer

In [None]:
plot_layer_weights(weights_hidden_layer, num_epochs, os.path.join(unique_folder_name, 'weights_hidden_layer'), layer_name='Hidden')

### Weights Output Layer

In [None]:
plot_layer_weights(weights_output_layer, num_epochs, os.path.join(unique_folder_name, 'weights_output_layer'), layer_name='Output')

### Spike Count vs Epochs

In [None]:
plot_spike_counts(hidden_spike_count, output_spike_count, output_spike_counts_neuron0, output_spike_counts_neuron1, num_epochs, os.path.join(unique_folder_name, 'spike_counts'))

### Spikes Hidden Layer

In [None]:
plot_snn_spikes(model, test_loader, device, os.path.join(unique_folder_name, 'hidden_layer_spikes'), layer_name='Hidden', layer_size=hidden_size, num_steps=num_steps)

### Spikes Output Layer

In [None]:
plot_snn_spikes(model, test_loader, device, os.path.join(unique_folder_name, 'output_layer_spikes'), layer_name='Output', layer_size=output_size, num_steps=num_steps)

### Vmem Hidden Layer

In [None]:
plot_membrane_potentials(model, test_loader, device, 'Hidden', hidden_size, num_steps, os.path.join(unique_folder_name, 'hidden_membrane_potentials'))

### Vmem Output Layer

In [None]:
plot_membrane_potentials(model, test_loader, device, 'Output', output_size, num_steps, os.path.join(unique_folder_name, 'output_membrane_potentials'))

### Threshold Hidden Layer

In [None]:
plot_threshold_potentials(threshold_hidden_layer, num_epochs, os.path.join(unique_folder_name, 'threshold_hidden_layer'), 'Hidden')

### Threshold Output Layer

In [None]:
plot_threshold_potentials(threshold_output_layer, num_epochs, os.path.join(unique_folder_name, 'threshold_output_layer'), 'Output')