# Core Training Notebook

This notebook contains the essential components for training a PI05 policy:

## What's Included
- **Dataset Loading**: Load and filter episodes from HuggingFace datasets
- **Policy Setup**: Initialize PI05 policy with proper configuration
- **Training Loop**: Core training loop with gradient updates and metrics tracking

## What's Excluded
- Evaluation loops
- WandB logging
- Checkpoint saving strategies
- Model inference/testing

## Usage
1. Set your dataset and policy configuration in the configuration cells
2. Enable `DEBUG_MODE = True` for quick testing with a subset of episodes
3. Set `DEBUG_MODE = False` for full training

## Setup

The following cells set up the environment:
1. Clone the XHUMAN repository
2. Install dependencies
3. Authenticate with HuggingFace Hub
4. Import required libraries

Cloning into 'XHUMAN'...
remote: Enumerating objects: 2067, done.[K
remote: Counting objects: 100% (273/273), done.[K
remote: Compressing objects: 100% (170/170), done.[K
remote: Total 2067 (delta 171), reused 178 (delta 100), pack-reused 1794 (from 2)[K
Receiving objects: 100% (2067/2067), 7.81 MiB | 15.44 MiB/s, done.
Resolving deltas: 100% (1314/1314), done.


In [1]:
%cd XHUMAN

/content/XHUMAN


In [2]:
!uv pip install -e .[pi]

[2mUsing Python 3.12.12 environment at: /usr[0m
[2K[2mResolved [1m215 packages[0m [2min 500ms[0m[0m
[2K[2mPrepared [1m1 package[0m [2min 796ms[0m[0m
[2mUninstalled [1m1 package[0m [2min 0.37ms[0m[0m
[2K[2mInstalled [1m1 package[0m [2min 0.84ms[0m[0m
 [31m-[39m [1mxhuman[0m[2m==0.1.0 (from file:///content/XHUMAN/XHUMAN)[0m
 [32m+[39m [1mxhuman[0m[2m==0.1.0 (from file:///content/XHUMAN)[0m


In [11]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import time
from contextlib import nullcontext
from typing import Any

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import Optimizer

from lerobot.configs import parser
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import load_training_state
from lerobot.utils.utils import (
    format_big_number,
    has_method,
    init_logging,
)

from xhuman.policies.factory import make_xhuman_policy, make_xhuman_pre_post_processors
from xhuman.configs.train import TrainPipelineConfigXHUMAN
from xhuman.datasets.factory import make_dataset_xhuman
from xhuman.datasets.utils import split_train_eval_episodes
from xhuman.logger import logger

ESO


## Helper Functions

These functions handle dataset loading and policy updates. They are designed to work with distributed training using HuggingFace Accelerate.

In [4]:
def load_dataset(cfg: TrainPipelineConfigXHUMAN, episodes: list[int], is_main_process: bool = True, accelerator: Accelerator | None = None):
    """
    Load the dataset for training and evaluation.
    """
    # Dataset loading synchronization: main process downloads first to avoid race conditions
    cfg.dataset.episodes = episodes

    if is_main_process:
        logger.info("Creating dataset")
        dataset = make_dataset_xhuman(cfg)

    accelerator.wait_for_everyone()

    # Now all other processes can safely load the dataset
    if not is_main_process:
        dataset = make_dataset_xhuman(cfg)

    return dataset

In [5]:
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    accelerator: Accelerator,
    lr_scheduler=None,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """
    Performs a single training step to update the policy's weights.

    This function executes the forward and backward passes, clips gradients, and steps the optimizer and
    learning rate scheduler. Accelerator handles mixed-precision training automatically.

    Args:
        train_metrics: A MetricsTracker instance to record training statistics.
        policy: The policy model to be trained.
        batch: A batch of training data.
        optimizer: The optimizer used to update the policy's parameters.
        grad_clip_norm: The maximum norm for gradient clipping.
        accelerator: The Accelerator instance for distributed training and mixed precision.
        lr_scheduler: An optional learning rate scheduler.
        lock: An optional lock for thread-safe optimizer updates.

    Returns:
        A tuple containing:
        - The updated MetricsTracker with new statistics for this step.
        - A dictionary of outputs from the policy's forward pass, for logging purposes.
    """
    start_time = time.perf_counter()
    policy.train()

    # Let accelerator handle mixed precision
    with accelerator.autocast():
        loss, output_dict = policy.forward(batch)

    # Use accelerator's backward method
    accelerator.backward(loss)

    # Clip gradients if specified
    if grad_clip_norm > 0:
        grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            policy.parameters(), float("inf"), error_if_nonfinite=False
        )

    # Optimizer step
    with lock if lock is not None else nullcontext():
        optimizer.step()

    optimizer.zero_grad()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Update internal buffers if policy has update method
    if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
        accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict

In [6]:
from lerobot.policies.pi05 import PI05Config

policy_config = PI05Config(repo_id="none",device="cuda")

## Configuration and Setup

Configure your dataset and policy settings here. The dataset configuration specifies which HuggingFace repository to load, and the policy configuration sets up the PI05 model architecture.

In [None]:
from xhuman.configs.default import LerobotDatasetConfig

dataset_config = LerobotDatasetConfig(
    repo_id="NONHUMAN-RESEARCH/pick-and-place-fruits-v2-test",
)

In [130]:
# Load configuration (you can modify this to load from a config file or set directly)
# Example: cfg = TrainPipelineConfigXHUMAN.from_dict({...})
# For now, we'll use the parser to load from command line or config file
# In notebook, you might want to set cfg directly


cfg = TrainPipelineConfigXHUMAN(
    dataset=dataset_config,
    policy=policy_config # Example policy configuration, replace with your actual policy path
)
cfg.validate()

## Training Setup

Initialize the Accelerator for distributed training and set up the training environment. The accelerator automatically handles:
- Multi-GPU training
- Mixed precision training
- Gradient synchronization across processes

In [131]:
# Create Accelerator
# It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])

init_logging(accelerator=accelerator)

# Determine if this is the main process (for logging and checkpointing)
is_main_process = accelerator.is_main_process

# Set seed if specified
if cfg.seed is not None:
    set_seed(cfg.seed, accelerator=accelerator)

# Use accelerator's device
device = accelerator.device
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# ============================================================================
# Dataset Loading with Episode Filtering
# ============================================================================
# This cell loads the dataset with proper episode filtering.
# For debugging: Set DEBUG_MODE = True to use only a subset of episodes
# For production: Set DEBUG_MODE = False to use all available episodes

DEBUG_MODE = True  # Set to False for full training
DEBUG_MAX_EPISODES = 3  # Use only first N episodes for debugging

# First, get total episodes count (load minimal dataset to check)
if is_main_process:
    temp_dataset = make_dataset_xhuman(cfg)
    total_episodes = temp_dataset.meta.total_episodes
    del temp_dataset
    logger.info(f"Total episodes available: {total_episodes}")
else:
    # For non-main processes, use a reasonable default
    # In practice, this will be synced after main process loads
    total_episodes = 4  # Fallback - adjust if needed

accelerator.wait_for_everyone()

# Limit episodes for debugging
if DEBUG_MODE:
    episodes = list(range(min(DEBUG_MAX_EPISODES, total_episodes)))
    if is_main_process:
        logger.info(f"DEBUG MODE: Using only {len(episodes)} episodes")
else:
    episodes = list(range(total_episodes))

# Split episodes
train_episodes, eval_episodes = split_train_eval_episodes(
    episodes, split_ratio=cfg.split_ratio, seed=42
)

# Load dataset with ONLY train episodes (proper way to filter)
# This uses the load_dataset helper function which sets cfg.dataset.episodes
if is_main_process:
    logger.info(f"Loading train dataset with {len(train_episodes)} episodes")
dataset = load_dataset(cfg, train_episodes, is_main_process=is_main_process, accelerator=accelerator)

Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

In [17]:
# Create policy
if is_main_process:
    logger.info("Creating policy")
policy = make_xhuman_policy(
    cfg=cfg.policy,
    ds_meta=dataset.meta,
)

accelerator.wait_for_everyone()



## Dataset and Model Information

Display metadata about the loaded dataset and model. This includes:
- Total number of episodes and frames
- Model parameter counts
- Effective batch size (accounting for distributed training)

In [146]:
# Display dataset metadata and model configuration
if is_main_process:
    from pprint import pprint

    print("=" * 80)
    print("DATASET METADATA")
    print("=" * 80)
    print(f"\nDataset Repository: {dataset.repo_id}")
    print(f"Total Episodes: {dataset.meta.total_episodes}")
    print(f"Training Episodes: {len(train_episodes)}")
    print(f"Number of Frames: {dataset.num_frames:,}")
    print(f"Number of Episodes (loaded): {dataset.num_episodes}")


    print("\n" + "=" * 80)

DATASET METADATA

Dataset Repository: NONHUMAN-RESEARCH/TEST_RECORD_ANNOTATIONS
Total Episodes: 4
Training Episodes: 3
Number of Frames: 1,866
Number of Episodes (loaded): 3



In [147]:
processor_kwargs = {}
postprocessor_kwargs = {}

if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
    processor_kwargs["dataset_stats"] = dataset.meta.stats

if cfg.policy.pretrained_path is not None:
    processor_kwargs["preprocessor_overrides"] = {
        "device_processor": {"device": device.type},
        "normalizer_processor": {
            "stats": dataset.meta.stats,
            "features": {**policy.config.input_features, **policy.config.output_features},
            "norm_map": policy.config.normalization_mapping,
        },
    }
    postprocessor_kwargs["postprocessor_overrides"] = {
        "unnormalizer_processor": {
            "stats": dataset.meta.stats,
            "features": policy.config.output_features,
            "norm_map": policy.config.normalization_mapping,
        },
    }

In [None]:
# This cell was removed - autoreload is not needed for production training

In [148]:
# Create processors
preprocessor, postprocessor = make_xhuman_pre_post_processors(
    policy_cfg=cfg.policy,
    pretrained_path=cfg.policy.pretrained_path,
    **processor_kwargs,
    **postprocessor_kwargs,
)

<class 'lerobot.policies.act.configuration_act.ACTConfig'>
PI05Config(n_obs_steps=1, input_features={'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(14,)), 'observation.images.left': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 376, 672)), 'observation.images.top': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 376, 672)), 'observation.images.right': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 376, 672))}, output_features={'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(14,))}, device='cuda', use_amp=False, use_peft=False, push_to_hub=True, repo_id='none', private=None, tags=None, license=None, pretrained_path=None, paligemma_variant='gemma_2b', action_expert_variant='gemma_300m', dtype='float32', chunk_size=50, n_action_steps=50, max_state_dim=32, max_action_dim=32, num_inference_steps=10, time_sampling_beta_alpha=1.5, time_sampling_beta_beta=1.0, time_sampling_scale=0.999, time_sampling_of

In [149]:
# Create optimizer and scheduler
if is_main_process:
    logger.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

step = 0  # number of policy updates

# Resume from checkpoint if needed
if cfg.resume:
    step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

In [150]:
# Print training info
if is_main_process:
    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())
    logger.info(f"Output dir: {cfg.output_dir}")
    logger.info(f"Steps: {cfg.steps} ({format_big_number(cfg.steps)})")
    logger.info(f"Dataset frames: {dataset.num_frames} ({format_big_number(dataset.num_frames)})")
    logger.info(f"Dataset episodes: {dataset.num_episodes}")
    num_processes = accelerator.num_processes
    effective_bs = cfg.batch_size * num_processes
    logger.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
    logger.info(f"Learnable params: {num_learnable_params} ({format_big_number(num_learnable_params)})")
    logger.info(f"Total params: {num_total_params} ({format_big_number(num_total_params)})")

In [151]:
# Create dataloader
if hasattr(cfg.policy, "drop_n_last_frames"):
    logger.info(f"Dropping {cfg.policy.drop_n_last_frames} last frames")
    shuffle = False
    sampler = EpisodeAwareSampler(
        dataset.meta.episodes["dataset_from_index"],
        dataset.meta.episodes["dataset_to_index"],
        drop_n_last_frames=cfg.policy.drop_n_last_frames,
        shuffle=True,
    )
else:
    logger.info("Not dropping any frames")
    shuffle = True
    sampler = None

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.num_workers,
    batch_size=cfg.batch_size,
    shuffle=shuffle and not cfg.dataset.streaming,
    sampler=sampler,
    pin_memory=device.type == "cuda",
    drop_last=False,
    prefetch_factor=2 if cfg.num_workers > 0 else None,
)

In [152]:
# Prepare everything with accelerator
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)

policy.train()

# Setup metrics tracking
train_metrics = {
    "loss": AverageMeter("loss", ":.3f"),
    "grad_norm": AverageMeter("grdn", ":.3f"),
    "lr": AverageMeter("lr", ":0.1e"),
    "update_s": AverageMeter("updt_s", ":.3f"),
    "dataloading_s": AverageMeter("data_s", ":.3f"),
}

effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
    effective_batch_size,
    dataset.num_frames,
    dataset.num_episodes,
    train_metrics,
    initial_step=step,
    accelerator=accelerator,
)

## Training Loop

The main training loop iterates through batches, performs forward/backward passes, and updates the policy weights. 

**Note**: The loop below runs for 6 steps as an example. For full training, replace with:
```python
while step < cfg.steps:
    # ... training code ...
```

Metrics are logged at intervals specified by `cfg.log_freq`.

In [None]:
# Training initialization
# This logs the start of training and shows how many episodes will be used
if is_main_process:
    logger.info("Start offline training on a fixed dataset")
    logger.info(f"Train episodes: {len(train_episodes)}")
    logger.info(f"Total training steps: {cfg.steps}")

In [82]:
from torch.utils.data import Dataset

class SmartSubset(Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

    def __getattr__(self, name):
        # This is the magic part:
        # If the code asks for 'meta', 'fps', etc., and this class doesn't have it,
        # it automatically looks inside the original dataset.
        return getattr(self.dataset, name)

# --- USAGE ---
# Use SmartSubset instead of torch.utils.data.Subset
debug_subset = SmartSubset(dataset, range(0, 50))

# Now create your loader normally
train_dataloader = DataLoader(
    debug_subset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    drop_last=True
)

dl_iter = iter(train_dataloader)

In [162]:
for step in range(0, 2): # Or however many steps you want
    # --- SAFE ITERATOR LOGIC ---
    try:
        batch_ = next(dl_iter)
    except StopIteration:
        # We finished the 50 items! Restart from the beginning.
        dl_iter = iter(train_dataloader)
        batch_ = next(dl_iter)
    start_time = time.perf_counter()
    batch = preprocessor(batch_)

    print(batch)

    train_tracker, output_dict = update_policy(
        train_tracker,
        policy,
        batch,
        optimizer,
        cfg.optimizer.grad_clip_norm,
        accelerator=accelerator,
        lr_scheduler=lr_scheduler,
    )
    print(train_tracker)
    print(output_dict)

KeyError: 7

In [51]:
dl_iter = iter(dataloader)

## Training Completion

After the training loop completes, metrics are logged and the accelerator is cleaned up. 

**Note**: Any debug cells below this point (if present) can be ignored - they were used during development to test dataset subsetting approaches. The proper way to filter episodes is now handled via `DEBUG_MODE` in the dataset loading cell.

In [None]:
# Main training loop
# Replace the range(0, 6) with while step < cfg.steps: for full training
if is_main_process:
    logger.info("Starting training loop")

while step < cfg.steps:
    # Measure data loading time
    start_time = time.perf_counter()
    batch = next(dl_iter)
    batch = preprocessor(batch)
    train_tracker.dataloading_s = time.perf_counter() - start_time

    # Update policy (forward pass, backward pass, optimizer step)
    train_tracker, output_dict = update_policy(
        train_tracker,
        policy,
        batch,
        optimizer,
        cfg.optimizer.grad_clip_norm,
        accelerator=accelerator,
        lr_scheduler=lr_scheduler,
    )

    step += 1
    train_tracker.step()

    # Log metrics at specified intervals
    is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
    if is_log_step:
        logger.info(train_tracker)
        train_tracker.reset_averages()

    # Optional: Add checkpoint saving here
    # if cfg.checkpoint_freq > 0 and step % cfg.checkpoint_freq == 0:
    #     save_checkpoint(...)

if is_main_process:
    logger.info(f"Training completed! Total steps: {step}")

KeyError: Caught KeyError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/content/XHUMAN/xhuman/datasets/xhuman_dataset.py", line 156, in __getitem__
    query_result = self._query_hf_dataset(query_indices)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/lerobot/datasets/lerobot_dataset.py", line 993, in _query_hf_dataset
    else [self._absolute_to_relative_idx[idx] for idx in q_idx]
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: 42377

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/content/XHUMAN/xhuman/datasets/xhuman_dataset.py", line 166, in __getitem__
    query_result = self._query_hf_dataset(safe_indices)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/lerobot/datasets/lerobot_dataset.py", line 993, in _query_hf_dataset
    else [self._absolute_to_relative_idx[idx] for idx in q_idx]
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: 42377


In [None]:
# ============================================================================
# Cleanup
# ============================================================================
# Synchronize all processes and clean up accelerator resources
accelerator.wait_for_everyone()
accelerator.end_training()

if is_main_process:
    logger.info("Training session ended successfully")