In [1]:
import os
import warnings
import torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
warnings.filterwarnings("ignore", message="expandable_segments not supported on this platform")
os.environ["MLFLOW_LOCK_MODEL_DEPENDENCIES"] = "true"
torch.cuda.empty_cache()
print(torch.__version__)
print(os.cpu_count())

2.5.1+cu118
16


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Dataset, Subset
import optuna
import mlflow

from spikingjelly.clock_driven import functional

import json
import random
import numpy as np
import os

from typing import Optional
from functools import partial
import traceback

import importlib
import dtypeconvert
import snn_model
import ravdess_dataset
importlib.reload(ravdess_dataset)
importlib.reload(snn_model)
importlib.reload(dtypeconvert)
from snn_model import EmotionSNN
from ravdess_dataset import RAVDESSDataset
from dtypeconvert import convert_dataset_dtype

In [3]:
mlflow.set_tracking_uri("sqlite:///snn.db")
mlflow.set_experiment("SNN_audio-experiment-study")

2025/09/24 23:51:41 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/09/24 23:51:41 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.
INFO  [alembic.runtime.migration] Context impl SQLiteImpl.
INFO  [alembic.runtime.migration] Will assume non-transactional DDL.


<Experiment: artifact_location='file:///c:/Users/Marc/Desktop/Programming/Main Project/snn/audio-snn/mlruns/1', creation_time=1758737404883, experiment_id='1', last_update_time=1758737404883, lifecycle_stage='active', name='SNN_audio-experiment-study', tags={}>

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

Using device: cuda


In [5]:
# Add this cell early in the notebook
import platform
import sys
import subprocess

def log_environment():
    """Log complete environment for reproducibility"""
    env_info = {
        "python_version": sys.version,
        "platform": platform.platform(),
        "pytorch_version": torch.__version__,
        "cuda_version": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
        "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None",
        "spikingjelly_version": getattr(functional, '__version__', 'unknown'),
        "mlflow_version": mlflow.__version__,
        "optuna_version": optuna.__version__,
    }
    
    # Log complete environment as JSON artifact
    mlflow.log_text(json.dumps(env_info, indent=2), "environment.json")
    
    # Log key info as MLflow tags (for filtering/searching)
    mlflow.set_tag("env_pytorch_version", env_info["pytorch_version"])
    mlflow.set_tag("env_cuda_version", str(env_info["cuda_version"]))
    mlflow.set_tag("env_cuda_available", str(env_info["cuda_available"]))
    mlflow.set_tag("env_gpu_name", env_info["gpu_name"])
    mlflow.set_tag("env_platform", env_info["platform"])

    return env_info



In [6]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# PyTorch deterministic flags for reproducibility (may slow training)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [7]:
dataset_root = "Audio_Speech_Actors_01-24"

# Load dataset
full_dataset = RAVDESSDataset(root_dir=dataset_root, T=50, augment_prob=0.7)

train_size = int(0.6 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

generator = torch.Generator().manual_seed(SEED)

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, 
    [train_size, val_size, test_size],
    generator=generator
)

# Log split indices for exact reproduction
train_indices = train_dataset.indices
val_indices = val_dataset.indices
test_indices = test_dataset.indices

train_dataset = convert_dataset_dtype(train_dataset, torch.float16)
val_dataset = convert_dataset_dtype(val_dataset, torch.float16)
test_dataset = convert_dataset_dtype(test_dataset, torch.float16)

print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}, Test size: {len(test_dataset)}")

Train size: 864, Validation size: 288, Test size: 288


In [8]:
x, y = train_dataset[0]
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([50, 1, 128, 400]) torch.float16
torch.Size([]) torch.int64


In [9]:
# quick checks for what train_dataset returns
print(f"min: {x.min().item()} \nmax: {x.max().item()} \nmean: {x.mean().item()}")
print("y:", y)

min: 0.0 
max: 1.0 
mean: 0.0919189453125
y: tensor(5)


