# Acoustic Navigation Training Notebook

In [1]:
import torch
print("torch:", torch.__version__)
print("built with CUDA:", torch.version.cuda)        # None => CPU-only build
print("cuda available:", torch.cuda.is_available())  # should be True
print("cuda built:", torch.backends.cuda.is_built()) # True if GPU build


torch: 2.9.1+cu126
built with CUDA: 12.6
cuda available: True
cuda built: True


In [2]:
import sys
sys.path.append('../')

import numpy as np
from pathlib import Path
import torch
from torch.utils.data import DataLoader

from src.cave_dataset import (
    MultiCaveDataset,
    ACTION_MAP,
    ACTION_NAMES,
    MIC_OFFSETS,
    compute_class_distribution,
    compute_class_weights,
)
from src.models import CompactAcousticNet, SpatialTemporalAcousticNet, FocalLoss
from src.lmdb_dataset import LMDBAcousticDataset

print(torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
np.random.seed(42)
torch.manual_seed(42)

2.9.1+cu126
Using device: cuda


<torch._C.Generator at 0x1ab3b0c00f0>

In [3]:
# Cell 4+5: Perfectly Balanced Dataset (Dynamic Downsampling)
# ---------------------------------------------------------
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import random
from pathlib import Path
from collections import defaultdict
# Import your specific dataset class
from src.lmdb_dataset import LMDBAcousticDataset 

# 1. Load Raw Data
# UPDATE PATH IF NEEDED
DATASET_DIR = Path('D:/audiomaze_dataset_100')
raw_dataset = LMDBAcousticDataset('D:/audiomaze_lmdb_100')

# 2. Scan and Sort Indices by Class
print("Scanning dataset to group indices by class...")
indices_by_class = defaultdict(list)

# Iterate through dataset to sort indices (Fast with LMDB)
for idx in range(len(raw_dataset)):
    _, action, _, _ = raw_dataset[idx]
    act = int(action)
    # We only care about 1=UP, 2=DOWN, 3=LEFT, 4=RIGHT
    if act in [1, 2, 3, 4]:
        indices_by_class[act].append(idx)

# 3. Find the LIMIT (The count of the smallest class)
# This will automatically be ~2300 or ~2500 based on your data
min_count = min(len(indices_by_class[c]) for c in [1, 2, 3, 4])
print(f"\nClass Counts Found: { {k: len(v) for k,v in indices_by_class.items()} }")
print(f"--> Downsampling all classes to match the smallest: {min_count} samples each")

# 4. Create Perfectly Balanced Index List
balanced_indices = []
for cls in [1, 2, 3, 4]:
    # Randomly sample 'min_count' indices from this class
    # This ensures we use the MAXIMUM possible amount of data while staying balanced
    sampled = random.sample(indices_by_class[cls], min_count)
    balanced_indices.extend(sampled)

# Shuffle to mix classes
random.shuffle(balanced_indices)

# 5. Create Dataset & Split
balanced_dataset = Subset(raw_dataset, balanced_indices)

# Get targets for stratify split
balanced_targets = []
for i in balanced_indices:
    # Quick fetch of just the label
    _, action, _, _ = raw_dataset[i]
    balanced_targets.append(int(action))

train_idx, val_idx = train_test_split(
    np.arange(len(balanced_dataset)), 
    test_size=0.2, 
    random_state=42, 
    stratify=balanced_targets
)

train_dataset = Subset(balanced_dataset, train_idx)
val_dataset = Subset(balanced_dataset, val_idx)

# 6. Loaders
# We can use a standard loader now because the underlying data is already balanced
BATCH_SIZE = 256 
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"\n✅ Data Ready:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val:   {len(val_dataset)} samples")

  from scipy.sparse import csr_matrix, issparse


Loaded LMDB dataset: 96,868 samples
Action distribution: {'stop': 100, 'up': 48697, 'down': 43300, 'left': 2318, 'right': 2453}
Scanning dataset to group indices by class...

Class Counts Found: {2: 43300, 4: 2453, 3: 2318, 1: 48697}
--> Downsampling all classes to match the smallest: 2318 samples each

✅ Data Ready:
  Train: 7417 samples
  Val:   1855 samples


In [4]:
# Cell 6: Model & Focal Loss
# ---------------------------------------------------------
import torch.nn as nn
from src.models import CompactAcousticNet, SpatialTemporalAcousticNet, FocalLoss

MODEL_TYPE = 'spatial'
NUM_CLASSES = 5  # Back to 5 classes (0=STOP included)

if MODEL_TYPE == 'compact':
    model = CompactAcousticNet(num_classes=NUM_CLASSES, dropout=0.3).to(device)
    print('Using CompactAcousticNet')
else:
    model = SpatialTemporalAcousticNet(num_classes=NUM_CLASSES, dropout=0.3).to(device)
    print('Using SpatialTemporalAcousticNet')

