# VITRA Inference Demo

This notebook runs interactive inference with the VITRA model. Load the model once (cell 2), then edit the inputs (image, instruction, hand masks) and re-run the inference cells without reloading the weights.

In [None]:
# 1. Environment Setup & Imports
import os
import sys

# Change working directory to project root.
# This is crucial for relative paths (like weights loading) to work correctly.
# Assumes the notebook lives in my_demo/, so the project root is one level up.
# NOTEBOOK_DIR is captured only on the first run of this cell: on a re-run the
# cwd is already the project root, and resolving ".." from there again would
# climb one directory higher every time.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
project_root = os.path.abspath(os.path.join(NOTEBOOK_DIR, ".."))
try:
    os.chdir(project_root)
    print(f"Changed working directory to: {os.getcwd()}")
except OSError as e:
    print(f"Error setting working directory: {e}")

# Add project root to path so the local `vitra` package imported below resolves.
if project_root not in sys.path:
    sys.path.append(project_root)

# Default to the hf-mirror endpoint, but keep an HF_ENDPOINT the user already set.
# NOTE(review): presumably read by the Hugging Face libraries at import time,
# so keep this above the vitra imports.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

import json
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from vitra.models import VITRA_Paligemma, load_model
from vitra.utils.data_utils import resize_short_side_to_target, load_normalizer
from vitra.datasets.human_dataset import pad_state_human, pad_action
from vitra.utils.config_utils import load_config
from vitra.datasets.dataset_utils import (
    ActionFeature,
    StateFeature,
)

print("Libraries imported successfully.")

In [None]:
# 2. Load Model & Config (Run Once)
# Moving the weights onto the GPU is the slow part. Later cells reuse `model`
# and `normalizer`, so this cell only needs to run once per kernel.

# Checkpoint id used for the config, the weights and the normalization
# statistics (presumably a Hugging Face Hub repo id).
pretrained_path = 'VITRA-VLA/VITRA-VLA-3B'
statistics_path = pretrained_path

configs = load_config(pretrained_path)
# Point the config at the weights and statistics to load.
configs['model_load_path'] = pretrained_path
configs['statistics_path'] = statistics_path

print("Loading model... (this may take a minute)")
model = load_model(configs)
model = model.cuda()
model.eval()

normalizer = load_normalizer(configs)
print("Model loaded successfully!")

In [None]:
# 3. Load & Process Image
# Cell 1 changes the working directory to the project root, so a bare
# "image.png" is looked up there. Also try the notebook's own my_demo/ folder.
# Adjust this list if your image is elsewhere.
IMAGE_CANDIDATES = ["image.png", os.path.join("my_demo", "image.png")]
image_path = next((p for p in IMAGE_CANDIDATES if os.path.exists(p)), None)

try:
    if image_path is None:
        raise FileNotFoundError(f"none of {IMAGE_CANDIDATES} exist under {os.getcwd()}")
    # convert("RGB") drops alpha/palette channels, so the array is always
    # (H, W, 3) like the dummy fallback. The `with` closes the file handle.
    with Image.open(image_path) as image_file:
        image_pil = image_file.convert("RGB")
    image_resized = resize_short_side_to_target(image_pil, target=224)

    # Convert to numpy for model input
    image_np = np.array(image_resized)
except Exception as e:
    print(f"Error loading image: {e}")
    # Create dummy image for testing if file missing
    image_np = np.zeros((224, 224, 3), dtype=np.uint8)
    print("Using dummy black image.")

print(f"[DEBUG] Image shape: {image_np.shape}")

# Display outside the try so a plotting error cannot replace a loaded image
# with the dummy.
plt.imshow(image_np)
plt.axis('off')
plt.show()

In [None]:
# 4. Prepare Input Data (Edit this cell to change inputs)

# --- Parameters ---
instruction = "Left hand: None. Right hand: Pick up the phone on the table."
fov_deg = 60.0

# Hands the model should act for, as [Left, Right]. Keep these consistent with
# `instruction` ("Left hand: None" -> use_left_hand = False). Both the state mask
# and the action mask are built from this one array, so they cannot disagree.
use_left_hand = False
use_right_hand = True
hand_flags = np.array([use_left_hand, use_right_hand], dtype=bool)

# State Construction
# All-zero initial state. Its length comes from the normalizer statistics
# (noted as 122 for this checkpoint; TODO confirm).
state = np.zeros((normalizer.state_mean.shape[0],))
print(f"[DEBUG] Initial State shape: {state.shape}")

# Masks
# state_mask: [Left, Right]
state_mask = hand_flags.copy()
print(f"[DEBUG] State Mask shape: {state_mask.shape}")

# action_mask: [Time, 2], one row per step of the model's action chunk
# action_mask[:, 0] = Left, action_mask[:, 1] = Right
action_mask = np.tile(hand_flags[np.newaxis, :], (model.chunk_size, 1))
print(f"[DEBUG] Action Mask shape: {action_mask.shape}")

