# Speaker Verification Training & Inference

Complete pipeline for speaker verification using:
- **PTM Encoder**: WavLM/HuBERT/Wav2Vec2 (multi-layer weighted sum)
- **Handcrafted Encoder**: MFBE + F0 features with CNN
- **Fusion Methods**: Concatenation, Cross-Attention, or Gating
- **Backbone**: ECAPA-TDNN
- **Loss**: AAM-Softmax
- **Features**: Early stopping, LR scheduling, heatmap visualization

## 1. Setup & Parse Arguments

In [None]:
# Auto-reload modules
%load_ext autoreload
%autoreload 2

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import argparse
import sys
from datetime import datetime

# Select the compute device once and report what was found
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Device: {device}")
if has_cuda:
    # Report the name and total memory (in GB, base 10) of GPU index 0
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Memory: {total_bytes / 1e9:.2f} GB")

In [None]:
from config import (
    BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE, 
    MODE, FUSION_METHOD, FEATURE_MODE,
    CHECKPOINT_DIR, BEST_MODEL_NAME,
)

def parse_args(argv=None):
    """Parse training options for the notebook or for a calling script.

    Args:
        argv: Optional sequence of command-line style strings, for example
            ``["--mode", "3", "--lr", "1e-4"]`` or ``sys.argv[1:]`` when
            called from a script. When ``None`` (the default) an empty list
            is parsed, so every option takes its default value and the
            arguments the Jupyter kernel was started with are never read.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: Raised by argparse when ``argv`` contains an unknown
            option or a value outside the allowed choices.
    """
    parser = argparse.ArgumentParser(description="Train Speaker Verification Model")

    # Data arguments
    parser.add_argument("--embedding-path", type=str, default="./embedding.pt",
                        help="Path to embedding file")
    parser.add_argument("--feature-path", type=str, default="./feature.pt",
                        help="Path to feature file")

    # Model arguments (defaults come from the project config module)
    parser.add_argument("--mode", type=int, choices=[1, 2, 3], default=MODE,
                        help="Training mode: 1=PTM only, 2=Handcrafted only, 3=Fusion")
    parser.add_argument("--fusion-method", type=str,
                        choices=["concat", "cross_attention", "gating"],
                        default=FUSION_METHOD,
                        help="Fusion method for mode 3")
    parser.add_argument("--feature-mode", type=str,
                        choices=["mfbe_pitch", "mfcc_pitch", "mfbe_only", "mfcc_only", "pitch_only"],
                        default=FEATURE_MODE,
                        help="Feature mode")

    # Training arguments
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help="Batch size")
    parser.add_argument("--learning-rate", "--lr", type=float, default=LEARNING_RATE,
                        help="Learning rate")
    parser.add_argument("--epochs", type=int, default=NUM_EPOCHS,
                        help="Number of epochs")
    parser.add_argument("--exp-name", type=str, default=None,
                        help="Experiment name")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")

    # Passing an explicit list keeps argparse away from sys.argv, which in a
    # notebook holds the kernel's own launch arguments.
    return parser.parse_args([] if argv is None else list(argv))

# Build the argument namespace from the notebook defaults
args = parse_args()

# Fusion/feature settings are only shown for the modes that use them
fusion_display = args.fusion_method if args.mode == 3 else 'N/A'
feature_display = args.feature_mode if args.mode in [2, 3] else 'N/A'
banner = "=" * 60

# Assemble the whole summary first, then emit it in a single write
summary_lines = [
    banner,
    "PARSED ARGUMENTS",
    banner,
    f"  Embedding path: {args.embedding_path}",
    f"  Feature path: {args.feature_path}",
    f"  Mode: {args.mode} (1=PTM only, 2=Handcrafted only, 3=Fusion)",
    f"  Fusion method: {fusion_display}",
    f"  Feature mode: {feature_display}",
    f"  Batch size: {args.batch_size}",
    f"  Learning rate: {args.learning_rate}",
    f"  Epochs: {args.epochs}",
    f"  Exp name: {args.exp_name or 'Auto-generated'}",
    f"  Seed: {args.seed}",
    banner,
]
print("\n".join(summary_lines))

## 2. Training

In [None]:
from train import train

# Fall back to a descriptive, timestamped experiment name when none was given
if args.exp_name is None:
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    args.exp_name = (
        f"mode{args.mode}_fusion_{args.fusion_method}"
        f"_feat_{args.feature_mode}_{timestamp}"
    )

rule = "=" * 80
print("\n" + rule)
print("STARTING TRAINING")
print(rule)
print(f"Experiment: {args.exp_name}\n")

# train() hands back the model, the training history and the output directory
model, history, exp_dir = train(args)

print("\n✓ Training completed!")
print(f"  Results saved to: {exp_dir}")
print(rule)

## 3. Training Curves

In [None]:
import json

# Read the saved training history
# NOTE(review): this fixed path replaces the `history` returned by train();
# confirm it is the file the current run actually wrote.
history_path = "./outputs/training_history.json"
with open(history_path, "r") as f:
    history = json.load(f)