# OPTIONAL: Load weights if you want to continue training
# model.load_state_dict(torch.load("checkpoints/best_model.pt")['model_state'])

total_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {total_params:,}')

# Use Focal Loss to further penalize ignoring rare classes
# alpha can be set to class weights, but the Sampler already handles balance.
# We use gamma=2.0 to focus on "hard" examples.
criterion = FocalLoss(gamma=2.0, reduction='mean') 
print('Using FocalLoss (gamma=2.0) to handle class hardness')

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)

Using SpatialTemporalAcousticNet
Total parameters: 1,268,901
Using FocalLoss (gamma=2.0) to handle class hardness


In [None]:
# Cell 6: Final Model Architecture (WideFieldNet)
# ---------------------------------------------------------
import torch.nn as nn
import torch

class WideFieldNet(nn.Module):
    """
    Wide-Field Acoustic Network.
    Designed for Reverberant Environments:
    1. Large Kernel (64) in Layer 1 to capture long-range Impulse Response (Echoes).
    2. Stride 2 to preserve phase cues better than MaxPool.
    3. LeakyReLU to prevent dead gradients during training.
    """
    def __init__(self, num_classes=4, in_channels=8, dropout=0.5):
        super().__init__()
        
        self.encoder = nn.Sequential(
            # Layer 1: The "Echo Catcher" (64 samples ~= 0.4ms)
            nn.Conv1d(in_channels, 64, kernel_size=64, stride=2, padding=31),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout),

            # Layer 2: Texture & Decay
            nn.Conv1d(64, 128, kernel_size=32, stride=2, padding=15),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout),

            # Layer 3: Deep Features
            nn.Conv1d(128, 256, kernel_size=16, stride=2, padding=7),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout),
            
            # Layer 4: Global Context
            nn.Conv1d(256, 512, kernel_size=8, stride=2, padding=3),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.AdaptiveAvgPool1d(1), # Squeeze time to 1 vector
        )

        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 128),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )
        
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')

    def forward(self, x):
        return self.head(self.encoder(x))

# Instantiate
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = WideFieldNet(num_classes=4, dropout=0.5).to(device)

# Standard AdamW is stable for this architecture
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

print("✅ WideFieldNet Loaded (Final Architecture).")

✅ WavefrontNet (2D CNN) Loaded.
   Treats audio as an 8-row image to detect phase slopes.


In [None]:
# Cell 7: WavefrontNet Training
# ---------------------------------------------------------
from tqdm.auto import tqdm
import torch
import numpy as np
from pathlib import Path
from sklearn.metrics import classification_report
from contextlib import nullcontext

