<a href="https://colab.research.google.com/github/jorgemunozl/vla-test/blob/main/tests/train_val_pi05.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fourteenth Test

This test runs training with the same architecture as the script we use to train our production models.

In [1]:
import os

# CUDA_VISIBLE_DEVICES has to be set before torch (or tensorflow) is imported,
# so the restriction is in place when CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

# Sanity check: with exactly one GPU selected this prints 1.
visible_gpus = torch.cuda.device_count()
print(visible_gpus)

1


In [2]:
from contextlib import nullcontext
from typing import Any

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import Optimizer

from lerobot.configs import parser
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import load_training_state
from lerobot.utils.utils import (
    format_big_number,
    has_method,
    init_logging,
)

from xhuman.policies.factory import make_xhuman_policy, make_xhuman_pre_post_processors
from xhuman.configs.train import TrainPipelineConfigXHUMAN
from xhuman.datasets.factory import make_dataset_xhuman
from xhuman.datasets.utils import split_train_eval_episodes
from xhuman.logger import logger

## Helper Functions

These functions handle dataset loading and policy updates. They are designed to work with distributed training using HuggingFace Accelerate.

In [3]:
def load_dataset(cfg: TrainPipelineConfigXHUMAN, episodes: list[int], is_main_process: bool = True, accelerator: Accelerator | None = None):
    """
    Load the dataset restricted to ``episodes``.

    The main process builds the dataset first (it may download it) while the
    other processes wait at a barrier. This avoids races on the local cache.
    After the barrier, the remaining processes build the dataset as well.

    Args:
        cfg: Training config. Mutated in place: ``cfg.dataset.episodes`` is set
            to ``episodes`` and ``cfg.dataset.train_with_subtasks`` to True.
        episodes: Episode indices to load.
        is_main_process: Whether this is the main process.
        accelerator: Accelerator used for the barrier. When None (plain
            single-process use), no synchronization is performed.

    Returns:
        The dataset built by ``make_dataset_xhuman(cfg)``.
    """
    cfg.dataset.episodes = episodes
    cfg.dataset.train_with_subtasks = True

    if is_main_process:
        logger.info("Creating dataset")
        dataset = make_dataset_xhuman(cfg)

    # Barrier: other processes must not start loading until the main process
    # is done. Without an accelerator there is nothing to synchronize (and the
    # old unconditional call crashed on the documented default of None).
    if accelerator is not None:
        accelerator.wait_for_everyone()

    # Now all other processes can safely load the dataset.
    if not is_main_process:
        dataset = make_dataset_xhuman(cfg)

    return dataset

In [4]:
def update_policy(
    train_metrics: MetricsTracker,
    step: int,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    accelerator: Accelerator,
    lr_scheduler=None,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """
    Perform a single optimization step on the policy.

    Runs the forward pass under ``accelerator.autocast()``, backpropagates via
    the accelerator, clips (or just measures) the gradient norm, steps the
    optimizer and the optional scheduler, and records loss, gradient norm and
    learning rate in ``train_metrics``.

    Args:
        train_metrics: MetricsTracker that receives ``loss``, ``grad_norm``
            and ``lr`` for this step.
        step: Current update index. When ``step % 3 == 0`` the step calls
            ``policy.forward_subtask``; otherwise it calls ``policy.forward``.
        policy: The (possibly accelerator-wrapped) policy to train.
        batch: A training batch, passed unchanged to the forward method.
        optimizer: Optimizer over the policy parameters.
        grad_clip_norm: Maximum gradient norm. Values <= 0 disable clipping,
            but the norm is still computed for logging.
        accelerator: Accelerator used for autocast, backward, clipping and
            unwrapping the model.
        lr_scheduler: Optional scheduler, stepped once per batch.
        lock: Optional lock held around ``optimizer.step()``.

    Returns:
        ``(train_metrics, output_dict)``. ``output_dict`` contains the
        entries returned by the policy's forward method, plus the scalar
        total loss under ``"loss"``.
    """
    policy.train()

    # Every third step trains with forward_subtask; all others use forward.
    with accelerator.autocast():
        if step % 3 == 0:
            loss, policy_outputs = policy.forward_subtask(batch)
        else:
            loss, policy_outputs = policy.forward(batch)

    accelerator.backward(loss)

    if grad_clip_norm > 0:
        grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    else:
        # max_norm=inf only computes the total norm (for logging), no clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            policy.parameters(), float("inf"), error_if_nonfinite=False
        )

    with lock if lock is not None else nullcontext():
        optimizer.step()

    optimizer.zero_grad()

    # The scheduler is stepped per batch, not per epoch.
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Let the underlying policy refresh internal buffers if it supports it.
    unwrapped_policy = accelerator.unwrap_model(policy, keep_fp32_wrapper=True)
    if has_method(unwrapped_policy, "update"):
        unwrapped_policy.update()

    # backward() above requires a scalar loss, so read it once and reuse it.
    loss_value = loss.item()
    train_metrics.loss = loss_value
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]

    # Keep what the policy reported (e.g. component losses) for logging.
    # Add the total loss under "loss".
    output_dict = dict(policy_outputs or {})
    output_dict["loss"] = loss_value

    return train_metrics, output_dict

