# Acoustic Navigation Training Notebook

In [13]:
import torch
# Environment sanity check: report the installed torch build and whether it can see a GPU.
print(f"torch: {torch.__version__}")
print(f"built with CUDA: {torch.version.cuda}")  # None => CPU-only build
print(f"cuda available: {torch.cuda.is_available()}")  # should be True
print(f"cuda built: {torch.backends.cuda.is_built()}")  # True if GPU build


torch: 2.9.1+cu126
built with CUDA: 12.6
cuda available: True
cuda built: True


In [14]:
import sys
sys.path.append('../')

import numpy as np
from pathlib import Path
import torch
from torch.utils.data import DataLoader

from src.cave_dataset import (
    MultiCaveDataset,
    ACTION_MAP,
    ACTION_NAMES,
    MIC_OFFSETS,
    compute_class_distribution,
    compute_class_weights,
)
from src.models import CompactAcousticNet, SpatialTemporalAcousticNet

# Report the torch build, then pick the GPU when one is visible and fall back to the CPU otherwise.
print(torch.__version__)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Seed the NumPy and torch global RNGs so repeated runs start from the same state.
np.random.seed(42)
torch.manual_seed(42)

2.9.1+cu126
Using device: cuda


<torch._C.Generator at 0x1f529ce00b0>

In [15]:
# Dataset config
import os

# The dataset location is machine-specific. Set the AUDIOMAZE_DATASET_DIR
# environment variable to point somewhere else; the previous hard-coded
# path stays as the default so existing setups keep working.
DATASET_DIR = Path(os.environ.get('AUDIOMAZE_DATASET_DIR', 'D:/audiomaze_dataset_100'))
H5_FILES = sorted(DATASET_DIR.glob('cave_*.h5'))  # sorted => stable file order between runs
assert len(H5_FILES) > 0, f"No cave_*.h5 files found in {DATASET_DIR}"
AGENT_RADIUS = 1

# Build dataset

#Optional: use LMDB dataset
#from src.lmdb_dataset import LMDBAcousticDataset
#dataset = LMDBAcousticDataset('D:/audiomaze_lmdb_100')

dataset = MultiCaveDataset(H5_FILES, agent_radius=AGENT_RADIUS, mic_offsets=MIC_OFFSETS, action_map=ACTION_MAP)
print(f"Total valid positions: {len(dataset)} across {len(H5_FILES)} files")

# Class distribution & weights
# The counts feed compute_class_weights; the resulting tensor is later passed
# to the loss as per-class weights.
class_counts = compute_class_distribution(dataset)
print('Class counts:', class_counts)
class_weights = compute_class_weights(class_counts)
print('Class weights:', class_weights)


Total valid positions: 96868 across 100 files
Class counts: {'stop': 100, 'up': 48697, 'down': 43300, 'left': 2318, 'right': 2453}
Class weights: tensor([193.7360,   0.3978,   0.4474,   8.3579,   7.8979])


In [None]:
# Dataloaders (NO weighted sampling - weights handled in loss)
BATCH_SIZE = 256  # Reduced for larger model
TRAIN_SPLIT = 0.8
NUM_WORKERS = 10  # set >0 on Linux/Mac
# The training loop moves batches with `.to(device, non_blocking=True)`; that copy
# can only overlap with compute when the host batch sits in pinned (page-locked)
# memory, so pin whenever a CUDA device is present. On CPU-only machines it stays off.
PIN_MEMORY = torch.cuda.is_available()

# Deterministic train/validation split: the generator is seeded locally so the
# partition does not depend on how much global RNG state was consumed earlier.
train_size = int(TRAIN_SPLIT * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# Simple dataloaders - no weighted sampling (avoiding double weighting)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Smoke test: pull a single batch to confirm the pipeline works and show the input shape.
batch = next(iter(train_loader))
print('Batch mic shape:', batch[0].shape)
print(f'Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}')

In [None]:
# Model Architecture Selection
# Choose one:
# - CompactAcousticNet: Efficient baseline (faster training)
# - SpatialTemporalAcousticNet: Advanced with attention (better accuracy, slower)

MODEL_TYPE = 'compact'  # 'compact' or 'spatial_temporal'

# Match both options explicitly: a mistyped MODEL_TYPE must fail loudly instead
# of silently training the other architecture.
if MODEL_TYPE == 'compact':
    model = CompactAcousticNet(num_classes=len(ACTION_NAMES), dropout=0.3).to(device)
    print('Using CompactAcousticNet')
elif MODEL_TYPE == 'spatial_temporal':
    model = SpatialTemporalAcousticNet(num_classes=len(ACTION_NAMES), dropout=0.3).to(device)
    print('Using SpatialTemporalAcousticNet')
else:
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected 'compact' or 'spatial_temporal'"
    )

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

