# Comparative Analysis of Neural Network Architectures for Audio Classification on the Spiking Heidelberg Digits (SHD) Dataset

## Overview
This notebook implements a benchmark study comparing neural network architectures for audio classification on the Spiking Heidelberg Digits (SHD) dataset. It evaluates six models across four architectural paradigms: feed-forward Spiking Neural Networks (SNNs) with one, two, and three hidden layers, a Spiking Recurrent Neural Network (SRNN), a Long Short-Term Memory network (LSTM), and a Convolutional Neural Network (CNN). In the current configuration only the LSTM and CNN are trained; the spiking models are disabled in the model-initialization cell.

In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import time
import os, sys
from tqdm.notebook import tqdm
import gc


# --- Reproducibility: seed the numpy and torch global RNGs ---
np.random.seed(42)
torch.manual_seed(42)

# --- Compute device: prefer CUDA when a GPU is visible ---
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# --- Make the project root (parent of the working directory) importable ---
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from heidelberg_dataset import HeidelbergDatasetCached
from src.utils.save_model import save_model_checkpoint
from models.architectures.snn import SpikingNeuralNetwork
from models.architectures.lstm import LSTMClassifier
from models.architectures.cnn import CNNClassifier
from models.loss.max_over_time_loss import MaxOverTimeLoss
from models.loss.standard_cross_entropy import LSTMCNNLoss

Using device: cpu


In [2]:
# Experiment configuration shared by every model trained in this notebook.
CONFIG = dict(
    # HDF5 files for the SHD train/test splits (relative to the working directory)
    train_file='../data/shd_train.h5',
    test_file='../data/shd_test.h5',
    # Optimisation
    batch_size=256,
    num_epochs=300,
    learning_rate=0.002,
    # Time discretisation, forwarded to HeidelbergDatasetCached
    dt=0.14e-3,
    T_max=1.4,  # seconds; non-SNN models use T_max / 0.01 time bins
    # Network dimensions
    input_size=700,
    hidden_size=200,
    output_size=20,
    nb_steps=200,  # number of time steps used for SNN models
    dropout=0.1,
    # NOTE(review): not currently read by the training code -- checkpoints
    # are saved unconditionally.
    save_models=True,
)

In [3]:
def custom_collate_fn(batch):
    """Collate ``(data, label)`` pairs into batched tensors.

    Every sample's ``data`` tensor must have the same shape: samples are
    combined with ``torch.stack``, which raises if shapes differ, so this
    function does NOT pad variable-length sequences. Labels may be tensors or
    plain Python/numpy scalars; each is converted with ``torch.as_tensor``
    before stacking.

    Args:
        batch: Sequence of ``(data, label)`` tuples as produced by the dataset.

    Returns:
        Tuple ``(data, labels)``: ``data`` gains a new leading batch dimension
        and ``labels`` is a tensor with one entry per sample.
    """
    data, labels = zip(*batch)
    data = torch.stack(data, dim=0)
    # as_tensor returns tensors unchanged and lets scalar labels work as well.
    labels = torch.stack([torch.as_tensor(label) for label in labels], dim=0)
    return data, labels

In [4]:
def single_model_dataset_setup(model_type, CONFIG, val_split=0.12, bin_width=0.01):
    """Load the SHD training file and split it into train/validation loaders.

    SNN models use ``CONFIG['nb_steps']`` time steps; every other model type
    uses ``round(CONFIG['T_max'] / bin_width)`` steps (e.g. 140 for
    ``T_max=1.4`` and the default 10 ms bins).

    Args:
        model_type: Encoding passed to ``HeidelbergDatasetCached``
            ('snn', 'lstm', 'cnn', ...).
        CONFIG: Configuration dict providing ``train_file``, ``dt``,
            ``T_max``, ``nb_steps`` and ``batch_size``.
        val_split: Fraction of the training file held out for validation;
            must lie strictly between 0 and 1.
        bin_width: Time-bin width (same unit as ``T_max``) for non-SNN models.

    Returns:
        ``(train_loader, val_loader)``. The split uses a fixed seed (42), so it
        is identical on every run.

    Raises:
        ValueError: If ``val_split`` or ``bin_width`` is out of range, or the
            split would leave the train or validation subset empty.
    """
    if not 0.0 < val_split < 1.0:
        raise ValueError(f"val_split must be in (0, 1), got {val_split}")
    if bin_width <= 0:
        raise ValueError(f"bin_width must be positive, got {bin_width}")

    print(f"\n{'='*60}")
    print(f"Loading {model_type.upper()} datasets...")

    if model_type == 'snn':
        nb_steps = CONFIG['nb_steps']
    else:
        # round() instead of int(): the float division can land just below the
        # intended integer (0.7 / 0.01 == 69.999...), and int() would silently
        # drop a time step.
        nb_steps = round(CONFIG['T_max'] / bin_width)

    print(f"Using {nb_steps} time steps for {model_type.upper()}")

    full_train_dataset = HeidelbergDatasetCached(
        CONFIG['train_file'], 'train',
        dt=CONFIG['dt'], T_max=CONFIG['T_max'],
        model_type=model_type, nb_steps=nb_steps
    )

    total_size = len(full_train_dataset)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size
    if val_size == 0 or train_size == 0:
        raise ValueError(
            f"val_split={val_split} on {total_size} samples gives an empty split "
            f"(train={train_size}, val={val_size})"
        )

    # Seeded generator so the held-out set is reproducible across runs.
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )

    print(f"Dataset sizes - Train: {train_size}, Val: {val_size}")

    loader_kwargs = dict(
        batch_size=CONFIG['batch_size'],
        num_workers=0,
        collate_fn=custom_collate_fn,
    )
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

