# 02: Model Training for Acoustic Navigation

This notebook trains a deep CNN to predict navigation actions from acoustic spectrograms.

## Training Pipeline

1. Load pre-computed dataset from HDF5
2. Split into train/validation sets
3. Define CNN architecture
4. Train with cross-entropy loss
5. Evaluate and save model

## Model Architecture

- **Input:** Multi-channel spectrograms (8 mics × freq × time)
- **Architecture:** 3-layer CNN + Global Average Pooling + FC layers
- **Output:** Action probabilities (5 classes: STOP, UP, DOWN, LEFT, RIGHT)

In [None]:
# Add src to path
import sys
sys.path.append('../')

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from src.model import AudioNavCNN
from src.dataset import AcousticGridDataset, AcousticDataModule
from src.utils import plot_training_curves, plot_confusion_matrix

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Fix the PyTorch and NumPy seeds so repeated runs are comparable.
# (Order kept as-is: the cell must end on a call that returns None,
# otherwise Jupyter would echo the Generator returned by torch.manual_seed.)
torch.manual_seed(42)
np.random.seed(42)

## 1. Configuration

In [None]:
# ---- Data ----
DATA_FILE = '../data/acoustic_navigation_data.h5'
TRAIN_VAL_SPLIT = 0.8  # fraction of samples used for training; the remainder is validation

# ---- Model ----
NUM_MICROPHONES = 8
NUM_ACTIONS = 5
DROPOUT_RATE = 0.3

# ---- Training ----
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

# ---- Output ----
MODEL_SAVE_PATH = '../data/audio_nav_model.pth'

# Echo the settings that matter most so they are recorded in the cell output.
print(
    "Training Configuration:",
    f"  Batch size: {BATCH_SIZE}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Device: {device}",
    sep="\n",
)

## 2. Load Dataset

In [None]:
# Open the pre-computed HDF5 dataset with normalization switched on.
full_dataset = AcousticGridDataset(
    hdf5_path=DATA_FILE,
    normalize=True,
    cache_in_memory=False,  # Set to True if dataset fits in RAM
)

print(f"Total samples: {len(full_dataset)}")
print(f"Spectrogram shape: {full_dataset.spectrogram_shape}")

# The normalization statistics are kept in `mean` / `std` because the
# final-save cell stores them alongside the model.
mean, std = full_dataset.get_normalization_stats()
print(f"\nNormalization:")
print(f"  Mean: {mean:.6f}")
print(f"  Std: {std:.6f}")

# Seeded train/validation split: TRAIN_VAL_SPLIT of the samples go to
# training, whatever is left goes to validation.
n_total = len(full_dataset)
train_size = int(TRAIN_VAL_SPLIT * n_total)
val_size = n_total - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

print(f"\nSplit:")
print(f"  Training: {len(train_dataset)} samples")
print(f"  Validation: {len(val_dataset)} samples")

# Both loaders share everything except shuffling.
shared_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=0,  # Use 0 for Windows, 4+ for Linux/Mac
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

print(f"\nDataLoaders:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

## 3. Examine Sample Batch

In [None]:
# Pull one batch from the training loader; the model-init cell reuses
# `sample_spectrograms` for its test forward pass.
sample_spectrograms, sample_actions = next(iter(train_loader))

print(f"Batch shapes:")
print(f"  Spectrograms: {sample_spectrograms.shape}")
print(f"  Actions: {sample_actions.shape}")

# Class names indexed by action id (also used by the evaluation cells).
action_names = ['STOP', 'UP', 'DOWN', 'LEFT', 'RIGHT']

# How many samples of each action ended up in this batch.
print(f"\nActions in batch:")
for action_idx in range(len(action_names)):
    action_name = action_names[action_idx]
    count = int((sample_actions == action_idx).sum())
    print(f"  {action_name}: {count}")

## 4. Initialize Model

In [None]:
# Build the network from the configuration constants and move it to the
# selected device.
model = AudioNavCNN(
    num_microphones=NUM_MICROPHONES,
    num_actions=NUM_ACTIONS,
    dropout_rate=DROPOUT_RATE
).to(device)

# Count parameters (all, and the subset that will receive gradients)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model: AudioNavCNN")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

# Sanity-check forward pass on the sample batch.
# Run it in eval mode: a freshly constructed module is in training mode, so
# any dropout would be active (making the printed output range random) and
# any BatchNorm layers would update their running statistics from this
# throw-away batch. torch.no_grad() alone does not prevent either.
# NOTE(review): whether AudioNavCNN actually contains BatchNorm is not visible
# here; the guard is harmless if it does not.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        test_output = model(sample_spectrograms.to(device))
finally:
    # Restore whatever mode the model was in before the check.
    model.train(was_training)

print(f"\nTest forward pass:")
print(f"  Input shape: {sample_spectrograms.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Output range: [{test_output.min():.3f}, {test_output.max():.3f}]")

## 5. Define Loss and Optimizer

In [None]:
# Loss function (Cross Entropy for classification; default mean reduction)
criterion = nn.CrossEntropyLoss()

# Optimizer (Adam with weight decay)
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# Learning rate scheduler (reduce on plateau of the monitored value).
# The factor/patience are named once so the summary printed below cannot
# drift from the values actually passed to the scheduler.
LR_DECAY_FACTOR = 0.5
LR_PATIENCE_EPOCHS = 5

# `verbose=True` is intentionally not passed: the argument is deprecated in
# recent PyTorch releases (and rejected by the newest ones), and the training
# loop already prints the current learning rate after every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=LR_DECAY_FACTOR,
    patience=LR_PATIENCE_EPOCHS
)

