[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/khetansarvesh/NLP/blob/main/dpo/dpo_from_scratch.ipynb)

In [7]:
# # Install required packages
# NOTE(review): the pins below look inconsistent with the rest of this notebook:
# later cells use `DPOConfig`, `processing_class=` and `dtype=`, which presumably
# require much newer `trl` / `transformers` releases than 0.4.5 / 4.31.0, and the
# last line re-installs `trl` unpinned. Confirm the intended versions before
# re-enabling these lines (and prefer `%pip install` so the kernel's env is targeted).
# !pip install "transformers==4.31.0" "trl==0.4.5" "datasets>=4.1.0" "torch>=2.8.0"
# !pip install "accelerate>=1.10.1" "peft>=0.17.0" "trackio"
# !pip install trl

In [9]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer, DPOConfig
import json
import torch.nn.functional as F

# Pick the compute device: use the GPU when CUDA is usable, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# **Preference Dataset**

In [10]:
# Load a preference dataset to understand the format.
# This fetches the "train" split of Anthropic/hh-rlhf from the Hugging Face Hub,
# so the cell needs network access the first time it runs.
dataset = load_dataset("Anthropic/hh-rlhf", split="train")
print(f"Dataset size: {len(dataset)}")
print("Dataset features:", dataset.features.keys())

# Examine a preference pair: every record carries a 'chosen' and a 'rejected' string.
# Only the first 200 characters of each are printed; in the recorded output the two
# previews are identical because both start with the same dialogue text.
sample = dataset[0]
print(f"\nChosen (Preferred): {sample['chosen'][:200]}...")
print(f"\nRejected (Non-preferred): {sample['rejected'][:200]}...")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

harmless-base/train.jsonl.gz:   0%|          | 0.00/13.2M [00:00<?, ?B/s]

helpful-base/train.jsonl.gz:   0%|          | 0.00/16.2M [00:00<?, ?B/s]

helpful-online/train.jsonl.gz:   0%|          | 0.00/20.1M [00:00<?, ?B/s]

helpful-rejection-sampled/train.jsonl.gz:   0%|          | 0.00/25.7M [00:00<?, ?B/s]

harmless-base/test.jsonl.gz:   0%|          | 0.00/743k [00:00<?, ?B/s]

helpful-base/test.jsonl.gz:   0%|          | 0.00/875k [00:00<?, ?B/s]

helpful-online/test.jsonl.gz:   0%|          | 0.00/1.05M [00:00<?, ?B/s]

helpful-rejection-sampled/test.jsonl.gz:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/160800 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/8552 [00:00<?, ? examples/s]

Dataset size: 160800
Dataset features: dict_keys(['chosen', 'rejected'])

Chosen (Preferred): 

Human: What are some cuss words in english?

Assistant: Here’s an incomplete list.

Ass, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cu...

Rejected (Non-preferred): 

Human: What are some cuss words in english?

Assistant: Here’s an incomplete list.

Ass, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cu...


In [None]:
# Load a small subset for local testing:
# keep only the first 1000 preference pairs (row indices 0..999) of the train split.
small_dataset = dataset.select(range(1000))

# **Load the Pretrained (SFT) Model**

In [None]:
# Checkpoint that will be aligned with DPO.
# NOTE(review): confirm this repo id exists on the Hub; the trailing comment keeps
# "HuggingFaceTB/SmolLM3-3B" as the alternative id.
model_name = "HuggingFaceTB/SmolLM3-3B-Instruct" #"HuggingFaceTB/SmolLM3-3B"

# Load the weights in bfloat16 and let `device_map="auto"` decide device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reuse the end-of-sequence token for padding. This assignment is unconditional, so it
# also overwrites a padding token the tokenizer may already define.
tokenizer.pad_token = tokenizer.eos_token


# **DPO Training : Using Hugging Face's TRL Library**