In [5]:
def single_model_dataset_clean(train_loader, val_loader):
    """Best-effort release of memory held by a finished model's data loaders.

    ``del`` inside this function removes only this function's own references.
    The loaders, and the in-memory datasets behind them, stay alive as long as
    the caller still holds them. Callers should also drop their own references
    (``del train_loader, val_loader``) and run ``gc.collect()`` for the memory
    to be reclaimed.

    Args:
        train_loader: Training DataLoader to release.
        val_loader: Validation DataLoader to release.
    """
    del train_loader, val_loader
    gc.collect()
    # Hand cached, now-unused GPU memory back to the driver; skipped on CPU.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


In [6]:
def _predictions_from_outputs(outputs, is_snn, loss_type):
    """Map raw model outputs to a tensor of predicted class indices."""
    if isinstance(outputs, dict):  # LSTM: dict of readouts
        if loss_type == 'last_time_step':
            return torch.argmax(outputs['last_time_step'], dim=1)
        over_time = outputs['all_time_steps']
    elif is_snn and isinstance(outputs, tuple):
        # SNN returning (readout, auxiliary): only the readout is scored.
        over_time = outputs[0]
    elif is_snn and outputs.dim() == 3:
        # SNN returning the bare (batch, time, classes) readout.
        over_time = outputs
    else:  # CNN (or an already time-reduced readout): (batch, classes)
        return torch.argmax(outputs, dim=1)
    # Class score = peak activation over the time axis (dim 1).
    max_over_time, _ = torch.max(over_time, dim=1)
    return torch.argmax(max_over_time, dim=1)


