# PI05 Policy Test: Subtask Generation and Action Prediction

This notebook tests the PI05 policy's subtask generation and action prediction by simulating the inference process on frames from a LeRobot dataset, and visualizes the resulting prompts, subtasks, and actions.


In [None]:
# Imports
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.core import TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_TOKENS
from transformers import AutoTokenizer
from xhuman.policies.pi05.processor_pi05 import make_pi05_pre_post_processors_ki
from xhuman.policies.factory import make_xhuman_policy
from xhuman.policies.pi05.configuration_pi05 import PI05Config

# Confirm the import cell ran without errors.
print("✓ Imports loaded")


In [None]:
# Constants
# Hub identifiers for the dataset replayed below, the PI05 checkpoint to load,
# and the tokenizer used to decode prompts for display.
DS_ID = "NONHUMAN-RESEARCH/TEST_RECORD_ANNOTATIONS"
PRETRAINED_PATH = "lerobot/pi05_base"
TOKENIZER_NAME = "google/paligemma-3b-pt-224"

print("Dataset ID:", DS_ID)
print("Pretrained path:", PRETRAINED_PATH)


In [None]:
# Helper Functions
def decode_tokens(tokens: torch.Tensor, tokenizer, pad_token_id=None) -> str:
    """Decode language token ids back to text so a prompt can be inspected.

    Args:
        tokens: Token ids, either 1-D ``(seq,)`` or 2-D ``(batch, seq)``.
            For a 2-D input only the first sample is decoded.
        tokenizer: Object exposing ``decode`` (e.g. a Hugging Face tokenizer)
            and, optionally, a ``pad_token_id`` attribute.
        pad_token_id: Id treated as padding and removed before decoding.
            Defaults to ``tokenizer.pad_token_id``. If the tokenizer does not
            define one, it falls back to 0, the id this notebook previously
            hard-coded.

    Returns:
        The decoded string. Special tokens are kept so prompt markers stay visible.
    """
    if tokens.dim() == 2:
        tokens = tokens[0]  # Take first sample from batch

    # Resolve the padding id from the tokenizer instead of assuming 0.
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = 0

    # Drop padding positions. This assumes the pad id never occurs inside the real prompt.
    tokens = tokens[tokens != pad_token_id]

    # Pass plain Python ints so decoding works regardless of the tensor's device.
    return tokenizer.decode(tokens.tolist(), skip_special_tokens=False)


def visualize_prompt(batch: dict, tokenizer, step: int, prompt_type: str):
    """Print the decoded language prompt contained in a preprocessed batch.

    Only the first sample of a batched token tensor is shown (see ``decode_tokens``).

    Args:
        batch: Preprocessed batch holding token ids under ``OBS_LANGUAGE_TOKENS``.
        tokenizer: Tokenizer used to turn the ids back into text.
        step: Time index printed in the header.
        prompt_type: Free-form label printed in the header (e.g. "SUBTASK GENERATION").
    """
    separator = "=" * 80
    decoded = decode_tokens(batch[OBS_LANGUAGE_TOKENS], tokenizer)

    print(f"\n{separator}")
    print(f"Step {step}: {prompt_type}")
    print(separator)
    print(f"Prompt: {decoded}")
    print(f"{separator}\n")

print("✓ Helper functions defined")


In [None]:
# Configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
subtask_prediction_frequency = 50  # Generate subtask every N steps (<= 0 disables generation)
max_steps = 20  # Test for N steps
episode_index = 0  # Which episode to test
SEED = 0  # Seed for torch's RNGs

# Seed torch (CPU and CUDA) so any torch-based random sampling in the policy
# gives the same results on every Restart & Run All.
torch.manual_seed(SEED)

print(f"Device: {device}")
print(f"Subtask prediction frequency: {subtask_prediction_frequency}")
print(f"Max steps: {max_steps}")
print(f"Episode index: {episode_index}")
print(f"Seed: {SEED}")


In [None]:
# Load Dataset
print("Loading dataset...")
dataset = LeRobotDataset(DS_ID)
print(f"✓ Dataset loaded: {len(dataset)} samples")
print(f"Features: {list(dataset.features.keys())}")

# Get the first frame of episode `episode_index`.
# `dataset[i]` indexes frames across the whole dataset, not episodes, so
# `dataset[episode_index]` would return an arbitrary frame for any episode other
# than the first. Look up where the episode starts in the metadata instead.
_episodes_meta = getattr(dataset.meta, "episodes", None)
if _episodes_meta is not None and "dataset_from_index" in getattr(_episodes_meta, "column_names", ()):
    # LeRobot dataset format v3.x: meta.episodes stores each episode's global index range.
    episode_start = int(_episodes_meta[episode_index]["dataset_from_index"])
else:
    # Presumably an older (v2.x) dataset, which exposes boundaries via `episode_data_index`.
    episode_start = int(dataset.episode_data_index["from"][episode_index])

# NOTE: `episode_data` is a single, unbatched frame (the episode's first), not the whole episode.
episode_data = dataset[episode_start]
print(f"\n✓ Episode {episode_index} loaded")
print(f"  - First frame dataset index: {episode_start}")
print(f"Episode keys: {list(episode_data.keys())}")


In [None]:
# Create Policy Config
# Build the PI05 config; device and checkpoint come from the configuration cells above.
print("Creating policy config...")
policy_config = PI05Config(
    pretrained_path=PRETRAINED_PATH,  # Set pretrained path so factory loads weights
    device=device,
)
print("✓ Config created")
print(f"  - Type: {policy_config.type}")
print(f"  - Device: {policy_config.device}")
print(f"  - Chunk size: {policy_config.chunk_size}")