# Config
EPOCHS = 20
LR = 0.001
save_dir = Path("checkpoints")
save_dir.mkdir(parents=True, exist_ok=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
criterion = torch.nn.CrossEntropyLoss()

scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler(enabled=False)
autocast_ctx = lambda: torch.amp.autocast('cuda') if torch.cuda.is_available() else nullcontext()

best_val_acc = 0.0

print("=" * 60)
print(f"STARTING 2D WAVEFRONT TRAINING")
print("=" * 60)

for epoch in range(EPOCHS):
    model.train()
    train_correct = 0
    train_total = 0
    pred_counts = {0:0, 1:0, 2:0, 3:0}
    
    pbar = tqdm(train_loader, desc=f"Ep {epoch+1}/{EPOCHS}", dynamic_ncols=True, colour="#4CAF50")
    for mic, action, _, _ in pbar:
        mic = mic.to(device, non_blocking=True)
        action = action.to(device, non_blocking=True)
        targets = action - 1 
        
        optimizer.zero_grad(set_to_none=True)
        with autocast_ctx():
            logits = model(mic)
            loss = criterion(logits, targets)
            
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        preds = logits.argmax(dim=1)
        train_correct += (preds == targets).sum().item()
        train_total += targets.numel()
        
        for p in preds.tolist(): pred_counts[p] += 1
            
        pbar.set_postfix(acc=train_correct/train_total)

    # --- VALIDATION ---
    model.eval()
    val_correct = 0
    val_total = 0
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for mic, action, _, _ in val_loader:
            mic = mic.to(device, non_blocking=True)
            action = action.to(device, non_blocking=True)
            targets = action - 1

            with autocast_ctx():
                logits = model(mic)
            
            preds = logits.argmax(dim=1)
            val_correct += (preds == targets).sum().item()
            val_total += targets.numel()
            
            all_preds.extend(preds.cpu().tolist())
            all_targets.extend(targets.cpu().tolist())

    avg_val_acc = val_correct / val_total
    scheduler.step(avg_val_acc)
    
    total_preds = sum(pred_counts.values())
    dist_str = " | ".join([f"{k}:{v/total_preds:.0%}" for k,v in pred_counts.items()])
    print(f"\n  Val Acc: {avg_val_acc:.3f} | Preds: {dist_str}")
    
    if avg_val_acc > best_val_acc:
        best_val_acc = avg_val_acc
        torch.save(model.state_dict(), "best_model_2d.pt")
        print("  --> Saved Best Model")
    
    print("-" * 60)

STARTING 2D WAVEFRONT TRAINING


Ep 1/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.250 | Preds: 0:32% | 1:27% | 2:23% | 3:19%
  --> Saved Best Model
------------------------------------------------------------


Ep 2/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.258 | Preds: 0:30% | 1:22% | 2:23% | 3:25%
  --> Saved Best Model
------------------------------------------------------------


Ep 3/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.250 | Preds: 0:28% | 1:23% | 2:23% | 3:26%
------------------------------------------------------------


Ep 4/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.250 | Preds: 0:25% | 1:24% | 2:25% | 3:25%
------------------------------------------------------------


Ep 5/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.252 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
------------------------------------------------------------


Ep 6/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.250 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
------------------------------------------------------------


Ep 7/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.250 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
------------------------------------------------------------


Ep 8/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.343 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
  --> Saved Best Model
------------------------------------------------------------


Ep 9/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.253 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
------------------------------------------------------------


Ep 10/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.291 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
------------------------------------------------------------


Ep 11/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.351 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
  --> Saved Best Model
------------------------------------------------------------


Ep 12/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.292 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
------------------------------------------------------------


Ep 13/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.356 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
  --> Saved Best Model
------------------------------------------------------------


Ep 14/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.319 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
------------------------------------------------------------


Ep 15/20:   0%|          | 0/29 [00:00<?, ?it/s]


  Val Acc: 0.340 | Preds: 0:25% | 1:25% | 2:25% | 3:25%
------------------------------------------------------------


Ep 16/20:   0%|          | 0/29 [00:00<?, ?it/s]

In [None]:
# Cell: Deep Diagnostics (Inputs, Outputs, Gradients)
# ---------------------------------------------------------
import torch
import numpy as np

print("--- DIAGNOSTICS REPORT ---")

# 1. Check Input Statistics (Is data normalized?)
# ---------------------------------------------------------
batch = next(iter(train_loader))
mics, acts, _, _ = batch
mics = mics.to(device)
print(f"Input Shape: {mics.shape}")
print(f"Input Stats: Mean={mics.mean():.4f} | Std={mics.std():.4f}")
print(f"             Min={mics.min():.4f}  | Max={mics.max():.4f}")

if mics.std() < 0.1:
    print("⚠️ WARNING: Inputs are very quiet. Model might struggle to find signal.")

# 2. Check Model Confidence (Is it guessing blindly or confidently wrong?)
# ---------------------------------------------------------
model.eval()
with torch.no_grad():
    logits = model(mics)
    probs = torch.softmax(logits, dim=1)

print("\nModel Predictions (First 5 samples):")
print("   UP      DOWN    LEFT    RIGHT")
for p in probs[:5]:
    print(f"  [{p[0]:.4f}, {p[1]:.4f}, {p[2]:.4f}, {p[3]:.4f}]")

# 3. Check Gradients (Is the model learning or dead?)
# ---------------------------------------------------------
model.train()
model.zero_grad()
# Forward pass again for grads
logits = model(mics)
targets = (acts - 1).to(device)
loss = torch.nn.CrossEntropyLoss()(logits, targets)
loss.backward()

print("\nGradient Health (Layer-wise norms):")
dead_layers = 0
for name, param in model.named_parameters():
    if param.grad is not None and param.dim() > 1: # Only weights, ignore biases
        grad_norm = param.grad.norm().item()
        if grad_norm == 0.0:
            print(f"  ❌ DEAD (0.0): {name}")
            dead_layers += 1
        elif grad_norm > 10.0:
            print(f"  ⚠️ EXPLODING ({grad_norm:.2f}): {name}")
        else:
            # Print first layer only to save space
            if "front_end.0" in name:
                 print(f"  ✅ OK ({grad_norm:.4f}): {name}")

if dead_layers > 0:
    print(f"\nCRITICAL: {dead_layers} layers have zero gradient. Use LeakyReLU or check initialization.")
else:
    print("\nGradients look healthy. The architecture just needs to be smarter.")

--- DIAGNOSTICS REPORT ---
Input Shape: torch.Size([256, 8, 11434])
Input Stats: Mean=-0.0000 | Std=1.0000
             Min=-15.4346  | Max=16.2569

Model Predictions (First 5 samples):
   UP      DOWN    LEFT    RIGHT
  [0.1222, 0.5740, 0.0076, 0.2962]
  [0.0634, 0.6904, 0.0100, 0.2362]
  [0.0793, 0.6417, 0.0235, 0.2555]
  [0.1160, 0.5552, 0.0177, 0.3110]
  [0.1123, 0.5922, 0.0073, 0.2882]

Gradient Health (Layer-wise norms):

Gradients look healthy. The architecture just needs to be smarter.
