## Imports

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import snntorch as snn
import snntorch.functional as SF
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import optuna

## Dataset

In [None]:
class TSVData(Dataset):
    """
    Custom Dataset for UCR-style TSV files where:
      - The first column is the label (-1 for abnormal, 1 for normal).
      - The remaining columns are numeric features (each row is one univariate time series).

    The dataset:
      1. Reads the TSV file using pandas.
      2. Converts labels to {0, 1}: 1 -> 1 (normal), every other value (i.e. -1) -> 0 (abnormal).
      3. Returns each sample with shape [time_steps, 1].

    Args:
        file_path: path to the tab-separated file.
        header: forwarded to ``pandas.read_csv``. Defaults to ``None`` because UCR
            archive TSV files have no header row; with ``header=0`` the first sample
            would be consumed as column names and silently dropped. Pass ``0`` if
            your file really starts with a header line.
    """
    def __init__(self, file_path, header=None):
        self.data = pd.read_csv(file_path, sep='\t', header=header)
        raw_labels = self.data.iloc[:, 0].values.astype(int)
        # Convert: -1 -> 0 (abnormal) and 1 -> 1 (normal)
        self.labels = (raw_labels == 1).astype(int)
        self.features = self.data.iloc[:, 1:].values.astype(np.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Each sample: shape [time_steps] --> reshape to [time_steps, 1]
        x = torch.tensor(self.features[idx], dtype=torch.float32).unsqueeze(1)
        # CrossEntropyLoss expects class indices as long integers
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        return x, y

## Liquid State Machine - Spiking Reservoir

In [None]:
class SpikingReservoir(nn.Module):
    """
    Liquid-state-machine classifier built from snnTorch neurons.

    Per time step: ``input_fc`` -> one ``snn.Leaky`` input neuron -> ``reservoir_fc``
    -> recurrent ``snn.RLeaky`` reservoir (all-to-all recurrence, randomly sparsified
    and rescaled to ``spectral_radius``). After the last step, the time-averaged
    membrane potentials of a fixed random subset of ``num_connected`` reservoir
    neurons feed ``readout_fc``. Only ``readout_fc`` is trainable.

    The sparsity mask, the readout subset and the weight initialisation all draw
    from the global torch RNG, so seeding torch before construction makes the
    model reproducible.
    """
    def __init__(self,
                 input_size,
                 reservoir_size,
                 num_readout,
                 num_connected,
                 spectral_radius,
                 beta_input,
                 beta_reservoir,
                 threshold,
                 sparsity_percentage,  # percent of recurrent connections to drop, for sparsity (e.g., 10 for 10%)
                 spike_grad=None,      # surrogate gradient forwarded to snnTorch; None -> snnTorch's default surrogate
                 reset_mechanism='zero',  # or 'subtract'
                 reset_delay=0):       # forwarded unchanged to snnTorch's reset_delay option
        """
        Args:
            input_size: number of features per time step of the input ``x``.
            reservoir_size: number of reservoir neurons.
            num_readout: number of output logits.
            num_connected: number of reservoir neurons read out (1..reservoir_size).
            spectral_radius: target spectral radius of the masked recurrent weights.
            beta_input: membrane decay of the input neuron.
            beta_reservoir: membrane decay of the reservoir neurons.
            threshold: firing threshold shared by input and reservoir neurons.
            sparsity_percentage: percent (0-100) of recurrent connections dropped.
            spike_grad, reset_mechanism, reset_delay: forwarded to snnTorch.

        Raises:
            ValueError: if ``num_connected`` is not in [1, reservoir_size].
        """
        super().__init__()
        if not 0 < num_connected <= reservoir_size:
            raise ValueError(
                f"num_connected must be in [1, {reservoir_size}], got {num_connected}")

        # Project the input_size features of each time step onto a single input neuron.
        self.input_fc = nn.Linear(input_size, 1, bias=False)
        self.input_lif = snn.Leaky(beta=beta_input,
                                   threshold=threshold,
                                   spike_grad=spike_grad,
                                   reset_mechanism=reset_mechanism,
                                   reset_delay=reset_delay)

        # Fan the input neuron's spike out to every reservoir neuron.
        self.reservoir_fc = nn.Linear(1, reservoir_size, bias=False)
        self.reservoir_lif = snn.RLeaky(beta=beta_reservoir,
                                        linear_features=reservoir_size,
                                        threshold=threshold,
                                        spike_grad=spike_grad,
                                        reset_mechanism=reset_mechanism,
                                        reset_delay=reset_delay,
                                        all_to_all=True)

        # Keep each recurrent connection with probability (100 - sparsity_percentage)%.
        sparsity_mask = (torch.rand(reservoir_size, reservoir_size) >
                         (sparsity_percentage / 100)).float()
        self.register_buffer("sparsity_mask", sparsity_mask)
        with torch.no_grad():
            W = self.reservoir_lif.recurrent.weight
            W.mul_(self.sparsity_mask)  # in-place: modifies the layer's weight directly
            # eigvals returns complex eigenvalues; the spectral radius is their largest modulus.
            current_radius = torch.linalg.eigvals(W).abs().max()
            # An all-zero matrix (everything masked out) has radius 0; rescaling it
            # would divide by zero and fill the weights with NaN/inf.
            if current_radius > 0:
                W.mul_(spectral_radius / current_radius)

        # Fixed random subset of reservoir neurons seen by the readout. Stored as a
        # buffer (not trainable) so it follows .to(device) and is saved in state_dict.
        readout_indices = torch.randperm(reservoir_size)[:num_connected]
        self.register_buffer("readout_indices", readout_indices)

        self.readout_fc = nn.Linear(num_connected, num_readout, bias=True)

        # Freeze every layer except the readout.
        for frozen in (self.input_fc, self.input_lif, self.reservoir_fc, self.reservoir_lif):
            for param in frozen.parameters():
                param.requires_grad = False

    def forward(self, x):
        """
        Run the network over a batch of sequences, one time step at a time.

        Args:
            x: tensor of shape [batch_size, time_steps, input_size].

        Returns:
            logits: [batch_size, num_readout].
            selected_features: [batch_size, num_connected] time-averaged membrane
                potentials of the read-out reservoir neurons (fed to the readout).
            spk_rate_input: scalar tensor, spikes per time step of the input neuron,
                averaged over the batch.
            spk_rate_reservoir: scalar tensor, spikes per reservoir neuron per time
                step, averaged over the batch.

        Raises:
            ValueError: if ``x`` has no time steps.
        """
        batch_size, time_steps, _ = x.shape
        if time_steps == 0:
            raise ValueError("input sequence must have at least one time step")
        device = x.device
        reservoir_size = self.reservoir_fc.out_features

        input_mem = torch.zeros(batch_size, 1, device=device)
        reservoir_mem = torch.zeros(batch_size, reservoir_size, device=device)
        reservoir_spk = torch.zeros(batch_size, reservoir_size, device=device)

        # Spike counts for logging, and a running sum of reservoir membrane
        # potentials (instead of keeping every time step in memory).
        total_input_spikes = torch.zeros(batch_size, device=device)
        total_reservoir_spikes = torch.zeros(batch_size, device=device)
        reservoir_mem_sum = torch.zeros(batch_size, reservoir_size, device=device)

        for t in range(time_steps):
            input_current = self.input_fc(x[:, t, :])
            input_spk, input_mem = self.input_lif(input_current, input_mem)
            total_input_spikes += input_spk.squeeze(1)

            # Map the input spike to the reservoir dimension.
            reservoir_current = self.reservoir_fc(input_spk)
            reservoir_spk, reservoir_mem = self.reservoir_lif(reservoir_current, reservoir_spk, reservoir_mem)
            total_reservoir_spikes += reservoir_spk.sum(dim=1)
            reservoir_mem_sum = reservoir_mem_sum + reservoir_mem

        # Time-averaged reservoir membrane potentials: [batch_size, reservoir_size]
        reservoir_mem_avg = reservoir_mem_sum / time_steps
        # Fixed subset of reservoir neurons: [batch_size, num_connected]
        selected_features = reservoir_mem_avg[:, self.readout_indices]
        # Trainable linear readout: [batch_size, num_readout]
        logits = self.readout_fc(selected_features)

        # Average spiking rates per neuron and per time step. The input layer has a
        # single neuron; reservoir counts are summed over neurons, so also divide
        # by the number of reservoir neurons.
        spk_rate_input = (total_input_spikes / time_steps).mean()
        spk_rate_reservoir = (total_reservoir_spikes / (time_steps * reservoir_size)).mean()

        return logits, selected_features, spk_rate_input, spk_rate_reservoir

## Training, Evaluation & Metrics

In [None]:
def train_model(model, train_loader, criterion, optimizer, device, writer, epoch):
    """
    Run one training epoch, then log loss, spiking rates and readout weights.

    Args:
        model: module returning (logits, features, spk_rate_input, spk_rate_reservoir)
            and exposing a ``readout_fc`` submodule whose parameters are logged.
        train_loader: iterable of (inputs, labels) batches.
        criterion: loss function applied to (logits, labels).
        optimizer: optimizer stepping the trainable parameters.
        device: device each batch is moved to.
        writer: TensorBoard ``SummaryWriter`` (anything with ``add_scalar`` and
            ``add_histogram``).
        epoch: step value used for every logged entry.

    Returns:
        Mean of the per-batch training losses.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    running_spk_input = 0.0
    running_spk_reservoir = 0.0
    num_batches = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits, _, spk_rate_input, spk_rate_reservoir = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_spk_input += spk_rate_input.item()
        running_spk_reservoir += spk_rate_reservoir.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    avg_loss = running_loss / num_batches
    writer.add_scalar('Train/Loss', avg_loss, epoch)
    # Spiking rates are averaged over all batches of the epoch, not just the last one.
    writer.add_scalar('Train/SpikingRate_Input', running_spk_input / num_batches, epoch)
    writer.add_scalar('Train/SpikingRate_Reservoir', running_spk_reservoir / num_batches, epoch)

    for name, param in model.readout_fc.named_parameters():
        writer.add_histogram(f'Weights/Readout_{name}', param, epoch)
    return avg_loss

def test_model(model, test_loader, criterion, device):
    """
    Evaluate ``model`` on every batch of ``test_loader`` without gradients.

    Args:
        model: module returning (logits, features, spk_rate_input, spk_rate_reservoir).
        test_loader: iterable of (inputs, labels) batches.
        criterion: loss function; assumed to return the batch *mean* (the
            ``nn.CrossEntropyLoss`` default) so it can be re-weighted by batch size.
        device: device each batch is moved to.

    Returns:
        (y_true, y_pred, avg_loss, spk_rates): numpy arrays of true and predicted
        labels, the per-sample mean loss, and a dict with the per-sample mean
        'input' and 'reservoir' spiking rates.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    spk_input_sum = 0.0
    spk_reservoir_sum = 0.0
    num_samples = 0
    all_logits = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits, _, spk_rate_input, spk_rate_reservoir = model(inputs)
            batch_size = labels.size(0)
            # Weight each batch by its size so a short final batch does not
            # count as much as a full one.
            loss_sum += criterion(logits, labels).item() * batch_size
            spk_input_sum += spk_rate_input.item() * batch_size
            spk_reservoir_sum += spk_rate_reservoir.item() * batch_size
            num_samples += batch_size
            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())

    if num_samples == 0:
        raise ValueError("test_loader yielded no samples")

    all_logits = torch.cat(all_logits, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    # One logit per class: the predicted class is the argmax.
    preds = torch.argmax(all_logits, dim=1)
    spk_rates = {'input': spk_input_sum / num_samples,
                 'reservoir': spk_reservoir_sum / num_samples}
    return all_labels.numpy(), preds.numpy(), loss_sum / num_samples, spk_rates


# The name starts with "test_": keep pytest from collecting this helper as a test.
test_model.__test__ = False

def evaluate_metrics(y_true, y_pred):
    """
    Compute binary classification metrics from hard 0/1 predictions.

    Args:
        y_true: ground-truth labels (0 = abnormal, 1 = normal).
        y_pred: predicted labels, same length as ``y_true``.

    Returns:
        (accuracy, precision, recall, f1, roc_auc, confusion_matrix). Precision,
        recall and F1 treat class 1 as positive. The confusion matrix is always
        2x2 with rows (true) and columns (predicted) ordered [0, 1].

    Note:
        ROC-AUC is computed from hard predictions rather than scores, so the ROC
        curve has a single operating point and the AUC is only an approximation.
    """
    acc = accuracy_score(y_true, y_pred)
    # zero_division=0 yields the same value as sklearn's default ("warn"), minus
    # the UndefinedMetricWarning when a class is never predicted.
    prec = precision_score(y_true, y_pred, average='binary', zero_division=0)
    rec = recall_score(y_true, y_pred, average='binary', zero_division=0)
    f1 = 2 * prec * rec / (prec + rec + 1e-8)
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)
    # labels=[0, 1] keeps the matrix 2x2 even if only one class appears, which
    # plot_confusion_matrix's fixed two-tick layout relies on.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    return acc, prec, rec, f1, roc_auc, cm

