# 06. Basic Inference with OpenVLA

**Goal**: Run OpenVLA-7B inference on sample images to predict robot actions.

## What We'll Learn
1. Loading OpenVLA from HuggingFace
2. Processing inputs (image + instruction)
3. Generating action predictions
4. Multi-GPU inference strategies
5. Running inference on multiple observations and benchmarking speed

---
## 1. Load OpenVLA Model

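We'll use the HuggingFace transformers library for model loading. OpenVLA's model and processor classes are custom code hosted with the checkpoint, so both loaders need `trust_remote_code=True`.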

In [None]:
# ============================================================
# CRITICAL: Set these BEFORE importing any packages!
# ============================================================
import os

# For NERSC Perlmutter, use your $PSCRATCH directory (read from the environment when it is set)
PSCRATCH = os.environ.get("PSCRATCH", "/pscratch/sd/d/dpark1")  # CHANGE THE FALLBACK TO YOUR PATH
CACHE_DIR = f"{PSCRATCH}/.cache"

# Set all cache directories to $PSCRATCH/.cache
os.environ['XDG_CACHE_HOME'] = CACHE_DIR
os.environ['HF_HOME'] = f"{CACHE_DIR}/huggingface"
os.environ['TFDS_DATA_DIR'] = f"{CACHE_DIR}/tensorflow_datasets"
os.environ['TORCH_HOME'] = f"{CACHE_DIR}/torch"

# Create directories
for path in [CACHE_DIR, os.environ['HF_HOME'], os.environ['TFDS_DATA_DIR'], os.environ['TORCH_HOME']]:
    os.makedirs(path, exist_ok=True)

print(f"✅ All caches → {CACHE_DIR}")

# Now import other packages
import torch
import numpy as np
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
import time

# Check GPU availability
print("\nGPU Configuration:")
print(f"  CUDA available: {torch.cuda.is_available()}")
print(f"  GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"  GPU {i}: {props.name} ({props.total_memory / 1e9:.1f} GB)")

In [None]:
# Model configuration
MODEL_ID = "openvla/openvla-7b"
DEVICE = "cuda:0"  # Use first GPU
DTYPE = torch.bfloat16  # BF16 for memory efficiency

print(f"Loading OpenVLA from {MODEL_ID}...")
print(f"  Device: {DEVICE}")
print(f"  Dtype: {DTYPE}")

In [None]:
# Load processor (handles image and text preprocessing)
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True
)
print("Processor loaded!")

In [None]:
# Load model with optimizations
# Check if Flash Attention 2 is available
try:
    import flash_attn
    ATTN_IMPL = "flash_attention_2"
    print("✅ Flash Attention 2 available - using it for faster attention")
except ImportError:
    ATTN_IMPL = None  # Use default attention
    print("⚠️ Flash Attention 2 not installed - using default attention")
    print("   To install: pip install flash-attn --no-build-isolation")

# Build model kwargs
model_kwargs = {
    "torch_dtype": DTYPE,
    "low_cpu_mem_usage": True,
    "trust_remote_code": True,
}
if ATTN_IMPL:
    model_kwargs["attn_implementation"] = ATTN_IMPL

vla = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    **model_kwargs
).to(DEVICE)

# Set to evaluation mode
vla.eval()

print(f"\nModel loaded!")
print(f"  Parameters: {sum(p.numel() for p in vla.parameters()) / 1e9:.2f}B")
print(f"  Memory: ~{torch.cuda.memory_allocated() / 1e9:.1f} GB allocated")
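
As a sanity check on the allocated memory, weight memory is simply the parameter count times the bytes per parameter. A minimal sketch; the helper `weight_memory_gb` is hypothetical, and the ~7.5B parameter count is an assumption (use the count printed by the cell above). It ignores activations, buffers and the CUDA context.

```python
# Hypothetical helper: memory for model weights only (no activations, KV cache or CUDA context).
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2, "int8": 1}

def weight_memory_gb(n_params: float, dtype: str = "bf16") -> float:
    """Return the memory needed to hold n_params weights of the given dtype, in GB (1e9 bytes)."""
    if dtype not in BYTES_PER_PARAM:
        raise ValueError(f"unknown dtype {dtype!r}; options: {sorted(BYTES_PER_PARAM)}")
    return n_params * BYTES_PER_PARAM[dtype] / 1e9

# Assumption: ~7.5B parameters; substitute the value printed by the loading cell.
N_PARAMS = 7.5e9
for dtype in ("fp32", "bf16", "int8"):
    print(f"{dtype:>5}: {weight_memory_gb(N_PARAMS, dtype):5.1f} GB of weights")
```

At 2 bytes per parameter, BF16 weights for a ~7.5B-parameter model come to about 15 GB, which leaves headroom for activations on a 40 GB GPU.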

---
## 2. Prepare Input Data

OpenVLA needs:
- RGB image from robot camera
- Natural language instruction
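
Real camera frames often arrive as RGBA or BGR arrays (OpenCV's default channel order) or as floats in [0, 1], while OpenVLA expects an RGB image. A minimal NumPy sketch of the conversion; `to_rgb_uint8` is a hypothetical helper, and its result can be wrapped with `Image.fromarray(...)` before it is passed to the processor.

```python
import numpy as np

def to_rgb_uint8(frame, bgr=False):
    """Convert an (H, W, 3|4) camera frame (uint8, or float in [0, 1]) to (H, W, 3) uint8 RGB."""
    frame = np.asarray(frame)
    if frame.ndim != 3 or frame.shape[2] not in (3, 4):
        raise ValueError(f"expected an (H, W, 3) or (H, W, 4) array, got shape {frame.shape}")
    frame = frame[..., :3]            # drop the alpha channel if present
    if bgr:
        frame = frame[..., ::-1]      # BGR (OpenCV order) -> RGB
    if np.issubdtype(frame.dtype, np.floating):
        frame = np.round(np.clip(frame, 0.0, 1.0) * 255.0)  # [0, 1] floats -> [0, 255]
    elif frame.dtype != np.uint8:
        raise ValueError(f"unsupported dtype {frame.dtype}; expected uint8 or float")
    return np.ascontiguousarray(frame.astype(np.uint8))

rgba = np.zeros((4, 4, 4), dtype=np.uint8)
rgba[..., 0] = 200                    # red channel
print(to_rgb_uint8(rgba).shape, to_rgb_uint8(rgba)[0, 0])
```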

In [None]:
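# Create a sample robot observation image
# In practice, this would come from your robot's camera. This synthetic scene is far from
# OpenVLA's training images, so the predicted actions below only demonstrate the API.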

def create_sample_observation(size=(256, 256)):
    """
    Create a sample robot observation.
    
    In real usage:
    - image = camera.get_rgb_image()
    - instruction = task_specification
    """
    # Simulated robot workspace image
    # Add some structure to make it more realistic
    img = np.zeros((size[0], size[1], 3), dtype=np.uint8)
    
    # Background (table)
    img[:, :] = [200, 180, 160]  # Brownish table
    
    # Add some "objects"
    # Red block
    img[100:150, 80:130] = [200, 50, 50]  # Red
    # Blue cube
    img[80:120, 160:200] = [50, 50, 200]  # Blue
    # Green cylinder (approximated)
    img[140:180, 140:170] = [50, 200, 50]  # Green
    
    return Image.fromarray(img)

# Create sample image
sample_image = create_sample_observation()

# Display the image (in Jupyter)
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 6))
plt.imshow(sample_image)
plt.title("Sample Robot Observation")
plt.axis('off')
plt.show()

print(f"Image size: {sample_image.size}")

In [None]:
# Define the task instruction
instruction = "pick up the red block"

# Format as OpenVLA prompt
# The model expects: "In: What action should the robot take to {task}?\nOut:"
prompt = f"In: What action should the robot take to {instruction}?\nOut:"

print("Input prompt:")
print(f"  '{prompt}'")
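
The prompt template above can be wrapped in a small helper so every call site formats instructions the same way. A minimal sketch; the name `build_openvla_prompt` is hypothetical, and lowercasing plus stripping a trailing period are assumptions meant to keep typed instructions close to the terse, lowercase style of the instruction used here.

```python
def build_openvla_prompt(instruction: str) -> str:
    """Format a task instruction with the OpenVLA prompt template (hypothetical helper)."""
    # Assumption: normalize whitespace, trailing period and case before templating
    task = instruction.strip().rstrip(".").lower()
    return f"In: What action should the robot take to {task}?\nOut:"

print(repr(build_openvla_prompt("Pick up the red block.")))
```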

---
## 3. Process Inputs

In [None]:
# Process image and text with the processor
inputs = processor(prompt, sample_image)

print("Processed inputs:")
for key, value in inputs.items():
    if hasattr(value, 'shape'):
        print(f"  {key}: shape={value.shape}, dtype={value.dtype}")
    else:
        print(f"  {key}: {type(value)}")

In [None]:
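# Move inputs to device. Cast only floating-point tensors (pixel_values) to bfloat16:
# input_ids and attention_mask must stay integer, since bf16 cannot represent most token IDs exactly.
inputs_device = {
    k: (v.to(DEVICE, dtype=DTYPE) if torch.is_floating_point(v) else v.to(DEVICE))
    if isinstance(v, torch.Tensor) else v
    for k, v in inputs.items()
}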

print("Inputs moved to device:")
for key, value in inputs_device.items():
    if isinstance(value, torch.Tensor):
        print(f"  {key}: device={value.device}, dtype={value.dtype}")
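
Only the floating-point `pixel_values` should be cast to bfloat16. BF16 keeps 8 significant bits, so integers above 256 are not all representable, and casting `input_ids` would silently change token IDs. NumPy has no native bfloat16, so this sketch emulates the float32 to bfloat16 rounding on raw bit patterns; `round_to_bfloat16` is a hypothetical helper and ignores NaN/Inf.

```python
import numpy as np

def round_to_bfloat16(values):
    """Emulate float32 -> bfloat16 rounding (round half to even); returns float32. NaN/Inf ignored."""
    f32 = np.atleast_1d(np.asarray(values, dtype=np.float32))
    bits = f32.view(np.uint32).astype(np.uint64)
    # bfloat16 keeps the top 16 bits of a float32; adding this bias makes truncation round to nearest even
    lsb = (bits >> 16) & 1
    rounded = ((bits + 0x7FFF + lsb) >> 16) << 16
    return rounded.astype(np.uint32).view(np.float32)

token_ids = np.array([1, 255, 256, 257, 29871, 31999])
for tid, approx in zip(token_ids.tolist(), round_to_bfloat16(token_ids).tolist()):
    print(f"{tid:>6} -> {approx:>8.0f}")
```

Token 29871 would come back as 29824, and the highest action token, 31999, would round to 32000, outside the action-token range entirely.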

---
## 4. Generate Action Prediction

In [None]:
# First, see available unnorm_key options
print("Available unnorm_key options (dataset statistics):")
print("="*60)
for key in sorted(vla.norm_stats.keys()):
    print(f"  - {key}")

# Common choices:
# - bridge_orig: BridgeData V2 (WidowX arm, tabletop manipulation)
# - fractal20220817_data: Google RT-1 dataset
# - kuka: KUKA robot grasping
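
Because the checkpoint holds statistics for many datasets, `unnorm_key` must name one of them. The sketch below reproduces that selection rule in plain Python so it can be checked without loading the model: with no key, it only succeeds when exactly one dataset is available. The helper name `select_unnorm_key` is hypothetical, and the exception types are our choice; we do not claim they match what the model itself raises.

```python
def select_unnorm_key(norm_stats, unnorm_key=None):
    """Return the dataset key whose statistics should un-normalize actions (sketch)."""
    options = ", ".join(sorted(norm_stats))
    if unnorm_key is None:
        # Without an explicit key, the choice is only unambiguous for single-dataset models
        if len(norm_stats) != 1:
            raise ValueError(f"model was trained on {len(norm_stats)} datasets; pass unnorm_key, one of: {options}")
        return next(iter(norm_stats))
    if unnorm_key not in norm_stats:
        raise KeyError(f"unknown unnorm_key {unnorm_key!r}; options: {options}")
    return unnorm_key

print(select_unnorm_key({"bridge_orig": {}, "kuka": {}}, "bridge_orig"))
```

With the loaded model, `select_unnorm_key(vla.norm_stats, UNNORM_KEY)` fails early with the list of valid keys if the key is misspelled.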

In [None]:
# Generate action prediction with unnorm_key
# IMPORTANT: Must specify unnorm_key since model was trained on multiple datasets

UNNORM_KEY = "bridge_orig"  # Use Bridge dataset statistics (good for tabletop manipulation)

print(f"Generating action prediction using '{UNNORM_KEY}' statistics...")

with torch.no_grad():
    start_time = time.time()
    
    action = vla.predict_action(
        **inputs_device,
        unnorm_key=UNNORM_KEY,  # REQUIRED for multi-dataset models
        do_sample=False,        # Greedy decoding for determinism
    )
    
    inference_time = time.time() - start_time

print(f"\nInference completed in {inference_time*1000:.1f} ms (the first call may include warm-up overhead)")
print(f"\nPredicted action (un-normalized with '{UNNORM_KEY}' statistics):")
print(f"  Shape: {action.shape}")
print(f"  Values:")
action_names = ['x', 'y', 'z', 'roll', 'pitch', 'yaw', 'gripper']
for name, val in zip(action_names, action):
    print(f"    {name:8s}: {val:+.4f}")

In [None]:
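# Method 2: Manual generation (for more control)
# predict_action() wraps generate(): it appends the empty token (ID 29871) after "Out:" when it is
# missing, generates one token per action dimension, then decodes and un-normalizes the tokens.
print("Manual generation method:")

with torch.no_grad():
    input_ids = inputs_device['input_ids']
    attention_mask = inputs_device['attention_mask']
    
    # Append the empty token so the prompt matches what predict_action() feeds the model
    if not torch.all(input_ids[:, -1] == 29871):
        empty_token = torch.full_like(input_ids[:, :1], 29871)
        input_ids = torch.cat([input_ids, empty_token], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :1])], dim=1)
    
    # Generate action tokens (one per action dimension)
    generated_ids = vla.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=inputs_device['pixel_values'],
        max_new_tokens=7,
        do_sample=False,
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    
    # Extract generated tokens (excluding input)
    input_len = input_ids.shape[1]
    action_token_ids = generated_ids[0, input_len:].cpu().numpy()
    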
print(f"\nGenerated token IDs: {action_token_ids}")
print(f"Vocabulary size: {processor.tokenizer.vocab_size}")
print(f"Action token range: [{processor.tokenizer.vocab_size - 256}, {processor.tokenizer.vocab_size - 1}]")
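
The generated IDs can be turned back into normalized actions by reversing the discretization. As we read OpenVLA's model code, the last 256 vocabulary IDs are action bins counted down from the top of the vocabulary, and each bin maps to one of 255 bin centers spaced uniformly over [-1, 1]. A NumPy sketch of that decoding; `decode_action_tokens` is hypothetical, and the bin count and vocabulary size are assumptions (the loaded model carries its own values).

```python
import numpy as np

N_BINS = 256         # assumption: 256 uniform bin edges over [-1, 1]
VOCAB_SIZE = 32000   # assumption: tokenizer vocabulary size printed above

BINS = np.linspace(-1.0, 1.0, N_BINS)
BIN_CENTERS = (BINS[:-1] + BINS[1:]) / 2.0  # 255 bin centers

def decode_action_tokens(token_ids, vocab_size=VOCAB_SIZE):
    """Map action token IDs (the last N_BINS IDs of the vocabulary) to normalized values in [-1, 1]."""
    token_ids = np.asarray(token_ids, dtype=np.int64)
    discretized = vocab_size - token_ids                       # highest ID -> 1, lowest -> N_BINS
    idx = np.clip(discretized - 1, 0, BIN_CENTERS.shape[0] - 1)
    return BIN_CENTERS[idx]

print(np.round(decode_action_tokens([31999, 31872, 31745]), 4))
```

Note the reversed order: the highest token ID decodes to the most negative value. If the assumed constants match the model, applying `decode_action_tokens(action_token_ids)` and then the q01/q99 un-normalization should reproduce `predict_action()`'s output.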

---
## 5. Understanding Action Output

In [None]:
# Visualize the action prediction
def visualize_action(action, title="Predicted Action"):
    """
    Visualize a 7-DoF robot action.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    action_names = ['x', 'y', 'z', 'roll', 'pitch', 'yaw', 'gripper']
    colors = ['steelblue'] * 3 + ['coral'] * 3 + ['green']
    
    # Bar chart
    ax1 = axes[0]
    bars = ax1.bar(action_names, action, color=colors)
    ax1.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    # predict_action() returns un-normalized actions, which are not bounded to [-1, 1],
    # so let matplotlib pick the y-limits
    ax1.set_ylabel('Action value (dataset units)')
    ax1.set_title(f'{title}\nPosition (blue), Rotation (orange), Gripper (green)')
    
    # Interpretation panel
    ax2 = axes[1]
    ax2.axis('off')
    
    interpretation = f"""
    Action Interpretation:
    
    Position deltas (axes follow the dataset's robot frame):
      X: {action[0]:+.4f}
      Y: {action[1]:+.4f}
      Z: {action[2]:+.4f}
    
    Rotation deltas:
      Roll:  {action[3]:+.4f}
      Pitch: {action[4]:+.4f}
      Yaw:   {action[5]:+.4f}
    
    Gripper (0 = close, 1 = open):
      Value: {action[6]:+.3f} ({"open" if action[6] > 0.5 else "close"})
    
    Note: predict_action() has already un-normalized
    these values with the unnorm_key statistics, so
    units and axis directions follow that dataset.
    """
    ax2.text(0.1, 0.9, interpretation, transform=ax2.transAxes,
             fontfamily='monospace', fontsize=10, verticalalignment='top')
    
    plt.tight_layout()
    plt.show()

visualize_action(action, f"Predicted Action for: '{instruction}'")
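
Before sending the gripper value to a simulator, its convention may need converting. For Bridge-style data, OpenVLA's gripper output follows 0 = close, 1 = open. A minimal sketch, assuming the target environment instead expects -1 = open and +1 = close (check this for your environment); `gripper_to_sim_command` is a hypothetical helper.

```python
import numpy as np

def gripper_to_sim_command(action, binarize=True):
    """Convert the last (gripper) dimension from 0 = close / 1 = open to -1 = open / +1 = close.

    Assumption: the target environment uses the -1 = open, +1 = close convention.
    """
    action = np.array(action, dtype=np.float64)   # copy, so the caller's array is not modified
    g = 2.0 * action[..., -1] - 1.0               # [0, 1] -> [-1, +1]; +1 still means open
    if binarize:
        g = np.sign(g)                            # snap to -1 / +1 (exactly 0.5 maps to 0)
    action[..., -1] = -g                          # flip so that +1 means close
    return action

print(gripper_to_sim_command([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9]))
```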

---
## 6. Action Un-normalization

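`predict_action()` already returns actions in real robot units: it decodes the seven action tokens into normalized values in [-1, 1], then un-normalizes them with the 1st/99th-percentile bounds (q01/q99) of the dataset selected by `unnorm_key`.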

In [None]:
# predict_action() already returned `action` in real units. Here we look up the statistics
# it used and invert the mapping to recover the normalized [-1, 1] values.
# OpenVLA un-normalizes with 1st/99th-percentile bounds (q01/q99), not mean/std:
#   real = 0.5 * (normalized + 1) * (q99 - q01) + q01   (dims with mask=False pass through)
action_stats = vla.norm_stats[UNNORM_KEY]["action"]
q01 = np.array(action_stats["q01"])
q99 = np.array(action_stats["q99"])
mask = np.array(action_stats.get("mask", [True] * len(q01)), dtype=bool)

# Inverse mapping on the masked dimensions: normalized = 2 * (real - q01) / (q99 - q01) - 1
normalized_action = np.where(mask, 2.0 * (action - q01) / (q99 - q01 + 1e-8) - 1.0, action)

print(f"Action in normalized and real units ('{UNNORM_KEY}' statistics):")
print("="*60)
print(f"{'Dimension':<10} {'Normalized':>12} {'Real':>12} {'Unit':>10}")
print("-"*60)

units = ['m', 'm', 'm', 'rad', 'rad', 'rad', '']
for name, norm, real, unit in zip(action_names, normalized_action, action, units):
    print(f"{name:<10} {norm:>12.4f} {real:>12.6f} {unit:>10}")
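
The q01/q99 mapping can be checked without the model. A minimal NumPy sketch of both directions; the bound values are made up for illustration (they are not real dataset statistics), and the helper names are hypothetical. Normalization clips to [-1, 1], so on the masked dimensions un-normalized actions never leave the [q01, q99] range.

```python
import numpy as np

def unnormalize(normalized, q01, q99, mask):
    """[-1, 1] -> real units using 1st/99th-percentile bounds; dims with mask=False pass through."""
    normalized = np.asarray(normalized, dtype=np.float64)
    q01, q99 = np.asarray(q01, dtype=np.float64), np.asarray(q99, dtype=np.float64)
    real = 0.5 * (normalized + 1.0) * (q99 - q01) + q01
    return np.where(mask, real, normalized)

def normalize(real, q01, q99, mask):
    """Inverse mapping: real units -> [-1, 1], clipped; dims with mask=False pass through."""
    real = np.asarray(real, dtype=np.float64)
    q01, q99 = np.asarray(q01, dtype=np.float64), np.asarray(q99, dtype=np.float64)
    norm = np.clip(2.0 * (real - q01) / (q99 - q01 + 1e-8) - 1.0, -1.0, 1.0)
    return np.where(mask, norm, real)

# Illustrative bounds only (NOT real dataset statistics); the gripper dimension is unmasked
q01 = np.array([-0.03, -0.04, -0.03, -0.10, -0.10, -0.20, 0.0])
q99 = np.array([0.03, 0.04, 0.03, 0.10, 0.10, 0.20, 1.0])
mask = np.array([True] * 6 + [False])

n = np.array([0.0, 0.5, -1.0, 1.0, 0.0, -0.5, 1.0])
real = unnormalize(n, q01, q99, mask)
print("real:      ", np.round(real, 4))
print("round trip:", np.round(normalize(real, q01, q99, mask), 4))
```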

In [None]:
# Inspect the statistics OpenVLA stores for the selected dataset.
# predict_action() un-normalizes with q01, q99 and mask; other summary statistics are stored too.
for stat_name, values in vla.norm_stats[UNNORM_KEY]["action"].items():
    print(f"  {stat_name}: {values}")

---
## 7. Multi-GPU Inference

With 4×40GB GPUs, we can run multiple model instances in parallel.

In [None]:
# Strategy 1: Load model on multiple GPUs for parallel rollouts
def load_multi_gpu_models(model_id, devices=("cuda:0", "cuda:1", "cuda:2", "cuda:3")):
    """
    Load separate model instances on each GPU.
    
    Useful for:
    - Running multiple environment rollouts in parallel
    - Different tasks on different GPUs
    """
    # Check Flash Attention availability
    try:
        import flash_attn
        attn_impl = "flash_attention_2"
    except ImportError:
        attn_impl = None
    
    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        "low_cpu_mem_usage": True,
        "trust_remote_code": True,
    }
    if attn_impl:
        model_kwargs["attn_implementation"] = attn_impl
    
    models = {}
    
    for device in devices:
        print(f"Loading model on {device}...")
        model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            **model_kwargs
        ).to(device)
        model.eval()
        models[device] = model
    
    return models

# Example usage (commented out to save memory)
# multi_gpu_models = load_multi_gpu_models(MODEL_ID)
print("Multi-GPU strategy: One model per GPU for parallel rollouts")
print("  GPU 0: Environment 1")
print("  GPU 1: Environment 2")
print("  GPU 2: Environment 3")
print("  GPU 3: Environment 4")
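
With one model per GPU, the four rollouts can be stepped concurrently from a single process. A minimal sketch using a thread pool; `run_parallel` is a hypothetical helper, and it assumes PyTorch releases the GIL while CUDA work runs, so threads can overlap GPU execution (CPU-side preprocessing may still partly serialize). Each entry of `predict_fns` would wrap one device's model, for example a closure around that model's `predict_action` call; toy functions stand in for the models here.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Sequence

def run_parallel(predict_fns: Sequence[Callable[[Any], Any]], observations: Sequence[Any]) -> List[Any]:
    """Call predict_fns[i](observations[i]) concurrently and return results in input order."""
    if len(predict_fns) != len(observations):
        raise ValueError("need exactly one observation per model")
    with ThreadPoolExecutor(max_workers=max(1, len(predict_fns))) as pool:
        futures = [pool.submit(fn, obs) for fn, obs in zip(predict_fns, observations)]
        return [f.result() for f in futures]  # result() re-raises any worker exception

# Toy stand-ins for four per-GPU models
toy_fns = [lambda x, k=k: x * k for k in range(4)]
print(run_parallel(toy_fns, [10, 10, 10, 10]))  # [0, 10, 20, 30]
```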

In [None]:
# Strategy 2: Model parallelism with device_map="auto"
# Splits the model's layers across GPUs

model_parallel_config = """
# For model parallelism:
vla = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    device_map="auto",  # Automatically distribute layers across available GPUs
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# The layers are split across GPUs and run one GPU after another, so a single
# forward pass does not get faster. This is mainly useful when a model does not
# fit on one GPU (e.g., 70B-parameter models). OpenVLA-7B fits on one 40 GB GPU,
# and predict_action() handles one observation per call, so one model per GPU
# (Strategy 1) is usually the better fit here.
"""
print("Model Parallelism Configuration:")
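print(model_parallel_config)

---
## 8. Batch Inference

`predict_action()` handles one observation per call, so the helper below loops over the observations. This is convenient for running several tasks in one call, but it is not faster per sample than calling `predict_action()` directly.

In [None]:
# Batch inference for multiple observations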
def batch_predict_actions(model, processor, images, instructions, device, unnorm_key="bridge_orig"):
    """
    Predict actions for a list of observations.
    
    predict_action() decodes only the first sequence returned by generate(),
    so observations are processed one at a time rather than as a padded batch.
    
    Args:
        model: OpenVLA model
        processor: OpenVLA processor
        images: List of PIL Images
        instructions: List of instruction strings
        device: Target device
        unnorm_key: Dataset statistics to use for un-normalization
    
    Returns:
        Array of actions, shape (num_observations, 7)
    """
    if len(images) != len(instructions):
        raise ValueError("images and instructions must have the same length")
    
    # Format prompts
    prompts = [
        f"In: What action should the robot take to {inst}?\nOut:"
        for inst in instructions
    ]
    
    batch_actions = []
    for prompt, image in zip(prompts, images):
        inputs = processor(prompt, image)
        # Cast only floating-point tensors (pixel_values); token IDs and masks stay integer
        inputs_device = {
            k: (v.to(device, dtype=torch.bfloat16) if torch.is_floating_point(v) else v.to(device))
            if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        
        with torch.no_grad():
            action = model.predict_action(
                **inputs_device,
                unnorm_key=unnorm_key,  # REQUIRED
                do_sample=False,
            )
        batch_actions.append(action)
    
    return np.stack(batch_actions)

# Test batch inference
test_images = [create_sample_observation() for _ in range(4)]
test_instructions = [
    "pick up the red block",
    "push the blue cube",
    "move the green object",
    "grasp the red block",
]

print(f"Running inference on {len(test_images)} observations...")
start = time.time()
batch_actions = batch_predict_actions(vla, processor, test_images, test_instructions, DEVICE, UNNORM_KEY)
batch_time = time.time() - start

print(f"\nProcessed {len(test_images)} observations in {batch_time*1000:.1f} ms")
print(f"Average per sample: {batch_time/len(test_images)*1000:.1f} ms")
print(f"\nBatch actions shape: {batch_actions.shape}")

for i, (inst, task_action) in enumerate(zip(test_instructions, batch_actions)):
    print(f"\nTask {i+1}: '{inst}'")
    print(f"  Action: {np.round(task_action, 4)}")

---
## 9. Benchmarking Inference Speed

In [None]:
# Benchmark inference speed
def benchmark_inference(model, processor, device, unnorm_key="bridge_orig", n_runs=10, warmup=3):
    """Benchmark inference speed."""
    image = create_sample_observation()
    instruction = "pick up the red block"
    prompt = f"In: What action should the robot take to {instruction}?\nOut:"
    
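    inputs = processor(prompt, image)
    # Cast only floating-point tensors (pixel_values); token IDs and masks stay integer
    inputs_device = {
        k: (v.to(device, dtype=torch.bfloat16) if torch.is_floating_point(v) else v.to(device))
        if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }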
    
    # Warmup runs
    print(f"Warming up ({warmup} runs)...")
    for _ in range(warmup):
        with torch.no_grad():
            _ = model.predict_action(**inputs_device, unnorm_key=unnorm_key, do_sample=False)
    
    # Benchmark runs
    torch.cuda.synchronize()
    times = []
    
    print(f"Benchmarking ({n_runs} runs)...")
    for _ in range(n_runs):
        torch.cuda.synchronize()
        start = time.time()
        
        with torch.no_grad():
            _ = model.predict_action(**inputs_device, unnorm_key=unnorm_key, do_sample=False)
        
        torch.cuda.synchronize()
        times.append(time.time() - start)
    
    times = np.array(times) * 1000  # Convert to ms
    
    print(f"\nInference Benchmark Results:")
    print(f"  Mean: {times.mean():.1f} ms")
    print(f"  Std:  {times.std():.1f} ms")
    print(f"  Min:  {times.min():.1f} ms")
    print(f"  Max:  {times.max():.1f} ms")
    print(f"  Throughput: {1000/times.mean():.1f} actions/sec")
    
    return times

times = benchmark_inference(vla, processor, DEVICE, UNNORM_KEY)
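
For closed-loop control, tail latency matters as much as the mean: a single slow step delays the next command. A small NumPy sketch that adds the median and 95th percentile and, optionally, the fraction of calls that miss a target control period; `summarize_latencies` and its `control_hz` parameter are hypothetical, and the target rate is yours to choose.

```python
import numpy as np

def summarize_latencies(times_ms, control_hz=None):
    """Summarize per-call latencies in milliseconds; optionally compare them to a control period."""
    t = np.asarray(times_ms, dtype=np.float64)
    summary = {
        "mean_ms": float(t.mean()),
        "p50_ms": float(np.percentile(t, 50)),
        "p95_ms": float(np.percentile(t, 95)),
        "max_ms": float(t.max()),
        "throughput_hz": float(1000.0 / t.mean()),
    }
    if control_hz is not None:
        budget_ms = 1000.0 / control_hz                   # time available per control step
        summary["budget_ms"] = budget_ms
        summary["frac_over_budget"] = float(np.mean(t > budget_ms))
    return summary

for key, value in summarize_latencies([95.0, 101.0, 99.0, 140.0], control_hz=5).items():
    print(f"  {key}: {value:.2f}")
```

Applied to the benchmark above, this would be `summarize_latencies(times, control_hz=5)`; the 5 Hz target is only an example.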

---
## Summary

### Key Steps for OpenVLA Inference

1. **Load Model**: Use HuggingFace transformers (with Flash Attention 2 when it is installed)

2. **Prepare Inputs**:
   - RGB image from robot camera
   - Natural language instruction
   - Format as: "In: What action should the robot take to {task}?\nOut:"

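3. **Generate Action**:
   - `model.predict_action(..., unnorm_key=...)` returns the action already un-normalized into the chosen dataset's units
   - Or manual generation with `model.generate()`, followed by decoding the action tokens yourself

4. **Un-normalize**: Handled inside `predict_action()`, which maps [-1, 1] back to robot units with the dataset's q01/q99 bounds; choose the `unnorm_key` that matches your robot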

### Your 4×40GB GPU Setup
- **Option 1**: 4 parallel model instances for 4 parallel rollouts
- **Option 2**: Model parallelism (`device_map="auto"`), mainly for models too large for one GPU
- **Memory per model**: ~15 GB for the BF16 weights (2 bytes per parameter), plus activations

### Next Steps
→ Continue to **07_libero_setup.ipynb** to set up LIBERO simulation environment.

In [None]:
# Clean up: drop the model reference, collect garbage, then release cached GPU memory
import gc
del vla
gc.collect()
torch.cuda.empty_cache()
print("Memory cleared.")