# Loss with class weights (handles imbalance)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights.to(device))
print(f'Using class weights: {class_weights}')

# Optimizer and LR schedule (gradient clipping is applied in the training loop, not here).
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# T_max is the number of scheduler steps in the cosine decay; the training cell steps
# the scheduler once per epoch with EPOCHS = 20, so keep the two values in sync.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)

Using CompactAcousticNet
Total parameters: 610,981
Trainable parameters: 610,981
Using class weights: tensor([193.7360,   0.3978,   0.4474,   8.3579,   7.8979])


In [None]:
from collections import defaultdict
from tqdm.auto import tqdm
import copy
from contextlib import nullcontext

EPOCHS = 20
MAX_GRAD_NORM = 1.0  # Gradient clipping
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_type)

# AMP setup
if device_type == 'cuda':
    scaler = torch.amp.GradScaler('cuda')
    autocast_ctx = lambda: torch.amp.autocast('cuda')
else:
    scaler = torch.amp.GradScaler(enabled=False)
    autocast_ctx = lambda: nullcontext()

save_dir = Path("checkpoints")
save_dir.mkdir(parents=True, exist_ok=True)
best_val_loss = float('inf')
best_val_acc = 0.0

def save_ckpt(state, name):
    """Write ``state`` with ``torch.save`` to the file ``name`` inside ``save_dir``.

    Args:
        state: Object to serialize (the training loop passes checkpoint dicts).
        name: File name, joined onto the module-level ``save_dir`` path.
    """
    target_path = save_dir / name
    torch.save(state, target_path)

def step_metrics(logits, targets):
    """Count top-1 hits for one batch.

    Args:
        logits: Tensor whose dim 1 indexes the classes; argmax is taken over it.
        targets: Tensor of class indices compared element-wise with the predictions.

    Returns:
        Tuple ``(correct, total)``: the number of predictions equal to their
        target (Python int) and the number of elements in ``targets``.
    """
    hits = logits.argmax(dim=1).eq(targets)
    return hits.sum().item(), targets.numel()

# Per-class accuracy tracking
def compute_class_accuracies(all_preds, all_targets, num_classes=5):
    """Compute the accuracy of each class separately.

    Args:
        all_preds: Iterable of predicted class indices.
        all_targets: Iterable of true class indices, paired position-wise with
            ``all_preds``; each value is used to index a list of length ``num_classes``.
        num_classes: Number of classes to report.

    Returns:
        Dict mapping ``ACTION_NAMES[i]`` to the fraction of samples with target
        ``i`` that were predicted correctly; classes with no samples map to 0.0.
    """
    seen = [0 for _ in range(num_classes)]
    hits = [0 for _ in range(num_classes)]
    for predicted, expected in zip(all_preds, all_targets):
        seen[expected] += 1
        if predicted == expected:
            hits[expected] += 1

    accuracies = {}
    for idx in range(num_classes):
        # Guard the division so a class absent from the data reports 0.0.
        accuracies[ACTION_NAMES[idx]] = hits[idx] / seen[idx] if seen[idx] > 0 else 0.0
    return accuracies

# Echo the effective hyperparameters before training so the run log is self-describing.
banner = "=" * 70
opt_group = optimizer.param_groups[0]  # learning rate / weight decay are read from the first param group

print(banner)
print("TRAINING CONFIGURATION")
print(banner)
print(f"Model: {MODEL_TYPE}")
print(f"Epochs: {EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {opt_group['lr']}")
print(f"Weight decay: {opt_group['weight_decay']}")
print(f"Gradient clipping: {MAX_GRAD_NORM}")
print(f"Class weights: {class_weights}")
print(banner)

def _build_ckpt_state(epoch_num, val_loss, val_acc, class_accs, snapshot_model=False):
    """Assemble a checkpoint dict from the current model/optimizer/scheduler/scaler.

    Args:
        epoch_num: 1-based number of the epoch that just finished.
        val_loss: Validation loss of that epoch.
        val_acc: Validation accuracy of that epoch.
        class_accs: Per-class validation accuracies of that epoch.
        snapshot_model: When True, deep-copy the model state so the returned dict
            keeps these weights even after later optimizer steps (``state_dict()``
            tensors otherwise alias the live parameters).

    Returns:
        Dict ready to be passed to ``save_ckpt``.
    """
    model_state = model.state_dict()
    if snapshot_model:
        model_state = copy.deepcopy(model_state)
    return {
        'epoch': epoch_num,
        'model_state': model_state,
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
        # The GradScaler state is needed to resume an AMP run with the same loss scale.
        'scaler_state': scaler.state_dict(),
        'val_loss': val_loss,
        'val_acc': val_acc,
        'class_accs': class_accs,
    }