In [None]:
# Quick-test DPO configuration (50 optimizer steps, logging every 10).
# NOTE: the next cell assigns `training_args` again, so when the notebook is run top to
# bottom this configuration is replaced before the trainer is created.
training_args = DPOConfig(
    beta=0.1,                           # Preference optimization strength
    learning_rate=5e-7,                 # Lower than SFT
    per_device_train_batch_size=1,      # Small batch for local testing
    gradient_accumulation_steps=4,      # Effective batch size = 4 (1 sample x 4 accumulation steps, per device)
    max_steps=50,                       # Very short for testing
    logging_steps=10,
    output_dir="./local_dpo_test",
    report_to="trackio",                # Send training metrics to trackio
)

In [None]:
# Full DPO configuration. This rebinds `training_args`, replacing the quick-test
# configuration from the previous cell; the trainer below uses whichever was bound last.
training_args = DPOConfig(
    # Core DPO parameters
    beta=0.1,                           # Preference optimization strength
    max_prompt_length=512,              # Maximum prompt length
    max_length=1024,                    # Maximum total sequence length

    # Training configuration
    learning_rate=5e-7,                 # Lower than SFT for stability
    per_device_train_batch_size=1,      # batch size
    gradient_accumulation_steps=4,      # Effective batch size = 4 (1 sample x 4 accumulation steps, per device)
    max_steps=1000,                     # Sufficient for good alignment

    # Optimization
    warmup_steps=100,
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,        # Memory efficiency
    bf16=True,                          # Mixed precision

    # Logging and saving
    logging_steps=50,
    save_steps=250,
    output_dir="./smollm3-dpo-aligned",

    # Hub integration
    # NOTE(review): push_to_hub=True together with the placeholder repo id below will
    # presumably fail (or push to the wrong place) unless the id is edited and the
    # session is authenticated; confirm before running.
    push_to_hub=True,
    hub_model_id="your-username/smollm3-dpo-aligned",  # Change this!
    report_to="trackio",

    # Keep every dataset column: False turns OFF the automatic removal of unused columns
    remove_unused_columns=False,
)

In [None]:
# Create the trainer on the 1000-example subset. Training itself is started by the
# `trainer.train()` call in the next cell.
# No reference model is passed here; presumably TRL derives the frozen reference from
# `model` itself. TODO confirm for the installed TRL version.
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=small_dataset,
    processing_class=tokenizer,
)

In [None]:
# Run DPO fine-tuning with the configuration currently bound to `training_args`.
trainer.train()

# **DPO Training : Implementing from scratch**

### *Implementing DPO Loss Equation*

In [None]:
def compute_dpo_loss(
    model_chosen_logprobs,
    model_rejected_logprobs,
    reference_chosen_logprobs,
    reference_rejected_logprobs,
    beta=0.1,
):
    """
    Direct Preference Optimization loss for a batch of preference pairs.

    Args:
      model_chosen_logprobs: Policy log-probabilities of the chosen responses.
      model_rejected_logprobs: Policy log-probabilities of the rejected responses.
      reference_chosen_logprobs: Reference-model log-probabilities of the chosen responses.
      reference_rejected_logprobs: Reference-model log-probabilities of the rejected responses.
      beta: Scale applied to the log-ratio difference before the log-sigmoid.

    Returns:
      (loss, chosen_reward, rejected_reward), each reduced with `.mean()` over the
      inputs. The two reward values are detached and only meant for monitoring.
    """
    # How strongly each model prefers "chosen" over "rejected"
    policy_margin = model_chosen_logprobs - model_rejected_logprobs
    reference_margin = reference_chosen_logprobs - reference_rejected_logprobs

    # -log(sigmoid(beta * (policy_margin - reference_margin))), averaged over the batch
    per_sample_loss = -F.logsigmoid(beta * (policy_margin - reference_margin))
    batch_loss = per_sample_loss.mean()

    # Monitoring only: how far the policy moved away from the reference on each response
    chosen_reward = (model_chosen_logprobs - reference_chosen_logprobs).detach().mean()
    rejected_reward = (model_rejected_logprobs - reference_rejected_logprobs).detach().mean()

    return batch_loss, chosen_reward, rejected_reward

### *Implementing the log-probability calculator*

The DPO loss function above expects log probabilities as input, so here we will see how to calculate them.