In [5]:
from xhuman.policies.pi05ki.configuration_pi05ki import PI05KIConfig

# PI05KI policy initialized from the pretrained pi05 base checkpoint.
policy_config = PI05KIConfig(
    repo_id="none",
    device="cuda",
    pretrained_path="lerobot/pi05_base",
)

## Configuration and Setup

Configure the dataset and policy settings here. The dataset configuration specifies which Hugging Face repository to load; the policy configuration (created in the cell above) sets up the PI05KI model architecture.

In [6]:
from xhuman.configs.default import LerobotDatasetConfig

# Hugging Face repository holding the training data.
DATASET_REPO_ID = "NONHUMAN-RESEARCH/test-general-idx"
dataset_config = LerobotDatasetConfig(repo_id=DATASET_REPO_ID)

In [7]:
# Full training-pipeline config, combining the dataset config and the PI05KI
# policy config built above. validate() checks it before it is used.
cfg = TrainPipelineConfigXHUMAN(
    policy=policy_config,
    dataset=dataset_config,
)
cfg.validate()

## Training Setup

Initialize the Accelerator for distributed training and set up the training environment. The accelerator automatically handles:
- Multi-GPU training
- Mixed precision training
- Gradient synchronization across processes

In [8]:
# Build the Accelerator; it detects distributed vs. single-process mode itself.
#  - step_scheduler_with_optimizer=False stops accelerate from adjusting the
#    lr_scheduler steps based on num_processes.
#  - find_unused_parameters=True handles conditional computation: forward and
#    forward_subtask are used on different steps.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    step_scheduler_with_optimizer=False,
    kwargs_handlers=[ddp_kwargs],
)
init_logging(accelerator=accelerator)

# Logging and checkpointing are done by the main process only.
is_main_process = accelerator.is_main_process

if cfg.seed is not None:
    set_seed(cfg.seed, accelerator=accelerator)

# Device placement follows the accelerator.
device = accelerator.device

# cuDNN autotuning and TF32 matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

In [9]:
# ============================================================================
# Dataset Loading with Episode Filtering
# ============================================================================
# This cell loads the dataset with proper episode filtering.
# For debugging: Set DEBUG_MODE = True to use only a subset of episodes
# For production: Set DEBUG_MODE = False to use all available episodes
from accelerate.utils import broadcast_object_list

DEBUG_MODE = True  # Set to False for full training
DEBUG_MAX_EPISODES = 200  # Use only first N episodes for debugging

# Only the main process loads the full dataset to read the episode count. The
# count is then broadcast, so every process builds the same episode list and
# therefore the same train/eval split. A hard-coded fallback on the other
# ranks would silently give each process a different split.
episode_count = [None]
if is_main_process:
    temp_dataset = make_dataset_xhuman(cfg)
    episode_count[0] = temp_dataset.meta.total_episodes
    del temp_dataset
    logger.info(f"Total episodes available: {episode_count[0]}")

accelerator.wait_for_everyone()
broadcast_object_list(episode_count, from_process=0)
total_episodes = episode_count[0]

# Limit episodes for debugging
if DEBUG_MODE:
    episodes = list(range(min(DEBUG_MAX_EPISODES, total_episodes)))
    if is_main_process:
        logger.info(f"DEBUG MODE: Using only {len(episodes)} episodes")
else:
    episodes = list(range(total_episodes))

# Split episodes with a fixed seed so every process gets the same split.
train_episodes, eval_episodes = split_train_eval_episodes(
    episodes, split_ratio=cfg.split_ratio, seed=42
)

