# Hand VLA Inference and Visualization

This notebook reconstructs the hand state from a single image, predicts future hand actions with the VITRA vision-language-action (VLA) model, and renders the predicted hand trajectories to a video.

In [None]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
import cv2
import math
import json
import torch
import numpy as np
from PIL import Image, ImageOps
from pathlib import Path
import multiprocessing as mp
from scipy.spatial.transform import Rotation as R

# Make the project packages (vitra, visualization, data) importable.
# NOTE: this assumes the kernel's working directory is scripts/, i.e. the
# repository root is the parent of the current working directory.
repo_root = Path.cwd().parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

from vitra.models import VITRA_Paligemma, load_model
from vitra.utils.data_utils import resize_short_side_to_target, load_normalizer, recon_traj
from vitra.utils.config_utils import load_config
from vitra.datasets.human_dataset import pad_state_human, pad_action
from vitra.datasets.dataset_utils import (
    compute_new_intrinsics_resize, 
    calculate_fov,
    ActionFeature,
    StateFeature,
)

from visualization.visualize_core import HandVisualizer, normalize_camera_intrinsics, save_to_video, Renderer, process_single_hand_labels
from visualization.visualize_core import Config as HandConfig

## Helper Functions

In [None]:
def get_state(hand_data, hand_side='right'):
    """
    Build the initial (first-frame) state vector for one hand.

    Args:
        hand_data: mapping with a per-hand list of frame dicts under 'left' /
            'right' (each frame holding 'hand_pose', 'global_orient', 'transl'
            and 'beta') plus a top-level 'fov_x' entry.
        hand_side: which hand to read, 'left' or 'right'.

    Returns:
        Tuple (state, beta, fov_x, None):
        - state: concatenation of the translation, the global orientation as
          XYZ Euler angles (radians) and the flattened per-joint hand pose as
          XYZ Euler angles (radians) -- 3 + 3 + 45 = 51 values for a 15-joint pose.
        - beta: the first frame's 'beta' entry, unchanged.
        - fov_x: hand_data['fov_x'], unchanged.
        - a trailing None (always None here).

    Raises:
        ValueError: if hand_side is not 'left' or 'right'.
    """
    if hand_side not in ('left', 'right'):
        raise ValueError(f"hand_side must be 'left' or 'right', got '{hand_side}'")

    first_frame = hand_data[hand_side][0]

    # Per-joint rotation matrices -> one flat vector of Euler angles.
    pose_euler = R.from_matrix(first_frame['hand_pose']).as_euler('xyz', degrees=False).reshape(-1)
    orient_euler = R.from_matrix(first_frame['global_orient']).as_euler('xyz', degrees=False)
    translation = first_frame['transl']
    state = np.concatenate([translation, orient_euler, pose_euler])

    horizontal_fov = hand_data['fov_x']
    return state, first_frame['beta'], horizontal_fov, None

def euler_traj_to_rotmat_traj(euler_traj, T):
    """
    Turn a trajectory of flattened XYZ Euler hand poses into rotation matrices.

    Args:
        euler_traj: numpy array whose values are read as consecutive (x, y, z)
            Euler-angle triplets in radians; this notebook passes a (T, 45) slice.
        T: number of time steps in the trajectory.

    Returns:
        Array of shape (T, 15, 3, 3) holding one rotation matrix per joint per step.
    """
    angle_triplets = euler_traj.reshape(-1, 3)
    rotations = R.from_euler('xyz', angle_triplets)
    return rotations.as_matrix().reshape(T, 15, 3, 3)

## Persistent Workers for Multiprocessing

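Each worker runs in its own process (started with the `spawn` start method) and keeps its model loaded between requests. Keeping the hand-reconstruction and VLA models in separate processes avoids CUDA context conflicts between them.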

In [None]:
def _hand_reconstruction_worker(args_dict, task_queue, result_queue):
    """
    Serve hand-reconstruction requests from a dedicated worker process.

    Builds a ``HandReconstructor`` once on 'cuda', posts ``{'type': 'ready'}`` on
    ``result_queue`` and then serves tasks from ``task_queue`` until it receives
    ``{'type': 'shutdown'}``.

    Args:
        args_dict: settings exposed as attributes to
            ``data.tools.hand_recon_core.Config``.
        task_queue: queue of task dicts. ``{'type': 'reconstruct', 'image_path': str}``
            reconstructs one image; ``{'type': 'shutdown'}`` stops the loop.
        result_queue: receives exactly one reply per non-shutdown task:
            ``{'type': 'result', 'success': True, 'data': ...}`` or
            ``{'type': 'result', 'success': False, 'error': str, 'traceback': str}``.
            If start-up (or the serving loop itself) fails,
            ``{'type': 'error', 'error': str, 'traceback': str}`` is posted instead.
    """
    import traceback
    from types import SimpleNamespace

    hand_reconstructor = None

    try:
        # Imported inside the try so an import failure reaches the parent as an
        # 'error' message instead of leaving it blocked waiting for 'ready'.
        from data.tools.hand_recon_core import Config, HandReconstructor

        print("[HandRecon Process] Initializing hand reconstructor...")
        # Config is handed an attribute-style object rather than the raw dict.
        config = Config(SimpleNamespace(**args_dict))
        hand_reconstructor = HandReconstructor(config=config, device='cuda')
        print("[HandRecon Process] Hand reconstructor ready")

        result_queue.put({'type': 'ready'})

        while True:
            task = task_queue.get()
            task_type = task.get('type')
            if task_type == 'shutdown':
                break
            if task_type != 'reconstruct':
                # Always reply: the caller blocks on result_queue until it gets one.
                result_queue.put({
                    'type': 'result', 'success': False,
                    'error': f"Unknown task type: {task_type!r}", 'traceback': '',
                })
                continue
            try:
                image_path = task['image_path']
                image = cv2.imread(image_path)
                if image is None:
                    raise ValueError(f"Failed to load image from {image_path}")

                recon_results = hand_reconstructor.recon([image])
                result_queue.put({'type': 'result', 'success': True, 'data': recon_results})
            except Exception as e:
                result_queue.put({'type': 'result', 'success': False, 'error': str(e), 'traceback': traceback.format_exc()})
    except Exception as e:
        result_queue.put({'type': 'error', 'error': str(e), 'traceback': traceback.format_exc()})
    finally:
        # Only release GPU memory if the reconstructor (and therefore CUDA) was set up.
        if hand_reconstructor is not None:
            del hand_reconstructor
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        print("[HandRecon Process] Exiting")

def _vla_inference_worker(configs_dict, task_queue, result_queue):
    """
    Serve VLA action-prediction requests from a dedicated worker process.

    Loads the model (moved to CUDA, eval mode) and the normalizer once, posts
    ``{'type': 'ready'}`` on ``result_queue`` and then answers tasks from
    ``task_queue`` until it receives ``{'type': 'shutdown'}``.

    Args:
        configs_dict: configuration passed to ``load_model`` and ``load_normalizer``.
        task_queue: queue of task dicts. A ``'predict'`` task carries ``image``,
            ``instruction``, ``state``, ``state_mask``, ``action_mask``, ``fov``
            (numpy array) and optionally ``num_ddim_steps`` (default 10),
            ``cfg_scale`` (default 5.0) and ``sample_times`` (default 1).
            ``{'type': 'shutdown'}`` stops the loop.
        result_queue: receives exactly one reply per non-shutdown task:
            ``{'type': 'result', 'success': True, 'data': np.ndarray}`` holding
            the un-normalized actions, or ``{'type': 'result', 'success': False,
            'error': str, 'traceback': str}``. If start-up (or the serving loop
            itself) fails, ``{'type': 'error', 'error': str, 'traceback': str}``
            is posted instead.
    """
    import traceback

    model = None
    normalizer = None

    try:
        # Imported inside the try so an import failure reaches the parent as an
        # 'error' message instead of leaving it blocked waiting for 'ready'.
        from vitra.models import load_model
        from vitra.utils.data_utils import load_normalizer
        from vitra.datasets.human_dataset import pad_state_human, pad_action
        from vitra.datasets.dataset_utils import ActionFeature, StateFeature

        print("[VLA Process] Loading VLA model...")
        model = load_model(configs_dict).cuda()
        model.eval()
        normalizer = load_normalizer(configs_dict)
        print("[VLA Process] VLA model ready.")

        result_queue.put({'type': 'ready'})

        while True:
            task = task_queue.get()
            task_type = task.get('type')
            if task_type == 'shutdown':
                break
            if task_type != 'predict':
                # Always reply: the caller blocks on result_queue until it gets one.
                result_queue.put({
                    'type': 'result', 'success': False,
                    'error': f"Unknown task type: {task_type!r}", 'traceback': '',
                })
                continue
            try:
                image = task['image']
                instruction = task['instruction']
                state = task['state']
                state_mask = task['state_mask']
                action_mask = task['action_mask']
                fov = task['fov']
                num_ddim_steps = task.get('num_ddim_steps', 10)
                cfg_scale = task.get('cfg_scale', 5.0)
                sample_times = task.get('sample_times', 1)

                # Normalize the raw state, then pad state/masks into the unified layout.
                norm_state = normalizer.normalize_state(state.copy())
                unified_action_dim = ActionFeature.ALL_FEATURES[1]
                unified_state_dim = StateFeature.ALL_FEATURES[1]

                unified_state, unified_state_mask = pad_state_human(
                    state=norm_state,
                    state_mask=state_mask,
                    action_dim=normalizer.action_mean.shape[0],
                    state_dim=normalizer.state_mean.shape[0],
                    unified_state_dim=unified_state_dim,
                )
                _, unified_action_mask = pad_action(
                    actions=None,
                    action_mask=action_mask.copy(),
                    action_dim=normalizer.action_mean.shape[0],
                    unified_action_dim=unified_action_dim
                )

                # Add a leading batch dimension of 1.
                fov = torch.from_numpy(fov).unsqueeze(0)
                unified_state = unified_state.unsqueeze(0)
                unified_state_mask = unified_state_mask.unsqueeze(0)
                unified_action_mask = unified_action_mask.unsqueeze(0)

                norm_action = model.predict_action(
                    image=image,
                    instruction=instruction,
                    current_state=unified_state,
                    current_state_mask=unified_state_mask,
                    action_mask_torch=unified_action_mask,
                    num_ddim_steps=num_ddim_steps,
                    cfg_scale=cfg_scale,
                    fov=fov,
                    sample_times=sample_times,
                )

                # Keep the first 102 action dims; the caller slices them as
                # left [0:51] and right [51:102].
                norm_action = norm_action[:, :, :102]
                unnorm_action = normalizer.unnormalize_action(norm_action)

                if isinstance(unnorm_action, torch.Tensor):
                    unnorm_action_np = unnorm_action.cpu().numpy()
                else:
                    unnorm_action_np = np.array(unnorm_action)

                result_queue.put({'type': 'result', 'success': True, 'data': unnorm_action_np})
            except Exception as e:
                result_queue.put({'type': 'result', 'success': False, 'error': str(e), 'traceback': traceback.format_exc()})
    except Exception as e:
        result_queue.put({'type': 'error', 'error': str(e), 'traceback': traceback.format_exc()})
    finally:
        # Drop the references so the memory can be reclaimed.
        model = None
        normalizer = None
        # Only touch CUDA if it was actually initialised. Otherwise (e.g. CUDA
        # unavailable) these calls could raise inside `finally` and mask the
        # original error.
        if torch.cuda.is_initialized():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        print("[VLA Process] Exiting")

## Service Classes

In [None]:
class _PersistentWorkerService:
    """
    Shared plumbing for a service backed by one long-lived worker process.

    The worker receives task dicts on ``task_queue`` and answers each one with a
    single dict on ``result_queue``. Its first message is ``{'type': 'ready'}``
    on success or ``{'type': 'error', ...}`` if start-up failed.
    """

    # Seconds to block on the result queue before checking that the worker is still alive.
    _poll_interval = 1.0

    def _launch(self, target, payload, init_failure_prefix):
        """
        Start ``target(payload, task_queue, result_queue)`` in a spawned process
        and wait for its 'ready' message.

        Raises:
            RuntimeError: if the worker reports an error or dies during start-up.
        """
        self.ctx = mp.get_context('spawn')
        self.task_queue = self.ctx.Queue()
        self.result_queue = self.ctx.Queue()
        self.process = self.ctx.Process(target=target, args=(payload, self.task_queue, self.result_queue))
        self.process.start()
        try:
            ready_msg = self._receive(init_failure_prefix)
            if ready_msg.get('type') != 'ready':
                raise RuntimeError(f"{init_failure_prefix}: {self._describe_failure(ready_msg)}")
        except BaseException:
            # Don't leave a half-initialised worker (and any GPU memory it holds) behind.
            self._abort()
            raise

    def _receive(self, failure_prefix):
        """
        Block until the worker posts a message, failing fast if the worker dies.

        Raises:
            RuntimeError: if the worker process exits without replying.
        """
        import queue

        while True:
            try:
                return self.result_queue.get(timeout=self._poll_interval)
            except queue.Empty:
                pass
            if not self.process.is_alive():
                break
        # The worker has exited; pick up anything it flushed right before exiting.
        try:
            return self.result_queue.get(timeout=self._poll_interval)
        except queue.Empty:
            raise RuntimeError(
                f"{failure_prefix}: worker process exited unexpectedly "
                f"(exit code {self.process.exitcode})"
            ) from None

    def _request(self, task, failure_prefix):
        """Send one task and return its 'data' payload; raise RuntimeError on failure."""
        self.task_queue.put(task)
        result = self._receive(failure_prefix)
        if result.get('type') == 'result' and result.get('success'):
            return result['data']
        raise RuntimeError(f"{failure_prefix}: {self._describe_failure(result)}")

    @staticmethod
    def _describe_failure(msg):
        """Render a failure reply as text, including the worker-side traceback if present."""
        error = msg.get('error', 'Unknown error')
        worker_tb = msg.get('traceback')
        if worker_tb:
            return f"{error}\n[worker traceback]\n{worker_tb}"
        return error

    def _abort(self):
        """Forcefully stop the worker process (used when start-up fails)."""
        if self.process.is_alive():
            self.process.terminate()
        self.process.join()

    def shutdown(self):
        """Ask the worker to exit; terminate it if it has not exited within 10 seconds."""
        self.task_queue.put({'type': 'shutdown'})
        self.process.join(timeout=10)
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()


class HandReconstructionService(_PersistentWorkerService):
    """Run ``_hand_reconstruction_worker`` in a persistent spawned process."""

    def __init__(self, args_dict):
        self._launch(_hand_reconstruction_worker, args_dict, "Failed to initialize hand reconstruction")
        print("Hand reconstruction service initialized")

    def reconstruct(self, image_path):
        """Reconstruct hands in the image at ``image_path``; returns the worker's result data."""
        return self._request({'type': 'reconstruct', 'image_path': image_path}, "Hand reconstruction failed")


class VLAInferenceService(_PersistentWorkerService):
    """Run ``_vla_inference_worker`` in a persistent spawned process."""

    def __init__(self, configs):
        self._launch(_vla_inference_worker, configs, "Failed to initialize VLA model")
        print("VLA inference service initialized")

    def predict(self, image, instruction, state, state_mask, action_mask, fov, num_ddim_steps=10, cfg_scale=5.0, sample_times=1):
        """Run one prediction request; returns the un-normalized actions produced by the worker."""
        return self._request({
            'type': 'predict', 'image': image, 'instruction': instruction, 'state': state,
            'state_mask': state_mask, 'action_mask': action_mask, 'fov': fov,
            'num_ddim_steps': num_ddim_steps, 'cfg_scale': cfg_scale, 'sample_times': sample_times,
        }, "VLA inference failed")

## Configuration and Initialization

In [None]:
class Args:
    """
    Settings for this notebook.

    Edit the defaults below, or override individual settings by keyword, e.g.
    ``Args(sample_times=1, use_left=False)``.

    Raises:
        TypeError: if an override names a setting that does not exist.
    """

    def __init__(self, **overrides):
        # Model configuration
        self.config_path = '../config/human_vla.json'  # loaded with load_config()
        self.model_path = None        # if set, overrides configs['model_load_path']
        self.statistics_path = None   # if set, overrides configs['statistics_path']

        # Input/Output
        self.image_path = '../data/examples/human_example.jpg'  # hand state is cached next to it as .npy
        self.hand_path = None         # NOTE(review): not read by the cells shown here
        self.video_path = './example_human_inf.mp4'  # output video with all rendered samples

        # Hand reconstruction models (forwarded to the hand-reconstruction worker)
        self.hawor_model_path = '../weights/hawor/checkpoints/hawor.ckpt'
        self.detector_path = '../weights/hawor/external/detector.pt'
        self.moge_model_name = 'Ruicheng/moge-2-vitl'
        self.mano_path = '../weights/mano'

        # Prediction settings
        self.use_left = True          # predict/render the left hand
        self.use_right = True         # predict/render the right hand
        self.instruction = "Left: Put the trash into the garbage. Right: None."
        self.sample_times = 4         # number of sampled action chunks, tiled into one video grid
        self.fps = 8                  # output video frame rate
        self.save_state_local = True  # save freshly reconstructed hand state as <image>.npy

        for name, value in overrides.items():
            # Reject typos instead of silently creating an unused attribute.
            if not hasattr(self, name):
                raise TypeError(f"Unknown Args setting: {name!r}")
            setattr(self, name, value)
args = Args()
configs = load_config(args.config_path)
# Optional overrides of the checkpoint / normalization statistics named in the config file.
if args.model_path:
    configs['model_load_path'] = args.model_path
if args.statistics_path:
    configs['statistics_path'] = args.statistics_path

# Reconstructed hand state for the input image is cached next to it as <image>.npy.
image_path_obj = Path(args.image_path)
npy_path = image_path_obj.with_suffix('.npy')

hand_data = None
hand_recon_service = None

if npy_path.exists():
    print(f"Found precomputed hand state results: {npy_path}.")
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load .npy files you produced yourself.
    hand_data = np.load(npy_path, allow_pickle=True).item()
else:
    recon_args_dict = {
        'hawor_model_path': args.hawor_model_path,
        'detector_path': args.detector_path,
        'moge_model_name': args.moge_model_name,
        'mano_path': args.mano_path,
    }
    hand_recon_service = HandReconstructionService(recon_args_dict)

try:
    vla_service = VLAInferenceService(configs)
except BaseException:
    # Don't leak the already-running hand-reconstruction worker (and its GPU
    # memory) when the VLA worker fails to start.
    if hand_recon_service is not None:
        hand_recon_service.shutdown()
        hand_recon_service = None
    raise

hand_config = HandConfig(args)
hand_config.FPS = args.fps
visualizer = HandVisualizer(hand_config, render_gradual_traj=False)

## Run Inference and Visualization

In [None]:
try:
    # Everything below (state, FoV, rendering) is built from the enabled hand(s).
    if not (args.use_left or args.use_right):
        raise ValueError("At least one of args.use_left / args.use_right must be True.")

    # Reconstruct hand state only if no cached .npy was found during setup.
    if hand_data is None:
        print("Running hand reconstruction...")
        hand_data = hand_recon_service.reconstruct(args.image_path)
        if args.save_state_local:
            np.save(npy_path, hand_data, allow_pickle=True)
            print(f"Saved reconstructed hand state to {npy_path}")

    with Image.open(args.image_path) as raw_image:
        try:
            oriented = ImageOps.exif_transpose(raw_image)
        except Exception:
            # Best effort: keep the stored orientation if the EXIF data cannot be applied.
            oriented = raw_image
        # Force 3-channel RGB. RGBA/grayscale/palette inputs would otherwise give
        # the model, and the [..., ::-1] RGB->BGR flip below, the wrong channels.
        image = oriented.convert('RGB')
    # Read the size AFTER the EXIF transpose so the FoV/intrinsics math below uses
    # the same oriented dimensions as the resized image (a 90-degree EXIF
    # rotation swaps width and height).
    ori_w, ori_h = image.size

    image_resized = resize_short_side_to_target(image, target=224)
    w, h = image_resized.size

    # Initial 51-d state, shape parameters (beta) and horizontal FoV in degrees
    # for each enabled hand.
    current_state_left, current_state_right = None, None
    if args.use_right:
        current_state_right, beta_right, fov_x, _ = get_state(hand_data, hand_side='right')
    if args.use_left:
        current_state_left, beta_left, fov_x, _ = get_state(hand_data, hand_side='left')

    # Camera model: derive fov_y from the original image's aspect ratio, and a
    # pinhole intrinsics matrix for the resized image.
    fov_x_rad = fov_x * np.pi / 180
    f_ori = ori_w / np.tan(fov_x_rad / 2) / 2
    fov_y_rad = 2 * np.arctan(ori_h / (2 * f_ori))
    f = w / np.tan(fov_x_rad / 2) / 2
    intrinsics = np.array([[f, 0, w/2], [0, f, h/2], [0, 0, 1]])

    # A disabled hand is filled with zeros shaped like the enabled one.
    state_left = current_state_left if args.use_left else np.zeros_like(current_state_right)
    beta_left = beta_left if args.use_left else np.zeros_like(beta_right)
    state_right = current_state_right if args.use_right else np.zeros_like(current_state_left)
    beta_right = beta_right if args.use_right else np.zeros_like(beta_left)

    # State layout: [left state, left beta, right state, right beta].
    state = np.concatenate([state_left, beta_left, state_right, beta_right], axis=0)
    state_mask = np.array([args.use_left, args.use_right], dtype=bool)
    chunk_size = configs.get('fwd_pred_next_n', 16)
    action_mask = np.tile(np.array([[args.use_left, args.use_right]], dtype=bool), (chunk_size, 1))
    fov = np.array([fov_x_rad, fov_y_rad], dtype=np.float32)
    image_resized_np = np.array(image_resized)

    print("Running VLA inference...")
    unnorm_action = vla_service.predict(
        image=image_resized_np, instruction=args.instruction, state=state,
        state_mask=state_mask, action_mask=action_mask, fov=fov,
        num_ddim_steps=10, cfg_scale=5.0, sample_times=args.sample_times,
    )

    fx_exo, fy_exo = intrinsics[0, 0], intrinsics[1, 1]
    renderer = Renderer(w, h, (fx_exo, fy_exo), 'cuda')
    # Trajectory length: initial state plus one pose per predicted action step
    # (assumed to match the length recon_traj returns -- TODO confirm).
    T = len(action_mask) + 1
    traj_mask = np.tile(np.array([[args.use_left, args.use_right]], dtype=bool), (T, 1))
    hand_mask = (traj_mask[:, 0], traj_mask[:, 1])
    all_rendered_frames = []

    for i in range(args.sample_times):
        # Action dims: left hand [0:51], right hand [51:102].
        traj_left = recon_traj(state=state_left, rel_action=unnorm_action[i, :, 0:51]) if args.use_left else np.zeros((T, 51))
        traj_right = recon_traj(state=state_right, rel_action=unnorm_action[i, :, 51:102]) if args.use_right else np.zeros((T, 51))

        # Per-step 51-d pose: translation [0:3], global orientation [3:6], joints [6:51].
        left_hand_labels = {
            'transl_worldspace': traj_left[:, 0:3],
            'global_orient_worldspace': R.from_euler('xyz', traj_left[:, 3:6]).as_matrix(),
            'hand_pose': euler_traj_to_rotmat_traj(traj_left[:, 6:51], T),
            'beta': beta_left,
        }
        right_hand_labels = {
            'transl_worldspace': traj_right[:, 0:3],
            'global_orient_worldspace': R.from_euler('xyz', traj_right[:, 3:6]).as_matrix(),
            'hand_pose': euler_traj_to_rotmat_traj(traj_right[:, 6:51], T),
            'beta': beta_right,
        }
        verts_left, _ = process_single_hand_labels(left_hand_labels, hand_mask[0], visualizer.mano, is_left=True)
        verts_right, _ = process_single_hand_labels(right_hand_labels, hand_mask[1], visualizer.mano, is_left=False)

        # Identity camera extrinsics for every frame (rotation, translation).
        extrinsics = (np.broadcast_to(np.eye(3), (T, 3, 3)).copy(), np.zeros((T, 3, 1), dtype=np.float32))
        save_frames = visualizer._render_hand_trajectory([image_resized_np[..., ::-1]] * T, (verts_left, verts_right), hand_mask, extrinsics, renderer, mode='first')
        all_rendered_frames.append(save_frames)

    # Tile the sampled videos into a near-square grid, padding with black frames.
    grid_cols = math.ceil(math.sqrt(args.sample_times))
    grid_rows = math.ceil(args.sample_times / grid_cols)
    combined_frames = []
    for frame_idx in range(T):
        sample_frames = [all_rendered_frames[i][frame_idx] for i in range(args.sample_times)]
        while len(sample_frames) < grid_rows * grid_cols:
            sample_frames.append(np.zeros_like(sample_frames[0]))
        rows = [np.concatenate(sample_frames[r*grid_cols:(r+1)*grid_cols], axis=1) for r in range(grid_rows)]
        combined_frames.append(np.concatenate(rows, axis=0))

    save_to_video(combined_frames, str(args.video_path), fps=hand_config.FPS)
    print(f"Combined video saved to {args.video_path}")

finally:
    # Stop the worker processes even if anything above failed.
    print("Shutting down services...")
    if hand_recon_service is not None: hand_recon_service.shutdown()
    vla_service.shutdown()
    print("All services shut down")

## Display Results

You can use the following cell to display the generated video in the notebook.

In [None]:
# Show the video written by the inference cell, with embed=True so it is
# embedded in the notebook output.
from IPython.display import Video
Video(args.video_path, embed=True)