In [10]:
@torch.inference_mode()
def validate(model, val_loader, criterion, device):
    model.eval()

    total_loss = 0.0
    total_samples = 0
    total_correct = 0


    for data, target in val_loader:
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # Forward
        outputs = model(data)
        loss = criterion(outputs, target)

        # Update aggregates
        batch_size = target.size(0)
        total_loss += loss.item()
        total_samples += batch_size

        preds = outputs.argmax(dim=1)
        total_correct += (preds == target).sum().item()

        # Reset SNN states to avoid carryover between batches
        functional.reset_net(model)

    # Safeguards
    num_batches = max(1, len(val_loader))
    avg_loss = total_loss / num_batches
    accuracy = 100.0 * (total_correct / max(1, total_samples))
    return avg_loss, accuracy


In [11]:
def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs,
    learning_rate,
    device,
    patience=10,
    trial=None,
    min_delta=0.001,
    ckpt_dir="checkpoints",
    grad_clip_max_norm: Optional[float] = None,
    report_intermediate = True
):
    
    os.makedirs(ckpt_dir, exist_ok=True)
    
    # advanced optimizer setup
    optimizer = torch.optim.Adam(model.parameters(), 
                                 lr=learning_rate,
                                 weight_decay=0.01,
                                 betas=(0.9, 0.999)
                    )


    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min',
        factor=0.5,
        patience=5,
        verbose=True,
        min_lr=1e-7,
        threshold=min_delta
    )


    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    best_val_loss = float('inf')
    best_val_acc = 0.0
    best_epoch = 0
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    best_model_state = None
    no_improve_epochs = 0

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss, train_acc = 0.0, 0.0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            
            # Combined loss with L2 regularization
            ce_loss = criterion(output, target)
            l2_loss = model.get_l2_loss()
            loss = ce_loss + l2_loss
            
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            if grad_clip_max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_max_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            functional.reset_net(model)
            
            train_loss += loss.item()
            pred = output.argmax(dim=1)
            train_acc += pred.eq(target).sum().item()
            
            # printing batch summary
            batch_acc = 100.0 * pred.eq(target).sum().item() / len(target)
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_loader)}: "
                f"Loss: {loss.item():.4f}, Batch Acc: {batch_acc:.2f}%")
            
        # ------- Validate -------
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        # ------- Scheduler step -------
        scheduler.step(val_loss)
        
        # ------- Calculate metrics -------
        epoch_train_loss = train_loss / len(train_loader)
        epoch_train_acc = 100.0 * train_acc / len(train_loader.dataset)
        
        train_accs.append(epoch_train_acc)
        val_accs.append(val_acc)
        train_losses.append(epoch_train_loss)
        val_losses.append(val_loss)

        # ------- Optuna Reporting -------
        if trial is not None and report_intermediate:
            trial.report(val_loss, epoch)
            trial.set_user_attr(f"val_acc_epoch_{epoch}", val_acc)

        # ------- Track best -------
        val_improved = val_acc > best_val_acc + min_delta
        loss_improved = val_loss < best_val_loss - min_delta
        
        # Update best metrics
        if val_improved:
            best_val_acc = val_acc
            best_epoch = epoch
            no_improvement_epochs = 0
        else:
            no_improvement_epochs += 1
            
        if loss_improved:
            best_val_loss = val_loss
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        # ------- MLFlow Logging -------
        mlflow.log_metric("train_loss", epoch_train_loss, step=epoch)
        mlflow.log_metric("val_loss", val_loss, step=epoch)
        mlflow.log_metric("train_acc", epoch_train_acc, step=epoch)
        mlflow.log_metric("val_acc", val_acc, step=epoch)
        mlflow.log_metric("best_val_acc", best_val_acc, step=epoch)
        mlflow.log_metric("learning_rate", optimizer.param_groups[0]['lr'], step=epoch)
        
        # ------ optuna pruning hook ------
        if trial is not None:
            if trial.should_prune():
                print(f"Trial {trial.number} pruned at epoch {epoch}")
                if best_model_state is not None:
                    model.load_state_dict(best_model_state)
                raise optuna.TrialPruned()
          
        # -------- Print epoch summary -------
        if trial is None:
            print(f'Epoch {epoch+1}/{num_epochs}: '
                f'Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.2f}%, '
                f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, '
                f'Best Val Acc: {best_val_acc:.2f}%, '
                f'LR: {optimizer.param_groups[0]["lr"]:.2e}')

    # Load the best weights seen during this call
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        
    return best_val_acc, best_val_loss, train_accs, val_accs, train_losses, val_losses