In [None]:
def compute_logprobs(logits, # predicted output from llm
                     labels, # real outputs
                     selection_mask=None):
    """
    Compute the average per-token log probability of `labels` under `logits`.

    Args:
      logits: Tensor of shape (batch_size, num_tokens, vocab_size)
      labels: Tensor of shape (batch_size, num_tokens)
      selection_mask: Optional tensor of shape (batch_size, num_tokens); positions
        with value 0 are excluded from the average.

    Returns:
      Tensor of shape (batch_size,): the mean log probability of the label tokens.
      With a mask, the mean runs over the selected positions only; a row whose mask
      selects no position yields 0.0 instead of NaN.
    """

    # Labels are the inputs shifted by one: position t predicts token t + 1
    labels = labels[:, 1:].clone()

    # Truncate logits to match the labels num_tokens
    logits = logits[:, :-1, :]

    log_probs = F.log_softmax(logits, dim=-1)

    # Gather the log probabilities for the actual labels
    selected_log_probs = torch.gather(
        input=log_probs,
        dim=-1,
        index=labels.unsqueeze(-1)
    ).squeeze(-1)

    if selection_mask is None:
        return selected_log_probs.mean(-1)

    # Shift the mask the same way as the labels
    mask = selection_mask[:, 1:].clone()

    # Zero out the positions that must not contribute (e.g. padding tokens)
    selected_log_probs = selected_log_probs * mask

    # A fully masked row would otherwise compute 0 / 0 = NaN and poison the batch
    # loss; clamping the count to 1 makes such a row contribute exactly 0.
    selected_token_count = mask.sum(-1).clamp(min=1)

    # Average over the selected tokens, so the shape is (batch_size,)
    return selected_log_probs.sum(-1) / selected_token_count

- Note that this function above might look a bit intimidating at first due to the `torch.gather` function, but it's pretty similar to what happens under the hood in PyTorch's `cross_entropy` function
- For example, consider the following example:

In [None]:
# Toy example: 2 samples, 3 classes
logits = torch.tensor([
    [2.0, 1.0, 0.1],
    [0.5, 2.5, 0.3],
])  # Shape: (2, 3)
targets = torch.tensor([0, 2])  # Shape: (2,)


# Manual loss using torch.gather: take the log-probability of each target class
log_softmax_logits = F.log_softmax(logits, dim=1)  # Shape: (2, 3)
target_index = targets.unsqueeze(1)  # Shape: (2, 1)
selected_log_probs = torch.gather(log_softmax_logits, 1, target_index).squeeze(1)  # Shape: (2,)
manual_loss = -selected_log_probs.mean()  # Averaging over the batch


# Reference value from PyTorch's built-in loss
cross_entropy_loss = F.cross_entropy(logits, targets)

print(manual_loss, cross_entropy_loss)

tensor(1.4185) tensor(1.4185)


- So, above, we can see that the two implementations are equivalent, but let's narrow down a bit further to the `torch.gather` mechanics
- Consider the following two tensors:

In [None]:
# Values to select from, shape (2, 2)
t = torch.tensor([[1., 2.], [3., 4.]])

# Integer indices into the last dimension of `t`, shape (2, 2)
m = torch.tensor([[1, 1], [0, 1]])

- Above, `t` is a tensor we want to select from, and `m` is an index tensor that specifies which positions we want to select
 - For instance, since `m` contains `[1, 1]` in the first row, it will select the value of `t` at index position `1` twice, which is the value `2.`
 - The second row of `m`, `[0, 1]`, selects index positions 0 and 1 in the second row of `t`, which are `3.` and `4.`

In [None]:
# For every position (i, j) the result holds t[i, m[i, j]]: `m` indexes the last dimension of `t`.
torch.gather(input=t, dim=-1, index=m)

tensor([[2., 2.],
        [3., 4.]])

- In other words, `torch.gather` is a selection function
- When we computed the loss earlier, we used it to retrieve the log probabilities corresponding to the correct token in the 50,257-token vocabulary
- The "correct" tokens are the tokens given in the response entry

- Regarding the `compute_logprobs` function above, we use `torch.gather` here because it gives us a bit more control than `cross_entropy`, but is, in essence, a similar idea
- The `selection_mask` we use there is to optionally ignore prompt and padding tokens

