# 02 — Data Pipeline

Collect demonstrations, store in HDF5, apply augmentations, and inspect the data.

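This notebook walks through:
1. Generating a scripted disassembly plan and collecting demonstration trajectories
2. Storing trajectories in HDF5 and reloading them for inspection
3. Visual augmentation (color jitter, noise, cutout)
4. Dataset statistics (trajectory count, total timesteps, success rate)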

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 5)

## 2.1 — Generate a Disassembly Plan

In [None]:
from safedisassemble.sim.device_registry import get_device
from safedisassemble.data.demo_collector import ScriptedDisassemblyPolicy

# Build the scripted policy for the laptop device; the fixed seed makes the plan reproducible.
spec = get_device('laptop_v1')
policy = ScriptedDisassemblyPolicy(spec, noise_std=0.002, seed=42)

plan = policy.generate_plan()

# (key in step['instructions'], printed prefix) for each instruction granularity.
INSTRUCTION_LEVELS = (('high', 'High: '), ('mid', 'Mid:  '), ('low', 'Low:  '))

print(f"Disassembly plan ({len(plan)} steps):\n")
for step_no, step in enumerate(plan, start=1):
    # Flag any step whose component name mentions a battery.
    is_battery = 'battery' in step['component'].lower()
    safety = ' ⚠️ SAFETY' if is_battery else ''
    print(f"Step {step_no}: [{step['skill']:18s}] {step['component']}{safety}")
    for level_key, prefix in INSTRUCTION_LEVELS:
        print(f"   {prefix}{step['instructions'][level_key]}")
    print()

## 2.2 — Collect Demonstration Trajectories

In [None]:
from safedisassemble.sim.envs.disassembly_env import DisassemblyEnv
from safedisassemble.data.trajectory import TrajectoryDataset

# Demonstrations go to a scratch HDF5 file; the last cell of the notebook deletes it.
demo_path = Path('../data/trajectories/notebook_demos.h5')
demo_path.parent.mkdir(parents=True, exist_ok=True)

N_DEMOS = 5  # increase for real training

# Create the simulator only after path setup succeeds, and always release it,
# even if a rollout or an HDF5 write raises part-way through collection.
env = DisassemblyEnv(device_name='laptop_v1', image_size=84, max_steps=200)
try:
    with TrajectoryDataset(demo_path, mode='w') as ds:
        for i in range(N_DEMOS):
            traj = policy.collect_trajectory(env)
            ds.add_trajectory(traj)
            print(f"  Traj {i+1}: {traj.length} steps, "
                  f"recovered={traj.components_recovered}, "
                  f"success={traj.success}")

        # Queried inside the `with` so statistics are read before the dataset is closed.
        stats = ds.get_statistics()
finally:
    env.close()

print(f"\nDataset: {stats['num_trajectories']} trajs, {stats['total_timesteps']} timesteps")
print(f"Success rate: {stats['success_rate']:.1%}")

## 2.3 — Inspect a Trajectory

In [None]:
# Reload the first stored trajectory from the HDF5 file written above.
with TrajectoryDataset(demo_path, mode='r') as ds:
    traj = ds.get_trajectory(0)

summary_rows = [
    ("Trajectory length", traj.length),
    ("Device", traj.device_name),
    ("Success", traj.success),
    ("Components recovered", traj.components_recovered),
]
for row_label, row_value in summary_rows:
    print(f"{row_label}: {row_value}")

# Show subtask segments as (start, end, instruction) triples.
segments = traj.get_subtask_segments()
print(f"\nSubtask segments ({len(segments)}):")
for seg_start, seg_end, seg_text in segments:
    print(f"  [{seg_start:4d} - {seg_end:4d}] {seg_text}")

In [None]:
# Plot end-effector position, force magnitude, and commanded actions over the trajectory.
# NOTE(review): assumes ts.ee_pos has 3 components (x, y, z) and ts.action has at least
# 7 entries with the gripper command at index 6 — confirm against the env's action space.
GRIPPER_ACTION_IDX = 6
# NOTE(review): reference line only; keep in sync with the env's battery puncture threshold.
BATTERY_FORCE_THRESHOLD_N = 15.0

ee_positions = np.array([ts.ee_pos for ts in traj.timesteps])
# Scalar force magnitude per timestep (L2 norm of ts.ee_force).
ee_forces = np.array([np.linalg.norm(ts.ee_force) for ts in traj.timesteps])
actions = np.array([ts.action for ts in traj.timesteps])

fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True)

# EE position
for dim, coord in enumerate('xyz'):
    axes[0].plot(ee_positions[:, dim], label=coord)
axes[0].set_ylabel('EE Position (m)')
axes[0].legend()
axes[0].set_title('End-Effector Trajectory')

# EE force
axes[1].plot(ee_forces, color='red', alpha=0.7)
axes[1].axhline(y=BATTERY_FORCE_THRESHOLD_N, color='red', linestyle='--', alpha=0.3,
                label='Battery puncture threshold')
axes[1].set_ylabel('EE Force (N)')
axes[1].legend()

# Actions: label each artist where it is drawn, so the legend cannot drift out of
# sync with the plotted lines (a positional legend list silently mislabels on reorder).
for dim, name in enumerate(['dx', 'dy', 'dz']):
    axes[2].plot(actions[:, dim], alpha=0.7, label=name)
axes[2].plot(actions[:, GRIPPER_ACTION_IDX], 'k--', alpha=0.5, label='gripper')
axes[2].set_ylabel('Action')
axes[2].set_xlabel('Timestep')
axes[2].legend()

# Mark subtask boundaries (segment start indices) on every panel.
segment_starts = [start for start, _end, _inst in segments]
for ax in axes:
    for start in segment_starts:
        ax.axvline(x=start, color='gray', linestyle=':', alpha=0.3)

plt.tight_layout()
plt.show()

## 2.4 — Visual Augmentation

In [None]:
from safedisassemble.data.augmentation import VisualAugmentor

augmentor = VisualAugmentor(seed=42)

# The first overhead frame of the reloaded trajectory is the source image.
original = traj.timesteps[0].image_overhead

# 2x4 grid: panel 0 shows the original, panels 1..7 each show a fresh augmentation,
# drawn in row-major order so the augmentor's RNG is consumed in the same sequence.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
for panel_idx, panel_ax in enumerate(axes.ravel()):
    if panel_idx == 0:
        panel_img, panel_title = original, 'Original'
    else:
        panel_img, panel_title = augmentor.augment(original), f'Augmented {panel_idx}'
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.suptitle('Visual Augmentation Examples (color jitter, noise, cutout)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Cleanup: delete the scratch demo file written in section 2.2, if it is still there.
demo_file_present = demo_path.exists()
if demo_file_present:
    demo_path.unlink()
    print(f"Cleaned up {demo_path}")