# Acoustic Navigation - Model Training

## Training Pipeline Overview

**Task**: Predict optimal navigation action from acoustic sensor data

### Agent Setup:
- Agent occupies the center of a 3x3 block of grid cells
- 8 microphones arranged in a ring around the agent, one per neighboring cell (1 cell away, diagonals included)
- Each microphone records pressure time-series
- **CRITICAL**: No part of 3x3 agent footprint can be inside walls!
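
The footprint rule above can be illustrated on a toy grid. This is a minimal sketch that assumes the convention used later in this notebook (0 = air, 1 = wall); the helper name `footprint_is_clear` and the 5x5 grid are illustrative only.

```python
import numpy as np

def footprint_is_clear(grid, y, x, radius=1):
    """Return True if the (2*radius+1)^2 footprint centered at (y, x) is all air."""
    rows, cols = grid.shape
    # Reject centers whose footprint would extend past the grid edge
    if y - radius < 0 or y + radius >= rows or x - radius < 0 or x + radius >= cols:
        return False
    window = grid[y - radius:y + radius + 1, x - radius:x + radius + 1]
    return bool(np.all(window == 0))

# Toy 5x5 cave (assumption): a single wall cell at (0, 0)
toy_grid = np.zeros((5, 5), dtype=np.uint8)
toy_grid[0, 0] = 1

print(footprint_is_clear(toy_grid, 2, 2))  # footprint covers rows 1-3, cols 1-3
print(footprint_is_clear(toy_grid, 1, 1))  # footprint includes the wall at (0, 0)
print(footprint_is_clear(toy_grid, 0, 2))  # footprint would leave the grid
```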

### Training Approach:
1. Sample agent positions from navigable cave space (ensuring 3x3 validity)
2. Extract 8-mic acoustic data at each position
3. Predict action: STOP, UP, DOWN, LEFT, RIGHT
4. Handle class imbalance (STOP is rarer than movement actions)

### This Notebook:
1. Load HDF5 dataset
2. Create proper PyTorch DataLoader with 3x3 validation
3. Visualize data distribution (action balance, agent positions)
4. Show sample agent with 8-mic array
5. Prepare for model training (next step)

In [1]:
import sys
sys.path.append('../')

import numpy as np
import h5py
import matplotlib.pyplot as plt
from pathlib import Path
from collections import Counter
from scipy.fft import fft, fftfreq

import torch
from torch.utils.data import Dataset, DataLoader

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

print("Libraries loaded successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

Libraries loaded successfully!
PyTorch version: 2.9.1+cpu
CUDA available: False


## 1. Configuration

In [None]:
# Dataset path
DATASET_PATH = Path('../dataset/acoustic_cave_dataset.h5')

# Agent configuration
AGENT_RADIUS = 1  # 3x3 grid means radius of 1 cell from center
NUM_MICS = 8      # 8 microphones in circular array

# Action mapping: string labels -> numeric labels
ACTION_MAP = {
    'stop': 0,
    'up': 1,
    'down': 2,
    'left': 3,
    'right': 4,
    '': -1  # Invalid/wall
}

# Model output classes
ACTION_NAMES = ['STOP', 'UP', 'DOWN', 'LEFT', 'RIGHT']
NUM_CLASSES = len(ACTION_NAMES)

# Microphone positions (relative to agent center)
MIC_OFFSETS = [
    (0, 1),   # Right
    (1, 1),   # Down-right
    (1, 0),   # Down
    (1, -1),  # Down-left
    (0, -1),  # Left
    (-1, -1), # Up-left
    (-1, 0),  # Up
    (-1, 1)   # Up-right
]

print(f"Dataset: {DATASET_PATH}")
print(f"Agent footprint: {2*AGENT_RADIUS+1}x{2*AGENT_RADIUS+1} = {(2*AGENT_RADIUS+1)**2} cells")
print(f"Microphones: {NUM_MICS}")
print(f"Actions: {NUM_CLASSES} classes - {ACTION_NAMES}")

## 2. Load and Inspect Dataset

In [None]:
# Load HDF5 dataset
print("Loading dataset...")
with h5py.File(DATASET_PATH, 'r') as f:
    print(f"\nAvailable scenes: {list(f.keys())}")
    
    # Load first scene (cave_0001)
    scene = f['cave_0001']
    
    print(f"\nScene datasets:")
    for key in scene.keys():
        print(f"  {key}: {scene[key].shape} ({scene[key].dtype})")
    
    print(f"\nMetadata:")
    for key, value in scene.attrs.items():
        print(f"  {key}: {value}")
    
    # Load data into memory
    cave_grid = scene['cave_grid'][:]
    action_grid = scene['action_grid'][:]  # Will be bytes (b'up', b'down', etc.)
    pressure_field = scene['pressure_timeseries'][:]  # (Nx, Ny, Nt)
    
    # Metadata
    start_pos = tuple(scene.attrs['start_position'])
    end_pos = tuple(scene.attrs['end_position'])  # Goal = sound source
    dt = scene.attrs['dt']
    f0 = scene.attrs['frequency_hz']

# Convert action_grid from bytes to strings
action_grid_str = np.vectorize(lambda x: x.decode('utf-8') if isinstance(x, bytes) else x)(action_grid)

Nx, Ny, Nt = pressure_field.shape

print(f"\n{'='*60}")
print(f"Loaded Cave Scene:")
print(f"  Grid size: {Nx}x{Ny}")
print(f"  Time steps: {Nt}")
print(f"  Sampling rate: {1/dt:.0f} Hz")
print(f"  Source frequency: {f0/1000:.1f} kHz")
print(f"  Start: {start_pos}")
print(f"  Goal: {end_pos}")
print(f"  Walls: {100*cave_grid.mean():.1f}%")
print(f"  Air: {100*(1-cave_grid.mean()):.1f}%")
print(f"{'='*60}")

## 3. Create PyTorch Dataset

**Key Logic**:
- Agent at (y, x) occupies 3x3 square centered at (y, x)
- Extract 8-mic pressure time-series from surrounding cells
- Label = action at center position
- With the 3x3 expansion system, action labels exist only at valid footprint centers (every 3rd cell)
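
The mic extraction step can also be written as a single indexing operation. The sketch below assumes a pressure field of shape `(Nx, Ny, Nt)` and the same `(dy, dx)` neighbor ordering as `MIC_OFFSETS`; the helper `extract_mic_data` and the synthetic field are hypothetical and only serve to show the resulting shape and ordering.

```python
import numpy as np

# Same neighbor ordering as MIC_OFFSETS in this notebook: (dy, dx) pairs
offsets = np.array([(0, 1), (1, 1), (1, 0), (1, -1),
                    (0, -1), (-1, -1), (-1, 0), (-1, 1)])

def extract_mic_data(pressure_field, y, x):
    """Gather the 8 neighbor time-series around (y, x) into an (8, Nt) array."""
    mic_rows = y + offsets[:, 0]
    mic_cols = x + offsets[:, 1]
    # Integer-array indexing on the first two axes keeps the time axis intact
    return pressure_field[mic_rows, mic_cols, :].astype(np.float32)

# Synthetic field (assumption): each cell holds its flat index at every time step
Nx, Ny, Nt = 4, 5, 3
toy_field = np.broadcast_to(
    np.arange(Nx * Ny).reshape(Nx, Ny, 1), (Nx, Ny, Nt)
).astype(np.float64)

mics = extract_mic_data(toy_field, y=1, x=2)
print(mics.shape)   # one row per mic, one column per time step
print(mics[:, 0])   # flat indices of the 8 neighbors of cell (1, 2)
```

This only works for centers whose footprint lies inside the grid, which is what the footprint validation is meant to guarantee.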

In [None]:
class AcousticCaveDataset(Dataset):
    """
    PyTorch Dataset for acoustic navigation in caves.
    
    Agent occupies 3x3 footprint (center + 8 surrounding cells).
    Dataset validates that entire footprint is in navigable space (air).
    """
    
    def __init__(self, cave_grid, action_grid, pressure_field,
                 agent_radius=1, mic_offsets=None, action_map=None):
        """
        Args:
            cave_grid: (Nx, Ny) binary grid (0=air, 1=wall)
            action_grid: (Nx, Ny) string array ('up', 'down', 'left', 'right', 'stop')
            pressure_field: (Nx, Ny, Nt) pressure time-series
            agent_radius: Radius of agent footprint (1 = 3x3 grid)
            mic_offsets: List of (dy, dx) tuples for microphone positions
            action_map: Dict mapping action strings to integer labels
        """
        self.cave_grid = cave_grid
        self.action_grid = action_grid
        self.pressure_field = pressure_field
        self.agent_radius = agent_radius
        self.mic_offsets = mic_offsets if mic_offsets else MIC_OFFSETS
        self.action_map = action_map if action_map else ACTION_MAP
        
        self.Nx, self.Ny, self.Nt = pressure_field.shape
        
        # Find all valid agent positions (with 3x3 footprint validation)
        self.valid_positions = self._find_valid_positions()
        
        print(f"AcousticCaveDataset initialized:")
        print(f"  Cave size: {self.Nx}x{self.Ny}")
        print(f"  Agent footprint: {2*agent_radius+1}x{2*agent_radius+1}")
        print(f"  Valid positions: {len(self.valid_positions)}")
        print(f"  Microphones: {len(self.mic_offsets)}")
        print(f"  Time steps per mic: {self.Nt}")
    
    def _is_valid_footprint(self, y, x):
        """Check if 3x3 footprint centered at (y,x) is entirely in air."""
        r = self.agent_radius
        
        # Check bounds
        if y - r < 0 or y + r >= self.Nx or x - r < 0 or x + r >= self.Ny:
            return False
        
        # Check if all cells in 3x3 footprint are air (0)
        footprint = self.cave_grid[y-r:y+r+1, x-r:x+r+1]
        return np.all(footprint == 0)
    
    def _find_valid_positions(self):
        """Find all positions where action labels exist AND 3x3 footprint is valid."""
        valid = []
        
        for y in range(self.Nx):
            for x in range(self.Ny):
                # Check if action exists at this position
                action_str = self.action_grid[y, x]
                if action_str in self.action_map and action_str != '':
                    # Validate 3x3 footprint is in air
                    if self._is_valid_footprint(y, x):
                        valid.append((y, x))
        
        return valid
    
    def __len__(self):
        return len(self.valid_positions)
    
    def __getitem__(self, idx):
        """
        Returns:
            mic_data: (8, Nt) tensor - pressure time-series for 8 mics
            action: int - action label (0-4)
        """
        y, x = self.valid_positions[idx]
        
        # Extract 8-mic time-series
        mic_data = []
        for dy, dx in self.mic_offsets:
            mic_y, mic_x = y + dy, x + dx
            # Get pressure time-series at this mic location
            pressure = self.pressure_field[mic_y, mic_x, :]
            mic_data.append(pressure)
        
        mic_data = np.array(mic_data, dtype=np.float32)  # (8, Nt)
        
        # Get action label
        action_str = self.action_grid[y, x]
        action_label = self.action_map[action_str]
        
        return torch.from_numpy(mic_data), torch.tensor(action_label, dtype=torch.long)
    
    def get_action_distribution(self):
        """Compute action distribution across all valid positions."""
        action_counts = {i: 0 for i in range(NUM_CLASSES)}
        
        for y, x in self.valid_positions:
            action_str = self.action_grid[y, x]
            action_num = self.action_map[action_str]
            action_counts[action_num] += 1
        
        # Convert to named dict
        named_counts = {ACTION_NAMES[i].lower(): action_counts[i] for i in range(NUM_CLASSES)}
        return Counter(named_counts)
    
    def get_sample_with_position(self, idx):
        """Get sample with position info (for visualization)."""
        y, x = self.valid_positions[idx]
        mic_data, action = self.__getitem__(idx)
        return mic_data, action, (y, x)

# Create dataset
dataset = AcousticCaveDataset(
    cave_grid=cave_grid,
    action_grid=action_grid_str,
    pressure_field=pressure_field,
    agent_radius=AGENT_RADIUS,
    mic_offsets=MIC_OFFSETS,
    action_map=ACTION_MAP
)

## 4. Analyze Action Distribution

**Important**: STOP will be rarer than movement actions (only at goal).  
We'll need to handle class imbalance during training.

In [None]:
# Get action distribution
action_dist = dataset.get_action_distribution()

print("Action Distribution:")
print("="*50)
total_samples = len(dataset)

for action_name in ACTION_NAMES:
    action_str = action_name.lower()
    count = action_dist.get(action_str, 0)
    percentage = 100 * count / total_samples
    print(f"  {action_name:>5s}: {count:>5d} ({percentage:>5.2f}%)")

print("="*50)
print(f"  TOTAL: {total_samples:>5d} (100.00%)")

# Visualize distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart
counts = [action_dist.get(name.lower(), 0) for name in ACTION_NAMES]
colors = ['red', 'blue', 'cyan', 'orange', 'green']
bars = axes[0].bar(ACTION_NAMES, counts, color=colors, alpha=0.7, edgecolor='black')
axes[0].set_xlabel('Action', fontsize=12)
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('Action Distribution', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, count in zip(bars, counts):
    height = bar.get_height()
    axes[0].text(bar.get_x() + bar.get_width()/2., height,
                f'{count}\n({100*count/total_samples:.1f}%)',
                ha='center', va='bottom', fontsize=10)

# Pie chart
axes[1].pie(counts, labels=ACTION_NAMES, autopct='%1.1f%%', 
           colors=colors, startangle=90, textprops={'fontsize': 11})
axes[1].set_title('Action Distribution (Percentage)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Class imbalance warning
min_count = min(counts) if counts else 0
max_count = max(counts) if counts else 0
imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')

print(f"\nClass Imbalance Analysis:")
if counts:
    print(f"  Min count: {min_count} ({ACTION_NAMES[counts.index(min_count)]})")
    print(f"  Max count: {max_count} ({ACTION_NAMES[counts.index(max_count)]})")
    print(f"  Imbalance ratio: {imbalance_ratio:.2f}x")

if imbalance_ratio > 5:
    print(f"  ⚠️  HIGH IMBALANCE! Consider using class weights during training.")
elif imbalance_ratio > 2:
    print(f"  ⚠️  MODERATE IMBALANCE. Monitor per-class performance.")
else:
    print(f"  ✓ Balanced classes.")

## 5. Visualize Valid Agent Positions

Show where agents can be placed (3x3 expansion guarantees all are valid)

In [None]:
# Create valid position mask
valid_mask = np.zeros_like(cave_grid)
for y, x in dataset.valid_positions:
    valid_mask[y, x] = 1

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Cave layout
axes[0].imshow(cave_grid, origin='upper', cmap='binary')
axes[0].scatter([end_pos[1]], [end_pos[0]], s=200, c='red', marker='*',
               edgecolors='black', linewidths=2, label='Goal', zorder=10)
axes[0].set_title(f'Cave Layout ({Nx}x{Ny})', fontsize=14, fontweight='bold')
axes[0].set_xlabel('X')
axes[0].set_ylabel('Y')
axes[0].legend()
axes[0].axis('image')

# Valid agent positions
axes[1].imshow(valid_mask, origin='upper', cmap='Greens', alpha=0.6)
axes[1].contour(cave_grid, levels=[0.5], colors='black', linewidths=1)
axes[1].scatter([end_pos[1]], [end_pos[0]], s=200, c='red', marker='*',
               edgecolors='black', linewidths=2, label='Goal', zorder=10)
axes[1].set_title(f'Valid Agent Positions ({len(dataset.valid_positions)} total)', 
                 fontsize=14, fontweight='bold')
axes[1].set_xlabel('X')
axes[1].set_ylabel('Y')
axes[1].legend()
axes[1].axis('image')

plt.tight_layout()
plt.show()

print(f"Valid agent positions: {len(dataset.valid_positions)} / {(Nx*Ny)} total cells")
print(f"Coverage: {100*len(dataset.valid_positions)/(Nx*Ny):.1f}%")

## 6. Visualize Sample Agent with 8-Mic Array

Pick a random agent position and show:
- 3x3 footprint
- 8 microphone positions
- Pressure time-series at each mic
- Action label
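
As a sanity check on the per-mic statistics plotted for each sample, the RMS and peak values can be verified on a signal with a known answer. The sketch below uses an assumed toy signal (a 1 kHz sine sampled at 48 kHz, one amplitude per mic), not data from the dataset.

```python
import numpy as np

# Assumed toy signal: 8 channels of a 1 kHz sine sampled at 48 kHz for 10 full periods
fs = 48_000.0
f_sine = 1_000.0
n_samples = 480  # 10 periods of 48 samples each
t = np.arange(n_samples) / fs
amplitudes = np.linspace(0.1, 0.8, 8)  # one amplitude per mic
signals = amplitudes[:, None] * np.sin(2 * np.pi * f_sine * t)[None, :]  # (8, Nt)

rms = np.sqrt(np.mean(signals ** 2, axis=1))   # per-mic RMS
peak = np.max(np.abs(signals), axis=1)         # per-mic peak magnitude

# For a sine over whole periods, RMS = amplitude / sqrt(2) and peak = amplitude
print(np.allclose(rms, amplitudes / np.sqrt(2)))
print(np.allclose(peak, amplitudes))
```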

In [None]:
# Pick a random sample
sample_idx = np.random.randint(0, len(dataset))
mic_data, action, (agent_y, agent_x) = dataset.get_sample_with_position(sample_idx)

action_name = ACTION_NAMES[action.item()]

print(f"Sample {sample_idx}:")
print(f"  Agent position: ({agent_y}, {agent_x})")
print(f"  Action: {action_name}")
print(f"  Mic data shape: {mic_data.shape}")
print(f"  Cave grid at agent: {cave_grid[agent_y, agent_x]} (should be 0=air)")

# Verify 3x3 footprint
r = AGENT_RADIUS
footprint = cave_grid[agent_y-r:agent_y+r+1, agent_x-r:agent_x+r+1]
print(f"  3x3 footprint all air: {np.all(footprint == 0)}")

# Create comprehensive visualization
fig = plt.figure(figsize=(18, 12))
gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

# 1. Agent position in FULL cave (no zoom)
ax1 = fig.add_subplot(gs[0, 0])
ax1.imshow(cave_grid.T, origin='lower', cmap='binary', alpha=0.6)
ax1.contour(cave_grid.T, levels=[0.5], colors='black', linewidths=1)

# Draw 3x3 footprint
from matplotlib.patches import Rectangle
rect = Rectangle((agent_y-r-0.5, agent_x-r-0.5), 2*r+1, 2*r+1, 
                linewidth=3, edgecolor='blue', facecolor='blue', alpha=0.2)
ax1.add_patch(rect)

# Agent center
ax1.scatter([agent_y], [agent_x], s=300, c='blue', marker='o',
           edgecolors='white', linewidths=2, label='Agent', zorder=10)

# 8 microphones
mic_ys = [agent_y + dy for dy, dx in MIC_OFFSETS]
mic_xs = [agent_x + dx for dy, dx in MIC_OFFSETS]
ax1.scatter(mic_ys, mic_xs, s=100, c='lime', marker='^',
           edgecolors='black', linewidths=1, label='Mics', zorder=9)

# Goal
ax1.scatter([end_pos[0]], [end_pos[1]], s=200, c='red', marker='*',
           edgecolors='black', linewidths=1, label='Goal', zorder=8)

ax1.set_title(f'Agent Position (FULL CAVE)\nAction: {action_name}', fontsize=12, fontweight='bold')
ax1.set_xlabel('Row')
ax1.set_ylabel('Col')
ax1.legend(loc='upper right', fontsize=8)
ax1.axis('image')
# NO xlim/ylim - show full cave!

# 2. 3x3 Footprint Detail
ax2 = fig.add_subplot(gs[0, 1])
ax2.imshow(footprint.T, origin='lower', cmap='binary', alpha=0.6)
ax2.scatter([r], [r], s=300, c='blue', marker='o',
           edgecolors='white', linewidths=2, label='Agent')
mic_ys_local = [r + dy for dy, dx in MIC_OFFSETS]
mic_xs_local = [r + dx for dy, dx in MIC_OFFSETS]
ax2.scatter(mic_ys_local, mic_xs_local, s=100, c='lime', marker='^',
           edgecolors='black', linewidths=1, label='Mics')
ax2.set_title(f'3x3 Footprint Detail\n(All air = {footprint.sum() == 0})', 
             fontsize=12, fontweight='bold')
ax2.set_xlabel('Local Row')
ax2.set_ylabel('Local Col')
ax2.legend(fontsize=8)
ax2.axis('image')

# 3. Mic data statistics
ax3 = fig.add_subplot(gs[0, 2:])
mic_rms = np.sqrt(np.mean(mic_data.numpy()**2, axis=1))
mic_peak = np.max(np.abs(mic_data.numpy()), axis=1)
mic_labels = ['R', 'DR', 'D', 'DL', 'L', 'UL', 'U', 'UR']

x_pos = np.arange(8)
width = 0.35
ax3.bar(x_pos - width/2, mic_rms, width, label='RMS', alpha=0.7)
ax3.bar(x_pos + width/2, mic_peak, width, label='Peak', alpha=0.7)
ax3.set_xlabel('Microphone')
ax3.set_ylabel('Pressure')
ax3.set_title('Microphone Statistics', fontsize=12, fontweight='bold')
ax3.set_xticks(x_pos)
ax3.set_xticklabels(mic_labels)
ax3.legend()
ax3.grid(True, alpha=0.3, axis='y')

# 4-11. Time-series for each mic
time_array = np.arange(Nt) * dt * 1000  # Convert to ms

for i in range(8):
    row = 1 + i // 4
    col = i % 4
    ax = fig.add_subplot(gs[row, col])
    
    ax.plot(time_array, mic_data[i].numpy(), linewidth=1)
    ax.set_title(f'Mic {i+1} ({mic_labels[i]})', fontsize=10, fontweight='bold')
    ax.set_xlabel('Time (ms)', fontsize=8)
    ax.set_ylabel('Pressure', fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.tick_params(labelsize=8)

plt.suptitle(f'Agent Sample - Position ({agent_y},{agent_x}) - Action: {action_name}',
            fontsize=16, fontweight='bold', y=0.995)
plt.show()

print(f"\nMicrophone Data Summary:")
print(f"  Shape: {mic_data.shape} (8 mics × {Nt} timesteps)")
print(f"  RMS range: [{mic_rms.min():.6f}, {mic_rms.max():.6f}]")
print(f"  Peak range: [{mic_peak.min():.6f}, {mic_peak.max():.6f}]")

In [None]:
# FULL MAP VISUALIZATION - Agent, Mics, and Goal
# This shows the complete cave with the agent's 3x3 footprint clearly visible

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Left: Full cave with agent (origin upper so y increases downward, matching action grid)
ax = axes[0]
ax.imshow(cave_grid, origin='upper', cmap='binary', alpha=0.8)
ax.contour(cave_grid, levels=[0.5], colors='gray', linewidths=1.5, alpha=0.5)

from matplotlib.patches import Rectangle
r = AGENT_RADIUS
rect = Rectangle((agent_x - r - 0.5, agent_y - r - 0.5), 2*r + 1, 2*r + 1,
                linewidth=4, edgecolor='blue', facecolor='blue', alpha=0.3, label='3x3 Footprint')
ax.add_patch(rect)

ax.scatter([agent_x], [agent_y], s=500, c='blue', marker='o',
           edgecolors='white', linewidths=3, label='Agent', zorder=10)

# Mic positions; only draw mics that are in air
mic_ys = [agent_y + dy for dy, dx in MIC_OFFSETS]
mic_xs = [agent_x + dx for dy, dx in MIC_OFFSETS]
mic_in_air = [(mx, my) for mx, my in zip(mic_xs, mic_ys) if 0 <= my < Nx and 0 <= mx < Ny and cave_grid[int(my), int(mx)] == 0]
if mic_in_air:
    ax.scatter([mx for mx,_ in mic_in_air], [my for _,my in mic_in_air], s=150, c='lime', marker='^',
               edgecolors='black', linewidths=2, label='8 Microphones (air only)', zorder=9)

ax.scatter([end_pos[1]], [end_pos[0]], s=400, c='red', marker='*',
           edgecolors='black', linewidths=2, label='Goal (Sound Source)', zorder=8)

ax.set_title(f'Full Cave Map - Agent at ({agent_y},{agent_x})\nAction: {action_name}',
             fontsize=14, fontweight='bold')
ax.set_xlabel('Column Index', fontsize=12)
ax.set_ylabel('Row Index', fontsize=12)
ax.legend(loc='upper left', fontsize=10, framealpha=0.9)
ax.axis('image')
ax.grid(True, alpha=0.2, linestyle='--', linewidth=0.5)

# Right: Zoomed view around agent (same origin)
ax = axes[1]
zoom_radius = 10
y_min = max(0, agent_y - zoom_radius)
y_max = min(Nx, agent_y + zoom_radius + 1)
x_min = max(0, agent_x - zoom_radius)
x_max = min(Ny, agent_x + zoom_radius + 1)

cave_zoom = cave_grid[y_min:y_max, x_min:x_max]
# extent is (left, right, bottom, top); with origin='upper', row y_min must sit at the top
extent_zoom = [x_min - 0.5, x_max - 0.5, y_max - 0.5, y_min - 0.5]
ax.imshow(cave_zoom, origin='upper', cmap='binary', alpha=0.8, extent=extent_zoom)

rect_zoom = Rectangle((agent_x - r - 0.5, agent_y - r - 0.5), 2*r + 1, 2*r + 1,
                      linewidth=4, edgecolor='blue', facecolor='blue', alpha=0.3)
ax.add_patch(rect_zoom)

if mic_in_air:
    ax.scatter([mx for mx,_ in mic_in_air], [my for _,my in mic_in_air], s=150, c='lime', marker='^',
               edgecolors='black', linewidths=2, label='Mics (air only)', zorder=9)

mic_labels = ['R', 'DR', 'D', 'DL', 'L', 'UL', 'U', 'UR']
for label, (mx, my) in zip(mic_labels, zip(mic_xs, mic_ys)):
    if 0 <= my < Nx and 0 <= mx < Ny and cave_grid[int(my), int(mx)] == 0:
        ax.text(mx, my - 0.6, label, fontsize=9, fontweight='bold',
                ha='center', va='top', color='darkgreen',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8, edgecolor='none'))

ax.set_title(f'Zoomed View (±{zoom_radius} cells)', fontsize=14, fontweight='bold')
ax.set_xlabel('Column Index', fontsize=12)
ax.set_ylabel('Row Index', fontsize=12)
ax.legend(loc='upper left', fontsize=10, framealpha=0.9)
ax.axis('image')
ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)

plt.tight_layout()
plt.show()

# Print summary
print(f"{'='*60}")
print(f"AGENT VERIFICATION")
print(f"{'='*60}")
print(f"Agent Position: ({agent_y}, {agent_x})")
print(f"Cave grid at agent: {cave_grid[agent_y, agent_x]} (0=air, 1=wall)")
print(f"3x3 footprint all air: {np.all(footprint == 0)}")
print(f"Action label: {action_name}")
print(f"Goal Position: {end_pos}")
print(f"Distance to goal: {np.sqrt((agent_y-end_pos[0])**2 + (agent_x-end_pos[1])**2):.1f} cells")
print(f"{'='*60}")


In [None]:
# Follow action labels from current agent position to goal
moves = {'up': (-1,0), 'down': (1,0), 'left': (0,-1), 'right': (0,1), 'stop': (0,0)}

def follow_path(start_y, start_x, max_steps=2000):
    path = [(int(start_y), int(start_x))]
    r, c = start_y, start_x
    for step in range(max_steps):
        a = action_grid_str[r, c]
        if a == 'stop':
            return True, path
        if a not in moves:
            return False, path
        dr, dc = moves[a]
        nr, nc = r + dr, c + dc
        if not (0 <= nr < Nx and 0 <= nc < Ny):
            return False, path
        if cave_grid[nr, nc] == 1:
            return False, path
        r, c = nr, nc
        path.append((int(r), int(c)))
    return False, path

ok, path = follow_path(agent_y, agent_x)
print(f"Path follow result: {'SUCCESS' if ok else 'FAIL'}, steps={len(path)-1}, end={path[-1]}")

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.imshow(cave_grid, origin='upper', cmap='binary', alpha=0.8)
ys = [p[0] for p in path]; xs = [p[1] for p in path]
ax.plot(xs, ys, '-o', color='lime' if ok else 'red', linewidth=2, markersize=4)
ax.scatter([end_pos[1]], [end_pos[0]], s=200, c='red', marker='*', edgecolors='black', linewidths=1.5)
ax.set_title('Action-Label Path to Goal')
ax.set_xlabel('Column Index'); ax.set_ylabel('Row Index'); ax.axis('image')
plt.show()


In [None]:
# Validate mic footprint across all valid positions and sample a few
mic_hits_wall = 0
for (y, x) in dataset.valid_positions:
    for dy, dx in MIC_OFFSETS:
        my, mx = y + dy, x + dx
        if not (0 <= my < Nx and 0 <= mx < Ny) or cave_grid[my, mx] == 1:
            mic_hits_wall += 1
            break

print(f"Valid positions: {len(dataset.valid_positions)}")
print(f"Positions where any mic would hit a wall: {mic_hits_wall}")
assert mic_hits_wall == 0, "Some mic locations fall on walls; check MIC_OFFSETS or valid_positions logic"

import random
rng = random.Random(0)
samples = rng.sample(dataset.valid_positions, min(5, len(dataset.valid_positions)))
print("\nSampled positions (y, x): action, dist_to_goal")
for y, x in samples:
    action = action_grid_str[y, x]
    dist = ((y - end_pos[0])**2 + (x - end_pos[1])**2) ** 0.5
    print(f"  ({y}, {x}): {action.upper()}, dist={dist:.1f}")


## 7. Create DataLoader

Prepare PyTorch DataLoader for training
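
A plain random split does not guarantee that a rare class such as STOP appears in both subsets. One possible alternative is a stratified split over sample indices. The sketch below is an illustration under stated assumptions: the helper `stratified_split` and the toy labels are hypothetical and not part of this notebook's pipeline.

```python
import random
from collections import defaultdict

def stratified_split(labels, train_fraction=0.8, seed=42):
    """Split sample indices so each class keeps roughly the same train/val ratio."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_train = int(train_fraction * len(indices))
        # Keep at least one sample on each side when the class has 2+ samples
        if len(indices) >= 2:
            n_train = min(max(n_train, 1), len(indices) - 1)
        train_idx.extend(indices[:n_train])
        val_idx.extend(indices[n_train:])
    return train_idx, val_idx

# Toy labels (assumption): class 0 (STOP) is rare, classes 1-4 are common
toy_labels = [0] * 2 + [1] * 10 + [2] * 10 + [3] * 10 + [4] * 10
train_idx, val_idx = stratified_split(toy_labels)
print(len(train_idx), len(val_idx))
```

The resulting index lists could be wrapped with `torch.utils.data.Subset` in place of `random_split`, if per-class coverage in the validation set turns out to matter.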

In [None]:
# DataLoader configuration
BATCH_SIZE = 32
TRAIN_SPLIT = 0.8
NUM_WORKERS = 0  # Use 0 for Windows, 4+ for Linux/Mac

# Split dataset
train_size = int(TRAIN_SPLIT * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"Dataset Split:")
print(f"  Training: {len(train_dataset)} samples ({100*train_size/len(dataset):.0f}%)")
print(f"  Validation: {len(val_dataset)} samples ({100*val_size/len(dataset):.0f}%)")

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

print(f"\nDataLoaders:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")

# Test batch loading
sample_batch_data, sample_batch_actions = next(iter(train_loader))
print(f"\nSample Batch:")
print(f"  Data shape: {sample_batch_data.shape}")
print(f"  Actions shape: {sample_batch_actions.shape}")
print(f"  Actions in batch: {sample_batch_actions.tolist()[:10]}...")

# Action distribution in batch
batch_action_counts = Counter(sample_batch_actions.numpy())
print(f"\nAction distribution in sample batch:")
for i, name in enumerate(ACTION_NAMES):
    count = batch_action_counts.get(i, 0)
    print(f"  {name}: {count}")

## 8. Compute Class Weights for Training

To handle class imbalance (STOP is rare), compute class weights
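
The inverse-frequency weighting can be checked on a small worked example. The sketch below uses made-up counts (an assumption, not the dataset's real distribution) and a hypothetical helper `inverse_frequency_weights`; it rescales the weights so they sum to the number of classes.

```python
import numpy as np

def inverse_frequency_weights(counts):
    """Inverse-frequency class weights, rescaled so they sum to the class count."""
    counts = np.asarray(counts, dtype=np.float64)
    # Clamp empty classes to 1 to avoid division by zero (a design choice, not a rule)
    safe_counts = np.maximum(counts, 1.0)
    raw = 1.0 / safe_counts
    return raw / raw.sum() * len(counts)

# Made-up counts in the order STOP, UP, DOWN, LEFT, RIGHT
toy_counts = [10, 250, 250, 250, 240]
weights = inverse_frequency_weights(toy_counts)
print(np.round(weights, 4))
print(weights.sum())
```

With these counts the STOP weight is 25 times the UP weight, mirroring the 250:10 count ratio.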

In [None]:
# Compute class weights (inverse frequency)
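# get_action_distribution() returns every class key (possibly with count 0),
# so clamp to 1 explicitly to avoid division by zero for an empty class
action_counts = np.array([max(action_dist.get(name.lower(), 0), 1) for name in ACTION_NAMES])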
class_weights = 1.0 / action_counts
class_weights = class_weights / class_weights.sum() * len(ACTION_NAMES)  # Normalize

class_weights_tensor = torch.FloatTensor(class_weights)

print("Class Weights (for loss function):")
print("="*50)
for name, weight in zip(ACTION_NAMES, class_weights):
    print(f"  {name:>5s}: {weight:.4f}")
print("="*50)

print(f"\nTo use in training:")
print(f"  criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)")

# Save for later use
torch.save({
    'class_weights': class_weights_tensor,
    'action_names': ACTION_NAMES,
    'action_map': ACTION_MAP,
    'dataset_info': {
        'total_samples': len(dataset),
        'action_distribution': dict(action_dist),
        'grid_size': (Nx, Ny),
        'time_steps': Nt,
        'sampling_rate': float(1/dt),  # cast NumPy scalars to built-in floats
        'frequency': float(f0)
    }
}, '../dataset/dataset_info.pt')

print(f"\n✓ Dataset info saved to ../dataset/dataset_info.pt")

## Summary

### Dataset Ready for Training!

**What we have:**
- ✓ PyTorch Dataset with 3x3 footprint validation
- ✓ DataLoaders for train/val splits
- ✓ Action distribution analysis
- ✓ Class weights to handle imbalance
- ✓ Verified 8-mic data extraction

**Next Steps:**
1. Design neural network architecture (CNN/RNN)
2. Implement training loop with class weights
3. Monitor per-class accuracy (especially STOP)
4. Evaluate on validation set
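
For step 3, per-class accuracy can be computed from true and predicted labels without extra dependencies. This is a minimal sketch: the helper `per_class_accuracy` and the toy labels are hypothetical, and the label order assumes the `ACTION_NAMES` convention (0 = STOP).

```python
def per_class_accuracy(y_true, y_pred, num_classes):
    """Return per-class accuracy (recall); None for classes with no true samples."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
    return [c / n if n > 0 else None for c, n in zip(correct, total)]

# Toy labels (assumption): 0=STOP, 1=UP, 2=DOWN, 3=LEFT, 4=RIGHT
y_true = [0, 1, 1, 2, 2, 2, 3, 4, 4, 0]
y_pred = [1, 1, 1, 2, 0, 2, 3, 4, 3, 0]
acc = per_class_accuracy(y_true, y_pred, num_classes=5)
print(acc)
```

Overall accuracy on these toy labels is 70%, yet STOP is right only half the time, which is the kind of gap a single aggregate number hides.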