def evaluate_model(model, data_loader, is_snn=True, loss_type='max_over_time'):
    """Return the classification accuracy of ``model`` on ``data_loader``.

    Readout per model family:

    * SNN (``is_snn=True``): a ``(batch, time, classes)`` tensor, either bare
      or as the first element of a tuple; scored by its max over time.
    * LSTM (dict output): ``'last_time_step'`` logits when
      ``loss_type == 'last_time_step'``, otherwise the max over time of
      ``'all_time_steps'``.
    * CNN (plain tensor, ``is_snn=False``): ``(batch, classes)`` logits.

    Args:
        model: Module to evaluate; it is switched to ``eval()`` mode.
        data_loader: Iterable of ``(data, labels)`` batches.
        is_snn: Whether ``model`` is a spiking network.
        loss_type: Readout selector for dict (LSTM) outputs.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, labels in data_loader:
            data, labels = data.to(device), labels.to(device)
            predictions = _predictions_from_outputs(model(data), is_snn, loss_type)
            # Count on-device instead of shipping every prediction to the CPU.
            correct += (predictions.reshape(-1) == labels.reshape(-1)).sum().item()
            total += labels.numel()

    if total == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")
    return correct / total

In [7]:
def _should_validate(epoch, num_epochs, val_every):
    """Return True if validation should run after 0-based ``epoch``.

    Keeps the original schedule (epoch 1, plus every ``val_every``-th epoch
    after epoch 0). It also always validates the last epoch, so the reported
    final accuracy belongs to the weights that get checkpointed.
    """
    return (
        epoch == 1
        or (epoch != 0 and epoch % val_every == 0)
        or epoch == num_epochs - 1
    )


def train_model(model, train_loader, val_loader, model_name, num_epochs=150, lr=1e-3,
                is_snn=True, loss_type='max_over_time', val_every=5, save_checkpoint=True):
    """Train ``model`` with Adamax and periodically evaluate it on ``val_loader``.

    Args:
        model: Module to train; it is moved to the global ``device``.
        train_loader: Iterable of ``(data, labels)`` training batches (must be
            non-empty).
        val_loader: Validation batches passed to ``evaluate_model``.
        model_name: Name used in progress output and the checkpoint name
            (``f"{model_name}_final"``).
        num_epochs: Number of training epochs (>= 1).
        lr: Adamax learning rate.
        is_snn: Use ``MaxOverTimeLoss`` if True, else ``LSTMCNNLoss(loss_type)``.
        loss_type: Loss/readout selector forwarded to the loss and evaluation.
        val_every: Validation period in epochs (see ``_should_validate``).
        save_checkpoint: Whether to write the final checkpoint.

    Returns:
        Dict with:
        - ``train_losses``: mean loss per epoch.
        - ``val_accuracies`` and ``val_epochs``: validation accuracies and the
          0-based epochs at which they were measured.
        - ``best_val_acc`` and ``final_val_acc``.
        - ``optimizer``.

    Raises:
        ValueError: If ``num_epochs`` or ``val_every`` < 1, or ``train_loader``
            is empty.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")
    if val_every < 1:
        raise ValueError(f"val_every must be >= 1, got {val_every}")
    if len(train_loader) == 0:
        raise ValueError(f"train_loader for {model_name} is empty")

    print(f"\n🚀 Training {model_name}...")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

    model.to(device)

    optimizer = torch.optim.Adamax(model.parameters(), lr=lr, betas=(0.9, 0.999))
    criterion = MaxOverTimeLoss() if is_snn else LSTMCNNLoss(loss_type=loss_type)

    train_losses = []
    val_accuracies = []
    val_epochs = []
    best_val_acc = 0.0

    for epoch in tqdm(range(num_epochs), desc=f"Training {model_name}"):
        model.train()  # evaluate_model leaves the model in eval mode
        epoch_loss = 0.0

        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)

            loss = criterion(model(data), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_loss)

        if _should_validate(epoch, num_epochs, val_every):
            val_acc = evaluate_model(model, val_loader, is_snn=is_snn, loss_type=loss_type)
            val_accuracies.append(val_acc)
            val_epochs.append(epoch)
            best_val_acc = max(best_val_acc, val_acc)
            print(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Val Acc = {val_acc:.4f}")

    print(f"✅ {model_name} training completed! Best Val Acc: {best_val_acc:.4f}")

    # The schedule always includes the last epoch, so this is the accuracy of
    # exactly the weights being checkpointed.
    final_val_acc = val_accuracies[-1]

    if save_checkpoint:
        save_model_checkpoint(
            model=model,
            optimizer=optimizer,
            epoch=num_epochs - 1,
            loss=train_losses[-1],
            accuracy=final_val_acc,
            model_name=f"{model_name}_final",
            checkpoint_dir="../models/checkpoints")

    return {
        'train_losses': train_losses,
        'val_accuracies': val_accuracies,
        'val_epochs': val_epochs,
        'best_val_acc': best_val_acc,
        'final_val_acc': final_val_acc,
        'optimizer': optimizer
    }

In [8]:
print("Initializing models...")

# Set to True to add the spiking models (SNN_1Layer, SNN_2Layer, SNN_3Layer,
# SRNN) to the benchmark. They are currently excluded from the run.
INCLUDE_SNN_MODELS = False

models = {}

if INCLUDE_SNN_MODELS:
    # Spiking Neural Networks
    models['SNN_1Layer'] = SpikingNeuralNetwork(
        input_size=CONFIG['input_size'],
        hidden_size=CONFIG['hidden_size'],
        output_size=CONFIG['output_size'],
        num_layers=1,
        recurrent=False,
        dt=CONFIG['dt'],
        reg_strength=2e-6
    )

    models['SNN_2Layer'] = SpikingNeuralNetwork(
        input_size=CONFIG['input_size'],
        hidden_size=CONFIG['hidden_size'],
        output_size=CONFIG['output_size'],
        num_layers=2,
        recurrent=False,
        dt=CONFIG['dt']
    )

    models['SNN_3Layer'] = SpikingNeuralNetwork(
        input_size=CONFIG['input_size'],
        hidden_size=CONFIG['hidden_size'],
        output_size=CONFIG['output_size'],
        num_layers=3,
        recurrent=False,
        dt=CONFIG['dt']
    )

    models['SRNN'] = SpikingNeuralNetwork(
        input_size=CONFIG['input_size'],
        hidden_size=CONFIG['hidden_size'],
        output_size=CONFIG['output_size'],
        dt=CONFIG['dt'],
        recurrent=True
    )

# LSTM
models['LSTM'] = LSTMClassifier(
    input_size=CONFIG['input_size'],
    hidden_size=CONFIG['hidden_size'],
    output_size=CONFIG['output_size'],
    dropout=CONFIG['dropout']
)

# CNN
models['CNN'] = CNNClassifier(
    input_channels=64,  # Spatially binned channels for CNN
    output_size=CONFIG['output_size'],
    dropout=CONFIG['dropout']
)

print("✅ All models initialized!")
# Print model summaries
for name, model in models.items():
    print(f"\n{name}: {sum(p.numel() for p in model.parameters()):,} parameters")

Initializing models...
✅ All models initialized!

LSTM: 725,620 parameters

CNN: 455,700 parameters




In [9]:
# Per-model training settings: which dataset encoding to load, whether the
# model is spiking, and which loss/readout to use. Each spiking variant gets
# its own (identical) settings dict.
training_configs = {
    **{
        snn_name: {'model_type': 'snn', 'is_snn': True, 'loss_type': 'max_over_time'}
        for snn_name in ('SNN_1Layer', 'SNN_2Layer', 'SNN_3Layer', 'SRNN')
    },
    'LSTM': {'model_type': 'lstm', 'is_snn': False, 'loss_type': 'max_over_time'},
    'CNN': {'model_type': 'cnn', 'is_snn': False, 'loss_type': 'standard'},
}


In [None]:
training_results = {}

print("🎯 Starting training for all models...")
print("=" * 60)

# Train each model on its own dataset encoding, one model at a time so only
# one cached dataset is resident in memory.
for model_name, model in models.items():
    config = training_configs[model_name]
    train_loader, val_loader = single_model_dataset_setup(
        model_type=config['model_type'],
        CONFIG=CONFIG
    )

    start_time = time.perf_counter()

    print(f"Training {model_name}...")

    result = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        model_name=model_name,
        num_epochs=CONFIG['num_epochs'],
        lr=CONFIG['learning_rate'],
        is_snn=config['is_snn'],
        loss_type=config['loss_type']
    )

    training_time = time.perf_counter() - start_time
    result['training_time'] = training_time
    training_results[model_name] = result

    # The helper can only drop its own references; release ours too so the
    # cached dataset is actually collectable before the next model loads.
    single_model_dataset_clean(train_loader, val_loader)
    del train_loader, val_loader
    gc.collect()

    print(f"⏱️  {model_name} Training time: {training_time/60:.2f} minutes")
    print("=" * 60)

