# Dataset Exploration

Explore and visualize the Multi-Dimensional AI dataset.

**Purpose:**
- Load and inspect dataset samples
- Visualize vision/audio/sensor data
- Verify data quality and preprocessing
- Analyze data distribution

In [None]:
import sys
sys.path.append('..')

import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from src.data.multimodal_dataset import SyntheticMultiModalDataset
from src.data.real_dataset import RealMultiModalDataset
from torch.utils.data import DataLoader

## Load Dataset

In [None]:
# Build a small synthetic dataset to explore in this notebook.
# The constructor arguments are kept in one place so they are easy to tweak.
synthetic_kwargs = dict(num_samples=100, seq_length=64, vocab_size=1000)
dataset = SyntheticMultiModalDataset(**synthetic_kwargs)

# Report the dataset length, then the top-level keys of the first sample.
size_line = f"Dataset size: {len(dataset)}"
print(size_line)
keys_line = f"Sample 0 keys: {dataset[0].keys()}"
print(keys_line)

## Inspect Sample Data

In [None]:
# Take the first sample and list the shape of every tensor it contains.
sample = dataset[0]

# (heading to print, key of the sub-dictionary to walk)
shape_sections = (
	("Input shapes:", 'inputs'),
	("\nTarget shapes:", 'targets'),
)

for heading, section in shape_sections:
	print(heading)
	for key, value in sample[section].items():
		# Entries that are not tensors are skipped.
		if not isinstance(value, torch.Tensor):
			continue
		print(f"  {key}: {value.shape}")

## Visualize Vision Input

In [None]:
# Visualize left eye image if available.
# Every outcome is reported: a missing key and an unexpected tensor rank both
# print a message instead of silently drawing nothing.
if 'vision_left' in sample['inputs']:
	img = sample['inputs']['vision_left']

	if img.dim() == 3:
		# Convert from CHW to HWC for visualization
		img_np = img.permute(1, 2, 0).numpy()

		# imshow accepts (H, W), (H, W, 3) or (H, W, 4) but rejects (H, W, 1),
		# so drop the channel axis of a single-channel image.
		if img_np.shape[-1] == 1:
			img_np = img_np[..., 0]

		plt.figure(figsize=(8, 8))
		# cmap only affects 2-D (single-channel) data.
		plt.imshow(img_np, cmap='gray' if img_np.ndim == 2 else None)
		plt.title('Left Eye Vision Input')
		plt.axis('off')
		plt.show()
	else:
		print(
			f"Cannot display vision data: expected a 3-D (CHW) tensor, "
			f"got shape {tuple(img.shape)}"
		)
else:
	print("No vision data in this sample")

## Visualize Audio Waveform

In [None]:
# Plot the audio waveform of the current sample, or say that there is none.
if 'audio' not in sample['inputs']:
	print("No audio data in this sample")
else:
	audio = sample['inputs']['audio'].numpy()

	# Same figure as the pyplot state-machine calls, built through the Axes API.
	fig, ax = plt.subplots(figsize=(12, 4))
	ax.plot(audio)
	ax.set_title('Audio Input Waveform')
	ax.set_xlabel('Sample')
	ax.set_ylabel('Amplitude')
	ax.grid(True)
	plt.show()

## Analyze Sensor Data Distribution

In [None]:
# Collect touch sensor values across dataset and plot one histogram per
# touch channel ("finger").
max_touch_samples = 50  # Sample first 50

touch_values = []
for idx in range(min(len(dataset), max_touch_samples)):
	# Use a dedicated name here: rebinding `sample` would silently change
	# which sample the other cells of this notebook display and validate.
	touch_sample = dataset[idx]
	if 'touch' in touch_sample['inputs']:
		touch_values.append(touch_sample['inputs']['touch'].numpy())

if touch_values:
	# One row per sample, one column per touch value. The reshape also covers
	# scalar or multi-dimensional per-sample readings, for which indexing
	# `shape[1]` / `[:, finger]` directly would fail or mis-plot.
	touch_array = np.array(touch_values).reshape(len(touch_values), -1)

	plt.figure(figsize=(10, 6))
	for finger in range(touch_array.shape[1]):
		plt.hist(touch_array[:, finger], alpha=0.5, label=f'Finger {finger}')

	plt.xlabel('Touch Value')
	plt.ylabel('Frequency')
	plt.title('Touch Sensor Value Distribution')
	plt.legend()
	plt.grid(True)
	plt.show()
else:
	# Report missing data instead of silently drawing nothing.
	print("No touch data in the inspected samples")

## Data Validation

In [None]:
# Run the project's shape and value-range validators on the current sample's
# inputs and report the outcome of each check. Only ValueError is treated as
# a validation failure; any other exception propagates.
from src.data.validation import validate_input_shapes, validate_value_ranges

inputs = sample['inputs']

try:
	validate_input_shapes(inputs)
except ValueError as shape_error:
	print(f"✗ Shape validation failed: {shape_error}")
else:
	print("✓ Input shapes are valid")

try:
	validate_value_ranges(inputs)
except ValueError as range_error:
	print(f"✗ Range validation failed: {range_error}")
else:
	print("✓ Value ranges are valid")