# OpenVLA Fine-tuned Model Evaluation: Understanding What the Model Learned

This notebook helps you understand:
1. **What do the metrics mean?** (Loss, L1 Error explained)
2. **What does the model actually predict?** (Side-by-side comparison)
3. **How accurate is each action dimension?** (Position, rotation, gripper)
4. **Is the model making reasonable predictions?** (Visualizations)

## Metrics Explained

| Metric | What it measures | Good value |
|--------|------------------|------------|
| **Loss** | How wrong the predicted tokens are (cross-entropy) | < 2.0 |
| **L1 Error** | Average absolute difference between predicted and true actions | < 0.15 |
| **Gripper Accuracy** | % of times gripper open/close is correct | > 90% |
| **Direction Accuracy** | % of times movement direction is correct | > 80% |

### What does L1 Error = 0.15 mean?
- Actions are in range [-1, 1]
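- L1 = 0.15 means each action dimension is off by 0.15 on average, i.e. ~15% of the maximum action magnitude (1.0), or 7.5% of the full [-1, 1] range
- Example: if a full-scale action (1.0) corresponds to a 10cm move, that's ~1.5cm error per step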

In [None]:
# Setup
import os
import sys
import numpy as np
import torch
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
import h5py
from tqdm import tqdm

# Configuration
# Every path hangs off a single scratch root: $SCRATCH when it is set,
# otherwise the hard-coded workspace directory below.
BASE_DIR = os.environ.get('SCRATCH', "/home/idies/workspace/Temporary/dpark1/scratch")

CACHE_DIR = f"{BASE_DIR}/.cache"
LIBERO_DATA_DIR = f"{BASE_DIR}/libero_data"
CHECKPOINT_DIR = f"{BASE_DIR}/openvla_finetuned"

# Point the Hugging Face cache at the scratch area and set the TensorFlow
# C++ log level before any model library reads these variables.
os.environ.update({
    'HF_HOME': f"{CACHE_DIR}/huggingface",
    'TF_CPP_MIN_LOG_LEVEL': '3',
})

import warnings
# Suppresses every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

for _label, _directory in (("Checkpoint", CHECKPOINT_DIR), ("LIBERO data", LIBERO_DATA_DIR)):
    print(f"{_label} directory: {_directory}")

In [None]:
# Find available checkpoints
# `runs` is always bound (to an empty list when nothing is found) so that the
# checkpoint-selection cell can safely test `if runs:` on a fresh kernel even
# when CHECKPOINT_DIR does not exist.
checkpoint_path = Path(CHECKPOINT_DIR)
runs = sorted(checkpoint_path.glob("libero_*")) if checkpoint_path.exists() else []

if runs:
    print("Available fine-tuning runs:")
    for i, run in enumerate(runs):
        print(f"  [{i}] {run.name}")
        # Check for best/final
        if (run / "best").exists():
            print(f"      - best/ (recommended)")
        if (run / "final").exists():
            print(f"      - final/")
else:
    print(f"No checkpoints found at {CHECKPOINT_DIR}")
    print("Please run fine-tuning first.")

In [None]:
# ============================================================
# SELECT YOUR CHECKPOINT HERE
# ============================================================
def _checkpoint_sort_key(path):
    """Order 'checkpoint-<step>' directories by numeric step, then by name.

    A plain string sort would place 'checkpoint-500' after 'checkpoint-1000';
    directories whose suffix is not a number sort first (step -1).
    """
    suffix = path.name.rsplit("-", 1)[-1]
    step = int(suffix) if suffix.isdigit() else -1
    return (step, path.name)


# Option 1: Use the latest run's best model
if runs:
    SELECTED_RUN = runs[-1]  # Latest run (last in sorted name order)
    if (SELECTED_RUN / "best").exists():
        LORA_CHECKPOINT = str(SELECTED_RUN / "best")
    elif (SELECTED_RUN / "final").exists():
        LORA_CHECKPOINT = str(SELECTED_RUN / "final")
    else:
        # Find latest checkpoint-XXXX (highest step number, not highest string)
        checkpoints = sorted(SELECTED_RUN.glob("checkpoint-*"), key=_checkpoint_sort_key)
        LORA_CHECKPOINT = str(checkpoints[-1]) if checkpoints else None

    print(f"Selected checkpoint: {LORA_CHECKPOINT}")
else:
    LORA_CHECKPOINT = None
    print("No checkpoint selected. Will evaluate base model only.")

# Option 2: Manual override (uncomment and set path)
# LORA_CHECKPOINT = "/path/to/your/checkpoint"

## 1. Load Models (Base vs Fine-tuned)

In [None]:
from transformers import AutoModelForVision2Seq, AutoProcessor

# Load the pretrained OpenVLA weights and the matching processor. Both calls
# share the same Hub options, so those are collected once.
print("Loading base OpenVLA model...")

hub_options = dict(
    trust_remote_code=True,
    cache_dir=f"{CACHE_DIR}/huggingface",
)

base_model = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
    **hub_options,
)

processor = AutoProcessor.from_pretrained("openvla/openvla-7b", **hub_options)

print(f"Base model loaded. Device: {base_model.device}")

In [None]:
# Attach the selected LoRA adapter to the base model, when one was selected.
# finetuned_model stays None otherwise, and later cells branch on that.
if not LORA_CHECKPOINT:
    finetuned_model = None
    print("No fine-tuned model - will compare base model predictions to ground truth.")
else:
    from peft import PeftModel

    print(f"Loading LoRA weights from: {LORA_CHECKPOINT}")
    finetuned_model = PeftModel.from_pretrained(
        base_model, LORA_CHECKPOINT, torch_dtype=torch.bfloat16
    )
    finetuned_model.eval()
    print("Fine-tuned model loaded!")

## 2. Action Tokenizer (for decoding predictions)

In [None]:
class ActionTokenizer:
    """Decode action tokens back to continuous values.

    The last ``n_bins`` token IDs below ``vocab_size`` are treated as action
    tokens. Token ``vocab_size - 1`` maps to the lowest bin (near -1) and
    token ``vocab_size - n_bins`` maps to the highest bin (near +1).

    Args:
        vocab_size: Size of the vocabulary the action tokens are counted back from.
        n_bins: Number of bin edges spanning [-1, 1]; there are ``n_bins - 1``
            bin centers, which are the values ``decode`` can return.
    """

    def __init__(self, vocab_size=32000, n_bins=256):
        self.vocab_size = vocab_size
        self.n_bins = n_bins
        # n_bins evenly spaced edges over [-1, 1] -> n_bins - 1 centers.
        self.bins = np.linspace(-1, 1, n_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2
        self.action_token_start = vocab_size - n_bins  # 31744 with the defaults
        self.action_token_end = vocab_size - 1  # 31999 with the defaults

    def decode(self, token_ids):
        """Convert token IDs to continuous actions.

        Args:
            token_ids: A torch tensor, numpy array, list/tuple of ints, or a
                single int.

        Returns:
            Bin-center values with the same shape as ``token_ids``. IDs outside
            the action-token range are clipped to the nearest valid bin.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.detach().cpu().numpy()
        # np.asarray lets plain Python sequences and scalars be decoded too;
        # subtracting a list from an int would otherwise raise TypeError.
        token_ids = np.asarray(token_ids, dtype=np.int64)
        discretized = self.vocab_size - token_ids
        # discretized is 1-based; clip because there is one fewer center than edges.
        indices = np.clip(discretized - 1, 0, len(self.bin_centers) - 1)
        return self.bin_centers[indices]

    def is_action_token(self, token_id):
        """Check if token is an action token (inclusive range check)."""
        return self.action_token_start <= token_id <= self.action_token_end

# Use the tokenizer's base vocabulary size. len(tokenizer) also counts tokens
# added after pretraining (e.g. a padding token), which would shift every
# action bin by one and move the reported range off 31744..31999.
# NOTE(review): confirm this matches the vocabulary size used by the fine-tuning script.
action_tokenizer = ActionTokenizer(vocab_size=processor.tokenizer.vocab_size)
print(f"Action tokens range: [{action_tokenizer.action_token_start}, {action_tokenizer.action_token_end}]")

## 3. Load LIBERO Test Data

In [None]:
def _read_libero_instruction(h5_file, default="complete the task"):
    """Return the language instruction stored in an open HDF5 file, or `default`.

    File-level attributes are checked first, then the attributes of the 'data'
    group, trying 'language_instruction', 'problem_info' and 'language' in that
    order. A 'problem_info' value that parses as a JSON object containing a
    'language_instruction' string is unwrapped to that string, so the raw JSON
    blob is never used as the prompt text.
    """
    import json  # function-scope import: only this helper needs it

    attr_sets = [h5_file.attrs]
    if 'data' in h5_file:
        # NOTE(review): LIBERO demo files appear to keep problem_info on the
        # 'data' group rather than the file root - confirm against your data.
        attr_sets.append(h5_file['data'].attrs)

    for attrs in attr_sets:
        for key in ['language_instruction', 'problem_info', 'language']:
            if key not in attrs:
                continue
            value = attrs[key]
            if isinstance(value, bytes):
                value = value.decode('utf-8')
            if key == 'problem_info':
                try:
                    parsed = json.loads(value)
                except (TypeError, ValueError):
                    parsed = None
                if isinstance(parsed, dict) and isinstance(parsed.get('language_instruction'), str):
                    value = parsed['language_instruction']
            return value
    return default


def _to_policy_action(raw_action):
    """Pad or truncate to 7 dims, clip, and invert the gripper channel.

    Dims 0-5 are clipped to [-1, 1]. Dim 6 (gripper) is clipped to [0, 1] and
    then replaced by ``1 - value``. The result is a new float32 array.
    """
    action = np.asarray(raw_action)
    if len(action) < 7:
        action = np.pad(action, (0, 7 - len(action)))
    else:
        action = action[:7]

    action = action.astype(np.float32)  # astype copies, so the input is untouched
    action[:6] = np.clip(action[:6], -1.0, 1.0)
    action[6] = 1.0 - np.clip(action[6], 0.0, 1.0)  # Invert gripper
    return action


def load_libero_samples(data_dir, suite="libero_spatial", max_demos_per_task=5, max_steps_per_demo=50):
    """Load samples from LIBERO for evaluation.

    Args:
        data_dir: Path to LIBERO data; searched recursively for ``*.hdf5`` files.
        suite: Which suite to load. NOTE(review): currently unused - every
            HDF5 file under ``data_dir`` is loaded regardless of this value.
        max_demos_per_task: How many demos to use per task (default: 5 = 10% of data).
            Values <= 0 select no demos.
        max_steps_per_demo: How many timesteps per demo (default: 50), sampled
            evenly across the episode.

    Returns:
        ``(all_samples, episodes)``. ``all_samples`` is a flat list of dicts with
        keys 'image', 'instruction', 'action', 'file', 'demo', 'timestep'.
        ``episodes`` is a list of dicts with keys 'images' (list), 'actions'
        (array of shape (steps, 7)), 'instruction', 'file', 'demo'. Both are
        empty when no HDF5 file is found.

    For libero_spatial with defaults: 10 tasks x 5 demos x 50 steps = 2,500 samples.
    This is enough for reliable metrics while being fast to evaluate.

    Files that fail to read are reported with a printed message and skipped.
    """
    data_path = Path(data_dir)

    # Find HDF5 files
    hdf5_files = sorted(data_path.rglob("*.hdf5"))
    if not hdf5_files:
        print(f"No HDF5 files found in {data_dir}")
        return [], []

    print(f"Found {len(hdf5_files)} task files (HDF5)")

    all_samples = []
    episodes = []  # For trajectory visualization

    for filepath in tqdm(hdf5_files, desc="Loading tasks"):
        try:
            with h5py.File(filepath, 'r') as f:
                if 'data' not in f:
                    continue

                instruction = _read_libero_instruction(f)

                # NOTE(review): this is a string sort ('demo_10' < 'demo_2'); make sure
                # the training split used the same ordering so these demos are held out.
                demo_keys = sorted(k for k in f['data'].keys() if k.startswith('demo_'))

                # Use last N demos as "validation" (held-out from training).
                # Guard N <= 0 explicitly: demo_keys[-0:] would return every demo.
                val_demos = demo_keys[-max_demos_per_task:] if max_demos_per_task > 0 else []

                for demo_key in val_demos:
                    demo = f['data'][demo_key]
                    if 'actions' not in demo or 'obs' not in demo:
                        continue

                    # Find image key (first match wins)
                    img_key = next(
                        (k for k in ['agentview_rgb', 'agentview_image', 'rgb', 'image'] if k in demo['obs']),
                        None,
                    )
                    if img_key is None:
                        continue

                    n_steps = len(demo['actions'])

                    # Store episode for trajectory viz
                    episode_data = {
                        'images': [],
                        'actions': [],
                        'instruction': instruction,
                        'file': filepath.name,
                        'demo': demo_key,
                    }

                    # Sample timesteps evenly across the episode
                    step_indices = np.linspace(0, n_steps-1, min(max_steps_per_demo, n_steps), dtype=int)

                    for t in step_indices:
                        image = demo['obs'][img_key][t]
                        image = np.rot90(image, k=2)  # 180 degree rotation

                        action = _to_policy_action(demo['actions'][t])

                        all_samples.append({
                            'image': image,
                            'instruction': instruction,
                            'action': action,
                            'file': filepath.name,
                            'demo': demo_key,
                            'timestep': t,
                        })

                        episode_data['images'].append(image)
                        episode_data['actions'].append(action)

                    if len(episode_data['actions']) > 0:
                        episode_data['actions'] = np.array(episode_data['actions'])
                        episodes.append(episode_data)

        except Exception as e:
            print(f"Error reading {filepath}: {e}")

    return all_samples, episodes

# ============================================================
# CONFIGURE EVALUATION SIZE
# ============================================================
# Approximate sample counts assume 10 task files:
#   quick:    max_demos_per_task=2,  max_steps_per_demo=20  (~400 samples)
#   standard: max_demos_per_task=5,  max_steps_per_demo=50  (~2,500 samples)
#   thorough: max_demos_per_task=10, max_steps_per_demo=100 (~10,000 samples)

MAX_DEMOS_PER_TASK = 5   # Use last 5 demos per task as validation
MAX_STEPS_PER_DEMO = 50  # Sample 50 timesteps per demo

print("\n" + "="*60)
print(" LOADING LIBERO VALIDATION SAMPLES")
print("="*60)
print(f"\nConfiguration:")
print(f"  Demos per task: {MAX_DEMOS_PER_TASK} (last N demos held out)")
print(f"  Steps per demo: {MAX_STEPS_PER_DEMO}")

test_samples, episodes = load_libero_samples(
    LIBERO_DATA_DIR,
    max_steps_per_demo=MAX_STEPS_PER_DEMO,
    max_demos_per_task=MAX_DEMOS_PER_TASK,
)

print(f"\n‚úÖ Loaded {len(test_samples)} validation samples from {len(episodes)} episodes")

if test_samples:
    # Each distinct instruction string is counted as one task.
    unique_instructions = set(s['instruction'] for s in test_samples)
    n_tasks = len(unique_instructions)

    print(f"\nDataset Statistics:")
    print(f"  Unique tasks: {n_tasks}")
    print(f"  Samples per task: ~{len(test_samples) // n_tasks}")
    print(f"\nSample instructions:")
    for inst in list(unique_instructions)[:5]:
        print(f"  - {inst[:60]}...")

## 4. Predict Actions

In [None]:
def predict_action(model, processor, image, instruction, action_tokenizer, action_dim=7):
    """Predict an action vector from an image and an instruction.

    Args:
        model: Model exposing ``.device`` and ``.generate(**inputs, ...)``.
        processor: Callable ``processor(prompt, pil_image, return_tensors="pt")``
            returning a mapping of tensors (must include 'input_ids'); its
            ``tokenizer.pad_token_id`` is forwarded to ``generate``.
        image: numpy array (cast to uint8) or an already-built PIL image.
        instruction: Task description; lower-cased and inserted into the prompt.
        action_tokenizer: Object whose ``decode(token_ids)`` maps token IDs to
            continuous values.
        action_dim: Number of action tokens to generate and decode (default 7).

    Returns:
        ``(action, action_tokens)``: the decoded action and the raw token IDs
        as a numpy array of length ``action_dim``.

    Raises:
        ValueError: If fewer than ``action_dim`` new tokens were generated.
            Taking the last ``action_dim`` IDs would then silently decode
            prompt tokens as if they were actions.
    """
    # Create prompt
    # NOTE(review): OpenVLA's own predict_action appends a trailing "empty"
    # token to the prompt before generating; confirm this prompt matches the
    # format the fine-tuning run was trained with.
    prompt = f"In: What action should the robot take to {instruction.lower()}?\nOut:"

    # Convert image to PIL
    if isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image.astype(np.uint8))
    else:
        pil_image = image

    # Resize to 224x224
    if pil_image.size != (224, 224):
        pil_image = pil_image.resize((224, 224), Image.LANCZOS)

    # Process inputs
    inputs = processor(prompt, pil_image, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # CRITICAL: Convert pixel_values to bfloat16 to match model dtype
    if 'pixel_values' in inputs:
        inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)

    prompt_len = inputs['input_ids'].shape[-1]

    # Generate (greedy decoding, so repeated calls give the same tokens)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=action_dim,
            do_sample=False,
            pad_token_id=processor.tokenizer.pad_token_id,
        )

    # The returned sequence is prompt + new tokens; make sure generation did
    # not stop early before slicing the action tokens off the end.
    n_generated = outputs.shape[-1] - prompt_len
    if n_generated < action_dim:
        raise ValueError(
            f"Model generated {n_generated} tokens but {action_dim} action tokens were expected"
        )

    # Extract action tokens (last action_dim)
    action_tokens = outputs[0, -action_dim:].cpu().numpy()

    # Decode to continuous actions
    action = action_tokenizer.decode(action_tokens)

    return action, action_tokens

# Test prediction
# One-sample smoke test: prefer the fine-tuned model, fall back to the base model.
if test_samples:
    sample = test_samples[0]
    model_to_test = finetuned_model or base_model
    pred_action, pred_tokens = predict_action(
        model_to_test,
        processor,
        sample['image'],
        sample['instruction'],
        action_tokenizer,
    )
    print("Test prediction:")
    for report_line in (
        f"  Instruction: {sample['instruction'][:50]}...",
        f"  Ground truth: {sample['action']}",
        f"  Prediction:   {pred_action}",
        f"  Token IDs:    {pred_tokens}",
    ):
        print(report_line)

## 5. Run Full Evaluation

In [None]:
def evaluate_model(model, processor, samples, action_tokenizer, model_name="Model"):
    """Run the model over every sample and collect prediction / ground-truth pairs.

    A sample whose prediction raises is reported with a printed "Error: ..."
    line and left out, so the returned list can be shorter than ``samples``.

    Returns:
        List of dicts with keys 'gt_action', 'pred_action', 'instruction', 'image'.
    """
    print(f"\nEvaluating {model_name} on {len(samples)} samples...")

    collected = []
    for item in tqdm(samples):
        try:
            predicted, _ = predict_action(
                model, processor, item['image'], item['instruction'], action_tokenizer
            )
            collected.append(dict(
                gt_action=item['action'],
                pred_action=predicted,
                instruction=item['instruction'],
                image=item['image'],
            ))
        except Exception as err:
            print(f"Error: {err}")

    return collected

# Evaluate fine-tuned model (or the base model when no adapter was loaded).
# finetuned_results is bound up front so later cells that test
# `if finetuned_results:` do not raise NameError when no samples were loaded.
finetuned_results = []

if finetuned_model and test_samples:
    finetuned_results = evaluate_model(
        finetuned_model, processor, test_samples, action_tokenizer, "Fine-tuned"
    )
elif test_samples:
    print("No fine-tuned model. Evaluating base model...")
    finetuned_results = evaluate_model(
        base_model, processor, test_samples, action_tokenizer, "Base"
    )
else:
    print("No validation samples loaded - skipping evaluation.")

## 6. Compute Interpretable Metrics

In [None]:
def compute_metrics(results):
    """Compute interpretable metrics from evaluation results.

    Args:
        results: Non-empty list of dicts with 'gt_action' and 'pred_action'
            entries, each a length-7 vector (dx, dy, dz, rx, ry, rz, gripper).

    Returns:
        ``(metrics, gt_actions, pred_actions)`` where ``metrics`` maps metric
        name to value and the two arrays have shape (n_results, 7).

    Raises:
        ValueError: If ``results`` is empty (e.g. every prediction failed).
    """
    if len(results) == 0:
        raise ValueError("compute_metrics() needs at least one result; got an empty list.")

    gt_actions = np.array([r['gt_action'] for r in results])
    pred_actions = np.array([r['pred_action'] for r in results])
    abs_error = np.abs(pred_actions - gt_actions)

    metrics = {}

    # 1. Overall L1 Error (lower is better)
    metrics['L1 Error (Overall)'] = abs_error.mean()

    # 2. Per-dimension L1 Error
    dim_names = ['dx', 'dy', 'dz', 'rx', 'ry', 'rz', 'gripper']
    for i, name in enumerate(dim_names):
        metrics[f'L1 Error ({name})'] = abs_error[:, i].mean()

    # 3. Position Error (dims 0-2)
    metrics['Position Error (xyz)'] = abs_error[:, :3].mean()

    # 4. Rotation Error (dims 3-5)
    metrics['Rotation Error (rpy)'] = abs_error[:, 3:6].mean()

    # 5. Gripper Accuracy (is open/close correct?)
    # Gripper > 0.5 = open, < 0.5 = close
    gt_gripper_open = gt_actions[:, 6] > 0.5
    pred_gripper_open = pred_actions[:, 6] > 0.5
    metrics['Gripper Accuracy (%)'] = (gt_gripper_open == pred_gripper_open).mean() * 100

    # 6. Direction Accuracy (is movement direction correct?)
    # Compare signs of the position dims, only where ground truth is not ~0.
    gt_signs = np.sign(gt_actions[:, :3])
    pred_signs = np.sign(pred_actions[:, :3])
    significant_movement = np.abs(gt_actions[:, :3]) > 0.05
    if significant_movement.sum() > 0:
        direction_accuracy = (gt_signs[significant_movement] == pred_signs[significant_movement]).mean() * 100
    else:
        direction_accuracy = 0
    metrics['Direction Accuracy (%)'] = direction_accuracy

    # 7. Action magnitude correlation
    # Pearson correlation is undefined when either series is constant
    # (np.corrcoef returns NaN), so both spreads must be non-zero.
    gt_magnitude = np.linalg.norm(gt_actions[:, :3], axis=1)
    pred_magnitude = np.linalg.norm(pred_actions[:, :3], axis=1)
    if gt_magnitude.std() > 0 and pred_magnitude.std() > 0:
        correlation = np.corrcoef(gt_magnitude, pred_magnitude)[0, 1]
    else:
        correlation = 0
    metrics['Magnitude Correlation'] = correlation

    return metrics, gt_actions, pred_actions

# Compute metrics and print a short report with a verdict on L1 error and
# gripper accuracy.
if finetuned_results:
    metrics, gt_actions, pred_actions = compute_metrics(finetuned_results)

    overall_l1 = metrics['L1 Error (Overall)']
    gripper_acc = metrics['Gripper Accuracy (%)']
    banner = "="*60

    print("\n" + banner)
    print(" EVALUATION RESULTS")
    print(banner)

    print("\nüìä OVERALL METRICS:")
    print(f"  L1 Error:           {overall_l1:.4f}  (lower is better, target < 0.15)")
    print(f"  Gripper Accuracy:   {gripper_acc:.1f}%  (higher is better, target > 90%)")
    print(f"  Direction Accuracy: {metrics['Direction Accuracy (%)']:.1f}%  (higher is better, target > 80%)")
    print(f"  Magnitude Corr:     {metrics['Magnitude Correlation']:.3f}  (higher is better, target > 0.7)")

    print("\nüìè PER-DIMENSION L1 ERROR:")
    print(f"  Position (x,y,z):  {metrics['Position Error (xyz)']:.4f}")
    print(f"  Rotation (r,p,y):  {metrics['Rotation Error (rpy)']:.4f}")
    print(f"  Gripper:           {metrics['L1 Error (gripper)']:.4f}")

    print("\nüìà INTERPRETATION:")
    # Choose each verdict first, then print once.
    if overall_l1 < 0.15:
        l1_verdict = "  ‚úÖ L1 Error is GOOD - model is making accurate predictions"
    elif overall_l1 < 0.25:
        l1_verdict = "  ‚ö†Ô∏è L1 Error is MODERATE - model is learning but could improve"
    else:
        l1_verdict = "  ‚ùå L1 Error is HIGH - model needs more training or debugging"
    print(l1_verdict)

    if gripper_acc > 90:
        gripper_verdict = "  ‚úÖ Gripper Accuracy is EXCELLENT"
    elif gripper_acc > 70:
        gripper_verdict = "  ‚ö†Ô∏è Gripper Accuracy is OK but could improve"
    else:
        gripper_verdict = "  ‚ùå Gripper Accuracy is LOW - check gripper transform"
    print(gripper_verdict)

## 7. Visualize Predictions vs Ground Truth

In [None]:
def visualize_predictions(results, n_samples=6):
    """Visualize side-by-side predictions vs ground truth.

    Draws up to ``n_samples`` randomly chosen results, three per row. Each
    panel shows the camera image with a per-dimension table of ground truth,
    prediction, and a pass/fail mark (absolute error below 0.15) in the title.
    The figure is saved to 'prediction_visualization.png' and shown.

    Args:
        results: List of dicts with 'image', 'instruction', 'gt_action', 'pred_action'.
        n_samples: Maximum number of panels. The grid grows in rows of three,
            so values above 6 are no longer silently truncated.
    """
    n_show = min(n_samples, len(results))
    if n_show <= 0:
        print("No results to visualize.")
        return

    n_cols = 3
    n_rows = (n_show + n_cols - 1) // n_cols  # ceil division
    # With the default n_samples=6 this is a 2x3 grid at (15, 10).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    indices = np.random.choice(len(results), n_show, replace=False)

    dim_names = ['dx', 'dy', 'dz', 'rx', 'ry', 'rz', 'grip']

    for idx, ax in zip(indices, axes):
        result = results[idx]

        # Show image
        ax.imshow(result['image'])
        ax.axis('off')

        # Add action comparison as text
        gt = result['gt_action']
        pred = result['pred_action']

        text = f"Instruction: {result['instruction'][:30]}...\n\n"
        text += "       GT    Pred   Err\n"
        for i, name in enumerate(dim_names):
            err = abs(gt[i] - pred[i])
            mark = '‚úì' if err < 0.15 else '‚úó'
            text += f"{name:5}: {gt[i]:+.2f}  {pred[i]:+.2f}  {mark}\n"

        ax.set_title(text, fontsize=8, family='monospace', loc='left')

    # Hide grid cells that received no sample so empty frames are not drawn.
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('prediction_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved to: prediction_visualization.png")

if finetuned_results:
    visualize_predictions(finetuned_results)

In [None]:
def plot_action_distribution(gt_actions, pred_actions):
    """Scatter predicted against ground-truth values, one panel per action dimension.

    Each panel carries a dashed y = x reference line and the correlation
    coefficient of that dimension in its title. The figure is saved to
    'action_distribution.png' and shown.
    """
    labels = ['dx', 'dy', 'dz', 'rx', 'ry', 'rz', 'gripper']
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    panels = axes.flatten()

    for dim, label in enumerate(labels):
        ax = panels[dim]
        truth = gt_actions[:, dim]
        guess = pred_actions[:, dim]
        corr = np.corrcoef(truth, guess)[0, 1]

        ax.scatter(truth, guess, alpha=0.5, s=20)
        ax.plot([-1, 1], [-1, 1], 'r--', label='Perfect')
        ax.set_xlabel(f'Ground Truth {label}')
        ax.set_ylabel(f'Predicted {label}')
        ax.set_title(f'{label}: corr={corr:.3f}')
        ax.set_xlim(-1.1, 1.1)
        ax.set_ylim(-1.1, 1.1)
        ax.grid(True, alpha=0.3)

    # Seven dimensions on a 2x4 grid: blank the unused eighth panel.
    panels[-1].axis('off')

    plt.suptitle('Predicted vs Ground Truth Actions (each point = one sample)', fontsize=14)
    plt.tight_layout()
    plt.savefig('action_distribution.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved to: action_distribution.png")

if finetuned_results:
    plot_action_distribution(gt_actions, pred_actions)

In [None]:
def plot_error_histogram(gt_actions, pred_actions):
    """Histogram the signed error (prediction minus ground truth) per action dimension.

    Each panel marks zero error (dashed red) and the mean error (solid green)
    and reports the error's standard deviation in its title. The figure is
    saved to 'error_histogram.png' and shown.
    """
    residuals = pred_actions - gt_actions
    labels = ['dx', 'dy', 'dz', 'rx', 'ry', 'rz', 'gripper']

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    panels = axes.flatten()

    for dim, label in enumerate(labels):
        ax = panels[dim]
        dim_error = residuals[:, dim]
        bias = dim_error.mean()

        ax.hist(dim_error, bins=30, edgecolor='black', alpha=0.7)
        ax.axvline(x=0, color='r', linestyle='--', label='Zero error')
        ax.axvline(x=bias, color='g', linestyle='-', label=f'Mean: {bias:.3f}')
        ax.set_xlabel(f'Error in {label}')
        ax.set_ylabel('Count')
        ax.set_title(f'{label}: std={dim_error.std():.3f}')
        ax.legend(fontsize=8)

    # Seven dimensions on a 2x4 grid: blank the unused eighth panel.
    panels[-1].axis('off')

    plt.suptitle('Distribution of Prediction Errors (0 = perfect)', fontsize=14)
    plt.tight_layout()
    plt.savefig('error_histogram.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved to: error_histogram.png")

if finetuned_results:
    plot_error_histogram(gt_actions, pred_actions)

## 8. Compare Base vs Fine-tuned (Optional)

In [None]:
# Compare base model vs fine-tuned model on ALL samples
# NOTE: PEFT modifies base_model in-place, so we need to disable adapters for true comparison

if finetuned_model and len(test_samples) > 0:
    print("="*70)
    print(" EVALUATING BASE MODEL (LoRA adapters DISABLED) - ALL SAMPLES")
    print("="*70)
    print("\n‚ö†Ô∏è  IMPORTANT: We disable LoRA adapters to get the TRUE base OpenVLA-7B")
    print("    performance. This lets us see how much fine-tuning actually helped.")
    print(f"\nüìä Evaluating on ALL {len(test_samples)} samples for consistent comparison.\n")

    # CRITICAL: Disable LoRA adapters to get true base model behavior
    finetuned_model.disable_adapter_layers()
    try:
        base_results = evaluate_model(
            finetuned_model, processor, test_samples, action_tokenizer, "Base OpenVLA-7B (no LoRA)"
        )
    finally:
        # Re-enable LoRA adapters even if the run above is interrupted;
        # otherwise every later "fine-tuned" evaluation would silently use
        # the base weights.
        finetuned_model.enable_adapter_layers()

    # Compute base metrics
    base_metrics, base_gt, base_pred = compute_metrics(base_results)

    # Also evaluate fine-tuned on same samples for fair comparison
    print("\n" + "="*70)
    print(" EVALUATING FINE-TUNED MODEL (LoRA adapters ENABLED) - ALL SAMPLES")
    print("="*70)

    ft_results_comparison = evaluate_model(
        finetuned_model, processor, test_samples, action_tokenizer, "Fine-tuned OpenVLA"
    )
    ft_metrics_comparison, ft_gt, ft_pred = compute_metrics(ft_results_comparison)

    # Store for trajectory visualization
    comparison_results = {
        'base': base_results,
        'finetuned': ft_results_comparison,
    }

    # Display comparison table. Rule widths follow the header row:
    # 25 characters for the metric column, 14 for each value column.
    table_rule = "‚îÄ"
    wide_rule = table_rule * 25
    narrow_rule = table_rule * 14

    print("\n" + "="*70)
    print(f" üìä BASE vs FINE-TUNED COMPARISON (on {len(test_samples)} samples)")
    print("="*70)
    print("\n‚îå" + wide_rule + "‚î¨" + narrow_rule + "‚î¨" + narrow_rule + "‚î¨" + narrow_rule + "‚îê")
    print("‚îÇ Metric                  ‚îÇ Base Model   ‚îÇ Fine-tuned   ‚îÇ Change       ‚îÇ")
    print("‚îú" + wide_rule + "‚îº" + narrow_rule + "‚îº" + narrow_rule + "‚îº" + narrow_rule + "‚î§")

    for key in ['L1 Error (Overall)', 'Gripper Accuracy (%)', 'Direction Accuracy (%)']:
        base_val = base_metrics[key]
        ft_val = ft_metrics_comparison[key]

        if 'Error' in key:
            improvement = base_val - ft_val  # Lower is better for error
            sign = "-" if improvement > 0 else "+"
            better = "‚úÖ" if improvement > 0 else "‚ùå"
        else:
            improvement = ft_val - base_val  # Higher is better for accuracy
            sign = "+" if improvement > 0 else ""
            better = "‚úÖ" if improvement > 0 else "‚ùå"

        print(f"‚îÇ {key:<23} ‚îÇ {base_val:>10.3f}   ‚îÇ {ft_val:>10.3f}   ‚îÇ {sign}{abs(improvement):>8.3f} {better} ‚îÇ")

    print("‚îî" + wide_rule + "‚î¥" + narrow_rule + "‚î¥" + narrow_rule + "‚î¥" + narrow_rule + "‚îò")

    print("\nüìå EXPLANATION:")
    print("   ‚Ä¢ Base Model     = Pretrained OpenVLA-7B (trained on Open-X Embodiment)")
    print("   ‚Ä¢ Fine-tuned     = Same model + LoRA adapters trained on YOUR LIBERO data")
    print("   ‚Ä¢ Change column  = How much fine-tuning improved (‚úÖ) or hurt (‚ùå) the metric")

    # Calculate overall improvement. The L1 figure is a relative reduction in
    # percent; the two accuracy figures are differences in percentage points.
    base_l1 = base_metrics['L1 Error (Overall)']
    ft_l1 = ft_metrics_comparison['L1 Error (Overall)']
    # Guard the division: a base L1 of exactly 0 leaves nothing to reduce.
    l1_improvement = (base_l1 - ft_l1) / base_l1 * 100 if base_l1 > 0 else 0.0
    grip_improvement = ft_metrics_comparison['Gripper Accuracy (%)'] - base_metrics['Gripper Accuracy (%)']
    dir_improvement = ft_metrics_comparison['Direction Accuracy (%)'] - base_metrics['Direction Accuracy (%)']

    print(f"\nüéØ IMPROVEMENTS FROM FINE-TUNING:")
    print(f"   ‚Ä¢ L1 Error reduced by:        {l1_improvement:.1f}%")
    print(f"   ‚Ä¢ Gripper Accuracy improved:  {grip_improvement:+.1f}%")
    print(f"   ‚Ä¢ Direction Accuracy improved: {dir_improvement:+.1f}%")

    if l1_improvement > 30 and grip_improvement > 5:
        print("\n‚úÖ Fine-tuning was SUCCESSFUL! The model learned LIBERO-specific actions.")
    elif l1_improvement > 10:
        print("\n‚ö†Ô∏è Fine-tuning showed MODERATE improvement. Consider more epochs or data.")
    else:
        print("\n‚ùå Fine-tuning showed MINIMAL improvement. Check hyperparameters or data quality.")

In [None]:
# Diagnose direction accuracy issue
# Runs only after the comparison cell has produced ft_gt / ft_pred (and, with
# them, base_pred). Reports magnitude statistics, overall sign agreement, and
# per-axis sign agreement for the three position dimensions.
if 'ft_gt' in dir() and 'ft_pred' in dir():
    print("="*70)
    print(" DIAGNOSING DIRECTION ACCURACY")
    print("="*70)

    # Compare action magnitude distributions
    print("\nüìä ACTION MAGNITUDE ANALYSIS (Position dims 0-2):")
    for heading, actions in (
        ("Ground Truth", ft_gt),
        ("Base Model Predictions", base_pred),
        ("Fine-tuned Predictions", ft_pred),
    ):
        pos_magnitude = np.abs(actions[:, :3])
        print(f"\n  {heading}:")
        print(f"    Mean magnitude: {pos_magnitude.mean():.4f}")
        print(f"    Std magnitude:  {pos_magnitude.std():.4f}")
        print(f"    % near zero (|a| < 0.05): {(pos_magnitude < 0.05).mean()*100:.1f}%")

    # Check for sign inversion pattern
    print("\nüìä SIGN ANALYSIS (for significant movements |gt| > 0.1):")
    significant = np.abs(ft_gt[:, :3]) > 0.1

    if significant.sum() > 0:
        gt_signs = np.sign(ft_gt[:, :3][significant])
        base_signs = np.sign(base_pred[:, :3][significant])
        ft_signs = np.sign(ft_pred[:, :3][significant])

        base_match = (gt_signs == base_signs).mean() * 100
        ft_match = (gt_signs == ft_signs).mean() * 100
        # Share of entries where the prediction has exactly the opposite sign.
        ft_inverted = (-gt_signs == ft_signs).mean() * 100

        print(f"  Samples with |gt| > 0.1: {significant.sum()}")
        print(f"  Base direction accuracy:      {base_match:.1f}%")
        print(f"  Fine-tuned direction accuracy: {ft_match:.1f}%")
        print(f"\n  Fine-tuned INVERTED direction: {ft_inverted:.1f}%")

        if ft_inverted > 60:
            print("\n  ‚ö†Ô∏è POSSIBLE SIGN INVERSION DETECTED!")
            print("     The model may have learned inverted actions.")
            print("     Check: Is there a coordinate frame mismatch?")
        elif ft_match < 40:
            print("\n  ‚ö†Ô∏è PREDICTIONS CLUSTER NEAR ZERO")
            print("     The model predicts small magnitudes, making signs unreliable.")

    # Per-dimension analysis
    print("\nüìä PER-DIMENSION DIRECTION ACCURACY:")
    dim_names = ['dx', 'dy', 'dz']
    for i, name in enumerate(dim_names):
        sig = np.abs(ft_gt[:, i]) > 0.05
        if not sig.sum() > 10:
            # Too few significant samples on this axis to report.
            continue

        gt_s = np.sign(ft_gt[:, i][sig])
        base_s = np.sign(base_pred[:, i][sig])
        ft_s = np.sign(ft_pred[:, i][sig])

        base_acc = (gt_s == base_s).mean() * 100
        ft_acc = (gt_s == ft_s).mean() * 100
        ft_inv = (-gt_s == ft_s).mean() * 100

        if ft_acc > base_acc:
            status = "‚úÖ"
        elif ft_inv > 60:
            status = "‚ö†Ô∏è INVERTED"
        else:
            status = "‚ùå"
        print(f"  {name}: Base={base_acc:.1f}%, Fine-tuned={ft_acc:.1f}%, Inverted={ft_inv:.1f}% {status}")

## 9. Summary and Next Steps

In [None]:
summary_rule = "="*60
print("\n" + summary_rule)
print(" SUMMARY")
print(summary_rule)

if finetuned_results:
    print(f"\nüìä Final Metrics:")
    print(f"   L1 Error:         {metrics['L1 Error (Overall)']:.4f}")
    print(f"   Gripper Accuracy: {metrics['Gripper Accuracy (%)']:.1f}%")
    print(f"   Direction Acc:    {metrics['Direction Accuracy (%)']:.1f}%")

    print(f"\nüéØ Quality Assessment:")

    def _below(value, limit):
        return value < limit

    def _above(value, limit):
        return value > limit

    # One row per graded metric: (metric key, comparison, threshold for
    # 3 points, threshold for 2 points, messages for 3 / 2 / 1 points).
    grading_table = [
        ('L1 Error (Overall)', _below, 0.15, 0.25, (
            "   ‚úÖ Excellent L1 Error",
            "   ‚ö†Ô∏è Moderate L1 Error - more training may help",
            "   ‚ùå High L1 Error - check preprocessing",
        )),
        ('Gripper Accuracy (%)', _above, 90, 70, (
            "   ‚úÖ Excellent Gripper Accuracy",
            "   ‚ö†Ô∏è Moderate Gripper Accuracy",
            "   ‚ùå Low Gripper Accuracy - check gripper transform",
        )),
        ('Direction Accuracy (%)', _above, 80, 60, (
            "   ‚úÖ Excellent Direction Accuracy",
            "   ‚ö†Ô∏è Moderate Direction Accuracy",
            "   ‚ùå Low Direction Accuracy",
        )),
    ]

    score = 0
    for metric_key, passes, top_limit, mid_limit, verdicts in grading_table:
        metric_value = metrics[metric_key]
        if passes(metric_value, top_limit):
            points = 3
        elif passes(metric_value, mid_limit):
            points = 2
        else:
            points = 1
        score += points
        print(verdicts[3 - points])

    print(f"\nüèÜ Overall Score: {score}/9")
    if score >= 7:
        print("   Model is ready for deployment!")
    elif score >= 5:
        print("   Model is learning. Consider more training epochs.")
    else:
        print("   Model needs debugging. Check preprocessing and hyperparameters.")

print("\nüìÅ Saved Files:")
for saved_name in ("prediction_visualization.png", "action_distribution.png", "error_histogram.png"):
    print(f"   - {saved_name}")

In [None]:
def predict_episode_trajectory(model, processor, episode, action_tokenizer):
    """Predict one action for every frame of an episode.

    Args:
        model: Forwarded unchanged to ``predict_action``.
        processor: Forwarded unchanged to ``predict_action``.
        episode: Mapping that provides ``'images'`` (an iterable of frames)
            and ``'instruction'`` (passed to ``predict_action`` with each frame).
        action_tokenizer: Forwarded unchanged to ``predict_action``.

    Returns:
        np.ndarray with one row per frame. Frames whose prediction raised are
        represented by a row of 7 zeros so the result stays aligned with the
        episode. An episode without frames yields an array of shape ``(0, 7)``.
    """
    action_dim = 7  # width of the zero-filled fallback row
    predicted_actions = []
    failed_steps = []
    first_error = None

    for step, image in enumerate(episode['images']):
        try:
            pred_action, _ = predict_action(
                model, processor, image, episode['instruction'], action_tokenizer
            )
        except Exception as exc:
            # Best-effort: keep the trajectory aligned with the frames, but
            # remember the failure so it can be reported below.
            if first_error is None:
                first_error = exc
            failed_steps.append(step)
            pred_action = np.zeros(action_dim)
        predicted_actions.append(pred_action)

    if failed_steps:
        # print() rather than warnings.warn(): the setup cell silences all warnings.
        print(
            f"    WARNING: prediction failed on {len(failed_steps)}/{len(predicted_actions)} steps "
            f"(first at step {failed_steps[0]}: {type(first_error).__name__}: {first_error}). "
            "Those steps are zero-filled and will distort the episode metrics."
        )

    if not predicted_actions:
        # Keep the result 2-D so callers can still slice columns.
        return np.zeros((0, action_dim))

    return np.array(predicted_actions)

def create_trajectory_animation(episode, gt_actions, pred_actions, save_path="trajectory_animation.gif"):
    """Create an animated GIF comparing ground-truth and predicted trajectories.

    The figure has six panels: the camera frame, the 3D path obtained by
    cumulatively summing the first three action columns, the position columns
    (0-2), the rotation columns (3-5), the gripper column (6) and the per-step
    mean absolute error across all action columns.

    Args:
        episode: Mapping that provides ``'images'``, an indexable sequence of
            frames accepted by ``imshow``.
        gt_actions: Array-like of ground-truth actions, one row per step.
        pred_actions: Array-like of predicted actions, one row per step.
        save_path: Destination path of the GIF.

    Returns:
        ``save_path``, after the GIF has been written.

    Raises:
        ValueError: If there are no frames or no action rows to animate.
    """
    from matplotlib.animation import FuncAnimation, PillowWriter
    # Imported only for its side effect: older matplotlib releases need it to
    # register the '3d' projection used below.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    images = episode['images']
    gt_actions = np.asarray(gt_actions)
    pred_actions = np.asarray(pred_actions)

    # Validate up front so a bad episode fails with a clear message before a
    # figure is allocated, instead of part-way through rendering.
    n_frames = min(len(images), len(gt_actions), len(pred_actions))
    if n_frames == 0:
        raise ValueError("Cannot animate an episode with no frames or actions")
    if not (len(images) == len(gt_actions) == len(pred_actions)):
        print(
            f"    Note: length mismatch (images={len(images)}, gt={len(gt_actions)}, "
            f"pred={len(pred_actions)}); animating the first {n_frames} steps"
        )
    gt_actions = gt_actions[:n_frames]
    pred_actions = pred_actions[:n_frames]

    # Compute cumulative positions from delta actions
    gt_positions = np.cumsum(gt_actions[:, :3], axis=0)
    pred_positions = np.cumsum(pred_actions[:, :3], axis=0)

    # Fixed 3D bounds so the view does not jump between frames
    all_pos = np.vstack([gt_positions, pred_positions])
    pos_min, pos_max = all_pos.min(axis=0), all_pos.max(axis=0)
    margin = (pos_max - pos_min).max() * 0.1 + 0.1

    # Time series data
    timesteps = np.arange(n_frames)
    per_step_errors = np.abs(pred_actions - gt_actions).mean(axis=1)

    fig = plt.figure(figsize=(16, 8))

    # Camera panel: the only panel whose artist is updated in place.
    ax_cam = fig.add_subplot(2, 3, 1)
    ax_cam.set_title("Camera View")
    ax_cam.axis('off')
    img_display = ax_cam.imshow(images[0])

    # The remaining panels are cleared and redrawn on every frame, so they
    # need no initial styling here.
    ax_traj = fig.add_subplot(2, 3, 2, projection='3d')
    ax_pos = fig.add_subplot(2, 3, 3)
    ax_rot = fig.add_subplot(2, 3, 4)
    ax_grip = fig.add_subplot(2, 3, 5)
    ax_err = fig.add_subplot(2, 3, 6)

    def _reset_series_axis(ax, title, ylabel, ylim):
        """Clear a time-series panel and re-apply its title, labels and limits."""
        ax.clear()
        ax.set_title(title)
        ax.set_xlabel("Timestep")
        ax.set_ylabel(ylabel)
        ax.set_xlim(0, n_frames)
        ax.set_ylim(*ylim)

    def _finish_series_axis(ax, frame, **legend_kwargs):
        """Add the current-time marker, legend and grid to a time-series panel."""
        ax.axvline(x=frame, color='gray', linestyle=':', alpha=0.5)
        ax.legend(loc='upper right', **legend_kwargs)
        ax.grid(True, alpha=0.3)

    def _draw_triplet(ax, title, labels, first_col, frame):
        """Plot three consecutive action columns (GT solid, prediction dashed)."""
        upto = frame + 1
        _reset_series_axis(ax, title, "Action Value", (-1.1, 1.1))
        for offset, (color, label) in enumerate(zip(['r', 'g', 'b'], labels)):
            col = first_col + offset
            ax.plot(timesteps[:upto], gt_actions[:upto, col], f'{color}-', alpha=0.7, label=f'GT {label}')
            ax.plot(timesteps[:upto], pred_actions[:upto, col], f'{color}--', alpha=0.7, label=f'Pred {label}')
        _finish_series_axis(ax, frame, fontsize=6, ncol=2)

    def init():
        return []

    def animate(frame):
        upto = frame + 1

        # Update camera view
        img_display.set_array(images[frame])
        ax_cam.set_title(f"Camera View (t={frame}/{n_frames-1})")

        # Update 3D trajectory
        ax_traj.clear()
        ax_traj.set_title("3D Trajectory")
        ax_traj.set_xlabel('X')
        ax_traj.set_ylabel('Y')
        ax_traj.set_zlabel('Z')
        ax_traj.set_xlim(pos_min[0] - margin, pos_max[0] + margin)
        ax_traj.set_ylim(pos_min[1] - margin, pos_max[1] + margin)
        ax_traj.set_zlim(pos_min[2] - margin, pos_max[2] + margin)

        # Paths up to the current frame. Drawn on frame 0 as well (a single
        # point renders no visible segment) so the legend always has entries.
        ax_traj.plot(gt_positions[:upto, 0], gt_positions[:upto, 1], gt_positions[:upto, 2],
                     'b-', linewidth=2, label='Ground Truth')
        ax_traj.plot(pred_positions[:upto, 0], pred_positions[:upto, 1], pred_positions[:upto, 2],
                     'r-', linewidth=2, label='Predicted')
        ax_traj.scatter([gt_positions[frame, 0]], [gt_positions[frame, 1]], [gt_positions[frame, 2]],
                        c='blue', s=100, marker='o')
        ax_traj.scatter([pred_positions[frame, 0]], [pred_positions[frame, 1]], [pred_positions[frame, 2]],
                        c='red', s=100, marker='o')
        ax_traj.legend(loc='upper left', fontsize=8)

        # Update position and rotation plots
        _draw_triplet(ax_pos, "Position Actions (dx, dy, dz)", ['dx', 'dy', 'dz'], 0, frame)
        _draw_triplet(ax_rot, "Rotation Actions (rx, ry, rz)", ['rx', 'ry', 'rz'], 3, frame)

        # Update gripper plot
        _reset_series_axis(ax_grip, "Gripper Action (1=open, 0=close)", "Gripper Value", (-0.1, 1.1))
        ax_grip.plot(timesteps[:upto], gt_actions[:upto, 6], 'b-', linewidth=2, label='GT Gripper')
        ax_grip.plot(timesteps[:upto], pred_actions[:upto, 6], 'r--', linewidth=2, label='Pred Gripper')
        ax_grip.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5, label='Threshold')
        _finish_series_axis(ax_grip, frame, fontsize=8)

        # Update error plot
        _reset_series_axis(
            ax_err,
            f"Per-Step L1 Error (current: {per_step_errors[frame]:.3f})",
            "L1 Error",
            (0, per_step_errors.max() * 1.1 + 0.01),
        )
        ax_err.bar(timesteps[:upto], per_step_errors[:upto], color='purple', alpha=0.7)
        ax_err.axhline(y=per_step_errors.mean(), color='orange', linestyle='--',
                       label=f'Avg: {per_step_errors.mean():.3f}')
        _finish_series_axis(ax_err, frame, fontsize=8)

        # Lay out this figure explicitly rather than whichever one is current.
        fig.tight_layout()
        return []

    # Create animation
    anim = FuncAnimation(fig, animate, init_func=init, frames=n_frames, interval=200, blit=False)

    # Save as GIF; always release the figure, even if writing fails, so that
    # repeated calls in a loop do not accumulate open figures.
    print(f"Saving animation to {save_path}...")
    writer = PillowWriter(fps=5)
    try:
        anim.save(save_path, writer=writer)
    finally:
        plt.close(fig)

    print(f"✅ Animation saved to: {save_path}")
    return save_path

## 10. Trajectory Animation: Ground Truth vs Prediction

This section creates animated visualizations comparing:
1. **Ground Truth trajectory** (what the robot actually did)
2. **Predicted trajectory** (what the model thinks should happen)

The animation shows:
- Robot camera view (image frames)
- 3D position trajectory (x, y, z)
- Position and rotation action values over time
- Gripper state changes
- Per-step L1 error between predicted and ground-truth actions

In [None]:
# Display animations inline in the notebook.
# IPython's Image is imported under an alias: the setup cell binds the bare
# name `Image` to PIL.Image, and rebinding it here would break any code that
# still expects the PIL class after this cell has run.
from IPython.display import HTML, display
from IPython.display import Image as IPyImage

# `animation_files` is created by the cell that generates the GIFs. Look it up
# defensively so this cell also works when that cell has not been run yet.
generated_gifs = globals().get('animation_files') or []

if generated_gifs:
    print("="*70)
    print(" DISPLAYING TRAJECTORY ANIMATIONS")
    print("="*70)
    print("\n🎬 Showing animated comparisons of Ground Truth (blue) vs Predicted (red):\n")

    for i, gif_path in enumerate(generated_gifs):
        if os.path.exists(gif_path):
            print(f"\n--- Episode {i+1} ---")
            display(IPyImage(filename=gif_path))
        else:
            print(f"Animation file not found: {gif_path}")
else:
    print("No animations to display. Run the animation-generation cell first.")

In [None]:
# Generate trajectory animations for a few episodes
# This shows how well the model predicts actions over an entire episode

# Reset on every run so the display cell never picks up GIF paths left over
# from an earlier execution when no model or episodes are available now.
animation_files = []

if finetuned_model and len(episodes) > 0:
    print("="*70)
    print(" GENERATING TRAJECTORY ANIMATIONS")
    print("="*70)
    print("\nThis creates animated GIFs comparing Ground Truth vs Predicted trajectories.")
    print("Each animation shows an entire episode with the model's predictions.\n")

    # Select a few episodes spread evenly across the loaded set
    n_animations = min(3, len(episodes))
    selected_indices = np.linspace(0, len(episodes)-1, n_animations, dtype=int)

    for i, idx in enumerate(selected_indices):
        episode = episodes[idx]
        print(f"\n[{i+1}/{n_animations}] Episode: {episode['demo']} from {episode['file']}")
        print(f"    Instruction: {episode['instruction'][:50]}...")
        print(f"    Length: {len(episode['images'])} steps")

        # Get ground truth actions; coerce to an array because the metrics
        # below slice columns with 2-D indexing.
        gt_actions_ep = np.asarray(episode['actions'])

        # Predict actions for this episode
        print("    Predicting actions...")
        pred_actions_ep = predict_episode_trajectory(
            finetuned_model, processor, episode, action_tokenizer
        )

        # Compute episode metrics (column 6 is the gripper, thresholded at 0.5)
        episode_l1 = np.abs(pred_actions_ep - gt_actions_ep).mean()
        gt_gripper = gt_actions_ep[:, 6] > 0.5
        pred_gripper = pred_actions_ep[:, 6] > 0.5
        gripper_acc = (gt_gripper == pred_gripper).mean() * 100

        print(f"    Episode L1 Error: {episode_l1:.4f}")
        print(f"    Episode Gripper Accuracy: {gripper_acc:.1f}%")

        # Create animation
        save_path = f"trajectory_episode_{i+1}.gif"
        create_trajectory_animation(episode, gt_actions_ep, pred_actions_ep, save_path)
        animation_files.append(save_path)

    print("\n" + "="*70)
    print(" ANIMATIONS COMPLETE")
    print("="*70)
    print("\n📁 Generated animation files:")
    for gif_file in animation_files:
        print(f"   - {gif_file}")
    print("\n💡 Open these GIF files to see the animated comparison!")
else:
    print("No fine-tuned model or episodes available for animation.")