# 05. Data Pipeline

**Goal**: Understand how OpenVLA loads and processes robot demonstration data.

## What We'll Learn
1. RLDS (Robot Learning Dataset Standard) format
2. Open X-Embodiment (OXE) dataset collection
3. Data loading and preprocessing
4. Dataset mixtures and sampling
5. Preparing custom data for OpenVLA

In [None]:
# ============================================================
# CRITICAL: Set these BEFORE importing any packages!
# ============================================================
import os

# For NERSC Perlmutter, use your $PSCRATCH directory
PSCRATCH = "/pscratch/sd/d/dpark1"  # CHANGE THIS TO YOUR PATH

# HuggingFace cache (models, tokenizers)
os.environ['HF_HOME'] = f"{PSCRATCH}/.cache/huggingface"

# TensorFlow Datasets (RLDS data) - THIS IS WHERE OXE DATA GOES
os.environ['TFDS_DATA_DIR'] = f"{PSCRATCH}/tensorflow_datasets"

# Torch Hub cache
os.environ['TORCH_HOME'] = f"{PSCRATCH}/.cache/torch"

# Create directories
for env_var in ['HF_HOME', 'TFDS_DATA_DIR', 'TORCH_HOME']:
    os.makedirs(os.environ[env_var], exist_ok=True)

print(f"✅ TFDS_DATA_DIR = {os.environ['TFDS_DATA_DIR']}")
print("   (All RLDS/OXE datasets will be stored here)")

---
## 1. RLDS: Robot Learning Dataset Standard

OpenVLA reads robot demonstration data in the **RLDS** format.

### Why RLDS?
- Standardized format across different robots
- Efficient storage and loading with TensorFlow
- Supports large-scale datasets
- Compatible with many existing robot datasets

In [None]:
rlds_structure = """
┌────────────────────────────────────────────────────────────────────┐
│                      RLDS Dataset Structure                         │
├────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  Dataset                                                            │
│  └── Episode (trajectory)                                           │
│      └── Step (timestep)                                            │
│          ├── observation                                            │
│          │   ├── image_primary: (H, W, 3) RGB                      │
│          │   ├── image_wrist: (H, W, 3) RGB (optional)             │
│          │   └── state: Robot proprioception (optional)             │
│          ├── action: (action_dim,) Robot action                    │
│          ├── reward: Scalar (optional)                              │
│          ├── is_terminal: Boolean                                   │
│          └── language_instruction: String                           │
│                                                                     │
│  Example Episode:                                                   │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │ Task: "pick up the red block"                                │   │
│  │                                                               │   │
│  │ Step 0: [image] [action: approach]                           │   │
│  │ Step 1: [image] [action: lower]                              │   │
│  │ Step 2: [image] [action: grasp]                              │   │
│  │ ...                                                           │   │
│  │ Step N: [image] [action: done] is_terminal=True              │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                                                     │
└────────────────────────────────────────────────────────────────────┘
"""
print(rlds_structure)

In [None]:
# Standard RLDS data fields
import numpy as np

# Example step data structure
example_step = {
    'observation': {
        # randint's upper bound is exclusive, so use 256 to cover the full uint8 range
        'image_primary': np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8),
        'image_wrist': np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8),
        'state': np.random.randn(7),  # Joint positions, velocities, etc.
    },
    'action': np.array([0.01, -0.02, 0.05, 0.0, 0.0, 0.1, 1.0]),  # 7-DoF
    'language_instruction': 'pick up the red block',
    'reward': 0.0,
    'is_terminal': False,
    'is_first': False,
    'is_last': False,
}

print("RLDS Step Structure:")
print("="*60)
for key, value in example_step.items():
    if isinstance(value, dict):
        print(f"{key}:")
        for k, v in value.items():
            if isinstance(v, np.ndarray):
                print(f"  {k}: shape={v.shape}, dtype={v.dtype}")
            else:
                print(f"  {k}: {v}")
    elif isinstance(value, np.ndarray):
        print(f"{key}: shape={value.shape}")
    else:
        print(f"{key}: {value}")
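
A small check like the one below can catch missing fields before a step is written out. This is a minimal sketch, not part of RLDS or OpenVLA: `REQUIRED_STEP_KEYS` and `validate_step` are hypothetical names, and the set of required fields is assumed from the structure diagram above.

```python
# Hypothetical helper: the required field names are assumed from the diagram above.
REQUIRED_STEP_KEYS = {"observation", "action", "is_first", "is_last", "is_terminal"}

def validate_step(step):
    """Return a list of problems found in a step dict (empty list means OK)."""
    missing = sorted(REQUIRED_STEP_KEYS - set(step))
    problems = [f"missing key: {k}" for k in missing]
    obs = step.get("observation", {})
    if "image_primary" not in obs:
        problems.append("observation has no 'image_primary'")
    # A terminal step must also be the last step of its episode.
    if step.get("is_terminal") and not step.get("is_last"):
        problems.append("is_terminal=True requires is_last=True")
    return problems

good = {"observation": {"image_primary": [[0]]}, "action": [0.0] * 7,
        "is_first": True, "is_last": False, "is_terminal": False}
bad = {"observation": {}, "action": [0.0] * 7, "is_terminal": True}

print(validate_step(good))
for problem in validate_step(bad):
    print(problem)
```

Returning a list of problems rather than raising on the first one makes it easier to report everything wrong with a dataset in a single pass.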

---
## 2. Open X-Embodiment (OXE) Dataset Collection

OpenVLA was trained on 970K robot trajectories drawn from **OXE**, a collection of more than 1M trajectories spanning 22 robot embodiments.

In [None]:
# OXE dataset information (approximate, rounded figures for illustration)
oxe_datasets = {
    'bridge_orig': {
        'trajectories': 60000,
        'robot': 'WidowX',
        'tasks': 'Kitchen manipulation',
        'source': 'Berkeley'
    },
    'fractal20220817_data': {
        'trajectories': 130000,
        'robot': 'Everyday Robot',
        'tasks': 'RT-1 dataset',
        'source': 'Google'
    },
    'taco_play': {
        'trajectories': 15000,
        'robot': 'Franka',
        'tasks': 'Table-top manipulation',
        'source': 'University of Freiburg'
    },
    'kuka': {
        'trajectories': 50000,
        'robot': 'KUKA iiwa',
        'tasks': 'Grasping',
        'source': 'Google'
    },
    'droid': {
        'trajectories': 76000,  # ~76K trajectories (about 350 hours of data)
        'robot': 'Franka',
        'tasks': 'Diverse manipulation',
        'source': 'DROID consortium'
    }
}

print("OXE Dataset Collection (sample):")
print("="*70)
print(f"{'Dataset':<25} {'Trajs':>10} {'Robot':<15} {'Tasks'}")
print("-"*70)
for name, info in oxe_datasets.items():
    print(f"{name:<25} {info['trajectories']:>10,} {info['robot']:<15} {info['tasks']}")
print("-"*70)
total = sum(d['trajectories'] for d in oxe_datasets.values())
print(f"{'Sample total':<25} {total:>10,}")
print("\nOpenVLA training set: ~970,000 trajectories drawn from OXE")
print("Full OXE: 1M+ trajectories from 22 robot embodiments")

---
## 3. Dataset Configuration in OpenVLA

In [None]:
import os
REPO_ROOT = "/Users/davidpark/Documents/Claude/openvla"  # CHANGE THIS TO YOUR LOCAL OPENVLA CLONE

# Read OXE config file
oxe_config_path = os.path.join(REPO_ROOT, "prismatic/vla/datasets/rlds/oxe/configs.py")
print(f"OXE Config location: {oxe_config_path}")
print("\nThis file defines per-dataset standardization:")
print("  - Image keys to extract")
print("  - Action normalization type")
print("  - Proprioceptive state keys")
print("  - Dataset-specific transforms")

In [None]:
# Example dataset config structure (simplified for illustration; see configs.py for the exact schema)
example_dataset_config = {
    'bridge_orig': {
        'image_obs_keys': {
            'primary': 'image_0',
            'wrist': None,
        },
        'depth_obs_keys': {'primary': None},
        'proprio_obs_key': 'state',
        'language_key': 'language_instruction',
        'action_normalization_type': 'normal',  # mean/std
        'action_proprio_normalization_type': 'normal',
    },
    'libero_spatial_no_noops': {
        'image_obs_keys': {
            'primary': 'agentview_image',
            'wrist': None,
        },
        'depth_obs_keys': {'primary': None},
        'proprio_obs_key': 'ee_states',
        'language_key': 'language_instruction',
        'action_normalization_type': 'bounds',  # min/max
    }
}

print("Dataset Configuration Examples:")
print("="*60)
for dataset, config in example_dataset_config.items():
    print(f"\n{dataset}:")
    for key, value in config.items():
        print(f"  {key}: {value}")
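
To make the role of these configs concrete, the sketch below applies one to a raw step and renames dataset-specific keys to shared names. It is illustrative only: `standardize_step` is a hypothetical helper, the output key names (`image_primary`, `proprio`) are assumptions, and OpenVLA's own standardization code is more involved.

```python
def standardize_step(raw_step, config):
    """Rename dataset-specific observation keys to shared names (hypothetical helper)."""
    raw_obs = raw_step["observation"]
    obs = {}
    for name, raw_key in config["image_obs_keys"].items():
        if raw_key is not None:  # None means this camera is not available
            obs[f"image_{name}"] = raw_obs[raw_key]
    if config.get("proprio_obs_key") is not None:
        obs["proprio"] = raw_obs[config["proprio_obs_key"]]
    return {
        "observation": obs,
        "action": raw_step["action"],
        "language_instruction": raw_step[config["language_key"]],
    }

# Same simplified schema as the example config above.
bridge_config = {
    "image_obs_keys": {"primary": "image_0", "wrist": None},
    "proprio_obs_key": "state",
    "language_key": "language_instruction",
}
raw = {
    "observation": {"image_0": "<image array>", "state": [0.0] * 7},
    "action": [0.0] * 7,
    "language_instruction": "put the spoon on the towel",
}

step = standardize_step(raw, bridge_config)
print(sorted(step["observation"]))  # ['image_primary', 'proprio']
```

After this step, downstream code can rely on `image_primary` regardless of which dataset the step came from.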

---
## 4. Data Mixtures and Sampling

OpenVLA uses **dataset mixtures** to balance training across different datasets.

In [None]:
# Example data mixture configuration
oxe_magic_soup = [
    ("fractal20220817_data", 0.54),
    ("bridge_orig", 1.0),
    ("taco_play", 2.0),
    ("berkeley_cable_routing", 3.0),
    ("roboturk", 1.0),
    ("nyu_door_opening_surprising_effectiveness", 5.0),
    ("viola", 2.0),
    ("berkeley_autolab_ur5", 2.0),
    ("toto", 1.0),
    ("language_table", 0.1),
]

print("OXE Magic Soup Mixture (sample):")
print("="*60)
print(f"{'Dataset':<45} {'Weight':>10}")
print("-"*60)

# Normalize weights for visualization
total_weight = sum(w for _, w in oxe_magic_soup)
for dataset, weight in oxe_magic_soup:
    pct = weight / total_weight * 100
    bar = "█" * int(pct / 2)
    print(f"{dataset:<45} {weight:>6.2f}  {bar}")

print(f"\nTotal weight: {total_weight:.2f}")
print("\nWeights are relative, not probabilities: only their ratios matter.")
print("Higher weight = that dataset is sampled more often during training.")

In [None]:
# Simulate mixture sampling
import numpy as np

def sample_from_mixture(mixture, n_samples=1000):
    """Simulate sampling from a dataset mixture."""
    datasets = [d for d, _ in mixture]
    weights = np.array([w for _, w in mixture])
    probs = weights / weights.sum()
    
    samples = np.random.choice(datasets, size=n_samples, p=probs)
    
    # Count samples per dataset
    counts = {}
    for d in datasets:
        counts[d] = (samples == d).sum()
    
    return counts

# Simulate 10000 samples
sample_counts = sample_from_mixture(oxe_magic_soup, n_samples=10000)

print("\nSimulated Sampling (10,000 samples):")
print("="*60)
for dataset, count in sorted(sample_counts.items(), key=lambda x: -x[1]):
    pct = count / 10000 * 100
    print(f"{dataset:<45} {count:>5} ({pct:.1f}%)")
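
The simulation above treats the mixture weights directly as sampling probabilities. A common variant scales each weight by the dataset's size first, so that a weight of 1.0 means "sample in proportion to size" and other values up- or down-weight from there. The sketch below shows that variant under stated assumptions: `balanced_probs` is a hypothetical helper, the dataset names and sizes are made up, and whether a given loader balances by size depends on its configuration.

```python
import random

def balanced_probs(mixture, dataset_sizes):
    """Turn (name, weight) pairs into probabilities proportional to weight * size."""
    scaled = {name: w * dataset_sizes[name] for name, w in mixture}
    total = sum(scaled.values())
    return {name: s / total for name, s in scaled.items()}

# Illustrative sizes only; these are not real dataset statistics.
toy_sizes = {"big_dataset": 900_000, "small_dataset": 100_000}
toy_mixture = [("big_dataset", 0.5), ("small_dataset", 2.0)]

probs = balanced_probs(toy_mixture, toy_sizes)
for name, p in probs.items():
    print(f"{name:<15} {p:.3f}")

# Draw samples with the resulting probabilities (seeded for repeatability).
rng = random.Random(0)
names = list(probs)
draws = rng.choices(names, weights=[probs[n] for n in names], k=1000)
print({n: draws.count(n) for n in names})
```

Here the small dataset is up-weighted from 10% of the raw data to roughly 31% of the samples, while the large one still dominates.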

---
## 5. Data Loading Pipeline

In [None]:
data_pipeline = """
┌────────────────────────────────────────────────────────────────────┐
│                    OpenVLA Data Loading Pipeline                    │
├────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  RLDS Files (TFRecord)                                              │
│       │                                                             │
│       ▼                                                             │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ 1. LOAD: TensorFlow data loader                              │  │
│  │    - Efficient parallel reading                               │  │
│  │    - Automatic sharding for multi-GPU                        │  │
│  └──────────────────────────────────────────────────────────────┘  │
│       │                                                             │
│       ▼                                                             │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ 2. STANDARDIZE: Unified format across datasets                │  │
│  │    - Map image keys (agentview → primary)                    │  │
│  │    - Extract language instruction                             │  │
│  │    - Normalize actions                                        │  │
│  └──────────────────────────────────────────────────────────────┘  │
│       │                                                             │
│       ▼                                                             │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ 3. TRANSFORM: Apply data augmentation                         │  │
│  │    - Resize images (224×224)                                 │  │
│  │    - Task augmentation (rephrase instructions)               │  │
│  │    - Goal relabeling (optional)                              │  │
│  └──────────────────────────────────────────────────────────────┘  │
│       │                                                             │
│       ▼                                                             │
│  ┌──────────────────────────────────────────────────────────────┐  │
│  │ 4. BATCH: Create training batches                            │  │
│  │    - Shuffle within/across trajectories                      │  │
│  │    - Sample from mixture weights                              │  │
│  │    - Collate into PyTorch tensors                            │  │
│  └──────────────────────────────────────────────────────────────┘  │
│       │                                                             │
│       ▼                                                             │
│  Training Batch:                                                    │
│  {                                                                  │
│    'pixel_values': (B, C, H, W),                                   │
│    'input_ids': (B, seq_len),                                      │
│    'attention_mask': (B, seq_len),                                 │
│    'labels': (B, seq_len),  # action tokens                        │
│  }                                                                  │
│                                                                     │
└────────────────────────────────────────────────────────────────────┘
"""
print(data_pipeline)
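
The final batching stage has to combine token sequences of different lengths. The sketch below shows one way to do that with right-padding, using plain Python lists to stay self-contained. It rests on assumptions: `collate` is a hypothetical function, `PAD_TOKEN_ID = 0` is a placeholder (the real value comes from the tokenizer), and `pixel_values` is omitted because images already share a fixed shape after resizing.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default
PAD_TOKEN_ID = 0     # placeholder; the real pad id comes from the tokenizer

def collate(examples):
    """Right-pad variable-length token sequences to the longest one in the batch."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in examples:
        n_real = len(ex["input_ids"])
        n_pad = max_len - n_real
        batch["input_ids"].append(ex["input_ids"] + [PAD_TOKEN_ID] * n_pad)
        batch["attention_mask"].append([1] * n_real + [0] * n_pad)
        batch["labels"].append(ex["labels"] + [IGNORE_INDEX] * n_pad)
    return batch

# Toy examples: prompt positions are excluded from the loss, action tokens are kept.
examples = [
    {"input_ids": [5, 6, 7, 90, 91], "labels": [IGNORE_INDEX] * 3 + [90, 91]},
    {"input_ids": [5, 8, 92], "labels": [IGNORE_INDEX] * 2 + [92]},
]

batch = collate(examples)
print(batch["input_ids"])       # [[5, 6, 7, 90, 91], [5, 8, 92, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
print(batch["labels"])          # [[-100, -100, -100, 90, 91], [-100, -100, 92, -100, -100]]
```

Padding positions get `IGNORE_INDEX` in `labels` so they contribute nothing to the loss, which matches the "action tokens" comment in the diagram.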

---
## 6. Action Statistics for Normalization

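Each dataset has precomputed per-dimension action statistics (such as mean, std, min, max, and the 1st/99th percentiles `q01`/`q99`), which are used to normalize actions.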

In [None]:
# Example action statistics (from OpenVLA)
action_statistics = {
    'bridge_orig': {
        'mean': [0.00021961, 0.00015, -0.00028, 0.00013, 0.00033, -0.00019, 0.49],
        'std': [0.0074, 0.0058, 0.0074, 0.026, 0.024, 0.052, 0.50],
        'min': [-0.05, -0.04, -0.05, -0.17, -0.16, -0.35, 0.0],
        'max': [0.05, 0.04, 0.05, 0.17, 0.16, 0.35, 1.0],
    },
    'libero_spatial_no_noops': {
        'q01': [-0.065, -0.065, -0.055, -0.17, -0.17, -0.42, -1.0],
        'q99': [0.065, 0.065, 0.055, 0.17, 0.17, 0.42, 1.0],
    }
}

print("Action Statistics Examples:")
print("="*60)
for dataset, stats in action_statistics.items():
    print(f"\n{dataset}:")
    for stat_name, values in stats.items():
        print(f"  {stat_name}: {[f'{v:.4f}' for v in values]}")
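
The two normalization types named in the configs can be written out as follows. This is a simplified sketch rather than OpenVLA's implementation: the function names are hypothetical, `low`/`high` can be either min/max or `q01`/`q99`, and the equal-width binning in `discretize` only approximates how a tokenizer might map normalized actions to 256 discrete values.

```python
def normalize_normal(action, mean, std, eps=1e-8):
    """'normal': shift to zero mean and scale to unit variance."""
    return [(a - m) / (s + eps) for a, m, s in zip(action, mean, std)]

def normalize_bounds(action, low, high, eps=1e-8):
    """'bounds': map [low, high] to [-1, 1], clipping values that fall outside."""
    scaled = [2 * (a - lo) / (hi - lo + eps) - 1 for a, lo, hi in zip(action, low, high)]
    return [max(-1.0, min(1.0, x)) for x in scaled]

def unnormalize_bounds(norm, low, high, eps=1e-8):
    """Inverse of normalize_bounds (exact only for values that were not clipped)."""
    return [(x + 1) / 2 * (hi - lo + eps) + lo for x, lo, hi in zip(norm, low, high)]

def discretize(norm, n_bins=256):
    """Assign each value in [-1, 1] to one of n_bins equal-width bins."""
    return [min(n_bins - 1, int((x + 1) / 2 * n_bins)) for x in norm]

# Percentile bounds copied from the libero_spatial_no_noops example above.
q01 = [-0.065, -0.065, -0.055, -0.17, -0.17, -0.42, -1.0]
q99 = [0.065, 0.065, 0.055, 0.17, 0.17, 0.42, 1.0]
action = [0.0, 0.065, -0.055, 0.5, 0.0, 0.0, 1.0]  # 4th value is an outlier

norm = normalize_bounds(action, q01, q99)
bins = discretize(norm)
restored = unnormalize_bounds(norm, q01, q99)

print([round(x, 3) for x in norm])
print(bins)
print([round(x, 3) for x in restored])
```

Using percentiles instead of the raw min/max keeps a few outliers from stretching the range; the outlier in the example is clipped to 1.0 and therefore does not round-trip to its original value.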

---
## 7. Preparing Custom Data for OpenVLA

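To fine-tune OpenVLA on your own data, first convert it to RLDS format, then register the dataset and compute its action statistics. The cells below sketch the per-step structure only; they do not write TFRecord files.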

In [None]:
# Example: Converting custom data to RLDS-compatible format
def create_rlds_example(image, action, instruction, is_first=False, is_last=False):
    """
    Create a single RLDS-style step dict from custom data.
    
    Args:
        image: RGB image array (H, W, 3), dtype uint8
        action: Robot action array (7,)
        instruction: Task description string
        is_first: Whether this is the first step of the episode
        is_last: Whether this is the last step of the episode
    """
    return {
        'observation': {
            'image_primary': image,
        },
        'action': action,
        'language_instruction': instruction,
        'is_first': is_first,
        'is_last': is_last,
        # Simplification: treat the end of every demonstration as terminal
        'is_terminal': is_last,
    }

# Example trajectory
trajectory = []
instruction = "pick up the red cube"
n_steps = 10
for i in range(n_steps):
    # randint's upper bound is exclusive, so 256 covers the full uint8 range
    image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
    action = np.random.randn(7) * 0.01  # Small random actions
    step = create_rlds_example(
        image, action, instruction,
        is_first=(i == 0), is_last=(i == n_steps - 1),
    )
    trajectory.append(step)

print(f"Example trajectory: {len(trajectory)} steps")
print(f"Task: '{instruction}'")
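
Custom data also needs per-dimension action statistics before it can be normalized. The sketch below computes them with the standard library only. It is an illustration under assumptions: `compute_action_stats` and `percentile` are hypothetical helpers, the percentile uses simple linear interpolation, and the output keys mirror the statistics shown in the previous section rather than any file format OpenVLA reads.

```python
import statistics

def percentile(sorted_vals, q):
    """Linear-interpolation percentile of an already sorted list, q in [0, 1]."""
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)

def compute_action_stats(actions):
    """Per-dimension statistics over a list of equal-length action vectors."""
    stats = {k: [] for k in ("mean", "std", "min", "max", "q01", "q99")}
    for dim in zip(*actions):  # iterate over action dimensions
        vals = sorted(dim)
        stats["mean"].append(statistics.fmean(vals))
        stats["std"].append(statistics.pstdev(vals))  # population std
        stats["min"].append(vals[0])
        stats["max"].append(vals[-1])
        stats["q01"].append(percentile(vals, 0.01))
        stats["q99"].append(percentile(vals, 0.99))
    return stats

# Toy 2-D actions; with real data, pass every action from every trajectory.
toy_actions = [[0.0, 1.0], [1.0, 1.0], [2.0, 0.0], [3.0, 0.0]]
stats = compute_action_stats(toy_actions)
print(stats["mean"])  # [1.5, 0.5]
print(stats["min"])   # [0.0, 0.0]
print(stats["max"])   # [3.0, 1.0]
```

For the trajectory built above, the input could be assembled as `[step['action'].tolist() for step in trajectory]`; statistics should be computed over the whole dataset, not a single episode.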

In [None]:
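# Register custom dataset with OpenVLA (schematic example)
custom_dataset_config = """
# Add to prismatic/vla/datasets/rlds/oxe/configs.py
# (and register a matching standardization transform in
#  prismatic/vla/datasets/rlds/oxe/transforms.py):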

OXE_DATASET_CONFIGS = {
    # ... existing configs ...
    
    'my_custom_dataset': {
        'image_obs_keys': {
            'primary': 'image_primary',
            'wrist': None,
        },
        'depth_obs_keys': {'primary': None},
        'proprio_obs_key': None,
        'language_key': 'language_instruction',
        'action_normalization_type': 'normal',  # or 'bounds'
    }
}
"""
print("Custom Dataset Registration:")
print(custom_dataset_config)

---
## Summary

### Key Concepts

1. **RLDS Format**: Standardized format for robot trajectory data
   - Episodes contain steps with observations, actions, instructions
   - Efficient TensorFlow-based storage

2. **OXE Collection**: OpenVLA trains on 970K trajectories drawn from OXE, which spans 22 robot embodiments
   - Diverse tasks and environments
   - Unified under RLDS standard

3. **Data Mixtures**: Weighted sampling across datasets
   - Balances data distribution during training
   - Configurable for different training objectives

4. **Normalization**: Per-dataset action statistics
   - `bounds` maps actions to the [-1, 1] range; `normal` standardizes to zero mean and unit variance
   - Critical for action tokenization

5. **Custom Data**: Convert your data to RLDS format
   - Register with dataset configs
   - Compute action statistics

### Next Steps
→ Continue to **06_basic_inference.ipynb** to run OpenVLA on sample data.