In [None]:
# Load Policy (factory populates input_features/output_features from dataset)
print("Loading policy from pretrained...")
policy = make_xhuman_policy(
    cfg=policy_config,
    ds_meta=dataset.meta,
)
# This notebook only runs inference, so put the policy in eval mode explicitly
# instead of relying on the factory to have done it. eval() is idempotent.
# NOTE(review): assumes the returned policy is a torch.nn.Module.
policy.eval()

print(f"✓ Policy loaded: {policy.name}")
print(f"  - Input features: {list(policy.config.input_features.keys())}")
print(f"  - Output features: {list(policy.config.output_features.keys())}")
action_dim = policy.config.output_features['action'].shape[0]
print(f"  - Action dimension: {action_dim}")
print(f"  - Chunk size: {policy.config.chunk_size}")


In [None]:
# Create Preprocessor and Postprocessor
# The preprocessor turns a transition dict into the batch the policy consumes
# (including language tokens under OBS_LANGUAGE_TOKENS).
# Note: policy.config now has input_features and output_features set
preprocessor, postprocessor = make_pi05_pre_post_processors_ki(
    policy.config,
    dataset_stats=dataset.stats,  # Important: pass dataset stats for normalization
)

print("✓ Preprocessor and postprocessor created")


In [None]:
# Load Tokenizer for Visualization
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
print(f"✓ Tokenizer loaded: {TOKENIZER_NAME}")

# Task string used in the prompt.
# LeRobotDataset samples normally carry their episode's task string under "task"
# (verify for this dataset), so prefer that. Previously `dataset.tasks` was the only
# source, and when it was missing the placeholder was used silently.
DEFAULT_TASK = "pick up object"
task = episode_data.get("task")
if not (isinstance(task, str) and task):
    task = (
        dataset.tasks[0]
        if hasattr(dataset, 'tasks') and len(dataset.tasks) > 0
        else DEFAULT_TASK
    )
print(f"✓ Task: {task}")


## Inference Loop

The loop simulates the inference process:
- Every `subtask_prediction_frequency` steps (i.e. when `time_index % subtask_prediction_frequency == 0`): generate a new subtask and cache it on the policy
- Every step: generate actions conditioned on the cached subtask


In [None]:
# Initialize inference loop: print a banner that separates setup output from the per-step logs.
print("\n".join(["", "=" * 80, "STARTING INFERENCE LOOP", "=" * 80, ""]))


In [None]:
# Single step example - you can modify time_index to test specific steps
time_index = 0  # Change this to test different time steps (0-based, within the episode)

print(f"\n{'─'*80}")
print(f"TIME INDEX: {time_index}")
print(f"{'─'*80}")

# Locate the requested frame.
# `episode_data` holds ONE unbatched frame (images presumably (C, H, W), state (D,)),
# so slicing its tensors with [time_index:time_index+1] would cut the channel/state
# dimension, not time. Instead, derive the episode's first global index from that
# frame's "index"/"frame_index" fields and fetch frame `time_index` of the same episode.
target_episode = int(episode_data["episode_index"])
episode_first_index = int(episode_data["index"]) - int(episode_data["frame_index"])
frame_global_index = episode_first_index + time_index
if time_index < 0 or frame_global_index >= len(dataset):
    raise IndexError(f"time_index={time_index} is out of range for episode {target_episode}")
frame = dataset[frame_global_index]
if int(frame["episode_index"]) != target_episode:
    raise IndexError(f"time_index={time_index} is past the end of episode {target_episode}")

# Prepare observation (add a leading batch dimension of size 1)
obs = {
    "observation.images.top": frame["observation.images.top"].unsqueeze(0),
    "observation.state": frame["observation.state"].unsqueeze(0),
}

# Add complementary data
complementary_data = {
    "task": task,
    "time_index": time_index,
    "subtask": policy.cached_subtask,  # Use cached subtask
}

print(f"✓ Observation prepared for time_index={time_index}")
print(f"  - Image shape: {obs['observation.images.top'].shape}")
print(f"  - State shape: {obs['observation.state'].shape}")
print(f"  - Current cached subtask: '{policy.cached_subtask}'")


In [None]:
# Check if we should generate subtask
# A new subtask is generated on every step that is a multiple of the frequency;
# a frequency <= 0 disables generation entirely.
should_generate_subtask = (
    subtask_prediction_frequency > 0
    and time_index % subtask_prediction_frequency == 0
)

if should_generate_subtask:
    print(f"🔄 GENERATING NEW SUBTASK at step {time_index}")

    # Prepare batch for subtask generation
    # The processor will create prompt: "Task: X. Subtask: "
    obs_subtask = {**obs}
    complementary_data_subtask = {
        "task": task,
        "time_index": time_index,
        "subtask": None,  # Force subtask generation prompt
    }

    # Note: The preprocessor expects transition format
    transition_subtask = {
        TransitionKey.OBSERVATION: obs_subtask,
        TransitionKey.COMPLEMENTARY_DATA: complementary_data_subtask,
    }

    # Preprocess
    batch_subtask = preprocessor(transition_subtask)

    # Visualize subtask generation prompt
    visualize_prompt(batch_subtask, tokenizer, time_index, "SUBTASK GENERATION")

    # Generate subtask. This is inference only, so skip building an autograd graph.
    with torch.no_grad():
        policy.update_subtask(batch_subtask)

    print(f"✅ Generated subtask: '{policy.cached_subtask}'")

    # Update complementary data so the action step uses the new subtask
    complementary_data["subtask"] = policy.cached_subtask
else:
    print(f"⏭️  Skipping subtask generation (frequency={subtask_prediction_frequency})")