### *DPO Loss Function -- Single Batch*

In [None]:
def _extract_logits(model_output):
    """Return the logits of a model call: the `.logits` attribute if present, else the output itself."""
    return getattr(model_output, "logits", model_output)


def compute_dpo_loss_batch(batch, policy_model, reference_model, beta):
    """
    Compute the DPO loss on a single input batch.

    Args:
      batch: Mapping with the keys "chosen", "rejected", "chosen_mask" and
        "rejected_mask". The "chosen"/"rejected" entries are fed to the models and
        reused as labels; the "*_mask" entries select the tokens that are scored.
      policy_model: Model being optimized; called as `policy_model(token_ids)`.
      reference_model: Frozen reference model; only evaluated under `torch.no_grad()`.
      beta: DPO temperature forwarded to `compute_dpo_loss`.

    Both models may return either a raw logits tensor or an output object exposing
    `.logits` (as Hugging Face causal LMs do).

    Returns:
      (loss, chosen_rewards, rejected_rewards) as produced by `compute_dpo_loss`.
    """

    def _response_logprobs(model, response_key):
        # Score one response type ("chosen" or "rejected") with one model
        token_ids = batch[response_key]
        return compute_logprobs(
            logits=_extract_logits(model(token_ids)),
            labels=token_ids,
            selection_mask=batch[f"{response_key}_mask"]
        )

    policy_chosen_log_probas = _response_logprobs(policy_model, "chosen")
    policy_rejected_log_probas = _response_logprobs(policy_model, "rejected")

    # The reference model is never trained, so no graph is needed for its forward passes
    with torch.no_grad():
        ref_chosen_log_probas = _response_logprobs(reference_model, "chosen")
        ref_rejected_log_probas = _response_logprobs(reference_model, "rejected")

    loss, chosen_rewards, rejected_rewards = compute_dpo_loss(
        model_chosen_logprobs=policy_chosen_log_probas,
        model_rejected_logprobs=policy_rejected_log_probas,
        reference_chosen_logprobs=ref_chosen_log_probas,
        reference_rejected_logprobs=ref_rejected_log_probas,
        beta=beta
    )

    return loss, chosen_rewards, rejected_rewards

In [None]:
# Sanity check on one batch, without building an autograd graph.
# NOTE(review): `batch`, `policy_model` and `reference_model` are not defined in any
# cell visible above this one; confirm they exist before this cell runs.
# `loss` receives the whole (loss, chosen_rewards, rejected_rewards) tuple. The
# recorded 0.6931 with zero rewards is what identical policy and reference
# log-probabilities would give (-log(sigmoid(0)) = ln 2).
with torch.no_grad():
    loss = compute_dpo_loss_batch(batch, policy_model, reference_model, beta=0.1)
print(loss)

(tensor(0.6931, device='cuda:0'), tensor(0., device='cuda:0'), tensor(0., device='cuda:0'))


### *DPO Loss Function -- Multiple Batches (Data Loader)*

Below, we extend this function to work for a specified `num_batches` in a data loader.  Why a specified `num_batches` and not the entire dataloader? That's purely for efficiency reasons (because calculating the loss on the whole dataset each time would slow down the training significantly)

In [None]:
def compute_dpo_loss_loader(data_loader, policy_model, reference_model, beta, num_batches=None):
    """
    Average the DPO loss and rewards over the first batches of a data loader.

    Args:
      data_loader: Iterable of batches accepted by `compute_dpo_loss_batch`.
      policy_model: Model being optimized.
      reference_model: Frozen reference model.
      beta: DPO temperature.
      num_batches: Maximum number of batches to evaluate; None (the default)
        evaluates every batch the loader yields.

    Returns:
      (mean_loss, mean_chosen_reward, mean_rejected_reward) as Python floats, averaged
      over the batches actually processed. If no batch is processed (empty loader
      or num_batches <= 0) all three values are NaN.
    """

    total_loss, total_chosen_rewards, total_rejected_rewards = 0., 0., 0.
    batches_seen = 0

    for batch in data_loader:
        # Stop once the requested number of batches has been evaluated
        if num_batches is not None and batches_seen >= num_batches:
            break

        loss, chosen_rewards, rejected_rewards = compute_dpo_loss_batch(
            batch=batch,
            policy_model=policy_model,
            reference_model=reference_model,
            beta=beta
        )
        total_loss += loss.item()
        total_chosen_rewards += chosen_rewards.item()
        total_rejected_rewards += rejected_rewards.item()
        batches_seen += 1

    # Nothing was evaluated: report NaN rather than a misleading 0.0
    if batches_seen == 0:
        nan = float("nan")
        return nan, nan, nan

    # Divide by the number of batches processed, which can be smaller than
    # `num_batches` when the loader runs out early
    return (
        total_loss / batches_seen,
        total_chosen_rewards / batches_seen,
        total_rejected_rewards / batches_seen,
    )