# One panel per metric: (train key, val key, y-axis label, panel title)
panels = [
    ("train_loss", "val_loss", "Loss", "Training & Validation Loss"),
    ("train_accuracy", "val_accuracy", "Accuracy", "Training & Validation Accuracy"),
]

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
for ax, (train_key, val_key, y_label, title) in zip(axes, panels):
    ax.plot(history[train_key], label="Train", marker='o', markersize=3)
    ax.plot(history[val_key], label="Val", marker='s', markersize=3)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("./outputs/training_curves.png", dpi=150, bbox_inches="tight")
plt.show()

# Summarise the best epoch for each validation metric
best_val_loss = min(history['val_loss'])
print(f"Best validation loss: {best_val_loss:.4f}")
best_val_accuracy = max(history['val_accuracy'])
print(f"Best validation accuracy: {best_val_accuracy:.4f}")

## 4. Test Results

In [None]:
from train import load_checkpoint
from model import AAMSoftmaxLoss

# Load best model
# The checkpoint lives inside this run's experiment directory: exp_dir is the
# value returned by train() in the training cell, BEST_MODEL_NAME comes from config.
best_model_path = os.path.join(exp_dir, BEST_MODEL_NAME)
# load_checkpoint returns four values; only the first (the model) is kept.
# NOTE(review): presumably the discarded values are optimizer/epoch/metric
# state — confirm against train.py.
model, _, _, _ = load_checkpoint(best_model_path, model)
model = model.to(device)
# NOTE(review): assumes model is a torch nn.Module, so eval() puts it in
# inference mode before the test pass — confirm.
model.eval()

print(f"✓ Best model loaded from: {best_model_path}")

In [None]:
# Load data for testing (test_loader needs to be recreated)
from dataset import create_data_loaders

from train import validate

# Rebuild the loaders; only the test split and the speaker count are used here
_, _, test_loader, _, num_speakers = create_data_loaders(
    args.embedding_path,
    args.feature_path,
    args.mode,
    args.batch_size,
    num_workers=0,
)

# NOTE(review): the criterion is constructed fresh rather than restored from
# the checkpoint. If AAMSoftmaxLoss holds learnable class weights, the test
# metrics below would be computed with untrained weights, and the module may
# also need moving to `device` — confirm against model.py / train.validate.
criterion = AAMSoftmaxLoss(num_speakers=num_speakers)
test_loss, test_acc = validate(model, test_loader, criterion, device)

print("Test Results:")
print(f"  Loss: {test_loss:.4f}")
print(f"  Accuracy: {test_acc:.4f}")

## 5. Experiment Comparison

In [None]:
import pandas as pd

# List all experiments
import json  # imported here so this cell also runs when the plotting cell was skipped


def collect_experiments(base_dir):
    """Gather one summary row per experiment found under *base_dir*.

    An entry counts as an experiment when ``<base_dir>/<entry>/results.json``
    exists. Entries are visited in sorted name order.

    Args:
        base_dir: Directory whose immediate children are experiment folders.

    Returns:
        A list of dicts with the keys ``Experiment``, ``Mode``, ``Fusion``,
        ``Feature``, ``Best Val Loss`` and ``Epochs``. ``Best Val Loss`` is
        formatted to four decimals, or ``"N/A"`` when the file holds no
        numeric ``best_val_loss`` (so a missing value cannot be mistaken for
        a perfect loss of 0.0000).
    """
    rows = []
    for entry in sorted(os.listdir(base_dir)):
        results_file = os.path.join(base_dir, entry, "results.json")
        if not os.path.isfile(results_file):
            continue

        # One unreadable or truncated file must not hide every other run.
        try:
            with open(results_file, "r") as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            print(f"Skipping {entry}: could not read results.json ({err})")
            continue
        if not isinstance(data, dict):
            print(f"Skipping {entry}: results.json is not a JSON object")
            continue

        config = data.get("config")
        if not isinstance(config, dict):
            config = {}

        best = data.get("best_val_loss")
        has_loss = isinstance(best, (int, float)) and not isinstance(best, bool)

        rows.append({
            "Experiment": entry,
            "Mode": config.get("mode", ""),
            "Fusion": config.get("fusion_method", "N/A"),
            "Feature": config.get("feature_mode", "N/A"),
            "Best Val Loss": f"{best:.4f}" if has_loss else "N/A",
            "Epochs": data.get("epochs_trained", 0),
        })
    return rows


exp_base_dir = "./outputs/experiments"
if os.path.isdir(exp_base_dir):
    experiments = collect_experiments(exp_base_dir)

    if experiments:
        df = pd.DataFrame(experiments)
        print("\n" + "="*120)
        print("EXPERIMENT COMPARISON")
        print("="*120)
        print(df.to_string(index=False))
        print("="*120)
    else:
        print("No experiments found.")
else:
    print(f"Directory {exp_base_dir} does not exist yet.")