def plot_confusion_matrix(cm, epoch):
    """
    Draw a 2x2 confusion matrix as a matplotlib figure.

    Args:
        cm: confusion matrix indexed [true label, predicted label], with class
            order [Abnormal, Normal].
        epoch: epoch number shown in the title.

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.
    """
    class_names = ['Abnormal', 'Normal']
    positions = np.arange(len(class_names))

    fig, ax = plt.subplots()
    image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title(f'Confusion Matrix (Epoch {epoch})')
    fig.colorbar(image, ax=ax)

    ax.set_xticks(positions)
    ax.set_xticklabels(class_names)
    ax.set_yticks(positions)
    ax.set_yticklabels(class_names)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')

    fig.tight_layout()
    return fig

## Optuna

In [None]:
def _stratified_split(labels, val_fraction, rng):
    """Split sample indices into sorted (train_idx, val_idx), holding out ~val_fraction of each class (at least one sample per class stays in training)."""
    labels = np.asarray(labels)
    train_parts, val_parts = [], []
    for cls in np.unique(labels):
        cls_idx = rng.permutation(np.flatnonzero(labels == cls))
        n_val = min(int(round(len(cls_idx) * val_fraction)), len(cls_idx) - 1)
        val_parts.append(cls_idx[:n_val])
        train_parts.append(cls_idx[n_val:])
    return np.sort(np.concatenate(train_parts)), np.sort(np.concatenate(val_parts))