### *DPO Loss Function -- Train + Validation Loaders*

In [None]:
def evaluate_dpo_loss_loader(policy_model, reference_model, train_loader, val_loader, beta, eval_iter):
    """
    Evaluate the DPO loss and rewards on the training and validation loaders.

    The policy model is put in eval mode for the measurement and switched back to
    train mode before returning; no gradients are recorded. At most `eval_iter`
    batches are read from each loader.

    Returns:
      Dict with the keys "train_loss", "train_chosen_reward", "train_rejected_reward",
      "val_loss", "val_chosen_reward" and "val_rejected_reward".
    """

    def _measure(loader):
        # Same settings for both splits; only the loader differs
        return compute_dpo_loss_loader(
            data_loader=loader,
            policy_model=policy_model,
            reference_model=reference_model,
            beta=beta,
            num_batches=eval_iter
        )

    policy_model.eval()
    with torch.no_grad():
        train_loss, train_chosen, train_rejected = _measure(train_loader)
        val_loss, val_chosen, val_rejected = _measure(val_loader)

    res = {}
    res["train_loss"] = train_loss
    res["train_chosen_reward"] = train_chosen
    res["train_rejected_reward"] = train_rejected
    res["val_loss"] = val_loss
    res["val_chosen_reward"] = val_chosen
    res["val_rejected_reward"] = val_rejected

    # Hand the model back in training mode
    policy_model.train()
    return res

### *Training*

The training logic remains the same as in pretraining and instruction finetuning, with minor differences:
 - we swap the cross-entropy loss with our new DPO loss function
 - we also track the rewards and reward margins, which are commonly used in RLHF and DPO contexts to track the training progress


In [None]:
def train_model_dpo_simple(
    policy_model, reference_model, train_loader, val_loader,
    optimizer, num_epochs, beta,
    eval_freq, eval_iter, start_context, tokenizer
):
    """
    Train `policy_model` with the DPO loss and periodically evaluate it.

    Args:
      policy_model: Model being optimized (its parameters must be in `optimizer`).
      reference_model: Frozen reference model used inside the DPO loss.
      train_loader: Iterable of training batches; each batch must provide a
        "chosen" tensor (used for the token counter) plus whatever
        `compute_dpo_loss_batch` reads.
      val_loader: Iterable of validation batches, used only for evaluation.
      optimizer: Optimizer stepping the policy model's parameters.
      num_epochs: Number of passes over `train_loader`.
      beta: DPO temperature.
      eval_freq: Evaluate every `eval_freq` optimizer steps (step 0 included);
        a falsy value (0 / None) disables evaluation.
      eval_iter: Number of batches per loader used for each evaluation.
      start_context, tokenizer: Accepted for interface compatibility; this function
        does not use them.

    Returns:
      Dict of lists with one entry per evaluation: "train_losses",
      "train_chosen_rewards", "train_rejected_rewards", "val_losses",
      "val_chosen_rewards", "val_rejected_rewards" and "tokens_seen".
    """

    # Initialize lists to track losses and tokens seen
    tracking = {
        "train_losses": [],
        "train_chosen_rewards": [],
        "train_rejected_rewards": [],
        "val_losses": [],
        "val_chosen_rewards": [],
        "val_rejected_rewards": [],
        "tokens_seen": []
    }
    tokens_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        policy_model.train()  # Set model to training mode

        for batch in train_loader:

            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration

            # The per-batch rewards are not needed here; they are re-measured during evaluation
            loss, _, _ = compute_dpo_loss_batch(
                batch=batch,
                policy_model=policy_model,
                reference_model=reference_model,
                beta=beta
            )

            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients

            tokens_seen += batch["chosen"].numel()
            global_step += 1

            # Periodic evaluation: record metrics and log the progress
            if eval_freq and global_step % eval_freq == 0:
                res = evaluate_dpo_loss_loader(
                    policy_model=policy_model,
                    reference_model=reference_model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    beta=beta,
                    eval_iter=eval_iter
                )
                tracking["train_losses"].append(res["train_loss"])
                tracking["train_chosen_rewards"].append(res["train_chosen_reward"])
                tracking["train_rejected_rewards"].append(res["train_rejected_reward"])
                tracking["val_losses"].append(res["val_loss"])
                tracking["val_chosen_rewards"].append(res["val_chosen_reward"])
                tracking["val_rejected_rewards"].append(res["val_rejected_reward"])
                tracking["tokens_seen"].append(tokens_seen)

                # Reward margin = chosen reward - rejected reward
                train_reward_margin = res["train_chosen_reward"] - res["train_rejected_reward"]
                val_reward_margin = res["val_chosen_reward"] - res["val_rejected_reward"]

                print(
                    f"Ep {epoch+1} (Step {global_step:06d}): "
                    f"Train loss {res['train_loss']:.3f}, Val loss {res['val_loss']:.3f}, "
                    f"Train reward margins {train_reward_margin:.3f}, "
                    f"Val reward margins {val_reward_margin:.3f}"
                )

    return tracking