# Load the dataset with ONLY the train episodes. load_dataset sets
# cfg.dataset.episodes and lets the main process load first.
if is_main_process:
    logger.info(f"Loading train dataset with {len(train_episodes)} episodes")
dataset = load_dataset(cfg, train_episodes, is_main_process=is_main_process, accelerator=accelerator)
dataset.train_with_subtask = True

Failed to read file '/home/lperez/.cache/huggingface/lerobot/NONHUMAN-RESEARCH/test-general-idx/data/chunk-000/file-000.parquet' with error <class 'datasets.table.CastError'>: Couldn't cast
action: list<element: float>
  child 0, element: float
observation.state: list<element: float>
  child 0, element: float
timestamp: float
frame_index: int64
episode_index: int64
index: int64
task_index: int64
general_task_index: int64
-- schema metadata --
pandas: '{"index_columns": [], "column_indexes": [], "columns": [{"name":' + 1051
to
{'action': List(Value('float32'), length=14), 'observation.state': List(Value('float32'), length=14), 'timestamp': Value('float32'), 'frame_index': Value('int64'), 'episode_index': Value('int64'), 'index': Value('int64'), 'task_index': Value('int64')}
because column names don't match
ERROR 2026-01-29 10:16:41 /parquet.py:108 Failed to read file '/home/lperez/.cache/huggingface/lerobot/NONHUMAN-RESEARCH/test-general-idx/data/chunk-000/file-000.parquet' with error <





In [10]:
# Build the policy from its config and the loaded dataset's metadata.
if is_main_process:
    logger.info("Creating policy")

policy = make_xhuman_policy(cfg=cfg.policy, ds_meta=dataset.meta)

# Barrier: every process finishes constructing the policy before moving on.
accelerator.wait_for_everyone()

Loading model from: lerobot/pi05_base




✓ Loaded state dict from model.safetensors
Remapped: action_in_proj.bias -> model.action_in_proj.bias
Remapped: action_in_proj.weight -> model.action_in_proj.weight
Remapped: action_out_proj.bias -> model.action_out_proj.bias
Remapped: action_out_proj.weight -> model.action_out_proj.weight
Remapped: paligemma_with_expert.gemma_expert.lm_head.weight -> model.paligemma_with_expert.gemma_expert.lm_head.weight
Remapped: paligemma_with_expert.gemma_expert.model.layers.0.input_layernorm.dense.bias -> model.paligemma_with_expert.gemma_expert.model.layers.0.input_layernorm.dense.bias
Remapped: paligemma_with_expert.gemma_expert.model.layers.0.input_layernorm.dense.weight -> model.paligemma_with_expert.gemma_expert.model.layers.0.input_layernorm.dense.weight
Remapped: paligemma_with_expert.gemma_expert.model.layers.0.mlp.down_proj.weight -> model.paligemma_with_expert.gemma_expert.model.layers.0.mlp.down_proj.weight
Remapped: paligemma_with_expert.gemma_expert.model.layers.0.mlp.gate_proj.wei

## Dataset and Model Information

Display metadata about the loaded dataset:
- Total number of episodes (in the dataset and used for training) and frames

Model parameter counts and the effective batch size (accounting for distributed training) are not printed here; they are logged by the training-info cell further down.

In [11]:
# Print dataset metadata on the main process only.
# (The unused `pprint` import that used to live here was removed.)
if is_main_process:
    separator = "=" * 80

    print(separator)
    print("DATASET METADATA")
    print(separator)
    print(f"\nDataset Repository: {dataset.repo_id}")
    # Total episodes in the repository, before train/eval filtering.
    print(f"Total Episodes: {dataset.meta.total_episodes}")
    print(f"Training Episodes: {len(train_episodes)}")
    print(f"Number of Frames: {dataset.num_frames:,}")
    print(f"Number of Episodes (loaded): {dataset.num_episodes}")
    print("\n" + separator)

DATASET METADATA

Dataset Repository: NONHUMAN-RESEARCH/test-general-idx
Total Episodes: 437
Training Episodes: 160
Number of Frames: 151,548
Number of Episodes (loaded): 160



In [12]:
processor_kwargs = {}
postprocessor_kwargs = {}

# Dataset stats are passed unless we both start from a pretrained checkpoint
# AND resume training. This is equivalent to
# `(pretrained and not resume) or not pretrained`.
has_pretrained = bool(cfg.policy.pretrained_path)
if not (has_pretrained and cfg.resume):
    processor_kwargs["dataset_stats"] = dataset.meta.stats