for epoch in range(EPOCHS):
    # Training
    model.train()
    running_loss = running_correct = running_total = 0

    pbar = tqdm(train_loader, desc=f"Train {epoch+1}/{EPOCHS}", dynamic_ncols=True, colour="blue")
    for mic, action, _, _ in pbar:
        mic = mic.to(device, non_blocking=True)
        action = action.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with autocast_ctx():
            logits = model(mic)
            loss = criterion(logits, action)

        scaler.scale(loss).backward()
        # Gradient clipping: unscale first so the norm threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        scaler.step(optimizer)
        scaler.update()

        # Accumulate the batch loss weighted by batch size so the epoch figure is per-sample.
        running_loss += loss.item() * mic.size(0)
        correct, total = step_metrics(logits, action)
        running_correct += correct
        running_total += total
        pbar.set_postfix(loss=running_loss / running_total, acc=running_correct / running_total)

    train_loss = running_loss / running_total
    train_acc = running_correct / running_total

    # Validation
    model.eval()
    val_loss = val_correct = val_total = 0
    all_preds = []
    all_targets = []

    pbar_val = tqdm(val_loader, desc=f"Val {epoch+1}/{EPOCHS}", dynamic_ncols=True, colour="green")
    with torch.no_grad():
        for mic, action, _, _ in pbar_val:
            mic = mic.to(device, non_blocking=True)
            action = action.to(device, non_blocking=True)
            with autocast_ctx():
                logits = model(mic)
                loss = criterion(logits, action)

            val_loss += loss.item() * mic.size(0)
            correct, total = step_metrics(logits, action)
            val_correct += correct
            val_total += total

            # Track per-class accuracy
            preds = logits.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(action.cpu().numpy())

            pbar_val.set_postfix(loss=val_loss / val_total, acc=val_correct / val_total)

    val_loss /= val_total
    val_acc = val_correct / val_total

    # Per-class accuracies; size the tally from ACTION_NAMES (the same source the
    # model's output dimension was built from) rather than a hard-coded class count.
    class_accs = compute_class_accuracies(all_preds, all_targets, num_classes=len(ACTION_NAMES))

    print(f"\nEpoch {epoch+1}/{EPOCHS}:")
    print(f"  Train: loss={train_loss:.4f}, acc={train_acc:.4f}")
    print(f"  Val:   loss={val_loss:.4f}, acc={val_acc:.4f}")
    print(f"  Per-class val acc: ", end="")
    for name, acc in class_accs.items():
        print(f"{name}={acc:.3f} ", end="")
    print()

    # Step the LR schedule once per epoch, before checkpointing, so the saved
    # scheduler state is the one the next epoch would start from.
    scheduler.step()

    # Save best model (by validation accuracy)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # NOTE: this is the loss of the best-accuracy epoch, not necessarily the lowest loss seen.
        best_val_loss = val_loss
        best_state = _build_ckpt_state(epoch + 1, val_loss, val_acc, class_accs, snapshot_model=True)
        save_ckpt(best_state, "best_model.pt")
        print(f"  ✓ Saved new best model (val_acc={val_acc:.4f})")

    # Save periodic checkpoint
    if (epoch + 1) % 5 == 0:
        ckpt_state = _build_ckpt_state(epoch + 1, val_loss, val_acc, class_accs)
        save_ckpt(ckpt_state, f"epoch_{epoch+1}.pt")
        print(f"  → Saved checkpoint at epoch {epoch+1}")

print("\n" + "=" * 70)
print(f"TRAINING COMPLETE")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print(f"Best validation loss: {best_val_loss:.4f}")
print("=" * 70)

TRAINING CONFIGURATION
Model: compact
Epochs: 20
Batch size: 512
Learning rate: 0.0003
Weight decay: 0.0001
Gradient clipping: 1.0
Class weights: tensor([193.7360,   0.3978,   0.4474,   8.3579,   7.8979])


Train 1/20:   0%|          | 0/152 [00:26<?, ?it/s]

KeyboardInterrupt: 