In [None]:
# Fix PyTorch's RNG seed before training
# NOTE(review): `policy_model`, `reference_model`, `train_loader`, `val_loader`,
# `format_input` and `val_data` are not defined in any cell visible above this one;
# confirm they are created before this cell runs.
torch.manual_seed(123)


# Only the parameters of the policy model are passed into the `AdamW` optimizer:
# that's the model we want to optimize (we don't want to modify the reference model).
# In DPO, it's best to use a very small learning rate.
optimizer = torch.optim.AdamW(policy_model.parameters(),
                              lr=5e-6,
                              weight_decay=0.01)


# num_epochs: we only train for 1 epoch; that's because DPO is very prone to collapse
#   (the loss might improve, but the model will start generating nonsensical texts)
# beta: the beta value can be increased from 0.1 to 0.5 to reduce the effect of DPO
#   (we use 0.1 here to make the results more noticeable)
# eval_freq / eval_iter: evaluation interval in steps / batches used per evaluation
tracking = train_model_dpo_simple(
    policy_model=policy_model,
    reference_model=reference_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    num_epochs=1,
    beta=0.1,
    eval_freq=5,
    eval_iter=5,
    start_context=format_input(val_data[2]),
    tokenizer=tokenizer
)

Ep 1 (Step 000000): Train loss 0.692, Val loss 0.693, Train reward margins 0.019, Val reward margins 0.009
Ep 1 (Step 000005): Train loss 0.690, Val loss 0.691, Train reward margins 0.070, Val reward margins 0.052
Ep 1 (Step 000010): Train loss 0.687, Val loss 0.688, Train reward margins 0.126, Val reward margins 0.108
Ep 1 (Step 000015): Train loss 0.676, Val loss 0.685, Train reward margins 0.362, Val reward margins 0.173
Ep 1 (Step 000020): Train loss 0.676, Val loss 0.680, Train reward margins 0.351, Val reward margins 0.264
Ep 1 (Step 000025): Train loss 0.666, Val loss 0.676, Train reward margins 0.564, Val reward margins 0.359
Ep 1 (Step 000030): Train loss 0.672, Val loss 0.672, Train reward margins 0.456, Val reward margins 0.441
Ep 1 (Step 000035): Train loss 0.663, Val loss 0.669, Train reward margins 0.658, Val reward margins 0.511
Ep 1 (Step 000040): Train loss 0.666, Val loss 0.666, Train reward margins 0.597, Val reward margins 0.574
Ep 1 (Step 000045): Train loss 0.648,