In [12]:
# Optuna objective function
def objective(trial, parent_run_id):
    """Optuna objective function for hyperparameter optimization"""
       
    # Suggest hyperparameters
    params = {
            'conv1_channels': trial.suggest_int('conv1_channels', 16, 64, step=8),
            'conv2_channels': trial.suggest_int('conv2_channels', 32, 128, step=16),
            'fc1_units': trial.suggest_int('fc1_units', 64, 256, step=32),
            'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True),
            'batch_size': trial.suggest_categorical('batch_size', [2, 4]),
            'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.5),
            
            # model specific
            "v_threshold" : trial.suggest_float("v_threshold", 0.2, 1.0),
            "tau" : trial.suggest_float("tau", 1.0, 4.0),
            'surrogate_func': trial.suggest_categorical('surrogate_func', ['Sigmoid', 'ATan']),
            'T_steps': trial.suggest_int('T_steps', 50, 200, step=25),
            }
        

    train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], 
                              shuffle=True, drop_last=True,
                              num_workers=0,pin_memory=True, 
                              )
    
    val_loader = DataLoader(val_dataset, batch_size=params['batch_size'], 
                            shuffle=False, drop_last=False, 
                            num_workers=0,pin_memory=True, 
                            )


    with mlflow.start_run(run_name=f"trial_{trial.number}",nested=True, parent_run_id=parent_run_id):
        
        # Set different random seed for each trial
        trial_seed = SEED + trial.number
        torch.manual_seed(trial_seed)
        torch.cuda.manual_seed_all(trial_seed)
        
        mlflow.set_tag("phase", "optuna")
        # Log hyperparameters
        mlflow.log_params(params)
        
        # Log data splits for reproducibility
        mlflow.log_text(json.dumps({
            "train_indices": train_indices,
            "val_indices": val_indices,
            "test_indices": test_indices
        }, indent=2), "data_splits.json")
    
        # Create model with suggested parameters
        model = EmotionSNN(
            num_classes=8,
            conv1_channels=params['conv1_channels'],
            conv2_channels=params['conv2_channels'], 
            fc1_units=params['fc1_units'],
            surrogate_func=params['surrogate_func'],
            dropout_rate=params['dropout_rate']
        ).to(device).half()
            
        
        
        # Train model (shorter for hyperparameter search)
        try:
            best_val_acc, best_val_loss, train_accs, val_accs, train_losses, val_losses = train_model(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                num_epochs=20,  # Shorter for optimization
                learning_rate=params['learning_rate'],
                device=device,
                patience=15,
                trial=trial,
                grad_clip_max_norm=1.0,
            )

            return best_val_loss, best_val_acc
        
        except optuna.exceptions.TrialPruned:
            raise
        except Exception as e:
            traceback.print_exc()
            print("Trial failed, returning high loss")
            return float('inf'), 0  # Return poor score for failed trials
        
        finally:
            torch.cuda.empty_cache()