# --- Preprocessing ---
# FOV in radians as a (1, 2) float32 tensor on the GPU. The same angle is used
# for both entries (presumably horizontal and vertical FOV; TODO confirm).
fov = torch.tensor([[np.deg2rad(fov_deg), np.deg2rad(fov_deg)]], dtype=torch.float32).cuda()

# Normalize state with the checkpoint's statistics
norm_state = normalizer.normalize_state(state)
print(f"[DEBUG] Normalized State shape: {norm_state.shape}")

# Unified (padded) dimensions
unified_action_dim = ActionFeature.ALL_FEATURES[1]   # 192 per the original notes
unified_state_dim = StateFeature.ALL_FEATURES[1]     # 212 per the original notes

# Padding: lift the dataset-specific state / masks into the unified dimensions
unified_state, unified_state_mask = pad_state_human(
    state = norm_state,
    state_mask = state_mask,
    action_dim = normalizer.action_mean.shape[0],
    state_dim = normalizer.state_mean.shape[0],
    unified_state_dim = unified_state_dim,
)
# Only the mask is needed at inference time; no ground-truth actions exist.
_, unified_action_mask = pad_action(
    actions=None,
    action_mask=action_mask,
    action_dim=normalizer.action_mean.shape[0],
    unified_action_dim=unified_action_dim
)

print(f"[DEBUG] Unified State shape: {unified_state.shape}")
print(f"[DEBUG] Unified State Mask shape: {unified_state_mask.shape}")
print(f"[DEBUG] Unified Action Mask shape: {unified_action_mask.shape}")

In [None]:
# 5. Run Inference
# Sampler settings passed to predict_action (edit here and re-run this cell).
num_ddim_steps = 10
cfg_scale = 5.0

print(f"Running inference for instruction: '{instruction}'")

with torch.no_grad():
    norm_action = model.predict_action(
        image = image_np,
        instruction = instruction,
        current_state = unified_state.unsqueeze(0), # Add batch dim
        current_state_mask = unified_state_mask.unsqueeze(0),
        action_mask_torch = unified_action_mask.unsqueeze(0),
        num_ddim_steps = num_ddim_steps,
        cfg_scale = cfg_scale,
        fov = fov,
        sample_times = 1
    )

print(f"[DEBUG] Raw Model Output shape: {norm_action.shape}")

# The output lives in the padded unified action space. Keep only the first
# `action_dim` columns. This is the width passed to pad_action when building
# the inputs, and the width of the normalizer statistics used to denormalize
# below, so it must not be hard-coded.
action_dim = normalizer.action_mean.shape[0]
assert norm_action.shape[-1] >= action_dim, (
    f"model output has {norm_action.shape[-1]} action dims, expected at least {action_dim}"
)
# Indexing assumes output shape [Batch, Time, Dim] (TODO confirm for sample_times > 1)
valid_norm_action = norm_action[0, :, :action_dim]

# Denormalize
unnorm_action = normalizer.unnormalize_action(valid_norm_action)

print(f"[DEBUG] Final Unnormalized Action shape: {unnorm_action.shape}")

In [None]:
# 6. Results Analysis
print("Predicted Action (First 5 steps):")
print(unnorm_action[:5])

# Optional: plot the wrist-position trajectory of each hand. Per the action
# layout this notebook assumes, these are columns 0:3 (indices 0-2) for the
# left hand and 51:54 (indices 51-53) for the right hand.
LEFT_WRIST_COLS = slice(0, 3)
RIGHT_WRIST_COLS = slice(51, 54)
traj_left = unnorm_action[:, LEFT_WRIST_COLS]
traj_right = unnorm_action[:, RIGHT_WRIST_COLS]


def plot_wrist_trajectory(ax, traj, fmt, label, title):
    """Draw one wrist trajectory on a 3D axis.

    Args:
        ax: a matplotlib axis created with projection='3d'.
        traj: array-like with one row per time step and X/Y/Z in columns 0-2.
        fmt: matplotlib line format string, e.g. 'r-'.
        label: legend label for the line.
        title: axis title.

    Returns:
        The same axis, for chaining.
    """
    ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], fmt, label=label)
    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    # Without a legend the `label` above is never shown.
    ax.legend()
    return ax


fig = plt.figure(figsize=(12, 5))

# Right Hand Trajectory
ax1 = plot_wrist_trajectory(
    fig.add_subplot(121, projection='3d'), traj_right, 'r-',
    'Right Hand', "Right Hand Wrist Trajectory",
)

# Left Hand Trajectory
ax2 = plot_wrist_trajectory(
    fig.add_subplot(122, projection='3d'), traj_left, 'b-',
    'Left Hand', "Left Hand Wrist Trajectory",
)

plt.tight_layout()
plt.show()