print("🎉 All models trained successfully!")


🎯 Starting training for all models...

Loading LSTM datasets...
Using 140 time steps for LSTM
Loading train dataset into memory...
✓ Loaded 8156 samples
✓ Time bins: 140 steps over 1.4s
Dataset sizes - Train: 7178, Val: 978
Training LSTMClassifier(
  (lstm): LSTM(700, 200, batch_first=True, dropout=0.1)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc): Linear(in_features=200, out_features=20, bias=True)
)...

🚀 Training LSTM...
Parameters: 725,620


Training LSTM:   0%|          | 0/300 [00:00<?, ?it/s]

In [None]:
# One panel per trained model, three per row. Training loss uses the left
# y-axis; validation accuracy shares the epoch x-axis on a twin right y-axis.
n_cols = 3
n_rows = max(1, -(-len(training_results) // n_cols))  # ceil division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, (model_name, results) in zip(axes, training_results.items()):
    train_losses = results['train_losses']
    val_accuracies = results['val_accuracies']

    # Prefer the epochs recorded during training. Otherwise rebuild the
    # schedule train_model used: epoch 1 and every 5th epoch after epoch 0.
    val_epochs = results.get('val_epochs')
    if val_epochs is None:
        val_epochs = [e for e in range(len(train_losses))
                      if e == 1 or (e != 0 and e % 5 == 0)]
    # Guard against any length mismatch, which would make plot() raise.
    n_points = min(len(val_epochs), len(val_accuracies))

    loss_line, = ax.plot(train_losses, color='b', alpha=0.7, label='Training Loss')
    ax2 = ax.twinx()
    acc_line, = ax2.plot(list(val_epochs)[:n_points], val_accuracies[:n_points],
                         'r-', label='Validation Accuracy')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Training Loss', color='b')
    ax2.set_ylabel('Validation Accuracy', color='r')
    ax.set_title(f'{model_name}\nBest Val Acc: {results["best_val_acc"]:.4f}')
    ax.grid(True, alpha=0.3)
    ax.legend(handles=[loss_line, acc_line], loc='center right')

# Hide grid cells left over when there are fewer models than panels.
for ax in axes[len(training_results):]:
    ax.set_visible(False)

fig.tight_layout()
fig.suptitle('Training Progress for All Models', fontsize=16, y=1.02)
plt.show()