# 02. OpenVLA Architecture Overview

**Goal**: Understand the high-level architecture of OpenVLA and how its components work together.

## What We'll Learn
1. Vision-Language-Action (VLA) model concept
2. OpenVLA's three main components
3. Information flow from image → action
4. Key design decisions and their rationale

---
## 1. What is a VLA Model?

A **Vision-Language-Action (VLA)** model is a neural network that:
- **Takes in**: RGB image(s) from robot camera(s) + natural language instruction
- **Outputs**: Robot action (typically 7D: 6-DoF end-effector pose delta + gripper)

```
┌─────────────────────────────────────────────────────────────┐
│                     VLA Model Pipeline                       │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│   [Camera Image]     [Language Instruction]                  │
│         │                     │                              │
│         ▼                     ▼                              │
│   ┌──────────┐         ┌──────────┐                         │
│   │  Vision  │         │   Text   │                         │
│   │ Encoder  │         │ Encoder  │                         │
│   └────┬─────┘         └────┬─────┘                         │
│        │                    │                                │
│        └──────────┬─────────┘                                │
│                   ▼                                          │
│           ┌──────────────┐                                   │
│           │   Backbone   │                                   │
│           │     LLM      │                                   │
│           └──────┬───────┘                                   │
│                  ▼                                           │
│           ┌──────────────┐                                   │
│           │    Action    │                                   │
│           │    Output    │                                   │
│           └──────────────┘                                   │
│                  │                                           │
│                  ▼                                           │
│     [7D Action: x, y, z, rx, ry, rz, gripper]               │
│                                                              │
└─────────────────────────────────────────────────────────────┘
```
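The input/output contract in the diagram can be written down as a small typed interface. This is only a sketch: `Action7D`, `VLAPolicy`, and `ZeroPolicy` are hypothetical names introduced here for illustration and are not part of OpenVLA.

```python
from typing import NamedTuple, Protocol


class Action7D(NamedTuple):
    """One robot action: 6-DoF end-effector command plus gripper (hypothetical container)."""
    x: float
    y: float
    z: float
    rx: float
    ry: float
    rz: float
    gripper: float


class VLAPolicy(Protocol):
    """Hypothetical interface: any VLA maps (image, instruction) to an action."""

    def predict(self, image: object, instruction: str) -> Action7D: ...


class ZeroPolicy:
    """Stand-in policy that ignores its inputs and returns a no-op action."""

    def predict(self, image: object, instruction: str) -> Action7D:
        return Action7D(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)


action = ZeroPolicy().predict(image=None, instruction="pick up the red block")
print(len(action), action.gripper)
```

Any real model, OpenVLA included, fills the same slot as `ZeroPolicy`: it receives an image and a string and returns seven numbers.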

---
## 2. OpenVLA's Architecture Components

OpenVLA builds on **PrismaticVLM** and consists of three main components:

```
┌────────────────────────────────────────────────────────────────────┐
│                        OpenVLA Architecture                         │
├────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  Component 1: VISION BACKBONE                                       │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │  DINOv2-ViT (semantic features) + SigLIP-ViT (text-aligned) │   │
│  │                                                              │   │
│  │  Image (224×224) → Patch Embeddings (256 tokens × 2176 dim) │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  Component 2: PROJECTOR                                            │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │           MLP Projector (3-layer with GELU)                  │   │
│  │                                                              │   │
│  │  Vision dim (2176) → LLM dim (4096)                         │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  Component 3: LLM BACKBONE                                         │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │               Llama-2 7B (autoregressive)                    │   │
│  │                                                              │   │
│  │  [Vision tokens] + [Text tokens] → [Action tokens]          │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  OUTPUT: Action Tokens → ActionTokenizer → Continuous Actions       │
│                                                                     │
└────────────────────────────────────────────────────────────────────┘
```
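The tensor shapes along this pipeline follow from simple arithmetic. The sketch below assumes a patch size of 14, channel-wise concatenation of the two encoders, encoder widths of 1024 (DINOv2) and 1152 (SigLIP), and an LLM hidden size of 4096. These values are assumptions for illustration, not read from the checkpoint, and the helper names are hypothetical.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


def fused_vision_shape(image_size, patch_size, encoder_dims):
    """Shape (tokens, channels) after concatenating encoder outputs channel-wise."""
    return num_patches(image_size, patch_size), sum(encoder_dims)


# Assumed values: 224 px input, patch size 14, encoder widths 1024 and 1152.
tokens, vision_dim = fused_vision_shape(224, 14, [1024, 1152])
llm_dim = 4096  # assumed Llama-2 7B hidden size

print(f"vision features: ({tokens}, {vision_dim})")
print(f"after projector: ({tokens}, {llm_dim})")
```

Because the concatenation is along the channel axis, the token count stays at the per-encoder patch count. For comparison, a 336 px input at the same patch size would produce 576 tokens.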

In [None]:
# ============================================================
# CRITICAL: Set environment variables BEFORE importing packages!
# ============================================================
import os

# Auto-detect the scratch directory: NERSC Perlmutter (and many other clusters)
# export $SCRATCH; otherwise fall back to the SciServer default below.
if os.environ.get('SCRATCH'):
    SCRATCH = os.environ['SCRATCH']  # NERSC Perlmutter or any system that sets $SCRATCH
else:
    SCRATCH = "/home/idies/workspace/Temporary/dpark1/scratch"  # SciServer default; CHANGE THIS TO YOUR PATH
CACHE_DIR = f"{SCRATCH}/.cache"

os.environ['XDG_CACHE_HOME'] = CACHE_DIR
os.environ['HF_HOME'] = f"{CACHE_DIR}/huggingface"
os.environ['TFDS_DATA_DIR'] = f"{CACHE_DIR}/tensorflow_datasets"
os.environ['TORCH_HOME'] = f"{CACHE_DIR}/torch"

for path in [CACHE_DIR, os.environ['HF_HOME'], os.environ['TFDS_DATA_DIR'], os.environ['TORCH_HOME']]:
    os.makedirs(path, exist_ok=True)

print(f"✅ All caches → {CACHE_DIR}")

# Now import packages
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

MODEL_ID = "openvla/openvla-7b"

print("\nLoading OpenVLA model...")
vla = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
print("Model loaded!")

In [None]:
# Explore the model's top-level structure
print("OpenVLA Top-Level Components:")
print("="*60)
for name, child in vla.named_children():
    num_params = sum(p.numel() for p in child.parameters())
    print(f"{name}: {num_params/1e6:.1f}M parameters")

In [None]:
# Detailed breakdown of model architecture
def print_model_structure(model, max_depth=2, prefix=""):
    """Print model architecture with parameter counts."""
    for name, child in model.named_children():
        num_params = sum(p.numel() for p in child.parameters())
        print(f"{prefix}{name}: {num_params/1e6:.1f}M params")
        
        if max_depth > 1:
            print_model_structure(child, max_depth-1, prefix + "  ")

print("\nDetailed Model Structure:")
print("="*60)
print_model_structure(vla, max_depth=3)

---
## 3. Component Deep Dive

### 3.1 Vision Backbone

OpenVLA uses a **dual vision encoder**:
- **DINOv2**: Self-supervised; captures fine-grained spatial and geometric features
- **SigLIP**: Trained contrastively on image-text pairs; provides language-aligned semantic features

In [None]:
# Explore vision backbone - discover actual attribute names
print("Vision Backbone Configuration:")
print("="*60)

# Find vision-related components dynamically
for name, child in vla.named_children():
    if 'vision' in name.lower() or 'image' in name.lower():
        params = sum(p.numel() for p in child.parameters())
        print(f"Found: {name} ({params/1e6:.1f}M params)")
        print(f"  Type: {type(child).__name__}")

# Alternative: print full model to see structure
print("\nFull model structure (first level):")
for name, child in vla.named_children():
    params = sum(p.numel() for p in child.parameters())
    print(f"  {name}: {type(child).__name__} ({params/1e6:.1f}M params)")

In [None]:
# Check image transform configuration
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

print("\nImage Processing Configuration:")
print("="*60)
if hasattr(processor, 'image_processor'):
    img_proc = processor.image_processor
    print(f"Image size: {getattr(img_proc, 'size', 'N/A')}")
    print(f"Mean: {getattr(img_proc, 'image_mean', 'N/A')}")
    print(f"Std: {getattr(img_proc, 'image_std', 'N/A')}")

### 3.2 Projector

The projector maps vision features to the LLM's embedding space.
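To get a feel for the projector's size, its parameter count can be estimated from the layer widths alone. The widths below (2176 → 4 × 2176 → 4096 → 4096) are an assumption made for this sketch; the inspection cell that follows prints the real `in_features` and `out_features` from the loaded model, which should be treated as authoritative.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of one fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def mlp_params(widths) -> int:
    """Parameter count of an MLP given its layer widths (activations add none)."""
    return sum(linear_params(a, b) for a, b in zip(widths, widths[1:]))


# Assumed widths, not read from the checkpoint.
widths = [2176, 4 * 2176, 4096, 4096]
total = mlp_params(widths)
print(f"Projector (assumed widths): {total / 1e6:.1f}M parameters")
```

Even under these assumptions the projector is a small fraction of a 7B-parameter model, which is why it is usually described as a lightweight adapter between the two pretrained backbones.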

In [None]:
# Explore projector - discover actual attribute names
print("Projector Configuration:")
print("="*60)

# Find projector-related components
for name, child in vla.named_children():
    if 'project' in name.lower() or 'mlp' in name.lower() or 'connector' in name.lower():
        params = sum(p.numel() for p in child.parameters())
        print(f"Found: {name} ({params/1e6:.1f}M params)")
        print(f"  Type: {type(child).__name__}")
        
        # Try to get layer dimensions
        for subname, layer in child.named_modules():
            if hasattr(layer, 'in_features'):
                print(f"    {subname}: {layer.in_features} → {layer.out_features}")

### 3.3 LLM Backbone

OpenVLA uses Llama-2 7B as its language model backbone.

In [None]:
# Explore LLM backbone - discover actual attribute names
print("LLM Backbone Configuration:")
print("="*60)

# Find language model components
for name, child in vla.named_children():
    if 'llm' in name.lower() or 'language' in name.lower() or 'lm' in name.lower() or 'model' in name.lower():
        params = sum(p.numel() for p in child.parameters())
        print(f"Found: {name} ({params/1e6:.1f}M params)")
        print(f"  Type: {type(child).__name__}")
        
        # Try to get config
        if hasattr(child, 'config'):
            config = child.config
            print(f"\n  Config attributes:")
            for attr in ['hidden_size', 'num_hidden_layers', 'num_attention_heads', 'vocab_size']:
                if hasattr(config, attr):
                    print(f"    {attr}: {getattr(config, attr)}")

# Also check top-level config
if hasattr(vla, 'config'):
    print(f"\nTop-level model config type: {type(vla.config).__name__}")

---
## 4. Information Flow: Image → Action

Let's trace how an image and instruction become a robot action.

In [None]:
from PIL import Image
import numpy as np

# Create a sample input
# np.random.randint excludes the upper bound, so use 256 to cover the full uint8 range
sample_image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
sample_instruction = "Pick up the red block"

print("Step 1: Raw Input")
print("="*60)
print(f"Image shape: {np.array(sample_image).shape}")
print(f"Instruction: '{sample_instruction}'")

In [None]:
# Step 2: Process inputs with processor
print("\nStep 2: Processor Output")
print("="*60)

# Format instruction as prompt
prompt = f"In: What action should the robot take to {sample_instruction.lower()}?\nOut:"

# Process image and text
inputs = processor(prompt, sample_image)

print(f"Input keys: {list(inputs.keys())}")
print(f"\nPixel values shape: {inputs['pixel_values'].shape}")
print(f"Input IDs shape: {inputs['input_ids'].shape}")
print(f"Attention mask shape: {inputs['attention_mask'].shape}")
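The prompt template from the cell above can be wrapped in a small helper so every instruction is formatted the same way. This is a sketch: `build_openvla_prompt` is a hypothetical name, and stripping a trailing period is a choice made here, not something the model requires.

```python
def build_openvla_prompt(instruction: str) -> str:
    """Wrap a task instruction in the prompt template used in this notebook."""
    cleaned = instruction.strip().rstrip(".").lower()
    if not cleaned:
        raise ValueError("instruction must not be empty")
    return f"In: What action should the robot take to {cleaned}?\nOut:"


print(build_openvla_prompt("Pick up the red block"))
```

Keeping the template in one place avoids subtle mismatches (extra punctuation, different casing) between the prompt used at inference time and the format the model saw during training.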

In [None]:
# Step 3: Decode input tokens to see the prompt structure
print("\nStep 3: Tokenized Prompt")
print("="*60)

tokenizer = processor.tokenizer
decoded = tokenizer.decode(inputs['input_ids'][0])
print(f"Decoded prompt:\n{decoded}")

print(f"\nNumber of text tokens: {len(inputs['input_ids'][0])}")

In [None]:
# Step 4: Vision encoding (trace the forward pass)
print("\nStep 4: Vision Encoding")
print("="*60)

# Move model and inputs to device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
vla = vla.to(device)

# IMPORTANT: Convert inputs to correct dtype (bfloat16 to match model)
inputs_device = {}
for k, v in inputs.items():
    if isinstance(v, torch.Tensor):
        if v.dtype == torch.float32:
            # Convert float tensors to bfloat16 to match model
            inputs_device[k] = v.to(device, dtype=torch.bfloat16)
        else:
            inputs_device[k] = v.to(device)
    else:
        inputs_device[k] = v

print(f"Device: {device}")
print(f"Pixel values shape: {inputs_device['pixel_values'].shape}")
print(f"Pixel values dtype: {inputs_device['pixel_values'].dtype}")  # Should be bfloat16

print("\nVision encoding converts the image into tokens that the LLM can process.")
print("Typical flow: Image → Patch embeddings → Vision transformer → Projected tokens")

In [None]:
# Step 5: Projection to LLM Space
print("\nStep 5: Projection to LLM Space")
print("="*60)

print("The projector (MLP) maps vision features to the LLM's embedding dimension.")
print("Typical transformation: vision_dim (e.g., 1024) → llm_dim (e.g., 4096)")
print("\nThis allows the LLM to 'see' the image as if it were a sequence of tokens.")

In [None]:
# Step 6: Action generation
print("\nStep 6: Action Token Generation")
print("="*60)

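# The model generates action tokens autoregressively.
# The 256 action bins reuse the last 256 token IDs of the tokenizer vocabulary
# (the least-used Llama tokens are overwritten; no new tokens are appended).

# The embedding matrix may be padded beyond the tokenizer vocabulary, so subtract
# the padding (if the config reports it) before locating the action tokens.
padded_vocab_size = vla.config.text_config.vocab_size
vocab_size = padded_vocab_size - (getattr(vla.config, "pad_to_multiple_of", 0) or 0)
print(f"Embedding rows (padded vocabulary): {padded_vocab_size}")
print(f"Tokenizer vocabulary size: {vocab_size}")
print(f"Action tokens occupy: {vocab_size - 256} to {vocab_size - 1}")
print("Number of action bins: 256")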
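The mapping between a normalized action value, its bin, and a token ID can be sketched in plain Python. The sketch assumes an un-padded vocabulary of 32000 tokens, 256 evenly spaced bin edges on [-1, 1], and that higher bins map to lower token IDs. All constants and function names here are illustrative assumptions rather than values read from the model.

```python
from bisect import bisect_right

N_BINS = 256
VOCAB_SIZE = 32000  # assumed un-padded tokenizer vocabulary size

# 256 evenly spaced edges on [-1, 1] define 255 intervals.
BIN_EDGES = [-1.0 + 2.0 * i / (N_BINS - 1) for i in range(N_BINS)]
BIN_CENTERS = [(lo + hi) / 2 for lo, hi in zip(BIN_EDGES, BIN_EDGES[1:])]


def action_to_token_id(value: float) -> int:
    """Map a normalized action value to a token ID near the end of the vocabulary."""
    clipped = max(-1.0, min(1.0, value))
    bin_number = bisect_right(BIN_EDGES, clipped)  # 1..N_BINS, same convention as np.digitize
    return VOCAB_SIZE - bin_number


def token_id_to_action(token_id: int) -> float:
    """Invert the mapping, returning the center of the token's bin."""
    bin_number = VOCAB_SIZE - token_id
    index = max(0, min(len(BIN_CENTERS) - 1, bin_number - 1))
    return BIN_CENTERS[index]


token_id = action_to_token_id(0.35)
print(token_id, round(token_id_to_action(token_id), 4))
```

Decoding returns a bin center rather than the original value, so the round trip loses at most one bin width of precision.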

In [None]:
# Step 7: Generate actual action
print("\nStep 7: Full Forward Pass (Inference)")
print("="*60)

# IMPORTANT: Must specify unnorm_key since model was trained on multiple datasets
# Available keys can be found with: list(vla.norm_stats.keys())
UNNORM_KEY = "bridge_orig"  # Good for tabletop manipulation

print(f"Using unnorm_key: '{UNNORM_KEY}'")
print(f"\nAvailable unnorm_key options: {len(vla.norm_stats)} datasets")
print("  (Use list(vla.norm_stats.keys()) to see all)")

with torch.no_grad():
    # Use the built-in predict_action method
    action = vla.predict_action(
        **inputs_device,
        unnorm_key=UNNORM_KEY,  # REQUIRED for multi-dataset models
        do_sample=False,        # Greedy decoding for determinism
    )
    
print(f"\nOutput action shape: {action.shape}")
print(f"Action values (un-normalized with '{UNNORM_KEY}' statistics):")
print(f"  x:       {action[0]:+.4f}")
print(f"  y:       {action[1]:+.4f}")
print(f"  z:       {action[2]:+.4f}")
print(f"  roll:    {action[3]:+.4f}")
print(f"  pitch:   {action[4]:+.4f}")
print(f"  yaw:     {action[5]:+.4f}")
print(f"  gripper: {action[6]:+.4f}")
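Passing `unnorm_key` makes the model rescale its normalized prediction back into the units of the chosen dataset. A minimal sketch of that rescaling follows, assuming the statistics store per-dimension 1st and 99th percentiles (`q01`, `q99`) and a mask marking which dimensions were normalized. The numbers are made up for illustration and are not the real `bridge_orig` statistics.

```python
def unnormalize_action(normalized, q01, q99, mask):
    """Map values in [-1, 1] back to dataset units wherever mask is True."""
    return [
        0.5 * (n + 1.0) * (hi - lo) + lo if m else n
        for n, lo, hi, m in zip(normalized, q01, q99, mask)
    ]


# Hypothetical statistics: six pose dimensions in [-0.02, 0.02], gripper left as-is.
q01 = [-0.02] * 6 + [0.0]
q99 = [0.02] * 6 + [1.0]
mask = [True] * 6 + [False]

normalized = [0.0, 1.0, -1.0, 0.5, 0.0, 0.0, 1.0]
action_units = unnormalize_action(normalized, q01, q99, mask)
print([round(a, 4) for a in action_units])
```

The same normalized output therefore yields different physical commands depending on which dataset's statistics are selected, which is why the key should match the robot setup being controlled.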

---
## 5. Key Design Decisions

### 5.1 Why Dual Vision Encoders?

| Encoder | Strength | What it Captures |
|---------|----------|------------------|
| DINOv2 | Self-supervised learning | Rich semantic features, object boundaries, spatial understanding |
| SigLIP | Contrastive text-image learning | Text-aligned representations, compositional understanding |

**Combination**: DINOv2's spatial features + SigLIP's language-aligned semantics = better spatial reasoning and instruction following.

In [None]:
# Visualize the dual encoder concept
dual_encoder_diagram = """
┌────────────────────────────────────────────────────────────────┐
│                    Dual Vision Encoder                          │
├────────────────────────────────────────────────────────────────┤
│                                                                 │
│                    [Input Image 224×224]                        │
│                           │                                     │
│              ┌────────────┴────────────┐                       │
│              │                         │                        │
│              ▼                         ▼                        │
│     ┌──────────────┐          ┌──────────────┐                 │
│     │   DINOv2     │          │   SigLIP     │                 │
│     │   ViT-L/14   │          │  So400m/14   │                 │
│     │              │          │              │                 │
│     │ Self-        │          │ Text-aligned │                 │
│     │ supervised   │          │ contrastive  │                 │
│     └──────┬───────┘          └──────┬───────┘                 │
│            │                         │                          │
│            │  [B, 256, 1024]        │  [B, 256, 1152]          │
│            │                         │                          │
│            └──────────┬──────────────┘                         │
│                       │                                         │
│                       ▼                                         │
│              ┌──────────────┐                                   │
│              │ Concatenate  │                                   │
│              │ + Project    │                                   │
│              └──────┬───────┘                                   │
│                     │                                           │
│                     ▼                                           │
│             [B, 256, 4096]                                      │
│         (LLM-compatible vision tokens)                          │
│                                                                 │
└────────────────────────────────────────────────────────────────┘
"""
print(dual_encoder_diagram)

### 5.2 Why Discretize Actions?

OpenVLA treats action prediction as a **token classification problem**, not regression.

**Benefits**:
1. **Leverages LLM strengths**: LLMs excel at discrete token prediction
2. **Unified training**: Same loss function for language and actions
3. **Multi-modal action distributions**: A categorical distribution over bins can represent several distinct valid actions, which a single regressed value cannot
4. **Simple integration**: Works with standard autoregressive decoding
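One commonly cited reason to prefer classification over bins is that demonstrations can contain several equally valid actions for the same situation. The toy example below uses made-up data and a simplified uniform binning (`to_bin` is a hypothetical helper) to show how a squared-error regressor and a bin classifier behave differently in that case.

```python
from collections import Counter

# Hypothetical demonstrations: half move left (-0.5), half move right (+0.5).
demos = [-0.5] * 50 + [0.5] * 50

# Squared-error regression converges to the mean of the targets,
# which here is an action that no demonstrator ever took.
regression_prediction = sum(demos) / len(demos)


def to_bin(value: float, n_bins: int = 256) -> int:
    """Simplified uniform binning of [-1, 1] into n_bins buckets."""
    return min(n_bins - 1, int((value + 1.0) / 2.0 * n_bins))


# Classification keeps one count per bin, so both modes survive;
# greedy decoding then commits to one of them.
bin_counts = Counter(to_bin(v) for v in demos)
most_likely_bin, _ = bin_counts.most_common(1)[0]

print(f"regression output: {regression_prediction}")
print(f"bins with probability mass: {sorted(bin_counts)}")
print(f"greedy choice: bin {most_likely_bin}")
```

The regressor outputs the midpoint between the two modes, while the classifier places its probability mass only on bins that were actually demonstrated.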

In [None]:
# Demonstrate action discretization
import numpy as np

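def discretize_action(continuous_action, n_bins=256):
    """Convert a continuous action in [-1, 1] to a discrete bin index."""
    # Clip to valid range
    clipped = np.clip(continuous_action, -1, 1)
    # n_bins evenly spaced edges define n_bins - 1 intervals
    bin_edges = np.linspace(-1, 1, n_bins)
    bin_idx = np.digitize(clipped, bin_edges) - 1
    # The value 1.0 falls past the last edge, so clip it into the last interval
    return np.clip(bin_idx, 0, n_bins - 2)

def undiscretize_action(bin_idx, n_bins=256):
    """Convert a discrete bin index back to a continuous action (the bin center)."""
    bin_edges = np.linspace(-1, 1, n_bins)
    # Use interval midpoints so the reconstruction error is at most half a bin
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    return bin_centers[bin_idx]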

# Example
original = 0.35
discretized = discretize_action(original)
reconstructed = undiscretize_action(discretized)

print("Action Discretization Example:")
print(f"  Original continuous: {original}")
print(f"  Discretized bin: {discretized}")
print(f"  Reconstructed: {reconstructed:.4f}")
print(f"  Quantization error: {abs(original - reconstructed):.6f}")
print(f"  Bin resolution: {2/255:.6f}")  # 256 edges on [-1, 1] are spaced 2/255 apart

### 5.3 Why Llama-2 as Backbone?

| Aspect | Benefit |
|--------|--------|
| **Pre-training** | 2T tokens of diverse text → strong language understanding |
| **Instruction following** | Strong language prior helps ground free-form task instructions |
| **Long context** | 4096 tokens → can process detailed task descriptions |
| **Size (7B)** | Balance between capability and inference speed |
| **Open weights** | Allows fine-tuning on domain-specific robot data |

---
## 6. Memory and Compute Summary

In [None]:
# Calculate model statistics
def model_stats(model):
    """Calculate detailed model statistics."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # Memory in different precisions
    mem_fp32 = total_params * 4 / 1e9  # GB
    mem_fp16 = total_params * 2 / 1e9  # GB
    mem_int8 = total_params * 1 / 1e9  # GB
    
    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'mem_fp32_gb': mem_fp32,
        'mem_fp16_gb': mem_fp16,
        'mem_int8_gb': mem_int8,
    }

stats = model_stats(vla)

print("OpenVLA-7B Model Statistics")
print("="*60)
print(f"Total parameters: {stats['total_params']/1e9:.2f}B")
print(f"Trainable parameters: {stats['trainable_params']/1e9:.2f}B")
print("\nMemory for weights only (activations, gradients, and optimizer state are extra):")
print(f"  FP32: {stats['mem_fp32_gb']:.1f} GB")
print(f"  FP16/BF16: {stats['mem_fp16_gb']:.1f} GB")
print(f"  INT8 (quantized): {stats['mem_int8_gb']:.1f} GB")
print("\nYour GPU Setup:")
print("  4 × 40GB GPUs = 160 GB total")
print(f"  Room for about {int(160 / stats['mem_fp16_gb'])} copies of the BF16 weights")

In [None]:
# Clean up
del vla
torch.cuda.empty_cache()
print("Model cleared from memory.")

---
## Summary

### Key Takeaways

1. **OpenVLA Architecture**: Vision Backbone → Projector → LLM → Action Tokens

2. **Three Components**:
   - Vision: DINOv2 + SigLIP (dual encoder for rich + aligned features)
   - Projector: MLP mapping vision to LLM space
   - LLM: Llama-2 7B for instruction understanding and action generation

3. **Action Representation**: 
   - Continuous actions discretized to 256 bins
   - Generated as tokens by the LLM
   - 7 dimensions: x, y, z, roll, pitch, yaw, gripper

4. **Memory**: ~15 GB of weights in BF16, which fits comfortably on a single 40 GB GPU

### Next Steps
→ Continue to **03_vision_backbone_deep_dive.ipynb** to understand the dual vision encoder in detail.