# Training Core

> Core components for the training pipeline, including loss functions and potentially shared callbacks or utilities.

In [1]:
#| default_exp training.core

In [1]:
#| hide
from nbdev.showdoc import *

In [2]:
#| export
import sys
from pathlib import Path
import os
import warnings

# Assumes the notebook is run from the project root or one level down (e.g., nbs/)
# Navigate up to the project root (where settings.ini or .git likely exists)
project_root = Path(os.getcwd())
# Simple check: If settings.ini is not in cwd, assume we are in nbs/ and go up one level
if not (project_root / 'settings.ini').exists() and (project_root.parent / 'settings.ini').exists():
    project_root = project_root.parent

project_root_str = str(project_root.resolve())

if project_root_str not in sys.path:
    print(f"Adding project root to sys.path: {project_root_str}")
    sys.path.insert(0, project_root_str)
else:
    print(f"Project root already in sys.path: {project_root_str}")

Adding project root to sys.path: /workspace/llava


In [20]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.callback.wandb import WandbCallback
from fastai.learner import Learner
from fastai.data.core import DataLoaders
from llava.data.preprocessing import IGNORE_INDEX # Import ignore index constant


from fastai.callback.fp16 import MixedPrecision
from fastai.callback.training import GradientAccumulation, find_bs
from fastai.torch_core import to_float
from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithPast # Base class for HF outputs

This notebook will contain the core elements needed for training, starting with the custom loss function required for LLaVA-style training where certain tokens (like prompts and padding) are ignored.

## Step 3.1: Implement Custom Loss Function

In [25]:
#| export
def extract_loss_from_output(model_output:CausalLMOutputWithPast, *args):
    if hasattr(model_output, 'loss') and model_output.loss is not None:
        return model_output.loss
    else:
        raise ValueError(
            "Loss attribute not found or is None in the model output object. "
            "Ensure 'labels' are correctly passed to the model's forward method during training, "
            "and the model computes loss internally. If this is learner.summary(), ensure dummy labels are passed if model expects them."
        )

In [6]:
#| export
class LLaVALoss(nn.Module):
    """ Custom CrossEntropyLoss that ignores indices where labels are IGNORE_INDEX (default -100).
    
    This loss function handles the standard autoregressive language modeling loss
    by shifting the logits and labels, ensuring the model predicts the next token.
    It specifically ignores tokens marked with `ignore_index` in the labels tensor,
    which is crucial for masking out prompt tokens, padding tokens, and image tokens
    during LLaVA training.
    """
    def __init__(self, ignore_index=IGNORE_INDEX):
        """ Initializes the loss function.
        
        Args:
            ignore_index (int): The label index to be ignored during loss calculation.
                                Defaults to the value imported from llava.data.preprocessing.
        """
        super().__init__()
        self.ignore_index = ignore_index
        # Initialize the standard CrossEntropyLoss with the specified ignore_index
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
        print(f"LLaVALoss initialized, ignoring index: {self.ignore_index}")

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """ Calculates the cross-entropy loss, ignoring specified indices.
        
        Args:
            output (torch.Tensor): The model's output logits.
                                   Shape: (batch_size, sequence_length, vocab_size).
            target (torch.Tensor): The target labels (token IDs).
                                   Shape: (batch_size, sequence_length).
                                   Should contain `ignore_index` for tokens to be ignored.
                                   
        Returns:
            torch.Tensor: A scalar tensor representing the calculated loss.
        """
        # --- Shift logits and labels for next token prediction --- 
        # Logits are shifted left (we predict the token *after* the current one)
        # output shape: (batch_size, sequence_length, vocab_size)
        shift_logits = output[..., :-1, :].contiguous()
        # Labels are shifted left (the target for the prediction at time t is the token at t+1)
        # target shape: (batch_size, sequence_length)
        shift_labels = target[..., 1:].contiguous()

        # --- Flatten the tokens for CrossEntropyLoss --- 
        # The CrossEntropyLoss expects input shape (N, C) where N is the number of samples
        # and C is the number of classes (vocab_size). The target shape should be (N).
        # Shift_logits flattened shape: (batch_size * (sequence_length - 1), vocab_size)
        # Shift_labels flattened shape: (batch_size * (sequence_length - 1))
        vocab_size = shift_logits.size(-1)
        loss = self.loss_fct(shift_logits.view(-1, vocab_size), 
                             shift_labels.view(-1))
        
        # Ensure loss is a scalar
        if loss.dim() > 0:
             # This might happen if the batch size or sequence length becomes 0 or 1 after masking/shifting.
             # Although CrossEntropyLoss typically returns a scalar, handle defensively.
             loss = loss.mean() 

        return loss

In [None]:
show_doc(LLaVALoss)

---

### LLaVALoss

>      LLaVALoss (ignore_index=-100)

*Custom CrossEntropyLoss that ignores indices where labels are IGNORE_INDEX (default -100).
    
    This loss function handles the standard autoregressive language modeling loss
    by shifting the logits and labels, ensuring the model predicts the next token.
    It specifically ignores tokens marked with `ignore_index` in the labels tensor,
    which is crucial for masking out prompt tokens, padding tokens, and image tokens
    during LLaVA training.*