def _balanced_sampler(labels):
    """Build a WeightedRandomSampler weighting each sample by 1 / (its class count), to oversample the minority class."""
    labels = np.asarray(labels)
    class_counts = np.bincount(labels)
    sample_weights = 1.0 / class_counts[labels]
    return WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)


def objective(trial):
    """
    Optuna objective: sample hyperparameters, briefly train a fresh
    SpikingReservoir and return its accuracy (to be maximized).

    Configuration is read from the module-level dict ``gs`` (set by ``main``):
    'train_tsv', 'test_tsv', 'batch_size', and optionally 'epochs_search'
    (default 5), 'val_fraction' (default 0.2) and 'seed' (default 42).

    With ``val_fraction > 0`` the score is computed on a stratified hold-out split
    of the *training* file, so the test set is not used for model selection. The
    split uses the same seed in every trial, so all trials are scored on identical
    data. ``val_fraction = 0`` scores on the test set instead.
    """
    from torch.utils.data import Subset

    reservoir_size = trial.suggest_int("reservoir_size", 20, 100)
    num_connected = trial.suggest_int("num_connected", 10, reservoir_size)
    spectral_radius = trial.suggest_float("spectral_radius", 0.5, 1.5)
    beta_input = trial.suggest_float("beta_input", 0.7, 0.99)
    beta_reservoir = trial.suggest_float("beta_reservoir", 0.7, 0.99)
    threshold = trial.suggest_float("threshold", 0.5, 1.5)
    sparsity_percentage = trial.suggest_float("sparsity_percentage", 0, 50)
    # suggest_loguniform is deprecated; suggest_float(log=True) samples the same distribution.
    learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)

    # File paths and other settings come from the module-level config dict (gs).
    global gs
    full_train = TSVData(gs['train_tsv'])
    val_fraction = gs.get('val_fraction', 0.2)
    if val_fraction > 0:
        rng = np.random.default_rng(gs.get('seed', 42))
        train_idx, val_idx = _stratified_split(full_train.labels, val_fraction, rng)
        train_dataset = Subset(full_train, train_idx.tolist())
        eval_dataset = Subset(full_train, val_idx.tolist())
        train_labels = full_train.labels[train_idx]
    else:
        train_dataset = full_train
        eval_dataset = TSVData(gs['test_tsv'])
        train_labels = full_train.labels

    # Classes are imbalanced: upsample the minority class during training.
    train_loader = DataLoader(train_dataset, batch_size=gs['batch_size'],
                              sampler=_balanced_sampler(train_labels))
    eval_loader = DataLoader(eval_dataset, batch_size=gs['batch_size'], shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SpikingReservoir(
        input_size=1,
        reservoir_size=reservoir_size,
        num_readout=2,
        num_connected=num_connected,
        spectral_radius=spectral_radius,
        beta_input=beta_input,
        beta_reservoir=beta_reservoir,
        threshold=threshold,
        sparsity_percentage=sparsity_percentage,
        spike_grad=None,
        reset_mechanism='zero',
        reset_delay=0
    )
    model.to(device)

    optimizer = torch.optim.Adam(model.readout_fc.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # One writer per trial, always closed (previously a new writer leaked every epoch).
    writer = SummaryWriter(log_dir='./temp_runs')
    try:
        for epoch in range(gs.get('epochs_search', 5)):
            train_model(model, train_loader, criterion, optimizer, device, writer, epoch)
    finally:
        writer.close()

    y_true, y_pred, _, _ = test_model(model, eval_loader, criterion, device)
    acc = evaluate_metrics(y_true, y_pred)[0]
    return acc

## main() definition

In [None]:
def main(search_space):
    """
    Optionally tune hyperparameters with Optuna, then train and evaluate the final
    SpikingReservoir, logging to TensorBoard and saving its weights.

    Args:
        search_space: configuration dict. Read keys: 'train_tsv', 'test_tsv',
            'batch_size', 'epochs_final', the model/optimizer hyperparameters
            ('reservoir_size', 'num_connected', 'spectral_radius', 'beta_input',
            'beta_reservoir', 'threshold', 'sparsity_percentage', 'learning_rate'),
            and optionally 'seed', 'use_optuna', 'n_trials', 'tensorboard_log_dir',
            'save_dir'. When Optuna runs, the dict is updated in place with the best
            parameters found.
    """
    # Expose the config to objective(), which Optuna calls with only a trial.
    global gs
    gs = search_space

    seed = search_space.get('seed', 42)
    torch.manual_seed(seed)
    np.random.seed(seed)

    train_dataset = TSVData(search_space['train_tsv'])
    test_dataset = TSVData(search_space['test_tsv'])

    # Oversample the minority class: each sample is weighted by 1 / (its class count).
    labels = train_dataset.labels
    class_counts = np.bincount(labels)
    sample_weights = [1.0 / class_counts[label] for label in labels]
    train_sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

    train_loader = DataLoader(train_dataset, batch_size=search_space['batch_size'], sampler=train_sampler)
    test_loader = DataLoader(test_dataset, batch_size=search_space['batch_size'], shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Optional short hyperparameter search before the final training run.
    if search_space.get('use_optuna', False):
        # Seed the (default) TPE sampler as well, otherwise the search is not
        # reproducible despite the seeds set above.
        study = optuna.create_study(direction="maximize",
                                    sampler=optuna.samplers.TPESampler(seed=seed))
        study.optimize(objective, n_trials=search_space.get('n_trials', 20))
        best_params = study.best_params
        print("Best trial:", study.best_trial)
        print("Best params:", best_params)
        # Mutates the caller's dict so the tuned values are visible afterwards.
        search_space.update(best_params)

    # Final model built from the (possibly tuned) hyperparameters.
    model = SpikingReservoir(
        input_size=1,
        reservoir_size=search_space['reservoir_size'],
        num_readout=2,
        num_connected=search_space['num_connected'],
        spectral_radius=search_space['spectral_radius'],
        beta_input=search_space['beta_input'],
        beta_reservoir=search_space['beta_reservoir'],
        threshold=search_space['threshold'],
        sparsity_percentage=search_space['sparsity_percentage'],
        spike_grad=None,
        reset_mechanism='zero',
        reset_delay=0
    )
    model.to(device)

    # Only the readout layer is trained.
    optimizer = torch.optim.Adam(model.readout_fc.parameters(), lr=search_space['learning_rate'])
    criterion = nn.CrossEntropyLoss()

    # Make sure the output directory exists even when main() is called directly.
    save_dir = search_space.get('save_dir', '.')
    os.makedirs(save_dir, exist_ok=True)

    writer = SummaryWriter(log_dir=search_space.get('tensorboard_log_dir', './runs'))
    try:
        # Final training (typically many more epochs than the search).
        num_epochs = search_space['epochs_final']
        for epoch in range(num_epochs):
            train_loss = train_model(model, train_loader, criterion, optimizer, device, writer, epoch)
            y_true, y_pred, test_loss, spk_rates = test_model(model, test_loader, criterion, device)
            acc, prec, rec, f1, roc_auc, cm = evaluate_metrics(y_true, y_pred)

            writer.add_scalar('Test/Accuracy', acc, epoch)
            writer.add_scalar('Test/Precision', prec, epoch)
            writer.add_scalar('Test/Recall', rec, epoch)
            writer.add_scalar('Test/F1', f1, epoch)
            writer.add_scalar('Test/ROC_AUC', roc_auc, epoch)
            writer.add_scalar('Test/Loss', test_loss, epoch)
            writer.add_scalar('Spiking/InputRate', spk_rates['input'], epoch)
            writer.add_scalar('Spiking/ReservoirRate', spk_rates['reservoir'], epoch)

            # Log the confusion matrix as an image.
            fig = plot_confusion_matrix(cm, epoch)
            writer.add_figure('Confusion_Matrix', fig, epoch)
            plt.close(fig)

            print(f"Epoch {epoch:03d}: Train Loss={train_loss:.4f} | Test Loss={test_loss:.4f} | Acc={acc:.4f} | Prec={prec:.4f} | Rec={rec:.4f} | F1={f1:.4f} | ROC_AUC={roc_auc:.4f}")
    finally:
        writer.close()

    # Saves the weights after the last epoch. No best-epoch selection happens;
    # the file name is kept for compatibility.
    save_path = os.path.join(save_dir, 'best_model.pth')
    torch.save(model.state_dict(), save_path)

## main() usage

In [None]:
if __name__ == '__main__':
    # Directory containing the UCR Wafer TSV files. Set the WAFER_DATA_DIR
    # environment variable to point elsewhere instead of editing this
    # machine-specific default.
    data_dir = os.environ.get(
        'WAFER_DATA_DIR',
        '/Users/mikel/Documents/GitHub/polimikel/data/UCR_dataset/Wafer',
    )

    search_space = {
        # Data paths
        'train_tsv': os.path.join(data_dir, 'Wafer_TRAIN.tsv'),
        'test_tsv': os.path.join(data_dir, 'Wafer_TEST.tsv'),

        # Model / optimizer hyperparameters (replaced by the Optuna best values when use_optuna is True)
        'reservoir_size': 100,
        'num_connected': 50,
        'spectral_radius': 0.9,
        'beta_input': 0.9,
        'beta_reservoir': 0.9,
        'threshold': 1.0,
        'sparsity_percentage': 10,

        'learning_rate': 1e-3,
        'batch_size': 32,
        'epochs_search': 5,
        'epochs_final': 20,

        'seed': 42,

        'tensorboard_log_dir': './runs/experiment1',

        'save_dir': './saved_models',

        # For Optuna:
        'use_optuna': True,         # set to False to skip the hyperparameter search
        'n_trials': 20,
    }

    # Create the save directory if it does not exist.
    os.makedirs(search_space['save_dir'], exist_ok=True)

    main(search_space)