print("Optimizer: Adam")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print("\nScheduler: ReduceLROnPlateau")
print(f"  Factor: {LR_DECAY_FACTOR}")
print(f"  Patience: {LR_PATIENCE_EPOCHS} epochs")

## 6. Training Loop

In [None]:
# Training history
# One value is appended to each list per epoch; the plotting and final-save
# cells read these lists afterwards.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# Lowest validation loss seen so far and the 1-based epoch it occurred in
# (best_epoch stays 0 until the first improvement).
best_val_loss = float('inf')
best_epoch = 0

print("Starting training...\n")

for epoch in range(NUM_EPOCHS):
    # ============ Training Phase ============
    model.train()
    # Running sums for this epoch: summed per-sample loss, number of correct
    # predictions, and number of samples seen.
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]")
    for spectrograms, actions in train_pbar:
        # Move to device
        spectrograms = spectrograms.to(device)
        actions = actions.to(device)
        
        # Forward pass
        outputs = model(spectrograms)
        loss = criterion(outputs, actions)
        
        # Backward pass (gradients are cleared right before they are recomputed)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Statistics
        # The criterion returns the batch mean, so multiply by the batch size
        # to accumulate a per-sample sum (keeps the epoch mean correct even
        # when the last batch is smaller).
        train_loss += loss.item() * spectrograms.size(0)
        # Predicted class = index of the largest output along dim 1
        _, predicted = torch.max(outputs, 1)
        train_correct += (predicted == actions).sum().item()
        train_total += actions.size(0)
        
        # Update progress bar (batch loss, running accuracy over the epoch so far)
        train_pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100 * train_correct / train_total:.2f}%'
        })
    
    # Epoch statistics: mean per-sample loss and accuracy in percent
    train_loss = train_loss / train_total
    train_accuracy = 100 * train_correct / train_total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    
    # ============ Validation Phase ============
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    # No gradients are needed for validation
    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [Val]")
        for spectrograms, actions in val_pbar:
            spectrograms = spectrograms.to(device)
            actions = actions.to(device)
            
            outputs = model(spectrograms)
            loss = criterion(outputs, actions)
            
            # Same accumulation scheme as in the training phase
            val_loss += loss.item() * spectrograms.size(0)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == actions).sum().item()
            val_total += actions.size(0)
            
            val_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100 * val_correct / val_total:.2f}%'
            })
    
    val_loss = val_loss / val_total
    val_accuracy = 100 * val_correct / val_total
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    
    # Learning rate scheduling: the plateau scheduler monitors the epoch's
    # mean validation loss
    scheduler.step(val_loss)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Save best model: "best" means strictly lower validation loss than any
    # previous epoch. The checkpoint file is overwritten in place each time.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        # Note: the stored 'epoch' is 0-based, whereas best_epoch is 1-based;
        # the evaluation cell adds 1 when reporting it.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'train_accuracy': train_accuracy,
            'val_accuracy': val_accuracy,
        }, MODEL_SAVE_PATH)
        print(f"  ✓ Best model saved! (Val Loss: {val_loss:.4f})")
    
    print("-" * 60)

print(f"\n{'='*60}")
print("Training Complete!")
print(f"Best epoch: {best_epoch}")
print(f"Best val loss: {best_val_loss:.4f}")
print(f"{'='*60}")

## 7. Plot Training Curves

In [None]:
# Draw the loss/accuracy curves from the per-epoch histories collected by
# the training loop.
curve_history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_accuracies': train_accuracies,
    'val_accuracies': val_accuracies,
}
fig = plot_training_curves(**curve_history)
plt.show()

# Metrics of the last epoch (not necessarily the best one).
print(f"Final Training Metrics:")
print(f"  Train Loss: {train_losses[-1]:.4f}")
print(f"  Train Accuracy: {train_accuracies[-1]:.2f}%")
print(f"  Val Loss: {val_losses[-1]:.4f}")
print(f"  Val Accuracy: {val_accuracies[-1]:.2f}%")

## 8. Evaluate on Validation Set

In [None]:
from sklearn.metrics import confusion_matrix, classification_report

# Load best model.
# map_location=device lets a checkpoint written on a GPU be restored in a
# CPU-only session (and vice versa) instead of failing during deserialization.
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# The training loop stores a 0-based 'epoch'. The file at MODEL_SAVE_PATH may
# also have been overwritten by the final-save cell, whose checkpoint layout
# differs, so do not assume the key is present when this cell is re-run.
if 'epoch' in checkpoint:
    print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")
