# Evaluate ACT Checkpoint

This notebook evaluates a trained ACT (Action Chunking with Transformers) checkpoint by running simulated rollouts of the bimanual block-manipulation task and reporting the fraction of rollouts that succeed.


In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path
# Make the project root (the parent of the notebook's working directory)
# importable. The membership check keeps this cell idempotent: re-running it
# does not pile up duplicate sys.path entries.
PROJECT_ROOT = str(Path(os.getcwd()).parent.absolute())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)


## Configuration

Set your checkpoint path and evaluation parameters here:


In [None]:
import torch
from omegaconf import OmegaConf

# ============ CONFIGURATION ============
# Change these paths to evaluate different checkpoints. They can also be
# overridden without editing the notebook via the ACT_CHECKPOINT_PATH and
# ACT_CONFIG_PATH environment variables.
CHECKPOINT_PATH = os.environ.get(
    "ACT_CHECKPOINT_PATH",
    "/home_shared/grail_andre/code/bimaminobolonana/runs/act_results/checkpoint/act-train/150.pt",
)
CONFIG_PATH = os.environ.get(
    "ACT_CONFIG_PATH",
    "/home_shared/grail_andre/code/bimaminobolonana/runs/act_results/config.yaml",
)

# Evaluation parameters
NUM_ROLLOUTS = 100  # Number of simulation rollouts to run
MAX_STEPS_PER_ROLLOUT = 600  # Maximum steps per rollout
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 0  # Seed for the Python/NumPy/PyTorch RNGs so evaluations are repeatable


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch global random number generators.

    Args:
        seed: Value passed to every generator.

    NOTE(review): this only makes rollouts repeatable if the simulator's
    randomization (``randomize_block_position``) draws from one of these
    global generators — confirm.
    """
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA


seed_everything(SEED)
print(f"Using device: {DEVICE}")


## Load Model and Checkpoint


In [None]:
from policy.act import build_act_policy

# Load the training config so the model is rebuilt with the same architecture.
print(f"Loading config from: {CONFIG_PATH}")
config = OmegaConf.load(CONFIG_PATH)
print(f"\nConfig: {config.name}")
print(f"  Chunk size: {config.chunk_size}")
print(f"  Temporal context: {config.temporal_context}")
print(f"  Encoder: {config.encoder.name}")
print(f"  Image size: {config.image_size}")
print(f"  Temporal ensemble: {config.get('temporal_ensemble', False)}")

# Build the model on the target device and switch it to eval mode
# (affects layers such as dropout).
print("\nBuilding ACT policy...")
model = build_act_policy(config).to(DEVICE)
model.eval()

# Load the checkpoint. It is a plain state_dict (it is passed straight to
# load_state_dict), so restrict unpickling to tensors and containers with
# weights_only=True; a full pickle load can execute arbitrary code.
print(f"Loading checkpoint: {Path(CHECKPOINT_PATH).name}")
state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
print("✓ Checkpoint loaded successfully!")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("\nModel parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")


## Setup Policy Wrapper

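Create a wrapper that turns the model into a per-step policy. It keeps a rolling window of recent observations (the temporal context), queries the model for an action chunk at every step, and — when temporal ensembling is enabled — averages the predictions for the current step from recent chunks: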


In [None]:
import numpy as np
from collections import deque
from robot.sim import BimanualAction, BimanualObs, BimanualSim, randomize_block_position
from train.dataset import TensorBimanualObs

class ACTPolicyWrapper:
    """Stateful adapter that exposes a trained ACT model as a per-step policy.

    Each call:
      1. appends the observation to a rolling window of the last
         ``config.temporal_context`` observations (at the start of a rollout the
         window is left-padded with copies of the current observation),
      2. queries the model for a fresh action chunk, and
      3. returns the action for the current timestep: the uniform mean of this
         timestep's predictions from the chunks held in ``action_buffer``.

    ``ensemble_window`` (the buffer length) is ``config.ensemble_window``
    (default 10) when ``config.temporal_ensemble`` is true, and 1 otherwise.
    With a window of 1, the first action of the newest chunk is returned unchanged.

    NOTE(review): the model is re-queried on every step. Without ensembling,
    only action 0 of each chunk is ever executed. With ensembling, chunks are
    weighted uniformly, whereas the ACT paper uses exponential weights.
    Confirm both are intended.

    Call ``reset()`` between rollouts.
    """

    def __init__(self, model, config, device):
        self.model = model
        self.config = config
        self.device = device
        self.temporal_context = config.temporal_context
        self.chunk_size = config.chunk_size
        # Past chunks are only kept around when temporal ensembling is enabled.
        self.ensemble_window = config.get('ensemble_window', 10) if config.get('temporal_ensemble', False) else 1
        self.obs_history = deque(maxlen=self.temporal_context)
        self.action_buffer = deque(maxlen=self.ensemble_window)

    def reset(self):
        """Reset temporal context and action buffers (call between rollouts)."""
        self.obs_history.clear()
        self.action_buffer.clear()

    def __call__(self, obs: BimanualObs) -> BimanualAction:
        """Predict the action to execute for the current observation.

        Args:
            obs: Latest simulator observation.

        Returns:
            The action for this timestep, wrapped in ``BimanualAction``.
        """
        self.obs_history.append(obs)
        # Right after a reset there is not enough history yet: repeat the
        # current observation on the left until the window is full.
        while len(self.obs_history) < self.temporal_context:
            self.obs_history.appendleft(obs)

        tensor_obs = self._to_tensor_obs(list(self.obs_history))

        with torch.no_grad():
            action_chunk = self.model.predict_action_chunk(tensor_obs)

        self.action_buffer.append(action_chunk.cpu())

        # The chunk stored k calls ago predicts the current timestep at index k
        # along dim 1 (assumes chunks are (batch=1, chunk_len, ...) — only the
        # [0, step] indexing is visible here). Chunks whose horizon has elapsed
        # are skipped. The newest chunk (k = 0) always contributes, so the list
        # is never empty. Iterating oldest -> newest keeps the summation order
        # of the mean stable.
        n_chunks = len(self.action_buffer)
        current_actions = [
            chunk[0, n_chunks - 1 - i]
            for i, chunk in enumerate(self.action_buffer)
            if n_chunks - 1 - i < chunk.shape[1]
        ]
        action_array = torch.stack(current_actions).mean(dim=0).numpy()

        return BimanualAction(action_array)

    def _to_tensor_obs(self, obs_list):
        """Batch a window of observations into a single ``TensorBimanualObs``.

        The ``visual`` arrays of all observations are stacked along a new leading
        time axis, and a batch axis of size 1 is prepended. ``qpos`` and ``qvel``
        come from the most recent observation only, also with a batch axis of 1.
        All tensors are cast to float32 and moved to ``self.device``.
        """
        visual = np.stack([o.visual for o in obs_list], axis=0)
        visual = torch.from_numpy(visual).float().unsqueeze(0).to(self.device)
        qpos = torch.from_numpy(obs_list[-1].qpos.array).float().unsqueeze(0).to(self.device)
        qvel = torch.from_numpy(obs_list[-1].qvel.array).float().unsqueeze(0).to(self.device)
        return TensorBimanualObs(visual, qpos, qvel)

# Instantiate the single wrapper shared by every rollout; `create_sim` resets it.
act_policy_wrapper = ACTPolicyWrapper(model=model, config=config, device=DEVICE)
print("✓ Policy wrapper created")


## Define Policy and Simulation Creator


In [None]:
def act_policy(obs: BimanualObs) -> BimanualAction:
    """Per-step policy callable handed to ``evaluate_policy``.

    Thin adapter over the module-level ``act_policy_wrapper`` instance, which
    carries the observation history and action buffer between calls.

    Args:
        obs: Current simulator observation.

    Returns:
        The action chosen by the wrapper for this step.
    """
    action = act_policy_wrapper(obs)
    return action

def create_sim() -> BimanualSim:
    """Build a fresh simulator for one rollout and reset the policy state.

    Used as the ``create_sim`` factory for ``evaluate_policy``. It resets the
    module-level ``act_policy_wrapper``, clearing its observation history and
    action buffer so nothing carries over from the previous rollout. This relies
    on the factory being called once per rollout — TODO confirm in
    ``evaluate_policy``.

    Returns:
        A ``BimanualSim`` configured with:
          - ``robot/block.xml`` merged in,
          - ``randomize_block_position`` as the MuJoCo init hook,
          - square camera images of ``config.image_size`` pixels,
          - the config's camera names.
    """
    act_policy_wrapper.reset()

    image_side = config.image_size
    # NOTE(review): 'robot/block.xml' is a relative path. If BimanualSim resolves
    # it against the working directory, it will break when the notebook runs from
    # a subdirectory (as adding the parent directory to sys.path suggests) — confirm.
    block_scene = Path('robot/block.xml')
    return BimanualSim(
        merge_xml_files=[block_scene],
        on_mujoco_init=randomize_block_position,
        camera_dims=(image_side, image_side),
        obs_camera_names=config.camera_names,
    )


print("✓ Policy and simulation creator defined")


## Run Full Evaluation

This runs `NUM_ROLLOUTS` simulated rollouts, each capped at `MAX_STEPS_PER_ROLLOUT` steps, and reports the success rate. Expect it to take a while!


In [None]:
from validate.evaluation import evaluate_policy

use_ensemble = config.get('temporal_ensemble', False)

print("=" * 60)
print("Starting ACT Policy Evaluation")
print("=" * 60)
print(f"Checkpoint: {Path(CHECKPOINT_PATH).name}")
print(f"Number of rollouts: {NUM_ROLLOUTS}")
print(f"Max steps per rollout: {MAX_STEPS_PER_ROLLOUT}")
print(f"Temporal context: {config.temporal_context} frames")
print(f"Action chunk size: {config.chunk_size}")
print(f"Temporal ensemble: {use_ensemble}")
if use_ensemble:
    print(f"Ensemble window: {config.get('ensemble_window', 10)}")
print("=" * 60)
print()

# Run the rollouts. The result is presumably a fraction in [0, 1]
# (it is reported below as a percentage via success_rate * 100).
success_rate = evaluate_policy(
    policy=act_policy,
    create_sim=create_sim,
    num_rollouts=NUM_ROLLOUTS,
    max_steps_per_rollout=MAX_STEPS_PER_ROLLOUT,
    verbose=True
)

# Use round(), not int(): the product is a float, and e.g. 0.29 * 100 evaluates
# to 28.999999999999996, which int() would truncate to 28 successes.
num_successes = round(success_rate * NUM_ROLLOUTS)

print()
print("=" * 60)
print(f"ACT Policy Success Rate: {success_rate * 100:.2f}%")
print(f"Successful rollouts: {num_successes}/{NUM_ROLLOUTS}")
print("=" * 60)