# With a pretrained checkpoint, override device and (un)normalization settings
# so they match this dataset and policy config.
if cfg.policy.pretrained_path is not None:
    normalization_features = {**policy.config.input_features, **policy.config.output_features}
    processor_kwargs["preprocessor_overrides"] = {
        "device_processor": {"device": device.type},
        "normalizer_processor": {
            "stats": dataset.meta.stats,
            "features": normalization_features,
            "norm_map": policy.config.normalization_mapping,
        },
    }
    postprocessor_kwargs["postprocessor_overrides"] = {
        "unnormalizer_processor": {
            "stats": dataset.meta.stats,
            "features": policy.config.output_features,
            "norm_map": policy.config.normalization_mapping,
        },
    }

In [13]:
# Build the pre/post-processing pipelines. Both kwarg dicts from the previous
# cell are splatted into a single call.
preprocessor, postprocessor = make_xhuman_pre_post_processors(
    policy_cfg=cfg.policy, **processor_kwargs, **postprocessor_kwargs
)

In [14]:
# Build the optimizer and learning-rate scheduler for the policy from the training config.
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

In [15]:
# Log the optimizer/scheduler stage, then restore their state when resuming.
if is_main_process:
    logger.info("Creating optimizer and scheduler")

# Number of policy updates done so far; replaced by the checkpoint's value on resume.
step = 0
if cfg.resume:
    step, optimizer, lr_scheduler = load_training_state(
        cfg.checkpoint_path, optimizer, lr_scheduler
    )

In [16]:
# Summarize the run configuration (main process only).
if is_main_process:
    num_processes = accelerator.num_processes
    effective_bs = cfg.batch_size * num_processes
    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    summary = [
        f"Output dir: {cfg.output_dir}",
        f"Steps: {cfg.steps} ({format_big_number(cfg.steps)})",
        f"Dataset frames: {dataset.num_frames} ({format_big_number(dataset.num_frames)})",
        f"Dataset episodes: {dataset.num_episodes}",
        f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}",
        f"Learnable params: {num_learnable_params} ({format_big_number(num_learnable_params)})",
        f"Total params: {num_total_params} ({format_big_number(num_total_params)})",
    ]
    for entry in summary:
        logger.info(entry)

In [17]:
# Choose how training samples are drawn.
# Policies that define `drop_n_last_frames` get an EpisodeAwareSampler, which
# shuffles by itself, so DataLoader-level shuffling is turned off (DataLoader
# rejects sampler + shuffle together). Otherwise plain shuffling is used.
use_episode_sampler = hasattr(cfg.policy, "drop_n_last_frames")
if not use_episode_sampler:
    logger.info("Not dropping any frames")
    shuffle = True
    sampler = None
else:
    logger.info(f"Dropping {cfg.policy.drop_n_last_frames} last frames")
    shuffle = False
    sampler = EpisodeAwareSampler(
        dataset.meta.episodes["dataset_from_index"],
        dataset.meta.episodes["dataset_to_index"],
        drop_n_last_frames=cfg.policy.drop_n_last_frames,
        shuffle=True,
    )



In [18]:

# Build the training DataLoader. It honours the sampler/shuffle choice made in
# the previous cell: an EpisodeAwareSampler with shuffle=False when the policy
# drops trailing frames, otherwise sampler=None with shuffle=True (identical
# to the previous hard-coded shuffle=True). batch_size stays small for this
# debug run.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=shuffle,
    sampler=sampler,
    num_workers=0,
    drop_last=True,
)

In [19]:
# Wrap model, optimizer, dataloader and scheduler for the current
# (single- or multi-process) setup.
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    policy, optimizer, train_dataloader, lr_scheduler
)

policy.train()

# Running-average meters reported during training.
train_metrics = {
    "loss": AverageMeter("loss", ":.3f"),
    "grad_norm": AverageMeter("grdn", ":.3f"),
    "lr": AverageMeter("lr", ":0.1e"),
    "update_s": AverageMeter("updt_s", ":.3f"),
    "dataloading_s": AverageMeter("data_s", ":.3f"),
}

# Use the batch size the DataLoader actually yields (hard-coded for this debug
# run) rather than cfg.batch_size, so sample/epoch counts tracked by the
# MetricsTracker match the real training.
effective_batch_size = train_dataloader.batch_size * accelerator.num_processes

