# Core Training Notebook

This notebook contains only the essential parts for training a policy:
- Dataset loading
- Policy updates
- Core training loop

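Excluded: evaluation (eval episodes are split off but never used), WandB logging, and checkpoint saving. When `cfg.resume` is set, training state is still restored from `cfg.checkpoint_path`.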

In [None]:
import time
from contextlib import nullcontext
from typing import Any

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import Optimizer

from lerobot.configs import parser
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import load_training_state
from lerobot.utils.utils import (
    format_big_number,
    has_method,
    init_logging,
)

from xhuman.policies.factory import make_xhuman_policy, make_xhuman_pre_post_processors
from xhuman.configs.train import TrainPipelineConfigXHUMAN
from xhuman.datasets.factory import make_dataset_xhuman
from xhuman.datasets.utils import split_train_eval_episodes
from xhuman.logger import logger

## Helper Functions

In [None]:
def load_dataset(cfg: TrainPipelineConfigXHUMAN, episodes: list[int], is_main_process: bool = True, accelerator: Accelerator | None = None):
    """
    Build the dataset restricted to ``episodes``, constructing it on the main process first.

    The main process builds the dataset before the others, so any download happens once and
    the remaining processes do not race it. All processes then meet at a barrier. After that,
    the non-main processes build their own copy.

    NOTE: this mutates ``cfg.dataset.episodes`` in place.

    Args:
        cfg: Training pipeline config. ``cfg.dataset.episodes`` is overwritten with ``episodes``
            before ``make_dataset_xhuman(cfg)`` is called.
        episodes: Episode indices to select for this dataset.
        is_main_process: Whether the calling process is the main process.
        accelerator: Provides the cross-process barrier. When ``None`` (plain single-process
            use), the barrier is skipped instead of failing with ``AttributeError``.

    Returns:
        The object returned by ``make_dataset_xhuman(cfg)``.
    """
    cfg.dataset.episodes = episodes

    if is_main_process:
        logger.info("Creating dataset")
        dataset = make_dataset_xhuman(cfg)

    # Barrier: non-main processes only build the dataset after the main process has finished.
    if accelerator is not None:
        accelerator.wait_for_everyone()

    if not is_main_process:
        dataset = make_dataset_xhuman(cfg)

    return dataset

In [None]:
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    accelerator: Accelerator,
    lr_scheduler=None,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """
    Perform a single training step that updates the policy's weights.

    Runs the forward pass under ``accelerator.autocast()`` and the backward pass through the
    accelerator. It then clips (or just measures) the gradient norm, steps the optimizer and
    the optional LR scheduler, and finally calls the unwrapped policy's ``update()`` hook if
    it has one.

    Args:
        train_metrics: MetricsTracker that receives ``loss``, ``grad_norm``, ``lr`` and
            ``update_s`` for this step.
        policy: The policy being trained. ``policy.forward(batch)`` must return
            ``(loss, output_dict)``.
        batch: One batch of training data, passed unchanged to ``policy.forward``.
        optimizer: Optimizer for the policy's parameters.
        grad_clip_norm: Maximum gradient norm. Values ``<= 0`` disable clipping, but the norm
            is still computed for logging.
        accelerator: Accelerator handling distributed training and mixed precision.
        lr_scheduler: Optional scheduler, stepped once per call (per batch, not per epoch).
        lock: Optional context manager held only around ``optimizer.step()``.

    Returns:
        A tuple of:
        - the updated ``train_metrics``;
        - the ``output_dict`` from the policy's forward pass, for logging.
    """
    start_time = time.perf_counter()
    policy.train()

    with accelerator.autocast():
        loss, output_dict = policy.forward(batch)

    accelerator.backward(loss)

    if grad_clip_norm > 0:
        # The accelerator variant unscales AMP gradients before clipping.
        grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    else:
        # No clipping requested, but the norm is still logged. Unscale first so that, under
        # fp16 AMP, the reported norm is not inflated by the loss-scale factor.
        # unscale_gradients() is a no-op outside fp16 native AMP.
        accelerator.unscale_gradients()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            policy.parameters(), float("inf"), error_if_nonfinite=False
        )

    with lock if lock is not None else nullcontext():
        optimizer.step()

    optimizer.zero_grad()

    # The scheduler is stepped every batch rather than every epoch.
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Some policies keep internal buffers refreshed through an optional update() hook.
    unwrapped_policy = accelerator.unwrap_model(policy, keep_fp32_wrapper=True)
    if has_method(unwrapped_policy, "update"):
        unwrapped_policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict

## Configuration and Setup

In [None]:
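# Define `cfg` (a TrainPipelineConfigXHUMAN) in this cell. Every cell below reads it and
# fails with a NameError until it exists.
#
# Option 1 - build it directly, e.g.:
#   cfg = TrainPipelineConfigXHUMAN.from_dict({...})
#   (NOTE(review): confirm TrainPipelineConfigXHUMAN actually provides from_dict.)
#
# Option 2 - parse it with lerobot's `parser.wrap()`.
# NOTE(review): the wrapped function presumably needs the config as its first,
# type-annotated parameter. The decorator presumably reads sys.argv, which inside a notebook
# kernel holds Jupyter's own arguments. Verify against lerobot.configs.parser before relying on it.
# @parser.wrap()
# def get_config(cfg: TrainPipelineConfigXHUMAN):
#     return cfg
#
# cfg = get_config()
# cfg.validate()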

## Training Setup

In [None]:
# Build the Accelerator. It detects distributed vs. single-process mode from the launch environment.
#  - step_scheduler_with_optimizer=False: accelerate must not rescale lr_scheduler steps
#    by num_processes.
#  - find_unused_parameters=True: DDP tolerates parameters that get no gradient in a step
#    (models with conditional computation).
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    kwargs_handlers=[ddp_kwargs],
    step_scheduler_with_optimizer=False,
)
init_logging(accelerator=accelerator)

# Only the main process logs and checkpoints.
is_main_process = accelerator.is_main_process

if cfg.seed is not None:
    set_seed(cfg.seed, accelerator=accelerator)

device = accelerator.device

# Performance switches: cuDNN autotuning and TF32 matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# Load the dataset as currently configured, only to learn how many episodes exist.
# load_dataset keeps the "main process first, then everyone else" ordering. Passing the
# current cfg.dataset.episodes back in leaves that setting unchanged.
# NOTE(review): load_dataset below overwrites cfg.dataset.episodes with the training split,
# so re-running this cell in the same kernel starts from that subset.
dataset = load_dataset(cfg, cfg.dataset.episodes, is_main_process, accelerator)

episodes = list(range(dataset.meta.total_episodes))
# Narrow the candidate episodes here if needed, e.g. episodes = episodes[1700:]

# seed=42 is a fixed literal, so the split does not depend on cfg.seed.
train_episodes, eval_episodes = split_train_eval_episodes(episodes, split_ratio=cfg.split_ratio, seed=42)

# Release the first dataset object before loading the training subset.
del dataset

if is_main_process:
    logger.info(f"Loading train dataset with {len(train_episodes)} episodes")
dataset = load_dataset(cfg, train_episodes, is_main_process, accelerator)

In [None]:
# Build the policy from its config and the training dataset's metadata, then wait until
# every process has constructed its copy.
if is_main_process:
    logger.info("Creating policy")
policy = make_xhuman_policy(cfg=cfg.policy, ds_meta=dataset.meta)
accelerator.wait_for_everyone()

## Dataset and Model Information

In [None]:
# Display dataset metadata and model configuration (main process only).


def _describe_feature(feature) -> str:
    """Format a feature spec as ``"<type> (shape: <shape>)"``.

    Policy features expose ``.type`` (an enum) and ``.shape``. Dataset features may instead
    be plain dicts (presumably LeRobot-style ``{"dtype": ..., "shape": ..., ...}``). Those
    have no ``.type`` attribute, so they are described by their ``dtype`` entry.
    """
    if isinstance(feature, dict):
        return f"{feature.get('dtype')} (shape: {feature.get('shape')})"
    return f"{feature.type.name} (shape: {feature.shape})"


def _describe_stat(stat_val) -> str | None:
    """Summarize one normalization statistic, or return None when there is nothing to print.

    Array-likes (anything with ``.shape``, e.g. torch tensors or numpy arrays) are read
    through a CPU float64 tensor. This makes ``numel``/``min``/``max``/``mean`` work for both
    array types, including integer-valued stats. Numpy arrays have no ``numel()``, which
    crashed the previous version. Empty arrays are skipped.
    """
    if hasattr(stat_val, "shape"):
        values = torch.as_tensor(stat_val).detach().cpu().double()
        if values.numel() == 0:
            return None
        return (
            f"shape={stat_val.shape}, "
            f"min={values.min().item():.4f}, max={values.max().item():.4f}, "
            f"mean={values.mean().item():.4f}"
        )
    if isinstance(stat_val, (int, float)):
        return f"{stat_val}"
    return None


if is_main_process:
    import dataclasses
    from pprint import pprint

    separator = "=" * 80

    print(separator)
    print("DATASET METADATA")
    print(separator)
    print(f"\nDataset Repository: {dataset.repo_id}")
    print(f"Total Episodes: {dataset.meta.total_episodes}")
    print(f"Training Episodes: {len(train_episodes)}")
    print(f"Number of Frames: {dataset.num_frames:,}")
    print(f"Number of Episodes (loaded): {dataset.num_episodes}")

    print("\nFeatures:")
    for key, feature in dataset.meta.features.items():
        print(f"  - {key}: {_describe_feature(feature)}")

    print("\nCamera/Video Keys:")
    # Prefer camera_keys; fall back to video_keys when camera_keys is missing or empty.
    visual_keys = (
        getattr(dataset.meta, "camera_keys", None)
        or getattr(dataset.meta, "video_keys", None)
        or []
    )
    for key in visual_keys:
        print(f"  - {key}")

    print("\nDataset Statistics (normalization):")
    # meta.stats may be None when no statistics were loaded; treat that as "no stats".
    for key, stats in (dataset.meta.stats or {}).items():
        if not isinstance(stats, dict):
            continue
        print(f"  {key}:")
        for stat_name, stat_val in stats.items():
            summary = _describe_stat(stat_val)
            if summary is not None:
                print(f"    {stat_name}: {summary}")

    print("\n" + separator)
    print("MODEL CONFIGURATION")
    print(separator)
    print(f"\nPolicy Type: {policy.config.type}")
    print(f"Policy Class: {policy.__class__.__name__}")

    print("\nInput Features:")
    for key, feature in policy.config.input_features.items():
        print(f"  - {key}: {_describe_feature(feature)}")

    print("\nOutput Features:")
    for key, feature in policy.config.output_features.items():
        print(f"  - {key}: {_describe_feature(feature)}")

    norm_mapping = getattr(policy.config, "normalization_mapping", None)
    if norm_mapping:
        print("\nNormalization Mapping:")
        for key, value in norm_mapping.items():
            print(f"  {key} -> {value}")

    # Model-specific configuration: everything except the fields shared by all policies.
    print("\nModel-Specific Configuration:")
    if hasattr(policy.config, "to_dict"):
        config_dict = policy.config.to_dict()
    else:
        # NOTE(review): fallback for configs without to_dict(); assumes the config is a dataclass.
        config_dict = dataclasses.asdict(policy.config)
    common_fields = {"type", "device", "input_features", "output_features", "normalization_mapping"}
    model_specific = {k: v for k, v in config_dict.items() if k not in common_fields}
    if model_specific:
        pprint(model_specific, width=80, indent=2)

    print("\n" + separator)

In [None]:
# Build the pre/post-processors.
# dataset_stats is supplied unless we resume from a pretrained path. Presumably the saved
# processor state is reused in that case.
# NOTE(review): the two checks below differ on purpose or by accident. The first uses
# truthiness, the second `is not None`, so an empty-string pretrained_path takes both branches.
has_pretrained = bool(cfg.policy.pretrained_path)
processor_kwargs = {}
postprocessor_kwargs = {}
if not (has_pretrained and cfg.resume):
    processor_kwargs["dataset_stats"] = dataset.meta.stats

if cfg.policy.pretrained_path is not None:
    # Override device and (un)normalization settings with values from this run/dataset.
    ds_stats = dataset.meta.stats
    norm_map = policy.config.normalization_mapping
    processor_kwargs["preprocessor_overrides"] = {
        "device_processor": {"device": device.type},
        "normalizer_processor": {
            "stats": ds_stats,
            "features": {**policy.config.input_features, **policy.config.output_features},
            "norm_map": norm_map,
        },
    }
    postprocessor_kwargs["postprocessor_overrides"] = {
        "unnormalizer_processor": {
            "stats": ds_stats,
            "features": policy.config.output_features,
            "norm_map": norm_map,
        },
    }

preprocessor, postprocessor = make_xhuman_pre_post_processors(
    policy_cfg=cfg.policy,
    pretrained_path=cfg.policy.pretrained_path,
    **processor_kwargs,
    **postprocessor_kwargs,
)

In [None]:
# Build the optimizer and LR scheduler from cfg for the (not yet prepared) policy.
if is_main_process:
    logger.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

if cfg.resume:
    # Restore the update counter plus optimizer/scheduler state from the checkpoint.
    step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
else:
    step = 0  # number of policy updates completed so far

In [None]:
# Summarize the run configuration (main process only).
if is_main_process:
    all_params = list(policy.parameters())
    num_total_params = sum(p.numel() for p in all_params)
    num_learnable_params = sum(p.numel() for p in all_params if p.requires_grad)
    num_processes = accelerator.num_processes
    effective_bs = cfg.batch_size * num_processes
    summary_lines = (
        f"Output dir: {cfg.output_dir}",
        f"Steps: {cfg.steps} ({format_big_number(cfg.steps)})",
        f"Dataset frames: {dataset.num_frames} ({format_big_number(dataset.num_frames)})",
        f"Dataset episodes: {dataset.num_episodes}",
        f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}",
        f"Learnable params: {num_learnable_params} ({format_big_number(num_learnable_params)})",
        f"Total params: {num_total_params} ({format_big_number(num_total_params)})",
    )
    for line in summary_lines:
        logger.info(line)

In [None]:
# Create the training dataloader.
# Policies that declare drop_n_last_frames use an episode-aware sampler, which does its own
# shuffling. DataLoader does not allow shuffle=True together with a sampler.
if hasattr(cfg.policy, "drop_n_last_frames"):
    if is_main_process:
        logger.info(f"Dropping {cfg.policy.drop_n_last_frames} last frames")
    shuffle = False
    sampler = EpisodeAwareSampler(
        dataset.meta.episodes["dataset_from_index"],
        dataset.meta.episodes["dataset_to_index"],
        # Restrict sampling to the episodes this dataset was loaded with (the training split).
        # Without this, the sampler covers every episode listed in the metadata, which
        # presumably includes the held-out eval episodes (as in upstream LeRobot, where
        # meta.episodes lists the whole repo). None keeps the previous "all episodes" behaviour.
        episode_indices_to_use=getattr(dataset, "episodes", None),
        drop_n_last_frames=cfg.policy.drop_n_last_frames,
        shuffle=True,
    )
else:
    if is_main_process:
        logger.info("Not dropping any frames")
    shuffle = True
    sampler = None

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.num_workers,
    batch_size=cfg.batch_size,
    # Streaming datasets are not shuffled by the DataLoader.
    shuffle=shuffle and not cfg.dataset.streaming,
    sampler=sampler,
    pin_memory=device.type == "cuda",
    drop_last=False,
    # prefetch_factor is only valid with worker processes.
    prefetch_factor=2 if cfg.num_workers > 0 else None,
)

In [None]:
# Wrap the model, optimizer, dataloader and scheduler for the current distributed setup.
# The barrier ensures every rank reaches this point before preparing.
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(policy, optimizer, dataloader, lr_scheduler)
dl_iter = cycle(dataloader)  # the training loop calls next() on this once per step
policy.train()

# Running averages shown in the terminal log (name, display format spec).
train_metrics = dict(
    loss=AverageMeter("loss", ":.3f"),
    grad_norm=AverageMeter("grdn", ":.3f"),
    lr=AverageMeter("lr", ":0.1e"),
    update_s=AverageMeter("updt_s", ":.3f"),
    dataloading_s=AverageMeter("data_s", ":.3f"),
)

# Per-process batch size times the number of processes.
effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
    effective_batch_size,
    dataset.num_frames,
    dataset.num_episodes,
    train_metrics,
    initial_step=step,
    accelerator=accelerator,
)

## Training Loop

In [None]:
if is_main_process:
    logger.info("Start offline training on a fixed dataset")
    logger.info(f"Train episodes: {len(train_episodes)}")

# Run policy updates from the current step (0 or the resumed value) up to cfg.steps.
log_every = cfg.log_freq
while step < cfg.steps:
    # "Data loading" time includes running the preprocessor on the batch.
    t0 = time.perf_counter()
    batch = preprocessor(next(dl_iter))
    train_tracker.dataloading_s = time.perf_counter() - t0

    train_tracker, output_dict = update_policy(
        train_tracker,
        policy,
        batch,
        optimizer,
        cfg.optimizer.grad_clip_norm,
        accelerator=accelerator,
        lr_scheduler=lr_scheduler,
    )

    step += 1
    train_tracker.step()

    # Only the main process logs, every `log_every` steps (disabled when <= 0).
    should_log = is_main_process and log_every > 0 and step % log_every == 0
    if should_log:
        logger.info(train_tracker)
        train_tracker.reset_averages()

if is_main_process:
    logger.info("End of training")

In [None]:
# Clean up: wait until every process reaches this point, then let accelerate run its
# end-of-training teardown. The barrier must come first.
accelerator.wait_for_everyone()
accelerator.end_training()