#### Example Usage & Test

In [7]:
#| test
# Create dummy data
batch_size = 2
seq_len = 5
vocab_size = 10
ignore_idx = -100

# Dummy logits (B, S, V)
dummy_logits = torch.randn(batch_size, seq_len, vocab_size)
# Dummy labels (B, S) with some ignored indices
dummy_labels = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long)
# Add ignore index (-100)
dummy_labels[0, -1] = ignore_idx  # Last token ignored for sample 0
dummy_labels[1, 2] = ignore_idx # 3rd token ignored for sample 1

# Instantiate the loss
loss_func = LLaVALoss(ignore_index=ignore_idx)

print(f"Target shape: {dummy_labels.shape}")
print(f"Target:\n{dummy_labels}")
print(f"Logits shape: {dummy_logits.shape}")

print("Calculating loss...")
# Calculate loss
loss = loss_func(dummy_logits, dummy_labels)

# Check if loss is a scalar tensor
assert loss.dim() == 0, f"Loss should be scalar, but got shape {loss.shape}"

# Manually calculate expected loss for verification (optional, complex)
# shift_logits_manual = dummy_logits[:, :-1, :].reshape(-1, vocab_size)
# shift_labels_manual = dummy_labels[:, 1:].reshape(-1)
# expected_loss = F.cross_entropy(shift_logits_manual, shift_labels_manual, ignore_index=ignore_idx)
# print(f"Expected Loss (Manual): {expected_loss.item():.4f}")

# Get shifted labels to show what's being used in loss
shift_labels_print = dummy_labels[..., 1:].contiguous()
print(f"Shifted Logits shape: {dummy_logits[..., :-1, :].shape}")
print(f"Shifted Labels shape: {shift_labels_print.shape}")
print(f"Shifted Labels:\n{shift_labels_print}")

print(f"Calculated Loss: {loss.item():.4f}")

assert torch.isfinite(loss), "Loss calculation resulted in non-finite value."
print("\nLoss calculation test passed.")

LLaVALoss initialized, ignoring index: -100
Target shape: torch.Size([2, 5])
Target:
tensor([[   5,    3,    4,    4, -100],
        [   7,    1, -100,    8,    9]])
Logits shape: torch.Size([2, 5, 10])
Calculating loss...
Shifted Logits shape: torch.Size([2, 4, 10])
Shifted Labels shape: torch.Size([2, 4])
Shifted Labels:
tensor([[   3,    4,    4, -100],
        [   1, -100,    8,    9]])
Calculated Loss: 3.2114

Loss calculation test passed.


## Training Utilities (e.g., Learner Setup - To be implemented later)

In [None]:
# Placeholder for learner setup functions (e.g., get_stage1_learner)
# These will utilize WandbCallback imported above.

In [21]:
#| export
class SafeGradientAccumulation(GradientAccumulation):
    "A GradientAccumulation callback that clones loss if it requires grad before in-place division."
    def after_loss(self):
        "Divides `loss_grad` by `n_acc`."
        # If loss_grad is a tensor that requires grad, clone it before the in-place op
        if isinstance(self.learn.loss_grad, torch.Tensor) and self.learn.loss_grad.requires_grad:
            self.learn.loss_grad = self.learn.loss_grad.clone() / (self.n_acc/find_bs(self.learn.yb))
        else: # Otherwise, proceed as normal (handles cases where loss_grad might not require grad, e.g. after detach)
            self.learn.loss_grad /= (self.n_acc/find_bs(self.learn.yb))

In [16]:
#| export
class LLaVAMixedPrecision(MixedPrecision):
    "Mixed precision training specifically handling HF model outputs"
    def after_pred(self):
        pred = self.learn.pred
        if isinstance(pred, ModelOutput):
            if hasattr(pred, 'logits'):
                logits_val = getattr(pred, 'logits')
                if isinstance(logits_val, torch.Tensor) and logits_val.is_floating_point():
                    try:
                        # This attempts to modify the logits attribute of the ModelOutput object
                        setattr(pred, 'logits', to_float(logits_val))
                        # self.learn.pred remains the ModelOutput object, but with float logits
                    except AttributeError:
                        # If pred is immutable (e.g. namedtuple, though HF usually uses dataclasses)
                        # This path is less likely for HF outputs but good for robustness
                        warnings.warn("LLaVAMixedPrecision: Could not set 'logits' attribute directly on model output. Applying to_float to the whole output.", UserWarning)
                        self.learn.pred = to_float(pred)
                # else: logits might not be float or not a tensor, do nothing to it.
            # else: pred is ModelOutput but no logits, do nothing.
        elif isinstance(pred, torch.Tensor) and pred.is_floating_point():
            self.learn.pred = to_float(pred)
        elif isinstance(pred, (list, tuple)): # Handle cases where pred might be a list/tuple of tensors
             self.learn.pred = apply(lambda x: to_float(x) if isinstance(x, torch.Tensor) and x.is_floating_point() else x, pred)
        # else: pred is some other type, do nothing.

In [27]:
#| hide
import nbdev; nbdev.nbdev_export()