In [13]:
def run_hyperparameter_optimization(n_trials=50):
    # Run Optuna hyperparameter optimization

    # Create study
    sampler = optuna.samplers.TPESampler(n_startup_trials=10, n_ei_candidates=24, multivariate=True, group=True)
    pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10, interval_steps=5)
    
    study = optuna.create_study(
        directions=["minimize", "maximize"],
        study_name="snn_emotion_recognition",
        storage="sqlite:///snn_optuna.db",  # Persist results
        load_if_exists=True,
        pruner=pruner,
        sampler=sampler,
    )

    print("Starting hyperparameter optimization...")
    print(f"Will run {n_trials} trials")

    # callback: print + optuna pruning only
    def stop_callback(study, trial):
        try:
            best_val = study.best_value
            print(f"Trial {trial.number} completed. Best validation loss so far: {best_val:.2f}")
        except Exception:
            print(f"Trial {trial.number} completed.")

    with mlflow.start_run(run_name="snn_optuna_sweep") as parent_run:
        parent_run_id = parent_run.info.run_id
        
        mlflow.set_tag("phase", "optuna_sweep")
        mlflow.set_tag("model", "EmotionSNN")
        mlflow.log_artifact("requirements.txt")

        study.optimize(partial(objective, parent_run_id=parent_run_id), n_trials=n_trials, timeout=7200, callbacks=[stop_callback])

        df = study.trials_dataframe()
        df.to_csv("study_trials.csv", index=False)
        mlflow.log_artifact("study_trials.csv", artifact_path="optuna")
        
        # Save best parameters to a JSON file
        with open("best_params.json", "w") as f:
            json.dump(study.best_params, f, indent=2)

        # Log the file as an artifact in MLflow
        mlflow.log_artifact("best_params.json", artifact_path="optuna")
        
        mlflow.log_text(json.dumps({"best_value": study.best_value}, indent=2), "optuna/best_value.json")
        return study



In [14]:
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

Allocated: 0.00 MB
Reserved: 0.00 MB


In [15]:
print("=== Warmup step: Training with baseline hyperparameters ===")
baseline_params = {
    'conv1_channels': 32,
    'conv2_channels': 64,
    'fc1_units': 128,
    'learning_rate': 1e-3,
    'batch_size': 4,
    'dropout_rate': 0.3,
    'surrogate_func': 'Sigmoid',
    'T_steps': 50,
    'v_threshold': 0.5,
    'tau': 2.0,
}
train_loader = DataLoader(train_dataset, batch_size=baseline_params['batch_size'], 
                              shuffle=True, drop_last=True,
                              num_workers=0,pin_memory=True, 
                              )

val_loader = DataLoader(val_dataset, batch_size=baseline_params['batch_size'], 
                            shuffle=False, drop_last=False, 
                            num_workers=0,pin_memory=True, 
                            )

model = EmotionSNN(
    num_classes=8,
    conv1_channels=baseline_params['conv1_channels'],
    conv2_channels=baseline_params['conv2_channels'],
    fc1_units=baseline_params['fc1_units'],
    surrogate_func=baseline_params['surrogate_func'],
    dropout_rate=baseline_params['dropout_rate']
).to(device).half()

torch.cuda.empty_cache()
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=2,  # Just a few epochs for warmup
    learning_rate=baseline_params['learning_rate'],
    device=device,
    patience=2,
    trial=None,
    grad_clip_max_norm=1.0,
)
print("=== Warmup complete ===")

=== Warmup step: Training with baseline hyperparameters ===




Epoch 1/2, Batch 1/216: Loss: 2.4630, Batch Acc: 0.00%
Epoch 1/2, Batch 2/216: Loss: nan, Batch Acc: 25.00%
Epoch 1/2, Batch 3/216: Loss: nan, Batch Acc: 50.00%
Epoch 1/2, Batch 4/216: Loss: nan, Batch Acc: 0.00%
Epoch 1/2, Batch 5/216: Loss: nan, Batch Acc: 0.00%
Epoch 1/2, Batch 6/216: Loss: nan, Batch Acc: 0.00%
Epoch 1/2, Batch 7/216: Loss: nan, Batch Acc: 0.00%
Epoch 1/2, Batch 8/216: Loss: nan, Batch Acc: 0.00%
Epoch 1/2, Batch 9/216: Loss: nan, Batch Acc: 0.00%
Epoch 1/2, Batch 10/216: Loss: nan, Batch Acc: 0.00%
Epoch 1/2, Batch 11/216: Loss: nan, Batch Acc: 0.00%
Epoch 1/2, Batch 12/216: Loss: nan, Batch Acc: 0.00%


KeyboardInterrupt: 

In [None]:
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

In [None]:
# # optuna hyperparameter optimization
print("=== Option 2: Hyperparameter Optimization ===")
study = run_hyperparameter_optimization(n_trials=100)  # Start with 100 trials