# OpenVLA Fine-tuned Model Evaluation: Understanding What the Model Learned

This notebook helps you understand:
1. **What do the metrics mean?** (Loss, L1 Error explained)
2. **What does the model actually predict?** (Side-by-side comparison)
3. **How accurate is each action dimension?** (Position, rotation, gripper)
4. **Is the model making reasonable predictions?** (Visualizations)

## Metrics Explained

| Metric | What it measures | Good value |
|--------|------------------|------------|
| **Loss** | How wrong the predicted tokens are (cross-entropy) | < 2.0 |
| **L1 Error** | Average absolute difference between predicted and true actions | < 0.15 |
| **Gripper Accuracy** | % of times gripper open/close is correct | > 90% |
| **Direction Accuracy** | % of times movement direction is correct | > 80% |

### What does L1 Error = 0.15 mean?
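- Position and rotation actions are clipped to [-1, 1]; the gripper value lies in [0, 1]
- L1 = 0.15 means each action dimension is off by 0.15 normalized units on average (15% of the maximum magnitude, i.e. 7.5% of the full [-1, 1] range)
- If a full-scale action (1.0) corresponds to a 10cm move, that's ~1.5cm error per step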

In [None]:
# Setup
import os
import sys
import numpy as np
import torch
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
import h5py
from tqdm import tqdm

# Configuration
# Scratch root: honor $SCRATCH when it is set, otherwise use the fixed fallback location.
BASE_DIR = os.environ.get('SCRATCH', "/home/idies/workspace/Temporary/dpark1/scratch")

# Every other location is derived from the scratch root.
CACHE_DIR = BASE_DIR + "/.cache"
LIBERO_DATA_DIR = BASE_DIR + "/libero_data"
CHECKPOINT_DIR = BASE_DIR + "/openvla_finetuned"

# Point the Hugging Face cache at scratch and set the TensorFlow C++ log level.
os.environ.update({
    'HF_HOME': CACHE_DIR + "/huggingface",
    'TF_CPP_MIN_LOG_LEVEL': '3',
})

# Suppress all Python warnings for the rest of the notebook.
import warnings
warnings.filterwarnings('ignore')

for label, location in (
    ("Checkpoint directory", CHECKPOINT_DIR),
    ("LIBERO data directory", LIBERO_DATA_DIR),
):
    print(f"{label}: {location}")

In [None]:
# Find available checkpoints
checkpoint_path = Path(CHECKPOINT_DIR)

# Always bind `runs` (possibly empty) so that later cells can test it
# without raising NameError when the checkpoint directory is missing.
runs = sorted(checkpoint_path.glob("libero_*")) if checkpoint_path.exists() else []

if checkpoint_path.exists():
    print("Available fine-tuning runs:")
    if not runs:
        print("  (none matching 'libero_*')")
    for i, run in enumerate(runs):
        print(f"  [{i}] {run.name}")
        # Check for best/final
        if (run / "best").exists():
            print(f"      - best/ (recommended)")
        if (run / "final").exists():
            print(f"      - final/")
else:
    print(f"No checkpoints found at {CHECKPOINT_DIR}")
    print("Please run fine-tuning first.")

In [None]:
# ============================================================
# SELECT YOUR CHECKPOINT HERE
# ============================================================
def _checkpoint_step(path):
    """Return the integer step from a ``checkpoint-<step>`` name, or -1 if it has none."""
    suffix = path.name.rsplit("-", 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


# Option 1: Use the latest run's best model
if runs:
    SELECTED_RUN = runs[-1]  # Latest run (last in sorted name order)
    if (SELECTED_RUN / "best").exists():
        LORA_CHECKPOINT = str(SELECTED_RUN / "best")
    elif (SELECTED_RUN / "final").exists():
        LORA_CHECKPOINT = str(SELECTED_RUN / "final")
    else:
        # Find latest checkpoint-XXXX. Sort by the numeric step: a plain string
        # sort would place "checkpoint-1000" before "checkpoint-500".
        checkpoints = sorted(
            SELECTED_RUN.glob("checkpoint-*"),
            key=lambda p: (_checkpoint_step(p), p.name),
        )
        LORA_CHECKPOINT = str(checkpoints[-1]) if checkpoints else None

    print(f"Selected checkpoint: {LORA_CHECKPOINT}")
    if LORA_CHECKPOINT is None:
        print(f"No checkpoint found in {SELECTED_RUN}. Will evaluate base model only.")
else:
    LORA_CHECKPOINT = None
    print("No checkpoint selected. Will evaluate base model only.")

# Option 2: Manual override (uncomment and set path)
# LORA_CHECKPOINT = "/path/to/your/checkpoint"

## 1. Load Models (Base vs Fine-tuned)

In [None]:
from transformers import AutoModelForVision2Seq, AutoProcessor

print("Loading base OpenVLA model...")

# The model and its processor share the same hub repository and cache directory.
_openvla_repo = "openvla/openvla-7b"
_hf_cache = f"{CACHE_DIR}/huggingface"

# Base model in bfloat16 with automatic device placement.
base_model = AutoModelForVision2Seq.from_pretrained(
    _openvla_repo,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    cache_dir=_hf_cache,
    low_cpu_mem_usage=True,
)

# Processor used later to turn (prompt, image) pairs into model inputs.
processor = AutoProcessor.from_pretrained(
    _openvla_repo,
    trust_remote_code=True,
    cache_dir=_hf_cache,
)

print(f"Base model loaded. Device: {base_model.device}")

In [None]:
# Load fine-tuned model (if checkpoint exists)
if not LORA_CHECKPOINT:
    # Nothing to load: downstream cells fall back to the base model.
    finetuned_model = None
    print("No fine-tuned model - will compare base model predictions to ground truth.")
else:
    from peft import PeftModel

    print(f"Loading LoRA weights from: {LORA_CHECKPOINT}")
    # NOTE(review): PEFT is understood to inject the adapter layers into the
    # module passed here, so `base_model` would also run with LoRA active
    # after this call - confirm against the PEFT docs before trusting any
    # later "base model" evaluation.
    finetuned_model = PeftModel.from_pretrained(
        base_model,
        LORA_CHECKPOINT,
        torch_dtype=torch.bfloat16,
    )
    finetuned_model.eval()
    print("Fine-tuned model loaded!")

## 2. Action Tokenizer (for decoding predictions)

In [None]:
class ActionTokenizer:
    """Decode action tokens back to continuous values.

    The last ``n_bins`` ids of a vocabulary of size ``vocab_size`` are treated
    as action tokens. ``n_bins`` edges are spread uniformly over [-1, 1], and a
    token decodes to the centre of the interval between two adjacent edges, so
    there are ``n_bins - 1`` distinct decoded values.
    """

    def __init__(self, vocab_size=32000, n_bins=256):
        """Build the bin table.

        Args:
            vocab_size: Size of the vocabulary whose tail holds the action tokens.
            n_bins: Number of uniformly spaced bin edges over [-1, 1].
        """
        self.vocab_size = vocab_size
        self.n_bins = n_bins
        self.bins = np.linspace(-1, 1, n_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2
        self.action_token_start = vocab_size - n_bins  # 31744 with the defaults
        self.action_token_end = vocab_size - 1  # 31999 with the defaults

    def decode(self, token_ids):
        """Convert token IDs to continuous actions.

        Args:
            token_ids: Integer token ids as a ``torch.Tensor``, a NumPy array,
                a plain Python sequence, or a single int.

        Returns:
            NumPy array (or scalar for a scalar input) of bin-centre values.
            Ids outside the action range are clipped to the nearest valid bin.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().numpy()
        # Coerce lists/tuples so that the arithmetic below is element-wise.
        token_ids = np.asarray(token_ids)
        # The highest id maps to bin 0: ids count down from the vocabulary end.
        discretized = self.vocab_size - token_ids
        indices = np.clip(discretized - 1, 0, len(self.bin_centers) - 1)
        return self.bin_centers[indices]

    def is_action_token(self, token_id):
        """Check if token is an action token (inclusive range check)."""
        return self.action_token_start <= token_id <= self.action_token_end

# Use the tokenizer's base vocabulary size rather than len(tokenizer):
# len() also counts special tokens added on top of the base vocabulary
# (for example a padding token), which would shift every decoded bin by one.
# NOTE(review): OpenVLA's reference action tokenizer is understood to index
# from `tokenizer.vocab_size` - confirm against the fine-tuning script.
action_tokenizer = ActionTokenizer(vocab_size=processor.tokenizer.vocab_size)
print(f"Action tokens range: [{action_tokenizer.action_token_start}, {action_tokenizer.action_token_end}]")

## 3. Load LIBERO Test Data

In [None]:
def _read_libero_instruction(h5_file, default="complete the task"):
    """Return the language instruction stored in an open HDF5 file, or `default`."""
    import json

    # Look at the file-level attributes first, then at the `data` group's.
    # NOTE(review): LIBERO demo files are understood to keep `problem_info`
    # (a JSON string) on the `data` group - confirm against your files.
    attr_sources = [h5_file.attrs]
    if 'data' in h5_file:
        attr_sources.append(h5_file['data'].attrs)

    for attrs in attr_sources:
        for key in ['language_instruction', 'problem_info', 'language']:
            if key not in attrs:
                continue
            value = attrs[key]
            if isinstance(value, bytes):
                value = value.decode('utf-8')
            if key == 'problem_info' and isinstance(value, str):
                # A JSON blob is not an instruction: pull the field out of it.
                try:
                    parsed = json.loads(value)
                except ValueError:
                    parsed = None  # plain text: use it as-is below
                if isinstance(parsed, dict):
                    value = parsed.get('language_instruction')
                    if isinstance(value, (list, tuple)):
                        # presumably a list of words - verify against your data
                        value = " ".join(str(part) for part in value)
                    if not value:
                        continue  # JSON without an instruction: try the next key
            return value
    return default


def load_libero_samples(data_dir, suite="libero_spatial", n_samples=50):
    """Load evaluation samples from LIBERO HDF5 demonstration files.

    Files are visited in sorted path order. Up to two demos are read per file
    and up to three randomly chosen timesteps per demo, until `n_samples`
    samples have been collected.

    Args:
        data_dir: Directory searched recursively for ``*.hdf5`` files.
        suite: Suite name. When at least one file's path (relative to
            `data_dir`) contains this string, only those files are used;
            otherwise all files are used. Pass ``None`` to disable filtering.
        n_samples: Maximum number of samples to return.

    Returns:
        List of dicts with keys ``image`` (rotated 180 degrees), ``instruction``,
        ``action`` (7 float32 values: six clipped to [-1, 1] and an inverted
        gripper value in [0, 1]), ``file``, ``demo`` and ``timestep``. Empty if
        no HDF5 file is found.
    """
    data_path = Path(data_dir)

    # Find HDF5 files (sorted so that the selection does not depend on
    # filesystem enumeration order).
    hdf5_files = sorted(data_path.rglob("*.hdf5"))
    if not hdf5_files:
        print(f"No HDF5 files found in {data_dir}")
        return []

    print(f"Found {len(hdf5_files)} HDF5 files")

    # Honor the requested suite when the directory layout allows it.
    if suite:
        suite_files = [p for p in hdf5_files if suite in str(p.relative_to(data_path))]
        if suite_files:
            print(f"Using {len(suite_files)} files matching suite '{suite}'")
            hdf5_files = suite_files
        else:
            print(f"No files match suite '{suite}'; using all files")

    samples = []
    for filepath in hdf5_files:
        try:
            with h5py.File(filepath, 'r') as f:
                # Get instruction
                instruction = _read_libero_instruction(f)

                if 'data' not in f:
                    continue

                demo_keys = [k for k in f['data'].keys() if k.startswith('demo_')]

                for demo_key in demo_keys[:2]:  # 2 demos per file
                    demo = f['data'][demo_key]
                    if 'actions' not in demo or 'obs' not in demo:
                        continue

                    # Find image key
                    img_key = None
                    for k in ['agentview_rgb', 'agentview_image', 'rgb', 'image']:
                        if k in demo['obs']:
                            img_key = k
                            break
                    if img_key is None:
                        continue

                    n_steps = len(demo['actions'])
                    # Sample a few timesteps per demo
                    timesteps = np.random.choice(n_steps, min(3, n_steps), replace=False)

                    for t in timesteps:
                        image = demo['obs'][img_key][t]
                        image = np.rot90(image, k=2)  # 180 degree rotation

                        action = demo['actions'][t]
                        # Force exactly 7 dimensions: zero-pad or truncate.
                        if len(action) < 7:
                            action = np.pad(action, (0, 7 - len(action)))
                        else:
                            action = action[:7]

                        # Clip the first six dims to [-1, 1]; clip the gripper
                        # to [0, 1] and invert it.
                        action = action.astype(np.float32)
                        action[:6] = np.clip(action[:6], -1.0, 1.0)
                        gripper = np.clip(action[6], 0.0, 1.0)
                        action[6] = 1.0 - gripper  # Invert gripper

                        samples.append({
                            'image': image,
                            'instruction': instruction,
                            'action': action,
                            'file': filepath.name,
                            'demo': demo_key,
                            'timestep': t,
                        })

                        if len(samples) >= n_samples:
                            return samples
        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            print(f"Error reading {filepath}: {e}")

    return samples

# Load samples
print("\nLoading LIBERO samples for evaluation...")
test_samples = load_libero_samples(LIBERO_DATA_DIR, n_samples=100)
print(f"Loaded {len(test_samples)} test samples")

# Preview up to five distinct instructions (each truncated to 60 characters).
if test_samples:
    print(f"\nSample instructions:")
    unique_instructions = {entry['instruction'] for entry in test_samples}
    preview = list(unique_instructions)[:5]
    for inst in preview:
        print(f"  - {inst[:60]}...")

## 4. Predict Actions

In [None]:
def predict_action(model, processor, image, instruction, action_tokenizer, n_action_dims=7):
    """Predict an action from an image and an instruction.

    Args:
        model: Model exposing ``device`` and ``generate(**inputs, ...)``.
        processor: Callable ``processor(prompt, image, return_tensors="pt")``
            returning a mapping of tensors; also provides
            ``tokenizer.pad_token_id``.
        image: NumPy array (cast to uint8) or a PIL image; resized to 224x224.
        instruction: Task description; lower-cased inside the prompt.
        action_tokenizer: Object whose ``decode`` maps token ids to actions.
        n_action_dims: Number of action tokens to generate and decode.

    Returns:
        Tuple ``(action, action_tokens)``: the decoded actions and the raw
        token ids as a NumPy array of length `n_action_dims`.

    Raises:
        ValueError: If generation produced fewer than `n_action_dims` new tokens.
    """
    # Create prompt
    prompt = f"In: What action should the robot take to {instruction.lower()}?\nOut:"

    # Convert image to PIL
    if isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image.astype(np.uint8))
    else:
        pil_image = image

    # Resize to 224x224
    if pil_image.size != (224, 224):
        pil_image = pil_image.resize((224, 224), Image.LANCZOS)

    # Process inputs. Floating-point tensors (the pixel values) are cast to the
    # model's dtype so that they match half-precision weights; integer tensors
    # (token ids, masks) are only moved to the device.
    inputs = processor(prompt, pil_image, return_tensors="pt")
    model_dtype = getattr(model, "dtype", None)
    prepared = {}
    for name, tensor in inputs.items():
        if model_dtype is not None and torch.is_floating_point(tensor):
            prepared[name] = tensor.to(model.device, dtype=model_dtype)
        else:
            prepared[name] = tensor.to(model.device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **prepared,
            max_new_tokens=n_action_dims,
            do_sample=False,
            pad_token_id=processor.tokenizer.pad_token_id,
        )

    # If generation stopped early, the trailing slice below would silently
    # include prompt tokens; fail loudly instead.
    prompt_ids = prepared.get("input_ids")
    if prompt_ids is not None:
        n_generated = outputs.shape[1] - prompt_ids.shape[1]
        if n_generated < n_action_dims:
            raise ValueError(
                f"Model generated {n_generated} tokens; expected {n_action_dims} action tokens"
            )

    # NOTE(review): this assumes the generated tokens are exactly the action
    # tokens; confirm that the prompt format matches the one used in training.
    action_tokens = outputs[0, -n_action_dims:].cpu().numpy()

    # Decode to continuous actions
    action = action_tokenizer.decode(action_tokens)

    return action, action_tokens

# Test prediction
# Smoke test on the first sample, using the fine-tuned model when one is loaded.
if test_samples:
    sample = test_samples[0]
    model_to_test = finetuned_model or base_model
    pred_action, pred_tokens = predict_action(
        model_to_test,
        processor,
        sample['image'],
        sample['instruction'],
        action_tokenizer,
    )
    report_lines = [
        "Test prediction:",
        f"  Instruction: {sample['instruction'][:50]}...",
        f"  Ground truth: {sample['action']}",
        f"  Prediction:   {pred_action}",
        f"  Token IDs:    {pred_tokens}",
    ]
    print("\n".join(report_lines))

## 5. Run Full Evaluation

In [None]:
def evaluate_model(model, processor, samples, action_tokenizer, model_name="Model"):
    """Run the model on every sample and collect predictions.

    Evaluation is best effort: a sample whose prediction raises is reported
    and skipped, so the returned list can be shorter than `samples`.

    Args:
        model: Model passed through to ``predict_action``.
        processor: Processor passed through to ``predict_action``.
        samples: Iterable of dicts with ``image``, ``instruction`` and ``action``.
        action_tokenizer: Decoder passed through to ``predict_action``.
        model_name: Label used in progress messages.

    Returns:
        List of dicts with keys ``gt_action``, ``pred_action``, ``instruction``
        and ``image``, one per successfully evaluated sample.
    """
    results = []
    n_failed = 0

    print(f"\nEvaluating {model_name} on {len(samples)} samples...")

    for idx, sample in enumerate(tqdm(samples)):
        try:
            pred_action, _ = predict_action(
                model, processor, sample['image'], sample['instruction'], action_tokenizer
            )

            results.append({
                'gt_action': sample['action'],
                'pred_action': pred_action,
                'instruction': sample['instruction'],
                'image': sample['image'],
            })
        except Exception as e:
            # Keep going, but say which sample failed so it can be inspected.
            n_failed += 1
            origin = sample.get('file', '?') if isinstance(sample, dict) else '?'
            print(f"Error on sample {idx} ({origin}): {e}")

    if n_failed:
        # Metrics computed from `results` cover fewer samples than requested.
        print(f"Warning: {n_failed}/{len(samples)} samples failed and were skipped for {model_name}.")

    return results

# Evaluate fine-tuned model
# Bind the name up front: the later cells test `finetuned_results`, which would
# otherwise be undefined (NameError) when no test samples were loaded.
finetuned_results = []

if finetuned_model and test_samples:
    finetuned_results = evaluate_model(
        finetuned_model, processor, test_samples, action_tokenizer, "Fine-tuned"
    )
elif test_samples:
    print("No fine-tuned model. Evaluating base model...")
    finetuned_results = evaluate_model(
        base_model, processor, test_samples, action_tokenizer, "Base"
    )
else:
    print("No test samples loaded - skipping evaluation.")

## 6. Compute Interpretable Metrics

In [None]:
def compute_metrics(results):
    """Compute interpretable metrics from evaluation results.

    Args:
        results: Non-empty sequence of dicts with 7-element ``gt_action`` and
            ``pred_action`` entries (dx, dy, dz, rx, ry, rz, gripper).

    Returns:
        Tuple ``(metrics, gt_actions, pred_actions)`` where `metrics` maps
        metric names to values and the two arrays have shape (N, 7).

    Raises:
        ValueError: If `results` is empty.
    """
    if len(results) == 0:
        raise ValueError("compute_metrics() needs at least one evaluation result")

    gt_actions = np.array([r['gt_action'] for r in results])
    pred_actions = np.array([r['pred_action'] for r in results])
    abs_error = np.abs(pred_actions - gt_actions)

    metrics = {}

    # 1. Overall L1 Error (lower is better)
    metrics['L1 Error (Overall)'] = abs_error.mean()

    # 2. Per-dimension L1 Error
    dim_names = ['dx', 'dy', 'dz', 'rx', 'ry', 'rz', 'gripper']
    for i, name in enumerate(dim_names):
        metrics[f'L1 Error ({name})'] = abs_error[:, i].mean()

    # 3. Position Error (dims 0-2)
    metrics['Position Error (xyz)'] = abs_error[:, :3].mean()

    # 4. Rotation Error (dims 3-5)
    metrics['Rotation Error (rpy)'] = abs_error[:, 3:6].mean()

    # 5. Gripper Accuracy (is open/close correct?)
    # Gripper > 0.5 = open, < 0.5 = close
    gt_gripper_open = gt_actions[:, 6] > 0.5
    pred_gripper_open = pred_actions[:, 6] > 0.5
    metrics['Gripper Accuracy (%)'] = (gt_gripper_open == pred_gripper_open).mean() * 100

    # 6. Direction Accuracy (is movement direction correct?)
    # Compare signs of the position components, counting only entries where
    # the ground truth is not ~0.
    gt_signs = np.sign(gt_actions[:, :3])
    pred_signs = np.sign(pred_actions[:, :3])
    significant_movement = np.abs(gt_actions[:, :3]) > 0.05
    if significant_movement.sum() > 0:
        direction_accuracy = (gt_signs[significant_movement] == pred_signs[significant_movement]).mean() * 100
    else:
        direction_accuracy = 0
    metrics['Direction Accuracy (%)'] = direction_accuracy

    # 7. Action magnitude correlation
    gt_magnitude = np.linalg.norm(gt_actions[:, :3], axis=1)
    pred_magnitude = np.linalg.norm(pred_actions[:, :3], axis=1)
    # The correlation is undefined (NaN) when either series is constant, e.g.
    # when the model always predicts the same action; report 0 in that case.
    if gt_magnitude.std() > 0 and pred_magnitude.std() > 0:
        correlation = np.corrcoef(gt_magnitude, pred_magnitude)[0, 1]
    else:
        correlation = 0
    metrics['Magnitude Correlation'] = correlation

    return metrics, gt_actions, pred_actions

# Compute metrics
# (The emoji in the messages below were stored mis-encoded and have been restored.)
if finetuned_results:
    metrics, gt_actions, pred_actions = compute_metrics(finetuned_results)

    print("\n" + "="*60)
    print(" EVALUATION RESULTS")
    print("="*60)

    print("\n📊 OVERALL METRICS:")
    print(f"  L1 Error:           {metrics['L1 Error (Overall)']:.4f}  (lower is better, target < 0.15)")
    print(f"  Gripper Accuracy:   {metrics['Gripper Accuracy (%)']:.1f}%  (higher is better, target > 90%)")
    print(f"  Direction Accuracy: {metrics['Direction Accuracy (%)']:.1f}%  (higher is better, target > 80%)")
    print(f"  Magnitude Corr:     {metrics['Magnitude Correlation']:.3f}  (higher is better, target > 0.7)")

    print("\n📏 PER-DIMENSION L1 ERROR:")
    print(f"  Position (x,y,z):  {metrics['Position Error (xyz)']:.4f}")
    print(f"  Rotation (r,p,y):  {metrics['Rotation Error (rpy)']:.4f}")
    print(f"  Gripper:           {metrics['L1 Error (gripper)']:.4f}")

    print("\n📈 INTERPRETATION:")
    if metrics['L1 Error (Overall)'] < 0.15:
        print("  ✅ L1 Error is GOOD - model is making accurate predictions")
    elif metrics['L1 Error (Overall)'] < 0.25:
        print("  ⚠️ L1 Error is MODERATE - model is learning but could improve")
    else:
        print("  ❌ L1 Error is HIGH - model needs more training or debugging")

    if metrics['Gripper Accuracy (%)'] > 90:
        print("  ✅ Gripper Accuracy is EXCELLENT")
    elif metrics['Gripper Accuracy (%)'] > 70:
        print("  ⚠️ Gripper Accuracy is OK but could improve")
    else:
        print("  ❌ Gripper Accuracy is LOW - check gripper transform")

## 7. Visualize Predictions vs Ground Truth

In [None]:
def visualize_predictions(results, n_samples=6):
    """Visualize side-by-side predictions vs ground truth.

    Randomly picks up to `n_samples` results, shows each image with a small
    GT / prediction / error table as its title, and saves the figure to
    ``prediction_visualization.png``.

    Args:
        results: Sequence of dicts with ``image``, ``instruction``,
            ``gt_action`` and ``pred_action``.
        n_samples: Maximum number of panels to draw.

    Returns:
        None. Prints a message and draws nothing when there is nothing to show.
    """
    n_show = min(n_samples, len(results))
    if n_show <= 0:
        print("No results to visualize.")
        return

    # Three panels per row and as many rows as needed (ceiling division);
    # six panels give the same 2x3, 15x10 inch layout as before.
    n_cols = 3
    n_rows = -(-n_show // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    indices = np.random.choice(len(results), n_show, replace=False)

    dim_names = ['dx', 'dy', 'dz', 'rx', 'ry', 'rz', 'grip']

    for idx, ax in zip(indices, axes):
        result = results[idx]

        # Show image
        ax.imshow(result['image'])
        ax.axis('off')

        # Add action comparison as text
        gt = result['gt_action']
        pred = result['pred_action']

        text = f"Instruction: {result['instruction'][:30]}...\n\n"
        text += "       GT    Pred   Err\n"
        for i, name in enumerate(dim_names):
            err = abs(gt[i] - pred[i])
            mark = '✓' if err < 0.15 else '✗'
            text += f"{name:5}: {gt[i]:+.2f}  {pred[i]:+.2f}  {err:.2f} {mark}\n"

        ax.set_title(text, fontsize=8, family='monospace', loc='left')

    # Blank out any panel that did not receive a sample.
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig('prediction_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved to: prediction_visualization.png")

# Draw the sample panels only when the evaluation produced results.
if finetuned_results:
    visualize_predictions(finetuned_results)

In [None]:
def plot_action_distribution(gt_actions, pred_actions):
    """Plot predicted vs ground truth actions, one scatter panel per dimension.

    Args:
        gt_actions: Array indexed as ``[sample, dimension]`` with 7 dimensions.
        pred_actions: Array of the same layout holding the predictions.

    The figure is saved to ``action_distribution.png`` and shown.
    """
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    dim_names = ['dx', 'dy', 'dz', 'rx', 'ry', 'rz', 'gripper']

    for i, (ax, name) in enumerate(zip(axes.flatten()[:7], dim_names)):
        gt_col = gt_actions[:, i]
        pred_col = pred_actions[:, i]

        # The correlation is undefined when either series is constant (for
        # example a model that always predicts the same gripper value).
        if len(gt_col) > 1 and gt_col.std() > 0 and pred_col.std() > 0:
            corr_text = f"{np.corrcoef(gt_col, pred_col)[0, 1]:.3f}"
        else:
            corr_text = "n/a"

        ax.scatter(gt_col, pred_col, alpha=0.5, s=20)
        ax.plot([-1, 1], [-1, 1], 'r--', label='Perfect')
        ax.set_xlabel(f'Ground Truth {name}')
        ax.set_ylabel(f'Predicted {name}')
        ax.set_title(f'{name}: corr={corr_text}')
        ax.set_xlim(-1.1, 1.1)
        ax.set_ylim(-1.1, 1.1)
        ax.grid(True, alpha=0.3)
        # Without a legend the 'Perfect' label on the diagonal is never shown.
        ax.legend(fontsize=8)

    # Hide last subplot
    axes.flatten()[-1].axis('off')

    plt.suptitle('Predicted vs Ground Truth Actions (each point = one sample)', fontsize=14)
    plt.tight_layout()
    plt.savefig('action_distribution.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved to: action_distribution.png")

# gt_actions / pred_actions are bound by the metrics cell, which runs under the same guard.
if finetuned_results:
    plot_action_distribution(gt_actions, pred_actions)

In [None]:
def plot_error_histogram(gt_actions, pred_actions):
    """Draw one histogram of signed errors (prediction - ground truth) per action dimension.

    Each panel marks zero error and the mean error, and reports the error
    standard deviation in its title. The figure is saved to
    ``error_histogram.png`` and shown.
    """
    errors = pred_actions - gt_actions
    dim_names = ['dx', 'dy', 'dz', 'rx', 'ry', 'rz', 'gripper']

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    panels = axes.flatten()

    for col, name in enumerate(dim_names):
        ax = panels[col]
        dim_errors = errors[:, col]
        mean_error = dim_errors.mean()

        ax.hist(dim_errors, bins=30, edgecolor='black', alpha=0.7)
        ax.axvline(x=0, color='r', linestyle='--', label='Zero error')
        ax.axvline(x=mean_error, color='g', linestyle='-', label=f'Mean: {mean_error:.3f}')
        ax.set_xlabel(f'Error in {name}')
        ax.set_ylabel('Count')
        ax.set_title(f'{name}: std={dim_errors.std():.3f}')
        ax.legend(fontsize=8)

    # Seven dimensions on a 2x4 grid: blank the unused eighth panel.
    panels[-1].axis('off')

    plt.suptitle('Distribution of Prediction Errors (0 = perfect)', fontsize=14)
    plt.tight_layout()
    plt.savefig('error_histogram.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Saved to: error_histogram.png")

# gt_actions / pred_actions are bound by the metrics cell, which runs under the same guard.
if finetuned_results:
    plot_error_histogram(gt_actions, pred_actions)

## 8. Compare Base vs Fine-tuned (Optional)

In [None]:
# Compare base model vs fine-tuned model
if finetuned_model and len(test_samples) > 0:
    # Both models are scored on the same subset so the numbers are comparable.
    comparison_samples = test_samples[:30]

    print("Evaluating BASE model for comparison...")
    # PeftModel.from_pretrained injects the LoRA layers into `base_model`
    # itself, so calling `base_model` directly would still run the fine-tuned
    # weights. Temporarily disable the adapter to get true base predictions.
    with finetuned_model.disable_adapter():
        base_results = evaluate_model(
            base_model, processor, comparison_samples, action_tokenizer, "Base"
        )

    # Reuse the earlier fine-tuned predictions when they line up one-to-one
    # with `test_samples`; otherwise (some samples failed) re-run the subset.
    if len(finetuned_results) == len(test_samples):
        ft_subset_results = finetuned_results[:len(comparison_samples)]
    else:
        ft_subset_results = evaluate_model(
            finetuned_model, processor, comparison_samples, action_tokenizer,
            "Fine-tuned (comparison subset)"
        )

    if not base_results or not ft_subset_results:
        print("Comparison skipped: no successful predictions for one of the models.")
    else:
        base_metrics, base_gt, base_pred = compute_metrics(base_results)
        ft_metrics, _, _ = compute_metrics(ft_subset_results)

        print("\n" + "="*60)
        print(" BASE vs FINE-TUNED COMPARISON")
        print("="*60)
        print(f"{'Metric':<25} {'Base':>12} {'Fine-tuned':>12} {'Improvement':>12}")
        print("-"*60)

        for key in ['L1 Error (Overall)', 'Gripper Accuracy (%)', 'Direction Accuracy (%)']:
            base_val = base_metrics[key]
            ft_val = ft_metrics[key]

            if 'Error' in key:
                improvement = base_val - ft_val  # Lower is better
            else:
                improvement = ft_val - base_val  # Higher is better
            better = "✅" if improvement > 0 else "❌"

            print(f"{key:<25} {base_val:>12.3f} {ft_val:>12.3f} {improvement:>+10.3f} {better}")

## 9. Summary and Next Steps

In [None]:
# (The emoji in the messages below were stored mis-encoded and have been restored.)
print("\n" + "="*60)
print(" SUMMARY")
print("="*60)

if finetuned_results:
    print(f"\n📊 Final Metrics:")
    print(f"   L1 Error:         {metrics['L1 Error (Overall)']:.4f}")
    print(f"   Gripper Accuracy: {metrics['Gripper Accuracy (%)']:.1f}%")
    print(f"   Direction Acc:    {metrics['Direction Accuracy (%)']:.1f}%")

    print(f"\n🎯 Quality Assessment:")

    # Each of the three metrics contributes 1 (poor) to 3 (good) points.
    score = 0
    if metrics['L1 Error (Overall)'] < 0.15:
        score += 3
        print("   ✅ Excellent L1 Error")
    elif metrics['L1 Error (Overall)'] < 0.25:
        score += 2
        print("   ⚠️ Moderate L1 Error - more training may help")
    else:
        score += 1
        print("   ❌ High L1 Error - check preprocessing")

    if metrics['Gripper Accuracy (%)'] > 90:
        score += 3
        print("   ✅ Excellent Gripper Accuracy")
    elif metrics['Gripper Accuracy (%)'] > 70:
        score += 2
        print("   ⚠️ Moderate Gripper Accuracy")
    else:
        score += 1
        print("   ❌ Low Gripper Accuracy - check gripper transform")

    if metrics['Direction Accuracy (%)'] > 80:
        score += 3
        print("   ✅ Excellent Direction Accuracy")
    elif metrics['Direction Accuracy (%)'] > 60:
        score += 2
        print("   ⚠️ Moderate Direction Accuracy")
    else:
        score += 1
        print("   ❌ Low Direction Accuracy")

    print(f"\n🏆 Overall Score: {score}/9")
    if score >= 7:
        print("   Model is ready for deployment!")
    elif score >= 5:
        print("   Model is learning. Consider more training epochs.")
    else:
        print("   Model needs debugging. Check preprocessing and hyperparameters.")

# List only the figures that are actually present in the working directory,
# instead of claiming they were saved when the plotting cells were skipped.
saved_files = [
    name
    for name in ('prediction_visualization.png', 'action_distribution.png', 'error_histogram.png')
    if Path(name).exists()
]
print("\n📁 Saved Files:")
if saved_files:
    for name in saved_files:
        print(f"   - {name}")
else:
    print("   (none - no figures were generated)")