<a href="https://colab.research.google.com/github/jorgemunozl/vla-test/blob/main/tests/train_val_core.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ten Test - Debug KI Loss without the subtask implementation

This test uses the same forward pass as my model, which makes debugging here much easier.

## Setup

The following cells set up the environment:
1. Clone the XHUMAN repository
2. Install dependencies
3. Authenticate with HuggingFace Hub
4. Import required libraries

In [None]:
# Clone

In [None]:
%cd XHUMAN

In [None]:
!uv pip install -e .[pi]

In [None]:
from huggingface_hub import login
login()

In [None]:
import time
from contextlib import nullcontext
from typing import Any

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import Optimizer

from lerobot.configs import parser
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import load_training_state
from lerobot.utils.utils import (
    format_big_number,
    has_method,
    init_logging,
)

from xhuman.policies.factory import make_xhuman_policy, make_xhuman_pre_post_processors
from xhuman.configs.train import TrainPipelineConfigXHUMAN
from xhuman.datasets.factory import make_dataset_xhuman
from xhuman.datasets.utils import split_train_eval_episodes
from xhuman.logger import logger

## Helper Functions

These functions handle dataset loading and policy updates. They are designed to work with distributed training using HuggingFace Accelerate.

In [None]:
def load_dataset(
    cfg: TrainPipelineConfigXHUMAN,
    episodes: list[int],
    is_main_process: bool = True,
    accelerator: Accelerator | None = None,
):
    """
    Load the dataset restricted to the given episodes.

    The main process builds the dataset first (downloading it if necessary),
    all processes then meet at a barrier, and only afterwards do the remaining
    processes build their own copy. This avoids several processes racing on
    the same download.

    Args:
        cfg: Training configuration. ``cfg.dataset.episodes`` is overwritten
            with ``episodes`` as a side effect.
        episodes: Episode indices to load.
        is_main_process: Whether the calling process is the main process.
        accelerator: Accelerator used for the cross-process barrier. When it is
            ``None`` (single-process use) the barrier is skipped.

    Returns:
        The dataset created by ``make_dataset_xhuman``.
    """
    # Dataset loading synchronization: main process downloads first to avoid race conditions
    cfg.dataset.episodes = episodes

    dataset = None
    if is_main_process:
        logger.info("Creating dataset")
        dataset = make_dataset_xhuman(cfg)

    # Without an accelerator there is nothing to synchronize with.
    if accelerator is not None:
        accelerator.wait_for_everyone()

    # Now all other processes can safely load the dataset
    if not is_main_process:
        dataset = make_dataset_xhuman(cfg)

    return dataset

In [None]:
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    accelerator: Accelerator,
    lr_scheduler=None,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """
    Performs a single training step to update the policy's weights.

    Runs the forward pass under the accelerator's autocast context, the
    backward pass through the accelerator, gradient clipping, the optimizer
    step and (optionally) the learning-rate scheduler step, then records the
    step statistics on ``train_metrics``.

    Args:
        train_metrics: A MetricsTracker instance to record training statistics.
        policy: The policy model to be trained.
        batch: A batch of training data.
        optimizer: The optimizer used to update the policy's parameters.
        grad_clip_norm: The maximum norm for gradient clipping. Values <= 0
            disable clipping; the gradient norm is still measured.
        accelerator: The Accelerator instance for distributed training and mixed precision.
        lr_scheduler: An optional learning rate scheduler, stepped once per call.
        lock: An optional lock held while the optimizer step runs.

    Returns:
        A tuple containing:
        - The updated MetricsTracker with new statistics for this step.
        - A dictionary of outputs from the policy's forward pass, for logging purposes.
    """
    # Imported locally on purpose: a later notebook cell rebinds the global
    # name `time` (`time = None`), which would break module-level lookups here
    # when this function is called again afterwards.
    from time import perf_counter

    start_time = perf_counter()
    policy.train()

    # Let accelerator handle mixed precision
    with accelerator.autocast():
        loss, output_dict = policy.forward(batch)

    # Use accelerator's backward method
    accelerator.backward(loss)

    # Clip gradients if specified
    if grad_clip_norm > 0:
        grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    else:
        # An infinite max norm never rescales the gradients; the call is only
        # used to obtain the total gradient norm for logging.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            policy.parameters(), float("inf"), error_if_nonfinite=False
        )

    # Optimizer step (under the caller-provided lock, if any)
    with lock if lock is not None else nullcontext():
        optimizer.step()

    optimizer.zero_grad()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Update internal buffers if policy has update method.
    # The model is unwrapped once and reused for both the check and the call.
    unwrapped_policy = accelerator.unwrap_model(policy, keep_fp32_wrapper=True)
    if has_method(unwrapped_policy, "update"):
        unwrapped_policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = perf_counter() - start_time
    return train_metrics, output_dict

In [None]:
from xhuman.policies.pi05.configuration_pi05 import PI05Config

# Policy configuration (XHUMAN PI05 variant) targeting the CUDA device.
# NOTE(review): repo_id="none" looks like a placeholder value — confirm.
policy_config = PI05Config(repo_id="none",device="cuda")

In [None]:
# If test PI05 policy
#from lerobot.policies.pi05 import PI05Config

#policy_config = PI05Config(repo_id="none",device="cuda")

## Configuration and Setup

Configure your dataset and policy settings here. The dataset configuration specifies which HuggingFace repository to load, and the policy configuration sets up the PI05 model architecture.

In [None]:
from xhuman.configs.default import LerobotDatasetConfig

# Dataset configuration: selects the HuggingFace dataset repository to load.
dataset_config = LerobotDatasetConfig(
    repo_id="NONHUMAN-RESEARCH/pick-and-place-fruits-v2-test",
)

In [None]:

# Assemble the training pipeline configuration from the dataset and policy
# configurations, then validate it.
cfg = TrainPipelineConfigXHUMAN(
    dataset=dataset_config,
    policy=policy_config # PI05Config instance created earlier (a config object, not a path)
)
cfg.validate()

## Training Setup

Initialize the Accelerator for distributed training and set up the training environment. The accelerator automatically handles:
- Multi-GPU training
- Mixed precision training
- Gradient synchronization across processes

In [None]:
# Create Accelerator
# It will automatically detect if running in distributed mode or single-process mode
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])

# Logging is configured with the accelerator passed in
init_logging(accelerator=accelerator)

# Determine if this is the main process (for logging and checkpointing)
is_main_process = accelerator.is_main_process

# Set seed if specified
if cfg.seed is not None:
    set_seed(cfg.seed, accelerator=accelerator)

# Use accelerator's device
device = accelerator.device
# Global PyTorch backend switches: enable cuDNN autotuning and allow TF32 matmuls
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
DEBUG_MODE = True  # Set to False for full training
DEBUG_MAX_EPISODES = 3  # Use only first N episodes for debugging

from accelerate.utils import broadcast_object_list

# Only the main process loads the dataset to read the total episode count.
# The value is then broadcast, so every process derives the same episode list
# and the same train/eval split (instead of relying on a guessed fallback).
_total_episodes_box = [None]
if is_main_process:
    temp_dataset = make_dataset_xhuman(cfg)
    _total_episodes_box[0] = temp_dataset.meta.total_episodes
    del temp_dataset

accelerator.wait_for_everyone()

# In single-process mode this simply returns the list unchanged.
total_episodes = broadcast_object_list(_total_episodes_box, from_process=0)[0]
if is_main_process:
    logger.info(f"Total episodes available: {total_episodes}")

# Limit episodes for debugging
if DEBUG_MODE:
    episodes = list(range(min(DEBUG_MAX_EPISODES, total_episodes)))
    if is_main_process:
        logger.info(f"DEBUG MODE: Using only {len(episodes)} episodes")
else:
    episodes = list(range(total_episodes))

# Split episodes
train_episodes, eval_episodes = split_train_eval_episodes(
    episodes, split_ratio=cfg.split_ratio, seed=42
)

# Load dataset with ONLY train episodes (proper way to filter)
# This uses the load_dataset helper function which sets cfg.dataset.episodes
if is_main_process:
    logger.info(f"Loading train dataset with {len(train_episodes)} episodes")
dataset = load_dataset(cfg, train_episodes, is_main_process=is_main_process, accelerator=accelerator)

In [None]:
# Create policy
# Every process builds its own policy instance from the policy config and the
# dataset metadata; only the main process logs.
if is_main_process:
    logger.info("Creating policy")
policy = make_xhuman_policy(
    cfg=cfg.policy,
    ds_meta=dataset.meta,
)

# Barrier: continue only once all processes have created their policy
accelerator.wait_for_everyone()

## Dataset and Model Information

Display metadata about the loaded dataset and model. This includes:
- Total number of episodes and frames
- Model parameter counts
- Effective batch size (accounting for distributed training)

In [None]:
# Display dataset metadata (main process only)
if is_main_process:
    # NOTE(review): pprint is imported but not used in this cell.
    from pprint import pprint

    print("=" * 80)
    print("DATASET METADATA")
    print("=" * 80)
    # Total Episodes comes from the dataset metadata, while the "loaded"
    # figures below describe the dataset object that was actually built.
    print(f"\nDataset Repository: {dataset.repo_id}")
    print(f"Total Episodes: {dataset.meta.total_episodes}")
    print(f"Training Episodes: {len(train_episodes)}")
    print(f"Number of Frames: {dataset.num_frames:,}")
    print(f"Number of Episodes (loaded): {dataset.num_episodes}")


    print("\n" + "=" * 80)

In [None]:
# Keyword arguments later forwarded to make_xhuman_pre_post_processors.
processor_kwargs = {}
postprocessor_kwargs = {}

# Dataset statistics are passed along in every case except when resuming a run
# that has a pretrained path set.
if not (cfg.policy.pretrained_path and cfg.resume):
    processor_kwargs["dataset_stats"] = dataset.meta.stats

# With a pretrained path, override the saved processor settings so that they
# use the current device and the statistics of the dataset loaded here.
if cfg.policy.pretrained_path is not None:
    processor_kwargs["preprocessor_overrides"] = dict(
        device_processor={"device": device.type},
        normalizer_processor=dict(
            stats=dataset.meta.stats,
            features={**policy.config.input_features, **policy.config.output_features},
            norm_map=policy.config.normalization_mapping,
        ),
    )
    postprocessor_kwargs["postprocessor_overrides"] = dict(
        unnormalizer_processor=dict(
            stats=dataset.meta.stats,
            features=policy.config.output_features,
            norm_map=policy.config.normalization_mapping,
        ),
    )

In [None]:
# This cell was removed - autoreload is not needed for production training

In [None]:
# Create processors
preprocessor, postprocessor = make_xhuman_pre_post_processors(
    policy_cfg=cfg.policy,
    pretrained_path=cfg.policy.pretrained_path,
    **processor_kwargs,
    **postprocessor_kwargs,
)

In [None]:
# Create optimizer and scheduler
if is_main_process:
    logger.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

step = 0  # number of policy updates

# Resume from checkpoint if needed
if cfg.resume:
    step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

In [None]:
# Print training info
if is_main_process:
    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())
    logger.info(f"Output dir: {cfg.output_dir}")
    logger.info(f"Steps: {cfg.steps} ({format_big_number(cfg.steps)})")
    logger.info(f"Dataset frames: {dataset.num_frames} ({format_big_number(dataset.num_frames)})")
    logger.info(f"Dataset episodes: {dataset.num_episodes}")
    num_processes = accelerator.num_processes
    effective_bs = cfg.batch_size * num_processes
    logger.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
    logger.info(f"Learnable params: {num_learnable_params} ({format_big_number(num_learnable_params)})")
    logger.info(f"Total params: {num_total_params} ({format_big_number(num_total_params)})")

In [None]:
# Create dataloader
if hasattr(cfg.policy, "drop_n_last_frames"):
    logger.info(f"Dropping {cfg.policy.drop_n_last_frames} last frames")
    shuffle = False
    sampler = EpisodeAwareSampler(
        dataset.meta.episodes["dataset_from_index"],
        dataset.meta.episodes["dataset_to_index"],
        drop_n_last_frames=cfg.policy.drop_n_last_frames,
        shuffle=True,
    )
else:
    logger.info("Not dropping any frames")
    shuffle = True
    sampler = None



HERE

In [None]:
# Build the training DataLoader from the dataset and the sampler/shuffle choice
# made in the previous cell.
# NOTE(review): batch_size is hard-coded to 2 here, while the metrics tracker
# derives the effective batch size from cfg.batch_size — confirm this is intended.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.num_workers,
    batch_size=2,
    # Shuffling is disabled for streaming datasets
    shuffle=shuffle and not cfg.dataset.streaming,
    sampler=sampler,
    # Pinned host memory is only requested when running on CUDA
    pin_memory=device.type == "cuda",
    drop_last=False,
    # prefetch_factor is only set when worker processes are used
    prefetch_factor=2 if cfg.num_workers > 0 else None,
)

In [None]:
# Prepare everything with accelerator
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)

policy.train()

# Setup metrics tracking
train_metrics = {
    "loss": AverageMeter("loss", ":.3f"),
    "grad_norm": AverageMeter("grdn", ":.3f"),
    "lr": AverageMeter("lr", ":0.1e"),
    "update_s": AverageMeter("updt_s", ":.3f"),
    "dataloading_s": AverageMeter("data_s", ":.3f"),
}

effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
    effective_batch_size,
    dataset.num_frames,
    dataset.num_episodes,
    train_metrics,
    initial_step=step,
    accelerator=accelerator,
)

HERE

## Training Loop

The main training loop iterates through batches, performs forward/backward passes, and updates the policy weights.

**Note**: The loop below runs for 2 steps as an example. For full training, replace with:
```python
while step < cfg.steps:
    # ... training code ...
```

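In full training, metrics are logged at intervals specified by `cfg.log_freq`; the debug loop below prints the metrics tracker and the policy outputs at every step.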

In [None]:
# Training initialization
# This logs the start of training and shows how many episodes will be used
if is_main_process:
    logger.info("Start offline training on a fixed dataset")
    logger.info(f"Train episodes: {len(train_episodes)}")
    logger.info(f"Total training steps: {cfg.steps}")

In [None]:
from torch.utils.data import Dataset

class SmartSubset(Dataset):
    """Subset of a dataset that forwards unknown attributes to the original.

    Indexing and ``len()`` operate on ``indices``; any other regular attribute
    that this object does not define itself (for example ``meta`` or ``fps``)
    is looked up on the wrapped dataset.

    Args:
        dataset: The dataset to take items from.
        indices: Positions in ``dataset`` that make up this subset.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Translate the subset position into a position in the wrapped dataset
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

    def __getattr__(self, name):
        # __getattr__ only runs after normal attribute lookup has failed.
        # Special (dunder) names are never forwarded: protocols such as pickle
        # and copy probe for them and must see this object's own behaviour.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        # Read the wrapped dataset straight from the instance dict. Going
        # through `self.dataset` would re-enter __getattr__ and recurse forever
        # when the instance was created without __init__ (unpickling, copying,
        # DataLoader worker start-up).
        try:
            wrapped = self.__dict__["dataset"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)

# --- USAGE ---
# Use SmartSubset instead of torch.utils.data.Subset
# Take at most the first 50 frames. The range is clamped to the number of
# frames actually loaded so that a smaller dataset is not indexed past its end.
debug_subset = SmartSubset(dataset, range(min(50, dataset.num_frames)))

# Now create your loader normally
# (single-process loading, shuffled, incomplete final batch dropped)
train_dataloader = torch.utils.data.DataLoader(
    debug_subset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    drop_last=True
)

In [None]:
# Fresh iterator over the debug DataLoader. This rebinds dl_iter, replacing the
# cycle(dataloader) iterator created in the accelerator-preparation cell.
dl_iter = iter(train_dataloader)

In [None]:
# Re-create the meters and the tracker (same construction as in the
# accelerator-preparation cell), presumably so the debug loop below starts
# from fresh statistics.
train_metrics = {
    "loss": AverageMeter("loss", ":.3f"),
    "grad_norm": AverageMeter("grdn", ":.3f"),
    "lr": AverageMeter("lr", ":0.1e"),
    "update_s": AverageMeter("updt_s", ":.3f"),
    "dataloading_s": AverageMeter("data_s", ":.3f"),
}

# NOTE(review): uses cfg.batch_size, whereas the debug DataLoader above is
# created with batch_size=2 — confirm which value the tracker should report.
effective_batch_size = cfg.batch_size * accelerator.num_processes
train_tracker = MetricsTracker(
    effective_batch_size,
    dataset.num_frames,
    dataset.num_episodes,
    train_metrics,
    initial_step=step,
    accelerator=accelerator,
)

In [None]:
# Debug training loop: fetch a batch, preprocess it, run one policy update and
# print the resulting metrics and policy outputs.
for step in range(0, 2): # Or however many steps you want
    # Time batch fetching + preprocessing so the "dataloading_s" meter is fed
    start_time = time.perf_counter()
    # --- SAFE ITERATOR LOGIC ---
    try:
        batch_ = next(dl_iter)
    except StopIteration:
        # We finished the 50 items! Restart from the beginning.
        dl_iter = iter(train_dataloader)
        batch_ = next(dl_iter)
    batch = preprocessor(batch_)
    train_tracker.dataloading_s = time.perf_counter() - start_time

    # print(batch)

    train_tracker, output_dict = update_policy(
        train_tracker,
        policy,
        batch,
        optimizer,
        cfg.optimizer.grad_clip_norm,
        accelerator=accelerator,
        lr_scheduler=lr_scheduler,
    )
    print(train_tracker)
    print(output_dict)

In [None]:
frames = next(dl_iter)

In [None]:
batch = preprocessor(frames)

In [None]:
batch.keys()

# Test loss calculation

In [None]:
# Batch keys holding the tokenized language prompt and its attention mask
tokens_id = "observation.language.tokens"
mask_id = "observation.language.attention_mask"

# Pull the model inputs out of the preprocessed batch
images, img_masks = policy._preprocess_images(batch)
tokens, masks = batch[tokens_id], batch[mask_id]
# Left as None so that the next cell samples noise and time from the model
noise = None
# NOTE(review): this rebinds the global name `time`, shadowing the `time`
# module imported at the top of the notebook. Cells that call
# time.perf_counter() will fail if they are re-run after this one.
time = None
actions = policy.prepare_action(batch)

In [None]:
        # Work on a copy: discrete action tokens are written into `tokens` later
        tokens = tokens.clone()
        # Embed actions and observations.
        # Sample noise / time from the model unless they were provided
        if noise is None:
            noise = policy.model.sample_noise(actions.shape, actions.device)

        if time is None:
            time = policy.model.sample_time(actions.shape[0], actions.device)

        # actions shape: (B, 1, action_dim)
        # NOTE(review): a later cell documents v_t as [B, chunk_size, action_dim];
        # confirm which shape is right for `actions`.
        # The tokenizer is fed a CPU copy of the actions.
        actions_clone = actions.clone().to("cpu")

        # discrete_actions list of ints: (B, discrete_action_seq_len)
        discrete_actions = policy.model.fast_tokenizer(actions_clone)
        # Map the action-tokenizer ids to PaliGemma token ids; the result is
        # indexed per sample in the following cells.
        list_discrete_actions = (policy.model.paligemma_with_expert.act_tokens_to_paligemma_tokens(discrete_actions))  # noqa: E501


In [None]:
        # Number of discrete action tokens for each sample
        list_len_discrete_actions = [
            t.shape[0] for t in list_discrete_actions
        ]

        # Find first padding position (where tokens == 0) for each sample
        # (assumes token id 0 is the padding token — TODO confirm).
        # argmax also returns 0 for rows without any padding, so those rows
        # are corrected below.
        pad_positions = (tokens == 0).int().argmax(dim=1)
        # If no padding token found, set to sequence length.
        # NOTE: despite its name, no_pad_mask is True for rows that DO contain
        # a padding token.
        no_pad_mask = (tokens == 0).any(dim=1)
        pad_positions = torch.where(
            no_pad_mask, pad_positions, tokens.shape[1]
        )

        # token_loss_mask shape: (B, seq_len)
        # True marks the positions that contribute to the token (CE) loss
        token_loss_mask = torch.zeros_like(
            tokens, dtype=torch.bool, device=tokens.device
        )

        # Clone masks to avoid modifying the original
        masks = masks.clone()

        # Insert discrete action tokens at padding positions and update masks
        for i in range(tokens.shape[0]):
            pad_idx = pad_positions[i].item()
            discrete_len = list_len_discrete_actions[i]
            discrete_tokens = list_discrete_actions[i].to(
                device=tokens.device, dtype=tokens.dtype
            )

            # Check if discrete actions fit in the sequence
            if pad_idx + discrete_len > tokens.shape[1]:
                # Truncate discrete actions if they don't fit
                available_space = tokens.shape[1] - pad_idx
                if available_space > 0:
                    discrete_tokens = discrete_tokens[:available_space]
                    discrete_len = available_space
                else:
                    # No space available, skip this sample's discrete actions
                    continue

            # Insert discrete action tokens at padding position
            tokens[i, pad_idx:pad_idx + discrete_len] = discrete_tokens
            # Mark these positions for loss computation
            token_loss_mask[i, pad_idx:pad_idx + discrete_len] = True
            # Mark discrete action positions as valid (True) in masks
            masks[i, pad_idx:pad_idx + discrete_len] = True

In [None]:
        # Embed the prefix (images + language tokens with the inserted
        # discrete action tokens) and get its padding / attention masks
        prefix_embs, prefix_pad_masks, prefix_att_masks = policy.model.embed_prefix(
            images, img_masks, tokens, masks,
        )
        # Clone prefix_att_masks before modifying since it's a view
        prefix_att_masks = prefix_att_masks.clone()
        # prefix_att_masks structure:
        # [image_embs..., original_lang_tokens, discrete_actions]
        # Prefix length minus the token sequence length gives the number of
        # leading image embeddings (relies on the layout described above).
        num_image_embs = prefix_embs.shape[1] - tokens.shape[1]


In [None]:
        # Give the inserted discrete action tokens a causal attention pattern.
        # The span is clamped to the room that was available after the first
        # padding position, which mirrors the truncation in the insertion loop:
        # samples whose discrete actions were cut to fit are masked as well.
        for i in range(tokens.shape[0]):
            pad_idx = pad_positions[i].item()
            discrete_len = min(
                list_len_discrete_actions[i], tokens.shape[1] - pad_idx
            )

            # Nothing was inserted for this sample (no room or no tokens)
            if discrete_len <= 0:
                continue

            # Calculate position in embedded space
            discrete_action_start_idx = num_image_embs + pad_idx
            discrete_action_end_idx = (
                discrete_action_start_idx + discrete_len
            )

            # Set att_mask=1 for discrete action tokens (causal masking)
            if discrete_action_end_idx <= prefix_att_masks.shape[1]:
                prefix_att_masks[
                    i, discrete_action_start_idx:discrete_action_end_idx
                ] = 1

In [None]:
from xhuman.policies.pi05.modeling_pi_05 import make_att_2d_masks

In [None]:

        # Flow-matching inputs: x_t interpolates between the actions (time=0)
        # and the noise (time=1); u_t is the regression target compared with
        # the model's output v_t in the loss cell.
        time_expanded = time[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        # Embed the noisy actions and the time as the suffix sequence
        suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = policy.model.embed_suffix(  # noqa: E501
            x_t, time,
        )

        # dtype of the first language-model layer's query projection weights
        q_proj_dtype = (
            policy.model.paligemma_with_expert.paligemma.language_model.layers[0]
            .self_attn.q_proj.weight.dtype
        )

        # Match the embeddings to the model weights when those are bfloat16
        if q_proj_dtype == torch.bfloat16:
            suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
            prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

        # Full-sequence masks: prefix followed by suffix
        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        att_2d_masks = make_att_2d_masks(pad_masks, att_masks)

        # Index of the first suffix position in the concatenated sequence
        suffix_start = prefix_pad_masks.shape[1]

In [None]:
        # Stop the suffix from attending to the inserted discrete action tokens.
        # The span is clamped to the room that was available after the first
        # padding position, which mirrors the truncation in the insertion loop:
        # samples whose discrete actions were cut to fit are blocked as well.
        for i in range(tokens.shape[0]):
            pad_idx = pad_positions[i].item()
            discrete_len = min(
                list_len_discrete_actions[i], tokens.shape[1] - pad_idx
            )

            # Nothing was inserted for this sample (no room or no tokens)
            if discrete_len <= 0:
                continue

            # Calculate position in embedded space
            discrete_action_start_idx = num_image_embs + pad_idx
            discrete_action_end_idx = (
                discrete_action_start_idx + discrete_len
            )

            # Prevent suffix (flow matching) from attending to discrete
            # action tokens
            if discrete_action_end_idx <= att_2d_masks.shape[2]:
                att_2d_masks[
                    i,
                    suffix_start:,
                    discrete_action_start_idx:discrete_action_end_idx,
                ] = False

In [None]:

        # Zero-based position ids: running count of the entries set in pad_masks
        position_ids = torch.cumsum(pad_masks, dim=1) - 1

        # Convert the 2D attention masks into the 4D form passed to the model
        att_2d_masks_4d = policy.model._prepare_attention_masks_4d(att_2d_masks)

In [None]:
import torch.nn.functional as F

In [None]:
        # Joint forward pass over prefix and suffix embeddings. It is wrapped
        # in a function so it can be handed to policy.model._apply_checkpoint
        # (presumably gradient checkpointing — confirm in the model code).
        def forward_func(prefix_embs,
                         suffix_embs,
                         att_2d_masks_4d,
                         position_ids,
                         adarms_cond,
                         ):
            # Returns both prefix_out and suffix_out for dual loss computation
            # No KV cache is used; adarms_cond is supplied for the suffix only.
            (prefix_out, suffix_out), _ = policy.model.paligemma_with_expert.forward(
                attention_mask=att_2d_masks_4d,
                position_ids=position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, suffix_embs],
                use_cache=False,
                adarms_cond=[None, adarms_cond],
            )
            return prefix_out, suffix_out

        prefix_out, suffix_out = policy.model._apply_checkpoint(
            forward_func, prefix_embs, suffix_embs,
            att_2d_masks_4d, position_ids, adarms_cond
        )

        # Keep the last chunk_size suffix positions and cast them to float32
        suffix_out = suffix_out[:, -policy.model.config.chunk_size:]
        suffix_out = suffix_out.to(dtype=torch.float32)

In [None]:
        # Project the suffix outputs to the action space (output named v_t)
        def action_out_proj_func(suffix_out):
            return policy.model.action_out_proj(suffix_out)

        v_t = policy.model._apply_checkpoint(action_out_proj_func, suffix_out)

        # v_t shape: [B, chunk_size, action_dim]
        # Per-sample flow-matching loss: squared error averaged over dims 1 and 2
        flow_matching_loss = F.mse_loss(u_t, v_t, reduction="none")
        flow_matching_loss = flow_matching_loss.mean(dim=(1, 2))

        # prefix_out shape: [B, prefix_seq_len, hidden_dim]
        # Language-model head applied to the prefix outputs
        def lm_head_func(prefix_out):
            return policy.model.paligemma_with_expert.paligemma.lm_head(prefix_out)

        # logits shape: [B, prefix_seq_len, vocab_size]
        logits = policy.model._apply_checkpoint(lm_head_func, prefix_out)

        # Drop the leading image positions so the logits line up with `tokens`
        num_image_embs = prefix_embs.shape[1] - tokens.shape[1]
        lang_logits = logits[:, num_image_embs:, :]

        # Next-token alignment: logits at position t are scored against the
        # token at position t + 1.
        # shift_logits shape: [B, lang_seq_len-1, vocab_size]
        shift_logits = lang_logits[:, :-1, :].contiguous()
        # shift_labels shape: [B, lang_seq_len-1]
        shift_labels = tokens[:, 1:].contiguous()

        # shift_mask shape: [B, lang_seq_len-1]
        # Shifted like the labels, so only discrete action targets are kept
        shift_mask = token_loss_mask[:, 1:].contiguous()

In [None]:
        # Flatten batch and sequence dimensions for the cross-entropy call
        # flat_logits shape: [B*(lang_seq_len-1), vocab_size]
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        # flat_labels shape: [B*(lang_seq_len-1)]
        flat_labels = shift_labels.view(-1)
        flat_mask = shift_mask.view(-1)

        # ce_loss_per_token shape: [B*(lang_seq_len-1)]
        # Unreduced loss: one value per token position, masked afterwards
        ce_loss_per_token = F.cross_entropy(
            flat_logits, flat_labels, reduction='none'
        )

        # masked_ce shape: [B, lang_seq_len-1]
        # Zero out every position that is not a discrete action target
        masked_ce = (ce_loss_per_token * flat_mask).view(tokens.shape[0], -1)  # noqa: E501
        # ce_loss shape: [B]
        # Mean over the masked positions of each sample; the 1e-8 keeps the
        # division defined when a sample has no masked positions.
        ce_loss = masked_ce.sum(dim=1) / (shift_mask.sum(dim=1) + 1e-8)

In [None]:
flat_logits

In [None]:
flat_mask

In [None]:

ce_loss, flow_matching_loss