# Dataset Visualization & Analysis

Explore and visualize the acoustic navigation dataset.

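## Contents
1. Dataset loading, class (action) distribution, and per-file statistics
2. Cave environments and action fields
3. Acoustic signal examples (time domain and frequency spectra)
4. Sample distribution across caves
5. Spatial distribution of actions within a cave
6. Summary statistics and quality checks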

In [None]:
import sys
sys.path.append('../')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import h5py
from tqdm.auto import tqdm

from data.audio_cave import AudioCave
from src.cave_dataset import (
    MultiCaveDataset,
    ACTION_MAP,
    ACTION_NAMES,
    MIC_OFFSETS,
    compute_class_distribution,
)

# Set style
# Global plotting defaults for every figure in this notebook: a matplotlib
# style sheet, seaborn's 'husl' color palette, and inline rendering.
# NOTE(review): the 'seaborn-v0_8-*' style names require matplotlib >= 3.6 — confirm installed version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

## 1. Load Dataset

In [None]:
# Dataset directory. Can be overridden with the AUDIOMAZE_DATASET_DIR
# environment variable; the default is the original author's local path.
import os

DATASET_DIR = Path(os.environ.get('AUDIOMAZE_DATASET_DIR', 'D:/audiomaze_dataset_100'))
H5_FILES = sorted(DATASET_DIR.glob('cave_*.h5'))

print(f"Found {len(H5_FILES)} cave files")
print(f"Dataset directory: {DATASET_DIR}")

# Fail fast: with no cave files every later cell breaks with a much less
# obvious error (empty statistics, out-of-range cave indices, ...).
if not H5_FILES:
    raise FileNotFoundError(
        f"No 'cave_*.h5' files found in {DATASET_DIR}; "
        "set AUDIOMAZE_DATASET_DIR to the dataset location."
    )

# Load dataset; mic geometry and the action mapping are imported from src.cave_dataset.
print("\nLoading dataset...")
dataset = MultiCaveDataset(H5_FILES, agent_radius=1, mic_offsets=MIC_OFFSETS, action_map=ACTION_MAP)
print(f"Total valid positions: {len(dataset):,}")

## 2. Dataset Statistics

In [None]:
# Class distribution: number of samples carrying each action label,
# followed by how valid positions are spread over the cave files.
class_counts = compute_class_distribution(dataset)
total = sum(class_counts.values())
separator = "=" * 50

print("Class Distribution:")
print(separator)
for label, n in class_counts.items():
    share = 100 * n / total
    print(f"{label.upper():>6}: {n:>8,} samples ({share:5.2f}%)")
print(separator)
print(f"{'TOTAL':>6}: {total:>8,} samples")

# Per-file statistics: number of valid positions in each cave file
print("\nPer-File Statistics:")
print(separator)
file_samples = [len(file_info['valid']) for file_info in dataset.file_infos]
print(f"Files: {len(H5_FILES)}")
print(f"Avg samples per file: {np.mean(file_samples):.1f}")
print(f"Min samples: {np.min(file_samples)}")
print(f"Max samples: {np.max(file_samples)}")
print(f"Std dev: {np.std(file_samples):.1f}")

In [None]:
# Visualize class distribution: absolute counts (bar chart) next to
# proportions (pie chart), sharing one color per action.
labels = list(class_counts)
values = list(class_counts.values())
palette = plt.cm.Set3(range(len(labels)))

fig, (ax_bar, ax_pie) = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart with the exact count printed above each bar
bar_container = ax_bar.bar(labels, values, color=palette, alpha=0.8, edgecolor='black')
ax_bar.set_xlabel('Action', fontsize=12)
ax_bar.set_ylabel('Count', fontsize=12)
ax_bar.set_title('Action Distribution (Linear Scale)', fontsize=14, fontweight='bold')
ax_bar.tick_params(axis='x', rotation=0)
ax_bar.grid(True, alpha=0.3, axis='y')

for rect, value in zip(bar_container, values):
    x_mid = rect.get_x() + rect.get_width() / 2.
    ax_bar.text(x_mid, rect.get_height(), f'{value:,}',
                ha='center', va='bottom', fontsize=10)

# Pie chart of the same counts
ax_pie.pie(values, labels=[label.upper() for label in labels], autopct='%1.1f%%',
           colors=palette, startangle=90, textprops={'fontsize': 11})
ax_pie.set_title('Action Distribution (Proportions)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

## 3. Visualize Sample Caves

In [None]:
# Show up to four caves spread evenly across the sorted file list: for each,
# the layout (with start/goal markers) and the action field.
n_files = len(H5_FILES)
# dict.fromkeys drops duplicate indices (they occur with < 4 files) and keeps order.
sample_indices = list(dict.fromkeys([0, n_files // 4, n_files // 2, 3 * n_files // 4]))

# Integer code per action label for imshow; "" (no action) and any
# unrecognized label map to 0 instead of None (which imshow cannot draw).
symbol_map = {"up": 1, "down": 2, "left": 3, "right": 4, "stop": 5, "": 0}
to_code = np.vectorize(lambda label: symbol_map.get(label, 0), otypes=[int])

fig, axes = plt.subplots(2, 4, figsize=(18, 9))
axes = axes.flatten()

for i, file_idx in enumerate(sample_indices):
    info = dataset.file_infos[file_idx]

    # Cave layout; grids are transposed so the first index runs along x,
    # matching the (pos[0], pos[1]) scatter coordinates.
    ax = axes[i*2]
    ax.imshow(info['cave_grid'].T, origin='lower', cmap='binary')
    ax.scatter([info['start_pos'][0]], [info['start_pos'][1]], s=200, c='green',
               marker='o', edgecolors='black', linewidths=2, label='Start', zorder=10)
    ax.scatter([info['end_pos'][0]], [info['end_pos'][1]], s=200, c='red',
               marker='*', edgecolors='black', linewidths=2, label='Goal', zorder=10)
    ax.set_title(f'Cave {file_idx}: Layout', fontsize=11, fontweight='bold')
    ax.legend(fontsize=8)
    ax.axis('image')

    # Action field, with the cave outline contoured on top
    ax = axes[i*2 + 1]
    action_numeric = to_code(info['action_grid'])
    ax.imshow(action_numeric.T, origin='lower', cmap='tab10', vmin=0, vmax=5, alpha=0.7)
    ax.contour(info['cave_grid'].T, levels=[0.5], colors='black', linewidths=0.5)
    ax.scatter([info['end_pos'][0]], [info['end_pos'][1]], s=200, c='red',
               marker='*', edgecolors='black', linewidths=2, zorder=10)
    ax.set_title(f'Cave {file_idx}: Action Field', fontsize=11, fontweight='bold')
    ax.axis('image')

    # Print stats for every sampled cave.
    valid_count = len(info['valid'])
    # Assumes walls are encoded as 1 and free space as 0 — TODO confirm.
    wall_pct = 100 * info['cave_grid'].mean()
    print(f"Cave {file_idx}: {valid_count:>4} valid positions, {wall_pct:>4.1f}% walls")

# Hide panels left unused when there are fewer than four distinct caves.
for ax in axes[2 * len(sample_indices):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

## 4. Acoustic Signal Analysis

In [None]:
# Load one acoustic sample for inspection. A fixed index keeps the notebook
# reproducible; it is clamped so smaller datasets don't raise IndexError.
print("Loading sample acoustic data...")
sample_idx = min(1000, len(dataset) - 1)
mic_data, action, file_idx, position = dataset[sample_idx]

print(f"\nSample {sample_idx}:")
# int() makes the lookup work whether the label is a Python int or a 0-d tensor.
print(f"  Action: {ACTION_NAMES[int(action)]}")
print(f"  File: {file_idx}")
print(f"  Position: ({position[0]}, {position[1]})")
print(f"  Signal shape: {mic_data.shape}")
print(f"  Signal dtype: {mic_data.dtype}")

# Convert to numpy for plotting. np.asarray accepts numpy arrays as well as
# CPU torch tensors (assumes mic_data is not on GPU / does not require grad).
mic_array = np.asarray(mic_data)

# NOTE(review): normalization happens inside MultiCaveDataset, not visible here — confirm.
print("\nSignal statistics (after normalization):")
print(f"  Mean: {mic_array.mean():.6f}")
print(f"  Std: {mic_array.std():.6f}")
print(f"  Min: {mic_array.min():.6f}")
print(f"  Max: {mic_array.max():.6f}")

In [None]:
# Plot the first 1000 time samples of each of the 8 microphone channels.
# Every panel shares the full signal's amplitude range so channels are
# directly comparable. mic_labels is reused by the spectral-analysis cell.
mic_labels = ['Right', 'Down-Right', 'Down', 'Down-Left', 'Left', 'Up-Left', 'Up', 'Up-Right']
window = slice(0, 1000)
y_lo, y_hi = mic_array.min(), mic_array.max()

fig, axes = plt.subplots(4, 2, figsize=(14, 10))

for mic_idx, (ax, label) in enumerate(zip(axes.flat, mic_labels)):
    ax.plot(mic_array[mic_idx, window], linewidth=0.5, color=f'C{mic_idx}')
    ax.set_title(f'Mic {mic_idx + 1}: {label}', fontsize=11, fontweight='bold')
    # Only the bottom row gets an x-axis label.
    ax.set_xlabel('Sample' if mic_idx >= 6 else '')
    ax.set_ylabel('Amplitude')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(y_lo, y_hi)

fig.suptitle(f'8-Mic Array Signals (Action: {ACTION_NAMES[action]})', 
            fontsize=14, fontweight='bold', y=1.00)
plt.tight_layout()
plt.show()

In [None]:
# Spectral analysis: magnitude spectrum of each microphone's full signal.
# No sample rate is available here, so frequencies are in normalized units
# (cycles per sample; 0.5 is the Nyquist frequency).
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for i in range(8):
    ax = axes[i]

    # Audio signals are presumably real-valued, so rfft returns exactly the
    # non-negative frequency bins; bin 0 (DC) is skipped.
    signal = mic_array[i]
    spectrum = np.abs(np.fft.rfft(signal))
    freqs = np.fft.rfftfreq(signal.shape[-1])

    ax.plot(freqs[1:], spectrum[1:], linewidth=1, color=f'C{i}')
    ax.set_title(f'Mic {i+1}: {mic_labels[i]}', fontsize=10)
    ax.set_xlabel('Normalized frequency (cycles/sample)' if i >= 4 else '')
    ax.set_ylabel('Magnitude')
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 0.5)  # Nyquist frequency in normalized units

fig.suptitle('Frequency Domain Analysis (FFT)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Samples Per File Distribution

In [None]:
# Distribution of valid positions per cave file: histogram with mean/median
# markers, plus the cumulative total over caves sorted by size.
file_samples = [len(file_info['valid']) for file_info in dataset.file_infos]
per_cave = np.asarray(file_samples)
mean_count = np.mean(per_cave)
median_count = np.median(per_cave)

fig, (ax_hist, ax_cum) = plt.subplots(1, 2, figsize=(14, 5))

# Histogram
ax_hist.hist(file_samples, bins=30, alpha=0.7, color='steelblue', edgecolor='black')
ax_hist.axvline(mean_count, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_count:.1f}')
ax_hist.axvline(median_count, color='green', linestyle='--', linewidth=2, label=f'Median: {median_count:.1f}')
ax_hist.set_xlabel('Valid Samples per Cave', fontsize=12)
ax_hist.set_ylabel('Frequency', fontsize=12)
ax_hist.set_title('Distribution of Samples per Cave', fontsize=14, fontweight='bold')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

# Cumulative distribution (caves in ascending order of size)
running_total = np.cumsum(np.sort(per_cave))
cave_rank = np.arange(running_total.size)
ax_cum.plot(cave_rank, running_total, linewidth=2, color='steelblue')
ax_cum.fill_between(cave_rank, running_total, alpha=0.3)
ax_cum.set_xlabel('Cave Index (sorted)', fontsize=12)
ax_cum.set_ylabel('Cumulative Samples', fontsize=12)
ax_cum.set_title('Cumulative Sample Distribution', fontsize=14, fontweight='bold')
ax_cum.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Files with < 500 samples: {int(np.count_nonzero(per_cave < 500))}")
print(f"Files with > 1500 samples: {int(np.count_nonzero(per_cave > 1500))}")

## 6. Action Distribution Across Cave Regions

In [None]:
# Analyze where different actions occur spatially in a single cave: one
# panel per action type, highlighting the grid cells labeled with it.
sample_file_idx = 0
info = dataset.file_infos[sample_file_idx]

# For each action type
action_types = ['up', 'down', 'left', 'right', 'stop']

fig, axes = plt.subplots(2, 3, figsize=(16, 10))
axes = axes.flatten()

# Compare the whole grid at once instead of looping over cells in Python.
action_grid = np.asarray(info['action_grid'])

for i, action_type in enumerate(action_types):
    ax = axes[i]

    hits = action_grid == action_type
    # Mask cells without this action so they render transparent and the
    # cave layout underneath stays visible.
    overlay = np.ma.masked_where(~hits, hits.astype(float))

    # Plot
    ax.imshow(info['cave_grid'].T, origin='lower', cmap='binary', alpha=0.3)
    ax.imshow(overlay.T, origin='lower', cmap='Reds', alpha=0.7, vmin=0, vmax=1)
    ax.scatter([info['end_pos'][0]], [info['end_pos'][1]], s=300, c='blue',
               marker='*', edgecolors='white', linewidths=2, zorder=10)
    ax.set_title(f'{action_type.upper()} Actions', fontsize=12, fontweight='bold')
    ax.axis('image')

    # Count of grid cells labeled with this action
    count = int(hits.sum())
    ax.text(0.02, 0.98, f'{count} samples', transform=ax.transAxes,
            fontsize=11, va='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Remove every subplot not used by an action type.
for ax in axes[len(action_types):]:
    fig.delaxes(ax)

fig.suptitle(f'Spatial Distribution of Actions (Cave {sample_file_idx})', 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 7. Summary Statistics

In [None]:
print("=" * 70)
print(" " * 20 + "DATASET SUMMARY")
print("=" * 70)

print(f"\nDataset Location: {DATASET_DIR}")
print(f"Number of caves: {len(H5_FILES)}")
print(f"Total valid samples: {len(dataset):,}")

print("\nSamples per cave:")
print(f"  Mean: {np.mean(file_samples):.1f}")
print(f"  Median: {np.median(file_samples):.1f}")
print(f"  Min: {np.min(file_samples)}")
print(f"  Max: {np.max(file_samples)}")
print(f"  Std: {np.std(file_samples):.1f}")

print("\nClass distribution:")
# Use the sum of class counts as the denominator, matching the class table in
# section 2 so both tables agree; guard against an empty distribution.
class_total = sum(class_counts.values())
for action_name, count in sorted(class_counts.items(), key=lambda kv: kv[1], reverse=True):
    pct = 100 * count / class_total if class_total else 0.0
    print(f"  {action_name.upper():>6}: {count:>8,} ({pct:5.2f}%)")

# mic_data is the sample loaded in section 4; it is indexed as (mic, time)
# there, so shape[0] is the mic count and shape[1] the signal length.
n_mics = mic_data.shape[0]
print("\nSignal properties:")
print(f"  Microphones: {n_mics} (circular array)")
print(f"  Samples per signal: {mic_data.shape[1]:,}")
print(f"  Data type: {mic_data.dtype}")
# NOTE(review): normalization is applied inside MultiCaveDataset — confirm it is per-sample z-scoring.
print("  Normalization: Per-sample (mean=0, std=1)")

# Estimate dataset size from the loaded sample's real shape and element size
# (rather than assuming 8 mics of float32). Assumes all samples have the same shape.
bytes_per_sample = np.asarray(mic_data).nbytes
total_bytes = bytes_per_sample * len(dataset)
print("\nEstimated dataset size:")
print(f"  Per sample: {bytes_per_sample / (1024**2):.2f} MB")
print(f"  Total: {total_bytes / (1024**3):.2f} GB")

print("\n" + "=" * 70)
print("✓ Dataset analysis complete!")