In [20]:
# Aggregates the meters in `train_metrics`. Counters start at `step`, so a
# resumed run continues from the step restored by load_training_state.
train_tracker = MetricsTracker(
    effective_batch_size,
    dataset.num_frames,
    dataset.num_episodes,
    train_metrics,
    initial_step=step,
    accelerator=accelerator,
)

In [21]:
"""
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.num_workers,
    batch_size=2,
    shuffle=shuffle and not cfg.dataset.streaming,
    sampler=sampler,
    pin_memory=device.type == "cuda",
    drop_last=False,
    prefetch_factor=2 if cfg.num_workers > 0 else None,
)
"""

'\ndataloader = torch.utils.data.DataLoader(\n    dataset,\n    num_workers=cfg.num_workers,\n    batch_size=2,\n    shuffle=shuffle and not cfg.dataset.streaming,\n    sampler=sampler,\n    pin_memory=device.type == "cuda",\n    drop_last=False,\n    prefetch_factor=2 if cfg.num_workers > 0 else None,\n)\n'

In [22]:
# Announce the start of training and the number of train episodes and steps.
if is_main_process:
    for line in (
        "Start offline training on a fixed dataset",
        f"Train episodes: {len(train_episodes)}",
        f"Total training steps: {cfg.steps}",
    ):
        logger.info(line)

In [23]:
# Iterate over the accelerator-prepared `dataloader`, not the raw
# `train_dataloader`. The prepared loader is what accelerator.prepare returned
# for this run (device placement and, when distributed, sharding).
dl_iter = cycle(dataloader)

In [24]:
# Pull one batch from the cycling iterator and list its keys.
batch = next(dl_iter)
batch.keys()

