# Training Routine

This notebook contains the DPO / DPO-Shift preference-training routine used for all models (the Qwen2.5-3B-Instruct baseline and the persona SFT models).



In [None]:
# Colab setup: install unsloth and its helper packages without resolving their
# dependencies (`--no-deps`), then the remaining utilities, then upgrade
# transformers / datasets.
# NOTE(review): transformers and datasets are upgraded to whatever is latest at
# run time, so a later re-run may pick up versions incompatible with the pinned
# trl==0.15.2 — consider pinning them too.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf huggingface_hub hf_transfer
!pip install --no-deps unsloth
!pip install -U transformers
!pip install -U datasets

In [None]:
# This notebook was prepared with help of
# https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Zephyr_(7B)-DPO.ipynb#scrollTo=E8-BWi7MzkRz
# https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynb#scrollTo=QmUBVEnvCDJv
# (part of the official unsloth documentation)
# According to the unsloth docs, this patch needs to run first, i.e. before the
# next cell imports DPOConfig / DPOTrainer from trl.
from unsloth import PatchDPOTrainer

PatchDPOTrainer()

In [None]:
import gc
import os
import random
import sys
from collections import defaultdict

from google.colab import userdata, drive

from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset, Dataset
from trl import DPOConfig, DPOTrainer
from transformers import TrainerCallback
from huggingface_hub import login
from transformers.trainer_utils import get_last_checkpoint

import wandb
import torch
import torch.nn.functional as F

In [None]:
# Mount Google Drive at /content/drive. This is only needed if the data gets
# loaded from Google Drive (the parquet paths in "Load Data" point there).
drive.mount('/content/drive')

In [None]:
# Read the W&B and Hugging Face tokens from the Colab secrets 'WB_TOKEN' and
# 'HF_TOKEN', export them as environment variables and log in to both services
# (W&B for experiment tracking, the Hub for pulling/pushing models).
os.environ['WANDB_API_KEY'] = userdata.get('WB_TOKEN')
wandb.login()

os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')
login(token = os.environ['HF_TOKEN'])

In [None]:
# Use expandable allocator segments to limit reserved-but-unallocated
# (fragmented) CUDA memory.
# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA caching
# allocator is initialised. torch and unsloth are already imported at this point
# (and unsloth presumably sets a similar value itself on import), so this
# assignment may have no effect. Confirm, or move it before the first torch import.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Release unreferenced Python objects and cached GPU memory, e.g. leftovers from
# an earlier run in the same kernel.
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

## Load Data

In [None]:
# Load the agent preference data from parquet files on Google Drive into a
# dataset dict with a 'train' and a 'test' split. Rows are expected to carry
# conversational 'prompt', 'chosen' and 'rejected' columns (lists of messages).
# NOTE(review): the 'test' split is filtered below but never used for training
# or evaluation in this notebook; presumably it is held out for later evaluation.
dataset = load_dataset( 'parquet',
    data_files={
        'train':    '/content/drive/MyDrive/practical_course2/data/agent_train.parquet',
        'test':     '/content/drive/MyDrive/practical_course2/data/agent_test.parquet'
    }
)

# Rows whose chosen answer is empty (or missing) are deleted
def is_valid(example):
    """
    Return True if the row's chosen completion contains non-blank text.

    `example['chosen']` is a list of messages (dicts with a 'content' key). Only
    the first message is checked. Rows with an empty or missing message list, or
    with a missing / non-string content, are treated as invalid and dropped
    instead of crashing the `filter` call.

    Args:
        example (dict): One dataset row.

    Returns:
        bool: Whether the row should be kept.
    """
    chosen = example['chosen']
    if not chosen:
        return False
    content = chosen[0].get('content')
    return isinstance(content, str) and content.strip() != ''

# Drop rows with an empty chosen answer from both splits (train first, then test).
dataset['train'], dataset['test'] = (
    dataset['train'].filter(is_valid),
    dataset['test'].filter(is_valid),
)

In [None]:
# Create the train/eval split *by prompt*, not by row. Every prompt occurs in
# several rows (one per <chosen, rejected> pair), so a row-level random split
# would put the same prompt into both sets and leak it into evaluation.
# Prompts are keyed exactly like in PromptSampler: the content of the last
# prompt message.
EVAL_FRACTION = 0.2
SPLIT_SEED    = 42

prompt_keys = [messages[-1]['content'] for messages in dataset['train']['prompt']]
# Sort before shuffling so the split does not depend on set iteration order.
unique_prompts = sorted(set(prompt_keys))
random.Random(SPLIT_SEED).shuffle(unique_prompts)
# At least one eval prompt, mirroring train_test_split's refusal of an empty test set.
n_eval_prompts = max(1, round(len(unique_prompts) * EVAL_FRACTION))
eval_prompts = set(unique_prompts[:n_eval_prompts])

train_dataset = dataset['train'].select(
    [idx for idx, key in enumerate(prompt_keys) if key not in eval_prompts]
)
eval_dataset = dataset['train'].select(
    [idx for idx, key in enumerate(prompt_keys) if key in eval_prompts]
)

## Load Model

In [None]:
# Load the model and tokenizer with unsloth. For the baseline, the Qwen instruct
# model is loaded directly from the Hub.
# For SFT models, no separate base model loading is necessary: unsloth loads the
# base model of the SFT checkpoint automatically.
# NOTE(review): presumably the nicomu99 SFT repos hold LoRA adapters — confirm.

model_string = 'Qwen2.5-3B-Instruct'        # Or 'Qwen2.5-3B-persona-SFT' for SFT
model_name = f'Qwen/{model_string}'         # For baseline model
# model_name = f'nicomu99/{model_string}'   # For SFT models

MAX_SEQ_LENGTH = 2048                       # maximum sequence length the model is set up for
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name,
    max_seq_length  = MAX_SEQ_LENGTH,
    dtype           = None,     # automatic detection (float16 or bfloat16)
    load_in_4bit    = False,    # no 4-bit quantization of the weights
)

In [None]:
# Attach trainable LoRA adapters (r=16, lora_alpha=32) to the attention
# (q/k/v/o) and MLP (gate/up/down) projections.
# Only needed for base model; for the SFT models this cell is skipped.
model = FastLanguageModel.get_peft_model(
    model,
    r               = 16,
    target_modules  = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha      = 32,
    lora_dropout    = 0.01,
    bias            = 'none',
    random_state    = 42,       # seed for the adapter initialisation
    use_gradient_checkpointing = False,
)

## Data Sampling

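The dataset contains the columns `prompt`, `chosen` and `rejected`. A prompt can
occur in several rows (each with a different `chosen`/`rejected` pair). Only the
full tuple `(prompt, chosen, rejected)` is unique. We therefore want the model
to see each prompt only once per epoch and to cycle through the prompt's pairs
across epochs.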

In [None]:
class PromptSampler:
    """
    Draws one row per unique prompt each time `sample_unique_per_prompt` is called.

    Rows are grouped by the content of the last message of their 'prompt'
    column. The groups keep their first-occurrence order, and inside a group the
    rows keep dataset order. On the k-th call (k = 0, 1, ...), the row at
    position `k % group_size` of every group is selected. Sampling is therefore
    deterministic and cycles round-robin through a prompt's <chosen, rejected>
    pairs. A pair is reused only after all other pairs of that prompt have been
    used (i.e. when there are more calls than pairs).

    Attributes:
        dataset (Dataset): The dataset used to generate the samples.
        prompt_to_examples (dict[str, list[int]]): Row indices per prompt key,
            in dataset order.
        epoch (int): Number of samples drawn so far.
        selected_indices (list[int]): Row indices of the most recent sample.
    """

    def __init__(self, dataset):
        """
        Initializes the PromptSampler and builds the prompt index.

        Args:
            dataset (Dataset): The dataset from which the samples are generated.
                It must support column access (`dataset['prompt']`) and
                `select(indices)`.
        """
        self.dataset = dataset
        self.prompt_to_examples = defaultdict(list)
        self._build_index()
        self.epoch = 0
        self.selected_indices = []

    def _build_index(self):
        """
        Maps every prompt key to the indices of the rows that contain it.
        """
        # Read only the 'prompt' column instead of iterating over full rows, so
        # the chosen/rejected columns are not decoded just to build the index.
        for idx, prompt_messages in enumerate(self.dataset['prompt']):
            # NOTE(review): only the last prompt message is used as key, so rows
            # whose conversations differ only in earlier messages (e.g. the system
            # prompt) are treated as the same prompt. Confirm this is intended.
            prompt = prompt_messages[-1]['content']
            self.prompt_to_examples[prompt].append(idx)

    def sample_unique_per_prompt(self) -> 'Dataset':
        """
        Returns the next sample and advances the internal epoch counter.

        Returns:
            Dataset: `dataset.select(...)` with exactly one row per unique prompt.
        """
        self.selected_indices = [
            idx_list[self.epoch % len(idx_list)]
            for idx_list in self.prompt_to_examples.values()
        ]
        self.epoch += 1
        return self.dataset.select(self.selected_indices)

    def __len__(self) -> int:
        """
        Returns the number of unique prompts in the dataset.

        Returns:
            int: The number of unique prompts (= rows per sample).
        """
        return len(self.prompt_to_examples)

# One sampler per split; every call draws one row per unique prompt.
train_sampler, eval_sampler = PromptSampler(train_dataset), PromptSampler(eval_dataset)

## Train

In [None]:
TRAIN_BATCH_SIZE    = 4
ACCUMULATION_STEPS  = 32

NUM_EPOCHS  = 5
LOG_STEPS   = 30

LR          = 1e-5
WS          = 50        # warmup steps
BETA        = 0.1       # DPO temperature beta
LR_STRAT    = 'cosine'
DPO_L       = 0.95      # DPO-Shift lambda, should be 0 <= x <= 1.0 (1.0 = standard DPO)
DPO_L_STRAT = 'fixed'   # either 'fixed', 'decrease_linear' or 'increase_linear'
DPO_MIN     = 0.5       # lambda bounds used by the linear strategies
DPO_MAX     = 1.0

# Token budgets. The full-sequence budget (max_length) is set to their sum so
# that prompt + completion is not truncated a second time by DPOConfig's default
# max_length (1024 in trl 0.15, smaller than 1024 + 256).
MAX_PROMPT_LENGTH     = 1024
MAX_COMPLETION_LENGTH = 256

# Fail fast on invalid settings before any training starts.
assert 0.0 <= DPO_L <= 1.0, f'DPO_L must be in [0, 1], got {DPO_L}'
assert DPO_L_STRAT in ('fixed', 'decrease_linear', 'increase_linear'), \
    f'Unknown DPO_L_STRAT: {DPO_L_STRAT}'


PROJECT_NAME = 'pr2-train'

# Run name for DPO-Shift. Plain DPO runs used the same name without the
# '-dpol...-<strategy>' suffix.
NAME = (
    f'{model_string}-lr{LR}-ws{WS}-{LR_STRAT}-'
    f'beta{BETA}-dpol{DPO_L}-{DPO_L_STRAT}'
).replace(".", "_")

output_dir = './out/'
training_args = DPOConfig(
    # Logging
    logging_strategy            = 'steps',
    logging_steps               = LOG_STEPS,
    report_to                   = 'wandb',
    run_name                    = NAME,     # same name as the wandb.init run
    logging_first_step          = True,

    # Batch and dataloader settings
    per_device_train_batch_size = TRAIN_BATCH_SIZE,
    per_device_eval_batch_size  = TRAIN_BATCH_SIZE,
    gradient_accumulation_steps = ACCUMULATION_STEPS,
    dataloader_num_workers      = 2,

    # Optimization hyperparameters
    learning_rate               = LR,
    weight_decay                = 0.0,
    warmup_steps                = WS,
    lr_scheduler_type           = LR_STRAT,
    fp16                        = not is_bfloat16_supported(),
    bf16                        = is_bfloat16_supported(),

    # Training epochs and model saving (one checkpoint + eval per epoch; the
    # training loop resumes from these checkpoints)
    num_train_epochs            = NUM_EPOCHS,
    eval_strategy               = 'epoch',
    output_dir                  = output_dir,
    save_strategy               = 'epoch',
    load_best_model_at_end      = True,
    save_total_limit            = 2,

    #DPO-Shift Config
    beta    = BETA,

    # Misc
    seed                        = 42,
    data_seed                   = 42,
    max_prompt_length           = MAX_PROMPT_LENGTH,
    max_completion_length       = MAX_COMPLETION_LENGTH,
    max_length                  = MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,
)

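- Prompts longer than `max_prompt_length` are truncated on the left side (the last tokens are kept).
- Completions (`chosen` and `rejected`) longer than `max_completion_length` are truncated on the right side (the first tokens are kept).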

In [None]:
# TAKEN FROM:
# https://stackoverflow.com/questions/71596920/how-to-perform-a-single-epoch-of-training-with-huggingfaces-trainer
class StopCallback(TrainerCallback):
    """
    Ends `Trainer.train()` as soon as the currently running epoch is finished.

    The training loop builds a new trainer with a freshly sampled dataset for
    every epoch, so each `train()` call must run exactly one epoch.
    """

    def on_epoch_end(self, args, state, control, logs=None, **kwargs):
        """
        Flags the training loop to stop; the epoch's evaluation and checkpoint
        are still handled by the Trainer before it exits.

        Returns:
            TrainerControl: The same control object, with the stop flag set.
        """
        control.should_training_stop = True
        return control

In [None]:
# This code snippet was partially taken from https://github.com/Meaquadddd/DPO-Shift
# DPO-Shift, submitted to ICML 2025
# developed based on the original DPOTrainer from Hugging Face TRL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class DPOShiftTrainer(DPOTrainer):
    """
    DPO trainer using the DPO-Shift objective.

    DPO-Shift scales the rejected log-probabilities by a factor lambda inside the
    DPO margin:

        logits = (log pi(chosen) - lambda * log pi(rejected))
               - (log ref(chosen) - lambda * log ref(rejected))

    With lambda = 1 this is exactly the standard DPO objective.

    Args:
        dpo_lambda: Lambda used by the 'fixed' strategy.
        dpo_lambda_strategy: 'fixed' (always `dpo_lambda`), 'decrease_linear'
            (linearly from `dpo_lambda_max` down to `dpo_lambda_min` over
            training) or 'increase_linear' (from `dpo_lambda_min` up to
            `dpo_lambda_max`).
        dpo_lambda_min: Lower lambda bound of the linear strategies.
        dpo_lambda_max: Upper lambda bound of the linear strategies.
        *args, **kwargs: Forwarded unchanged to DPOTrainer.

    Raises:
        ValueError: If `dpo_lambda_strategy` is not one of the values above.
    """

    _LAMBDA_STRATEGIES = ('fixed', 'decrease_linear', 'increase_linear')

    def __init__(self, dpo_lambda, dpo_lambda_strategy, dpo_lambda_min, dpo_lambda_max, *args, **kwargs):
        # Validate before the (expensive) parent initialisation tokenizes the data.
        if dpo_lambda_strategy not in self._LAMBDA_STRATEGIES:
            raise ValueError(
                f"Unknown dpo_lambda_strategy: {dpo_lambda_strategy}. "
                f"Should be one of {self._LAMBDA_STRATEGIES}"
            )

        self.dpo_lambda = dpo_lambda
        self.dpo_lambda_strategy = dpo_lambda_strategy
        self.dpo_lambda_min = dpo_lambda_min
        self.dpo_lambda_max = dpo_lambda_max

        super().__init__(*args, **kwargs)

    def _current_dpo_lambda(self) -> float:
        """
        Returns the lambda for the current optimisation step.

        Training progress is `state.global_step / state.max_steps`, clipped to
        [0, 1]. When training resumes from a checkpoint, the step counter is
        restored, so the linear schedules span the whole run even if `train()`
        is called once per epoch on a new trainer.
        """
        if self.dpo_lambda_strategy == 'fixed':
            return self.dpo_lambda

        max_steps = getattr(self.state, 'max_steps', 0) or 0
        if max_steps > 0:
            progress = min(max(self.state.global_step / max_steps, 0.0), 1.0)
        else:
            # Before training has started (max_steps unknown): start of schedule.
            progress = 0.0

        span = self.dpo_lambda_max - self.dpo_lambda_min
        if self.dpo_lambda_strategy == 'increase_linear':
            return self.dpo_lambda_min + span * progress
        return self.dpo_lambda_max - span * progress  # 'decrease_linear'

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ):
        """Compute the DPO-Shift loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO-Shift loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the (unshifted) implicit
            rewards beta * (log pi - log ref) for the chosen and rejected responses, respectively.

        Raises:
            ValueError: If `self.loss_type` is not 'sigmoid'.
        """
        # Move everything to the training device up front so that no arithmetic
        # below mixes devices.
        device = self.accelerator.device
        policy_chosen_logps = policy_chosen_logps.to(device)
        policy_rejected_logps = policy_rejected_logps.to(device)
        reference_chosen_logps = reference_chosen_logps.to(device)
        reference_rejected_logps = reference_rejected_logps.to(device)

        dpo_lambda = self._current_dpo_lambda()

        pi_logratios = policy_chosen_logps - dpo_lambda * policy_rejected_logps
        if self.reference_free:
            ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=device)
        else:
            ref_logratios = reference_chosen_logps - dpo_lambda * reference_rejected_logps
        logits = pi_logratios - ref_logratios

        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative DPO loss.
        if self.loss_type == "sigmoid":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be 'sigmoid'"
            )

        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

        return losses, chosen_rewards, rejected_rewards

In [None]:
# Init wandb once before the loop, so that the per-epoch trainers presumably all
# log into this already active run instead of each starting their own.
# NOTE(review): resume='allow' only resumes an existing run when a run id is
# given (id=... or WANDB_RUN_ID). With only `name` set, every execution of this
# cell starts a new run.
wandb.init(
    project = PROJECT_NAME,
    name    = NAME,
    resume  = 'allow'
)

In [None]:
# Custom training loop: HF does not provide options to modify training after each epoch.
# Small hack to allow update of the dataset. Every iteration:
#   - builds a new trainer on a freshly sampled train set (one pair per prompt),
#   - trains for one epoch (StopCallback),
#   - from the second epoch on, resumes the step/optimizer/scheduler state from
#     the checkpoint saved at the end of the previous epoch.
#
# For standard DPO, keep this trainer and set DPO_L = 1.0 with DPO_L_STRAT = 'fixed'.
# With lambda = 1 the DPO-Shift logits reduce exactly to the DPO logits.

# The eval set is sampled once, so eval losses (and the best-checkpoint choice of
# load_best_model_at_end) are computed on the same data in every epoch.
eval_epoch_dataset = eval_sampler.sample_unique_per_prompt()

for epoch in range(NUM_EPOCHS):
    print(f"\n===== Epoch {epoch + 1}/{NUM_EPOCHS} =====\n")

    trainer = DPOShiftTrainer(
        dpo_lambda          = DPO_L,
        dpo_lambda_strategy = DPO_L_STRAT,
        dpo_lambda_min      = DPO_MIN,
        dpo_lambda_max      = DPO_MAX,
        model               = model,
        ref_model           = None,
        beta                = BETA,
        args                = training_args,
        tokenizer           = tokenizer,
        train_dataset       = train_sampler.sample_unique_per_prompt(),
        eval_dataset        = eval_epoch_dataset,
        callbacks           = [StopCallback()],
    )

    # Epoch 0 starts fresh; later epochs continue from the latest checkpoint.
    trainer.train(resume_from_checkpoint = epoch > 0)

wandb.finish()

# Reload the last checkpoint written by the training loop and push its LoRA
# adapters and the tokenizer to the Hub.
# NOTE(review): this is the *last* checkpoint, not trainer.state.best_model_checkpoint,
# even though load_best_model_at_end=True. Confirm which one should be published.
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None:
    # Fail before dropping the in-memory model, so the trained weights are not lost.
    raise FileNotFoundError(
        f'No checkpoint found in {training_args.output_dir}; nothing to push.'
    )

# Release the training-time model and trainer before loading a second copy of
# the model from disk; otherwise both copies occupy GPU memory at the same time.
trainer = model = None
gc.collect()
torch.cuda.empty_cache()

model, tokenizer = FastLanguageModel.from_pretrained(
    last_checkpoint,
    max_seq_length  = MAX_SEQ_LENGTH,
    dtype           = None,     # automatic detection (float16 or bfloat16)
    load_in_4bit    = False,
)

# Save LoRA adapters
save_name = 'nicomu99/prompt-dpo-' + NAME
model.push_to_hub(save_name, private=True)
tokenizer.push_to_hub(save_name, private=True)