# Acoustic Navigation - Model Training

## Training Pipeline Overview

**Task**: Predict optimal navigation action from acoustic sensor data

### Agent Setup:
- Agent occupies the center of a 3x3 block of grid cells
- 8 microphones sit on the ring of cells around the agent (1-cell radius)
- Each microphone records pressure time-series
- **CRITICAL**: No part of 3x3 agent footprint can be inside walls!
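
The footprint rule can be written as a short check. This is a minimal sketch, assuming the grid encodes air as 0 and wall as 1 (the convention printed in the verification cell later in this notebook). `footprint_is_clear` is a hypothetical helper name; the actual filtering happens inside `MultiCaveDataset` and may be implemented differently.

```python
def footprint_is_clear(grid, y, x, radius=1):
    """Return True if the (2*radius+1)^2 block centered at (y, x) is in bounds and all air.

    Hypothetical helper for illustration; assumes grid[y][x] == 0 is air and 1 is wall.
    Works with nested lists and with 2D NumPy arrays.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    # Reject centers whose footprint would extend past the grid edge.
    if y - radius < 0 or y + radius >= n_rows:
        return False
    if x - radius < 0 or x + radius >= n_cols:
        return False
    # Every cell of the footprint must be air.
    return all(
        grid[yy][xx] == 0
        for yy in range(y - radius, y + radius + 1)
        for xx in range(x - radius, x + radius + 1)
    )


# Toy 4x5 grid with a single wall cell in the top-right corner.
toy_grid = [
    [0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
]
print(footprint_is_clear(toy_grid, 1, 1))  # footprint is all air
print(footprint_is_clear(toy_grid, 1, 3))  # footprint touches the wall at (0, 4)
print(footprint_is_clear(toy_grid, 0, 1))  # footprint leaves the grid
```

Checking bounds first keeps Python's negative indexing from silently wrapping to the opposite edge of the grid.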

### Training Approach:
1. Sample agent positions from navigable cave space (ensuring 3x3 validity)
2. Extract 8-mic acoustic data at each position
3. Predict action: STOP, UP, DOWN, LEFT, RIGHT
4. Handle class imbalance (STOP is rarer than movement actions)

### This Notebook:
1. Load HDF5 dataset
2. Build the PyTorch Dataset with 3x3 footprint validation
3. Visualize data distribution (action balance, agent positions)
4. Show sample agent with 8-mic array
5. Prepare for model training (DataLoaders and class weights are built in `03_Training.ipynb`)

In [None]:
import sys
sys.path.append('../')

import numpy as np
import h5py
import matplotlib.pyplot as plt
from pathlib import Path
from collections import Counter

import torch

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)


## 1. Configuration

In [None]:
# Dataset paths (folder with many cave_*.h5 files)
DATASET_DIR = Path('D:/audiomaze_dataset')
H5_FILES = sorted(DATASET_DIR.glob('cave_*.h5'))
assert len(H5_FILES) > 0, f"No cave_*.h5 files found in {DATASET_DIR}"

# Agent configuration
AGENT_RADIUS = 1

# Import dataset utilities
from src.cave_dataset import (
    MultiCaveDataset,
    ACTION_MAP,
    ACTION_NAMES,
    MIC_OFFSETS,
    compute_class_distribution,
    compute_class_weights,
)

print(f"Found {len(H5_FILES)} files in {DATASET_DIR}")


## 2. Load and Inspect Dataset

In [None]:
# Load first H5 file for visualization reference (file 0)
viz_file = H5_FILES[0]
with h5py.File(viz_file, 'r') as f:
    first_key = list(f.keys())[0]
    cave_grid = f[first_key]['cave_grid'][:]
    action_grid_raw = f[first_key]['action_grid'][:]
    if action_grid_raw.dtype.kind == 'S':
        action_grid_str = np.vectorize(lambda x: x.decode('utf-8'))(action_grid_raw)
    else:
        action_grid_str = action_grid_raw.astype(str)
    pressure_field = f[first_key]['pressure_timeseries'][:]
    end_pos = tuple(f[first_key].attrs['end_position'])
    start_pos = tuple(f[first_key].attrs.get('start_position', (-1, -1)))
    dt = f[first_key].attrs['dt']
    f0 = f[first_key].attrs['frequency_hz']

Nx, Ny, Nt = pressure_field.shape

print(f"Loaded reference file: {viz_file}")
print(f"Scene key: {first_key}")
print(f"Grid size: {Nx}x{Ny}, Nt={Nt}")
print(f"Start: {start_pos}, Goal: {end_pos}")


## 3. Create PyTorch Dataset

**Key Logic**:
- An agent at (y, x) occupies the 3x3 square centered at (y, x)
- The 8-mic pressure time-series are read from the surrounding cells
- Label = action at the center position
- With the 3x3 expansion system, actions exist only at valid centers (every 3rd cell)
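
The sampling logic above can be sketched without the dataset class. The offsets and their order below are assumptions chosen to match the `R, DR, D, ...` labels used in the plots later on. The real `MIC_OFFSETS` from `src.cave_dataset` is not shown in this notebook and may be ordered differently. `extract_mic_series` is a hypothetical name, and the pressure array is assumed to share the grid's `[row, col]` indexing.

```python
# Assumed (dy, dx) offsets for a radius-1 ring, clockwise from "right" in image
# coordinates (row index grows downward). Illustrative only; not the project's MIC_OFFSETS.
RING_OFFSETS = [(0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1), (-1, 0), (-1, 1)]


def extract_mic_series(pressure, y, x, offsets=RING_OFFSETS):
    """Collect one pressure time-series per microphone around the agent at (y, x).

    `pressure` is indexed as pressure[row][col] -> sequence of Nt samples, which works
    for nested lists and for a NumPy array of shape (rows, cols, Nt).
    The caller must ensure (y, x) has a valid footprint: indices are not bounds-checked
    here, and negative indices would wrap around in Python.
    """
    return [list(pressure[y + dy][x + dx]) for dy, dx in offsets]


# Toy field: 3x3 cells, 2 time steps; each cell stores [10*row + col, 0].
toy_pressure = [[[10 * row + col, 0] for col in range(3)] for row in range(3)]
mics = extract_mic_series(toy_pressure, 1, 1)
print(len(mics), [m[0] for m in mics])  # 8 [12, 22, 21, 20, 10, 0, 1, 2]
```

The result has one row per microphone and one column per time step, which is the `(8, Nt)` layout a model input would typically start from.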

In [None]:
# Build multi-file dataset
# Note: MultiCaveDataset handles valid-position filtering (3x3 footprint) per file

dataset = MultiCaveDataset(H5_FILES, agent_radius=AGENT_RADIUS, mic_offsets=MIC_OFFSETS, action_map=ACTION_MAP)
print(f"Total valid positions across all files: {len(dataset.valid_positions)}")


## 4. Analyze Action Distribution

**Important**: STOP will be rarer than movement actions (only at goal).  
We'll need to handle class imbalance during training.

In [None]:
# Get action distribution across all files
action_dist = dataset.get_action_distribution()

print("Action Distribution:")
print("="*50)
total_samples = len(dataset)
for action_name in ACTION_NAMES:
    action_str = action_name.lower()
    count = action_dist.get(action_str, 0)
    percentage = 100 * count / total_samples if total_samples else 0.0  # avoid dividing by zero on an empty dataset
    print(f"  {action_name:>5s}: {count:>6d} ({percentage:>5.2f}%)")
print("="*50)
print(f"Total samples (valid positions): {total_samples}")


## 5. Visualize Valid Agent Positions

Show where the agent can be placed in the first file. The 3x3 expansion is meant to guarantee that every listed position has a wall-free footprint.

In [None]:
# Create valid position mask for the first file (viz only)
viz_valid = np.zeros_like(cave_grid)
for y, x in dataset.file_infos[0]['valid']:
    viz_valid[y, x] = 1

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Cave layout
axes[0].imshow(cave_grid, origin='upper', cmap='binary')
axes[0].scatter([end_pos[1]], [end_pos[0]], s=200, c='red', marker='*',
               edgecolors='black', linewidths=2, label='Goal', zorder=10)
axes[0].set_title(f'Cave Layout ({Nx}x{Ny})', fontsize=14, fontweight='bold')
axes[0].set_xlabel('X')
axes[0].set_ylabel('Y')
axes[0].legend()
axes[0].axis('image')

# Valid agent positions (first file)
axes[1].imshow(viz_valid, origin='upper', cmap='Greens', alpha=0.6)
axes[1].contour(cave_grid, levels=[0.5], colors='black', linewidths=1)
axes[1].scatter([end_pos[1]], [end_pos[0]], s=200, c='red', marker='*',
               edgecolors='black', linewidths=2, label='Goal', zorder=10)
axes[1].set_title(f'Valid Agent Positions ({len(dataset.file_infos[0]["valid"])} total)', 
                 fontsize=14, fontweight='bold')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Y')
axes[1].legend()
axes[1].axis('image')

plt.tight_layout()
plt.show()

print(f"Valid agent positions (file 0): {len(dataset.file_infos[0]['valid'])} / {(Nx*Ny)} total cells")


## 6. Visualize Sample Agent with 8-Mic Array

Pick a random agent position and show:
- 3x3 footprint
- 8 microphone positions
- Pressure time-series at each mic
- Action label

In [None]:
# Pick a random sample from the dataset
sample_idx = np.random.randint(0, len(dataset))
mic_data, action, (agent_y, agent_x), file_idx = dataset.get_sample_with_position(sample_idx)
action_name = ACTION_NAMES[action.item()]

print(f"Sample {sample_idx} from file #{file_idx} ({dataset.file_infos[file_idx]['path'].name}):")
print(f"  Agent position: ({agent_y}, {agent_x})")
print(f"  Action: {action_name}")
print(f"  Mic data shape: {mic_data.shape}")
print(f"  Cave grid at agent: {dataset.file_infos[file_idx]['cave_grid'][agent_y, agent_x]}")


In [None]:
# FULL MAP VISUALIZATION - Agent, Mics, and Goal (per sampled file)
info = dataset.file_infos[file_idx]
cg = info['cave_grid']
ag = info['action_grid']
end_local = info['end_pos']
start_local = info['start_pos']
Nx_local, Ny_local = cg.shape

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

ax = axes[0]
ax.imshow(cg, origin='upper', cmap='binary', alpha=0.8)
ax.contour(cg, levels=[0.5], colors='gray', linewidths=1.5, alpha=0.5)

from matplotlib.patches import Rectangle
r = AGENT_RADIUS
rect = Rectangle((agent_x - r - 0.5, agent_y - r - 0.5), 2*r + 1, 2*r + 1,
                linewidth=4, edgecolor='blue', facecolor='blue', alpha=0.3, label='3x3 Footprint')
ax.add_patch(rect)

ax.scatter([agent_x], [agent_y], s=500, c='blue', marker='o',
           edgecolors='white', linewidths=3, label='Agent', zorder=10)

# Mic positions; only draw mics that are in air
mic_ys = [agent_y + dy for dy, dx in MIC_OFFSETS]
mic_xs = [agent_x + dx for dy, dx in MIC_OFFSETS]
mic_in_air = [(mx, my) for mx, my in zip(mic_xs, mic_ys) if 0 <= my < Nx_local and 0 <= mx < Ny_local and cg[int(my), int(mx)] == 0]
if mic_in_air:
    ax.scatter([mx for mx,_ in mic_in_air], [my for _,my in mic_in_air], s=150, c='lime', marker='^',
               edgecolors='black', linewidths=2, label='8 Microphones (air only)', zorder=9)

ax.scatter([end_local[1]], [end_local[0]], s=400, c='red', marker='*',
           edgecolors='black', linewidths=2, label='Goal (Sound Source)', zorder=8)

ax.set_title(f'Full Cave Map - Agent at ({agent_y},{agent_x})\nAction: {action_name} (file {file_idx})',
             fontsize=14, fontweight='bold')
ax.set_xlabel('Column Index', fontsize=12)
ax.set_ylabel('Row Index', fontsize=12)
ax.legend(loc='upper left', fontsize=10, framealpha=0.9)
ax.axis('image')
ax.grid(True, alpha=0.2, linestyle='--', linewidth=0.5)

# Right: Zoomed view around agent (same origin)
ax = axes[1]
zoom_radius = 10
y_min = max(0, agent_y - zoom_radius)
y_max = min(Nx_local, agent_y + zoom_radius + 1)
x_min = max(0, agent_x - zoom_radius)
x_max = min(Ny_local, agent_x + zoom_radius + 1)

cave_zoom = cg[y_min:y_max, x_min:x_max]
extent_zoom = [x_min - 0.5, x_max - 0.5, y_min - 0.5, y_max - 0.5]
ax.imshow(cave_zoom, origin='upper', cmap='binary', alpha=0.8, extent=extent_zoom)

rect_zoom = Rectangle((agent_x - r - 0.5, agent_y - r - 0.5), 2*r + 1, 2*r + 1,
                      linewidth=4, edgecolor='blue', facecolor='blue', alpha=0.3)
ax.add_patch(rect_zoom)

if mic_in_air:
    ax.scatter([mx for mx,_ in mic_in_air], [my for _,my in mic_in_air], s=150, c='lime', marker='^',
               edgecolors='black', linewidths=2, label='Mics (air only)', zorder=9)

mic_labels = ['R', 'DR', 'D', 'DL', 'L', 'UL', 'U', 'UR']
for label, (mx, my) in zip(mic_labels, zip(mic_xs, mic_ys)):
    if 0 <= my < Nx_local and 0 <= mx < Ny_local and cg[int(my), int(mx)] == 0:
        ax.text(mx, my - 0.6, label, fontsize=9, fontweight='bold',
                ha='center', va='top', color='darkgreen',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8, edgecolor='none'))

ax.set_title(f'Zoomed View (±{zoom_radius} cells)', fontsize=14, fontweight='bold')
ax.set_xlabel('Column Index', fontsize=12)
ax.set_ylabel('Row Index', fontsize=12)
ax.legend(loc='upper left', fontsize=10, framealpha=0.9)
ax.axis('image')
ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)

plt.tight_layout()
plt.show()

# Print summary
print(f"{'='*60}")
print(f"AGENT VERIFICATION")
print(f"{'='*60}")
print(f"Agent Position: ({agent_y}, {agent_x})")
print(f"Cave grid at agent: {cg[agent_y, agent_x]} (0=air, 1=wall)")
print(f"3x3 footprint all air: {np.all(cg[agent_y - r:agent_y + r + 1, agent_x - r:agent_x + r + 1] == 0)}")
print(f"Action label: {action_name}")
print(f"Goal Position: {end_local}")
print(f"Distance to goal: {np.sqrt((agent_y-end_local[0])**2 + (agent_x-end_local[1])**2):.1f} cells")
print(f"File: {info['path'].name}")
print(f"{'='*60}")


In [None]:
# Follow action labels from current agent position to goal (correct file)
info = dataset.file_infos[file_idx]
cg = info['cave_grid']
ag = info['action_grid']
end_local = info['end_pos']

moves = {'up': (-1,0), 'down': (1,0), 'left': (0,-1), 'right': (0,1), 'stop': (0,0)}

def follow_path(start_y, start_x, max_steps=2000):
    path = [(int(start_y), int(start_x))]
    r, c = start_y, start_x
    for step in range(max_steps):
        a = ag[r, c]
        if isinstance(a, bytes):
            a = a.decode('utf-8')
        if a == 'stop':
            return True, path
        if a not in moves:
            return False, path
        dr, dc = moves[a]
        nr, nc = r + dr, c + dc
        if not (0 <= nr < cg.shape[0] and 0 <= nc < cg.shape[1]):
            return False, path
        if cg[nr, nc] == 1:
            return False, path
        r, c = nr, nc
        path.append((int(r), int(c)))
    return False, path

ok, path = follow_path(agent_y, agent_x)
print(f"Path follow result: {'SUCCESS' if ok else 'FAIL'}, steps={len(path)-1}, end={path[-1]}, file={info['path'].name}")

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.imshow(cg, origin='upper', cmap='binary', alpha=0.8)
ys = [p[0] for p in path]; xs = [p[1] for p in path]
ax.plot(xs, ys, '-o', color='lime' if ok else 'red', linewidth=2, markersize=4)
ax.scatter([end_local[1]], [end_local[0]], s=200, c='red', marker='*', edgecolors='black', linewidths=1.5)
ax.set_title('Action-Label Path to Goal')
ax.set_xlabel('Column Index'); ax.set_ylabel('Row Index'); ax.axis('image')
plt.show()


In [None]:
# Validate mic footprint across all valid positions in all files
mic_hits_wall = 0
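# Use separate loop names so the sampled file_idx / info / cg from earlier cells stay intact
for info_i in dataset.file_infos:
    cg_i = info_i['cave_grid']
    for (y, x) in info_i['valid']:
        for dy, dx in MIC_OFFSETS:
            my, mx = y + dy, x + dx
            # Count a position once if any mic is out of bounds or inside a wall
            if not (0 <= my < cg_i.shape[0] and 0 <= mx < cg_i.shape[1]) or cg_i[my, mx] == 1:
                mic_hits_wall += 1
                break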

print(f"Valid positions (all files): {len(dataset.valid_positions)}")
print(f"Positions where any mic would hit a wall: {mic_hits_wall}")
assert mic_hits_wall == 0, "Some mic locations fall on walls; check MIC_OFFSETS or valid_positions logic"

import random
rng = random.Random(0)
samples = rng.sample(dataset.valid_positions, min(5, len(dataset.valid_positions)))
print("Sampled positions (file, y, x): action")
for fi, y, x in samples:
    a = dataset.file_infos[fi]['action_grid'][y, x]
    if isinstance(a, bytes):  # HDF5 strings may come back as bytes
        a = a.decode('utf-8')
    print(f"  (file {fi}, {y}, {x}): {str(a).upper()}")


## 7. Create DataLoader

DataLoader construction happens in `03_Training.ipynb`; this notebook only builds and inspects the Dataset.
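
One thing worth deciding before building loaders is how to split the samples. This is a minimal sketch, assuming `dataset.valid_positions` is a list of `(file_idx, y, x)` tuples (as it is iterated in the mic validation cell above) and that validation caves should be held out whole, so neighboring positions from one cave do not end up on both sides of the split. The helper name `split_indices_by_file` and the 20% validation fraction are illustrative and are not taken from `03_Training.ipynb`.

```python
import random


def split_indices_by_file(valid_positions, val_fraction=0.2, seed=42):
    """Split sample indices into train/val lists, holding out whole files for validation.

    valid_positions: sequence of (file_idx, y, x) tuples, one per dataset sample.
    Hypothetical helper; the split used in 03_Training.ipynb may differ.
    """
    file_ids = sorted({file_idx for file_idx, _, _ in valid_positions})
    rng = random.Random(seed)
    rng.shuffle(file_ids)
    # Hold out at least one file whenever there are two or more files.
    n_val = max(1, round(val_fraction * len(file_ids))) if len(file_ids) > 1 else 0
    val_files = set(file_ids[:n_val])
    train_idx, val_idx = [], []
    for i, (file_idx, _, _) in enumerate(valid_positions):
        (val_idx if file_idx in val_files else train_idx).append(i)
    return train_idx, val_idx


# Toy example: 5 files with 2 positions each.
toy_positions = [(f, 3 * k, 3 * k) for f in range(5) for k in range(2)]
train_idx, val_idx = split_indices_by_file(toy_positions)
print(len(train_idx), len(val_idx))  # 8 2
```

The returned index lists can be wrapped with `torch.utils.data.Subset(dataset, train_idx)` and passed to `torch.utils.data.DataLoader`.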

In [None]:
# Training is handled in 03_Training.ipynb.
# This notebook focuses on dataset inspection and visualization.


## 8. Compute Class Weights for Training

STOP appears only at the goal, so it is rare. Class weights that compensate for this imbalance are computed in `03_Training.ipynb`.
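
For reference, a common heuristic is inverse-frequency weighting, where class `c` with `n_c` samples out of `N` total across `K` classes gets weight `N / (K * n_c)`. The sketch below uses a hypothetical function name, `inverse_frequency_weights`, and made-up counts. `compute_class_weights` from `src.cave_dataset` is not shown in this notebook and may use a different formula.

```python
def inverse_frequency_weights(counts, class_names):
    """Weight each class by N / (K * n_c): rarer classes get proportionally larger weights.

    counts: mapping from class name to sample count.
    class_names: class order for the returned list (e.g. the order of the model's logits).
    Classes with zero samples get weight 0.0 to avoid division by zero.
    """
    total = sum(counts.get(name, 0) for name in class_names)
    num_classes = len(class_names)
    weights = []
    for name in class_names:
        n = counts.get(name, 0)
        weights.append(total / (num_classes * n) if n > 0 else 0.0)
    return weights


# Made-up counts for illustration only, not measured from the dataset.
toy_counts = {"stop": 10, "up": 240, "down": 250, "left": 250, "right": 250}
toy_names = ["stop", "up", "down", "left", "right"]
weights = inverse_frequency_weights(toy_counts, toy_names)
print([round(w, 2) for w in weights])  # [20.0, 0.83, 0.8, 0.8, 0.8]
```

A list like this can be converted with `torch.tensor(weights, dtype=torch.float32)` and passed as the `weight` argument of `torch.nn.CrossEntropyLoss`.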

In [None]:
# Class weight computation is performed in 03_Training.ipynb using compute_class_distribution/weights.


## Summary

### Dataset Ready for Training!

**What we have:**
- ✓ PyTorch Dataset with 3x3 footprint validation
- ✓ Action distribution analysis
- ✓ Verified 8-mic data extraction

**Deferred to `03_Training.ipynb`:**
- DataLoaders for train/val splits
- Class weights to handle imbalance

**Next Steps:**
1. Design neural network architecture (CNN/RNN)
2. Implement training loop with class weights
3. Monitor per-class accuracy (especially STOP)
4. Evaluate on validation set
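
Per-class accuracy matters here because overall accuracy can look fine while the rare STOP class is mostly missed. This is a minimal sketch with a hypothetical helper, `per_class_accuracy`, and toy labels. The class indices below are illustrative and do not reflect the real `ACTION_MAP`.

```python
def per_class_accuracy(y_true, y_pred, class_names):
    """Return {class_name: fraction of that class's samples predicted correctly}.

    This is per-class recall. y_true and y_pred are equal-length sequences of
    integer class indices. Classes absent from y_true map to None, since their
    accuracy is undefined.
    """
    correct = [0] * len(class_names)
    seen = [0] * len(class_names)
    for t, p in zip(y_true, y_pred):
        seen[t] += 1
        if t == p:
            correct[t] += 1
    return {
        name: (correct[i] / seen[i] if seen[i] else None)
        for i, name in enumerate(class_names)
    }


# Toy labels: index 0 stands for STOP in this illustration only.
names = ["STOP", "UP", "DOWN"]
y_true = [0, 0, 1, 1, 1, 1]
y_pred = [0, 1, 1, 1, 1, 1]
report = per_class_accuracy(y_true, y_pred, names)
print(report)  # {'STOP': 0.5, 'UP': 1.0, 'DOWN': None}
```

In this toy case the overall accuracy is 5/6, yet only half of the STOP samples are recovered, which is the kind of gap the per-class view exposes.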

In [None]:
# Diagnose STOP coverage and visualize the first problematic file (if any)
bad = []
r = AGENT_RADIUS

for i, info in enumerate(dataset.file_infos):
    cg = info['cave_grid']
    ag = info['action_grid']
    y, x = info['end_pos']
    label = ag[y, x]
    if isinstance(label, bytes):
        label = label.decode('utf-8')
    is_stop = (label == 'stop')
    footprint_ok = (
        y - r >= 0 and y + r < cg.shape[0] and
        x - r >= 0 and x + r < cg.shape[1] and
        np.all(cg[y - r:y + r + 1, x - r:x + r + 1] == 0)
    )
    stop_in_valid = any((y == vy and x == vx) for vy, vx in info['valid'])
    if not (is_stop and footprint_ok and stop_in_valid):
        bad.append((i, is_stop, footprint_ok, stop_in_valid, (y, x), info['path'].name))

print(f"Files failing a STOP check (label, footprint, or valid-position membership): {len(bad)}")
if bad:
    print("First few:", bad[:5])

# Visualize the first bad file (if any)
if bad:
    idx, is_stop, footprint_ok, stop_in_valid, (y, x), fname = bad[0]
    info = dataset.file_infos[idx]
    cg = info['cave_grid']
    ag = info['action_grid']
    if isinstance(ag[y, x], bytes):
        ag = np.vectorize(lambda t: t.decode('utf-8') if isinstance(t, bytes) else t)(ag)

    print(f"\nInspecting file #{idx}: {fname}")
    print(f"end_pos={info['end_pos']}, label_at_end={ag[y, x]}, footprint_ok={footprint_ok}, stop_in_valid={stop_in_valid}")

    from matplotlib.patches import Rectangle
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(cg, origin='upper', cmap='binary', alpha=0.8)
    ax.scatter([info['end_pos'][1]], [info['end_pos'][0]], c='red', marker='*', s=200, edgecolors='black')
    rect = Rectangle((x - r - 0.5, y - r - 0.5), 2 * r + 1, 2 * r + 1,
                     linewidth=3, edgecolor='blue', facecolor='none')
    ax.add_patch(rect)
    ax.set_title(f"{fname} | STOP label: {ag[y, x]} | footprint_ok={footprint_ok}")
    ax.axis('image')
    plt.show()
