# Training Core

> Core components for the training pipeline, including loss functions and potentially shared callbacks or utilities.

In [None]:
#| default_exp training.core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.callback.wandb import WandbCallback
from fastai.learner import Learner
from fastai.data.core import DataLoaders

This notebook will contain the core elements needed for training, starting with the custom loss function required for LLaVA-style training where certain tokens (like prompts and padding) are ignored.

## Loss Function (Step 3.1)

In [None]:
#| export
# LLaVA-style loss: next-token cross-entropy that skips labels equal to the ignore index.
class LLaVALoss(nn.Module):
    """Next-token cross-entropy loss that skips positions labelled `ignore_index`.

    Logits at position ``t`` are scored against the label at position ``t + 1``.
    Targets equal to `ignore_index` (by default -100) contribute neither to the
    loss nor to the gradient, and the loss is averaged over the remaining targets.

    If a batch contains no valid target at all, the result is a zero loss that is
    still connected to the autograd graph (all gradients are zero) rather than the
    NaN a plain mean-reduced cross-entropy would produce.
    """
    def __init__(self, ignore_index=-100):
        super().__init__()
        self.ignore_index = ignore_index
        # Retained so existing code that reads `loss_fct` keeps working; `forward`
        # computes the same mean itself so it can guard the "no valid target" case.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)

    def forward(self, output, target):
        """Return the mean cross-entropy over non-ignored, shifted targets.

        Args:
            output: logits, expected as (batch_size, sequence_length, vocab_size).
            target: label ids, expected as (batch_size, sequence_length).

        Returns:
            A scalar tensor in the dtype of the logits.
        """
        # Shift so that the logits at step t predict the token at step t + 1.
        shift_logits = output[..., :-1, :].contiguous()
        shift_labels = target[..., 1:].contiguous()

        # Flatten to (batch_size * (sequence_length - 1), vocab_size) and
        # (batch_size * (sequence_length - 1),) respectively.
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)

        # Per-token losses; ignored positions come back as exactly 0.
        per_token = F.cross_entropy(
            flat_logits,
            flat_labels,
            ignore_index=self.ignore_index,
            reduction="none",
        )

        # Clamp the denominator so an all-ignored batch yields 0 instead of 0/0.
        n_valid = (flat_labels != self.ignore_index).sum().clamp(min=1)

        # Accumulate in float32 so low-precision logits cannot overflow the sum,
        # then hand the result back in the dtype of the logits.
        mean_loss = per_token.float().sum() / n_valid
        return mean_loss.to(per_token.dtype)

## Training Utilities (e.g., Learner Setup - To be implemented later)

In [None]:
# Placeholder for learner setup functions (e.g., get_stage1_learner)
# These will utilize WandbCallback imported above.

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()