else:
    print("Loaded model weights (checkpoint has no 'epoch' entry)")

# Collect predictions and ground-truth labels over the whole validation set
all_predictions = []
all_targets = []

with torch.no_grad():
    for spectrograms, actions in tqdm(val_loader, desc="Evaluating"):
        # Only the inputs need to be on the model's device; the targets are
        # just recorded, so they are not moved there and back.
        outputs = model(spectrograms.to(device))
        _, predicted = torch.max(outputs, 1)

        all_predictions.extend(predicted.cpu().numpy())
        all_targets.extend(actions.cpu().numpy())

# Compute confusion matrix.
# Passing the full label set keeps the matrix len(action_names) square even
# if some class is absent from both targets and predictions; without it
# the matrix shrinks and no longer lines up with `action_names`.
class_labels = list(range(len(action_names)))
cm = confusion_matrix(all_targets, all_predictions, labels=class_labels)

# Plot confusion matrix (raw counts)
fig = plot_confusion_matrix(
    cm,
    class_names=action_names,
    title="Confusion Matrix (Validation Set)",
    normalize=False
)
plt.show()

# Plot normalized confusion matrix
fig = plot_confusion_matrix(
    cm,
    class_names=action_names,
    title="Normalized Confusion Matrix (Validation Set)",
    normalize=True
)
plt.show()

# Print classification report (same explicit label set, so target_names
# always matches the number of reported classes)
print("\nClassification Report:")
print(classification_report(
    all_targets,
    all_predictions,
    labels=class_labels,
    target_names=action_names,
    digits=4
))

## 9. Per-Class Analysis

In [None]:
# Convert the collected lists once instead of on every loop iteration.
targets_arr = np.asarray(all_targets)
predictions_arr = np.asarray(all_predictions)

# Compute per-class accuracy: of the samples whose true label is class i,
# the share that was predicted as class i.
print("Per-Class Accuracy:")
for i, action_name in enumerate(action_names):
    class_mask = targets_arr == i
    class_count = int(class_mask.sum())
    if class_count > 0:
        class_correct = int((predictions_arr[class_mask] == i).sum())
        class_accuracy = 100 * class_correct / class_count
        print(f"  {action_name}: {class_accuracy:.2f}% ({class_count} samples)")
    else:
        # Make a missing class visible instead of silently skipping it.
        print(f"  {action_name}: n/a (0 samples)")

# Overall accuracy, kept as a plain Python float because the final-save cell
# stores it inside the checkpoint (NumPy scalars are not accepted by
# torch.load's restricted "weights only" unpickler).
total_count = len(targets_arr)
total_correct = int((predictions_arr == targets_arr).sum())
overall_accuracy = 100 * total_correct / total_count if total_count else float('nan')
print(f"\nOverall Accuracy: {overall_accuracy:.2f}%")

## 10. Save Final Model and Metadata

In [None]:
import os

# Save complete model information.
# All scalar metadata is coerced to built-in floats: values coming from NumPy
# or torch (e.g. normalization statistics, accuracies) would otherwise be
# pickled as library scalar objects, which torch.load refuses to unpickle
# when it runs in its restricted "weights only" mode.
final_checkpoint = {
    # 0-based epoch of the best model, matching the key written by the
    # training loop so the evaluation cell can still read this file after it
    # has been overwritten here.
    'epoch': best_epoch - 1,
    'model_state_dict': model.state_dict(),
    'model_config': {
        'num_microphones': NUM_MICROPHONES,
        'num_actions': NUM_ACTIONS,
        'dropout_rate': DROPOUT_RATE,
    },
    'training_config': {
        'batch_size': BATCH_SIZE,
        'num_epochs': NUM_EPOCHS,
        'learning_rate': LEARNING_RATE,
        'weight_decay': WEIGHT_DECAY,
    },
    'normalization': {
        'mean': float(mean),
        'std': float(std),
    },
    'performance': {
        'best_epoch': best_epoch,
        'best_val_loss': float(best_val_loss),
        'final_val_accuracy': float(val_accuracies[-1]),
        'overall_accuracy': float(overall_accuracy),
    },
    'history': {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies,
    }
}

# Overwrites the best-epoch checkpoint written during training with the
# richer, deployment-oriented layout above.
torch.save(final_checkpoint, MODEL_SAVE_PATH)
print(f"✓ Final model saved to: {MODEL_SAVE_PATH}")

# Display model info
model_size_mb = os.path.getsize(MODEL_SAVE_PATH) / (1024 * 1024)
print(f"  Model size: {model_size_mb:.2f} MB")
# best_epoch is selected by lowest validation loss, so this is the accuracy
# at that epoch, which is not necessarily the highest accuracy reached.
print(f"  Validation accuracy at best epoch: {val_accuracies[best_epoch-1]:.2f}%")
print(f"  Final validation accuracy: {val_accuracies[-1]:.2f}%")

print("\n✓ Training complete!")
print(f"Next step: Open 03_Simulation_Demo.ipynb")