dict_keys(['observation.images.left', 'observation.images.top', 'observation.images.right', 'action', 'observation.state', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index', 'general_task_index', 'action_is_pad', 'task', 'general_task', 'train_with_subtask'])

In [25]:
# Run the subtask preprocessing pipeline on the batch and list the resulting
# keys (tokenized language and subtask fields appear here).
processed = preprocessor.subtask(batch)
processed.keys()

dict_keys(['action', 'next.reward', 'next.done', 'next.truncated', 'info', 'action_is_pad', 'task', 'index', 'task_index', 'episode_index', 'train_with_subtask', 'subtask', 'subtask_tokens', 'observation.images.left', 'observation.images.top', 'observation.images.right', 'observation.state', 'observation.language.tokens', 'observation.language.attention_mask'])

In [26]:
# Reference value from the model's own forward_subtask. It returns a
# (loss, output_dict) tuple, so `loss` here holds the whole tuple.
loss = policy.forward_subtask(processed)
loss

DEBUG: LINE 1089 modeling_pi05ki.py: final_loss: 


(tensor(2.2335, device='cuda:0', grad_fn=<MeanBackward0>),
 {'vlm_loss': 2.2334561347961426})

In [29]:
from lerobot.utils.constants import OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_ATTENTION_MASK

In [30]:
# Rebuild the forward_subtask inputs by hand. Images are taken from
# `processed`, the same preprocessed batch that was passed to
# policy.forward_subtask above, so images, tokens and masks all come from one
# source and the manual loss is comparable to the reference one.
images, img_masks = policy._preprocess_images(processed)
tokens = processed[OBS_LANGUAGE_TOKENS]
masks = processed[OBS_LANGUAGE_ATTENTION_MASK]
subtasks_tokenized = processed["subtask_tokens"]


tensor([[[[-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          ...,
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.]],

         [[-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          ...,
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.]],

         [[-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          ...,
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.]]],


        [[[-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3.

In [33]:
# First entry of `images` (presumably the first camera, since img_masks has one entry per image).
images[0]

tensor([[[[-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          ...,
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.]],

         [[-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          ...,
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.]],

         [[-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          ...,
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3., -3., -3.]]],


        [[[-3., -3., -3.,  ..., -3., -3., -3.],
          [-3., -3., -3.,  ..., -3.

In [None]:
# Per-image validity masks as returned by _preprocess_images.
img_masks

[tensor([True, True], device='cuda:0'),
 tensor([True, True], device='cuda:0'),
 tensor([True, True], device='cuda:0')]

In [None]:
# Masks after being overridden to [zeros, ones, zeros] (see the masking-setup cell below).
img_masks

[tensor([False, False], device='cuda:0'),
 tensor([True, True], device='cuda:0'),
 tensor([False, False], device='cuda:0')]

In [None]:
from xhuman.policies.pi05ki.modeling_pi05ki import make_att_2d_masks
from torch.nn import functional as F

In [None]:
        # Keep only the second image as valid and mask out the other two
        # (presumably the "top" camera, given the left/top/right key order).
        img_masks = [torch.zeros_like(img_masks[0]), torch.ones_like(img_masks[0]), torch.zeros_like(img_masks[0])]  # noqa: E501

        # Work on copies: the insertion loop writes subtask ids into `tokens`
        # and sets entries of `masks` to True. Both tensors come from the
        # `processed` batch, which forward_subtask also consumes, so editing
        # them in place would corrupt that batch.
        tokens = tokens.clone()
        masks = masks.clone()

        # Index of the first pad token (id 0) per sample. Subtask tokens are
        # inserted there.
        pad_positions = (tokens == 0).int().argmax(dim=1)
        # Samples without any pad token get the sequence length instead.
        has_pad = (tokens == 0).any(dim=1)
        pad_positions = torch.where(has_pad, pad_positions, tokens.shape[1])

        # True only at positions whose labels contribute to the subtask loss.
        token_loss_mask = torch.zeros_like(tokens, dtype=torch.bool, device=tokens.device)

        # Number of subtask tokens per sample.
        list_len_subtask = [t.shape[0] for t in subtasks_tokenized]



In [None]:
# Decode the prompt tokens with the policy's helper to check the prompt text.
policy._detokenize_subtask(tokens)

'Task: pick the fruits from the table and place them in the basket. Subtask:'

In [None]:
from transformers import AutoTokenizer
# PaliGemma tokenizer, used below to decode token ids independently of the policy helper.
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")


In [None]:
# Decode every sample, including <pad> tokens, to see the raw prompt layout.
detokenized_tokens = tokenizer.batch_decode(tokens)
detokenized_tokens

['<bos>Task: pick the fruits from the table and place them in the basket. Subtask: <pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<bos>Task: p

In [None]:
        # Insert each sample's subtask tokens at its first pad position,
        # truncating when they do not fit. Mark those positions both as loss
        # targets (token_loss_mask) and as valid tokens (masks).
        # NOTE(review): `masks` is written in place. Unless it was cloned
        # earlier, this also mutates processed["observation.language.attention_mask"].
        for i in range(tokens.shape[0]):
            pad_idx = pad_positions[i].item()
            discrete_len = list_len_subtask[i]
            subtask_tokens = subtasks_tokenized[i].to(
                device=tokens.device, dtype=tokens.dtype
            )

            # Check whether the subtask tokens fit in the sequence
            if pad_idx + discrete_len > tokens.shape[1]:
                # Truncate the subtask tokens if they don't fit
                available_space = tokens.shape[1] - pad_idx
                if available_space > 0:
                    subtask_tokens = subtask_tokens[:available_space]
                    discrete_len = available_space
                else:
                    # No space available, skip this sample's subtask tokens
                    continue

            # Insert the subtask tokens at the padding position
            tokens[i, pad_idx:pad_idx + discrete_len] = subtask_tokens
            # Mark these positions for loss computation
            token_loss_mask[i, pad_idx:pad_idx + discrete_len] = True
            # Mark the subtask positions as valid (True) in masks
            masks[i, pad_idx:pad_idx + discrete_len] = True

In [None]:
# Positions contributing to the subtask loss (True only where subtask tokens were inserted).
token_loss_mask


tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, F

In [None]:
# Language attention mask after the inserted subtask positions were set to True.
masks

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, F

In [None]:
# Index of the first pad token per sample, i.e. where subtask tokens are inserted.
pad_positions

tensor([20, 20], device='cuda:0')

In [None]:
# Spot-check a position past the inserted subtask span (0 = pad token id).
tokens[0][41]


tensor(0, device='cuda:0')

In [None]:
        prefix_embs, prefix_pad_masks, prefix_att_masks = policy.model.embed_prefix(
            images, img_masks, tokens, masks,
        )
        # Clone so the edits below do not alias the tensor returned by embed_prefix.
        prefix_att_masks = prefix_att_masks.clone()

        # The prefix is [image embeddings | language tokens]. Everything before
        # the language part belongs to the images.
        num_image_embs = prefix_embs.shape[1] - tokens.shape[1]

        # Causal masking for the inserted subtask tokens (att mask = 1).
        for i in range(tokens.shape[0]):
            pad_idx = pad_positions[i].item()
            # Use the same (possibly truncated) length that the insertion loop
            # wrote into `tokens`. Previously, truncated samples were skipped
            # here and their subtask tokens were left without causal masking.
            discrete_len = min(list_len_subtask[i], tokens.shape[1] - pad_idx)
            if discrete_len <= 0:
                continue

            discrete_action_start_idx = num_image_embs + pad_idx
            discrete_action_end_idx = discrete_action_start_idx + discrete_len

            if discrete_action_end_idx <= prefix_att_masks.shape[1]:
                prefix_att_masks[
                    i, discrete_action_start_idx:discrete_action_end_idx
                ] = 1

es
es


In [None]:
# Spot-check a position inside the subtask span (start 788, end 799 below).
prefix_att_masks[0][790]

tensor(True, device='cuda:0')

In [None]:
# Exclusive end of the last processed sample's subtask span within the prefix.
discrete_action_end_idx

799

In [None]:
# Start of the last processed sample's subtask span within the prefix.
discrete_action_start_idx

788

In [None]:
# Prefix attention-mask shape; presumably (batch, prefix length) — TODO confirm against prefix_embs.shape.
prefix_att_masks.shape


torch.Size([2, 968])

In [None]:
        # Combine pad and causal masks into a 2D attention mask. Position ids
        # are the running count of prefix_pad_masks minus one.
        att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
        att_2d_masks_4d = policy.model._prepare_attention_masks_4d(
            att_2d_masks, dtype=prefix_embs.dtype
        )

        # Run only the prefix branch: the second (expert) input embedding is None.
        def forward_func(prefix_embs, att_2d_masks_4d, position_ids):
            (prefix_out, _), _ = policy.model.paligemma_with_expert.forward(
                attention_mask=att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, None],
                use_cache=False,
                adarms_cond=[None, None],
            )
            return prefix_out

        # NOTE(review): _apply_checkpoint presumably applies gradient
        # checkpointing when enabled — confirm in modeling_pi05ki.
        prefix_out = policy.model._apply_checkpoint(
            forward_func, prefix_embs, att_2d_masks_4d, position_ids
        )

        # prefix_out shape: [B, prefix_seq_len, hidden_dim]
        # Project hidden states to vocabulary logits with the PaliGemma LM head.
        def lm_head_func(prefix_out):
            return policy.model.paligemma_with_expert.paligemma.lm_head(prefix_out)


In [None]:
        # Vocabulary logits for every prefix position.
        logits = policy.model._apply_checkpoint(lm_head_func, prefix_out)

        # Drop the image positions; keep only the language part of the prefix.
        num_image_embs = prefix_embs.shape[1] - tokens.shape[1]
        lang_logits = logits[:, num_image_embs:, :]

        # Next-token alignment: logits at position t are scored against token t + 1.
        # shift_logits shape: [B, lang_seq_len-1, vocab_size]
        shift_logits = lang_logits[:, :-1, :].contiguous()

        # shift_labels shape: [B, lang_seq_len-1]
        shift_labels = tokens[:, 1:].contiguous()

        # shift_mask shape: [B, lang_seq_len-1]
        shift_mask = token_loss_mask[:, 1:].contiguous()

In [None]:
        # Label every non-subtask position -100 so cross_entropy ignores it.
        shift_labels = shift_labels.masked_fill(shift_mask == 0, -100)

        # Flatten to (N, vocab) logits and (N,) labels for F.cross_entropy.
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)

In [None]:


# Per-token cross-entropy. Positions labelled -100 (every non-subtask token)
# are ignored and come out as 0 with reduction="none".
per_token_ce = F.cross_entropy(
    flat_logits,
    flat_labels,
    reduction="none",
    ignore_index=-100,
)
loss_per_token = per_token_ce.view(shift_labels.shape[0], -1)

# Batch-level average over supervised subtask tokens (not a per-sample mean).
# The epsilon keeps the division finite when no token is supervised.
total_loss = loss_per_token.sum()
total_valid_tokens = shift_mask.sum()
final_loss = total_loss / (total_valid_tokens + 1e-8)

In [None]:
# Manually recomputed subtask loss, to compare with the forward_subtask reference above.
# NOTE(review): the recorded run shows 14.06 here vs 2.23 from forward_subtask — the reproduction does not match yet.
final_loss

tensor(14.0556, device='cuda:0', grad_fn=<DivBackward0>)