# Exercise 5

## Group
- **ID**: 5

- **Members**:
    - Hasan Algafri
    - Emre Dursunluer
    - Taha El Amine Kassabi

## Hand-in
- Please hand in this notebook with your code implementation via Ilias 
- Please make sure that there is exactly **one** submission per group

## Task Description

In this exercise, you will implement the DPO algorithm and integrate it into a finetuning pipeline. The algorithm is described in the paper [DPO: Direct Preference Optimization for Language Models](https://arxiv.org/abs/2305.18290).
You will use DPO to finetune a pretrained language model which is already aligned to conversational tasks.

You will use the dataset published by Anthropic. The dataset is a collection of human preference comparisons between two model outputs. The dataset is available at [Huggingface (Anthropic's dataset)](https://huggingface.co/datasets/Anthropic/hh-rlhf).

## Grading scheme
Total: 5 points
1. **Preparing the Dataloader** (1 point)
2. **Loss Function** (1.5 points)
3. **Choose a suitable Model** (0.5 points)
4. **Training loop** (1.5 points)
5. **Generation of Question Answer pairs** (0.5 points)

In [1]:
import os
from functools import partial

from datasets import load_dataset
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from tqdm.auto import tqdm, trange

## Task 1: Preparing the Dataloader (1 point) <br>
Get familiar with the dataset and prepare the dataloader.

In [2]:
class AnthropicDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx].values()

In [3]:
def custom_collate_fn(batch, tokenizer):
    N = len(batch)
    chosen, rejected = zip(*batch)

    encodings = tokenizer(
        chosen + rejected,
        padding='longest',
        truncation=True,
        return_tensors='pt',
        return_attention_mask=True,
        add_special_tokens=True
    )

    return {
        'chosen': {
            'input_ids': encodings['input_ids'][:N],
            'attention_mask': encodings['attention_mask'][:N],
        },
        'rejected': {
            'input_ids': encodings['input_ids'][N:],
            'attention_mask': encodings['attention_mask'][N:],
        }
    }

In [4]:
def load_data(collate_fn, batch_size=8, num_workers=0):
    ds = load_dataset('Anthropic/hh-rlhf')
    train_ds, val_ds = AnthropicDataset(ds['train']), AnthropicDataset(ds['test'])
    return (DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers, shuffle=True),
            DataLoader(val_ds, collate_fn=collate_fn, num_workers=num_workers))

## Task 2: Loss Function (1.5 points) <br>
Implement the DPO loss function.

In [5]:
class DPOLoss(nn.Module):
    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, input_ids_w, input_ids_l, logits_theta_w, logits_theta_l, logits_ref_w, logits_ref_l):
        logp_theta_w = self.to_seq_logp(logits_theta_w, input_ids_w)
        logp_theta_l = self.to_seq_logp(logits_theta_l, input_ids_l)
        logp_ref_w = self.to_seq_logp(logits_ref_w, input_ids_w)
        logp_ref_l = self.to_seq_logp(logits_ref_l, input_ids_l)

        score_w = self.beta * (logp_theta_w - logp_ref_w)
        score_l = self.beta * (logp_theta_l - logp_ref_l)

        return -F.logsigmoid(score_w - score_l).mean()

    @staticmethod
    def to_seq_logp(logits, input_ids):
        return F.log_softmax(logits, dim=-1).gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1).sum(dim=-1)

## Task 3: Choose a suitable Model (0.5 points) <br>
Choose a **suitable** model **for the DPO finetuning**. You can use any pretrained model from the Huggingface Hub. Please justify your choice in 1-2 sentences.

In [6]:
class Configuration:
    def __init__(self):
        self.num_epochs = 3
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.mps.is_available() else 'cpu')

        self.model_name = 'gpt2-medium'
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model_theta = self.get_model()
        self.model_ref = self.get_model()
        for p in self.model_ref.parameters(): p.requires_grad = False

        lr = 1e-6
        self.optimizer = optim.AdamW(self.model_theta.parameters(), lr=lr)

        self.criterion = DPOLoss(beta=1.0)

        collate_fn = partial(custom_collate_fn, tokenizer=self.tokenizer)

        num_workers = os.cpu_count() if self.device.type == 'cuda' else 0
        batch_size = 8
        self.train_dl, self.val_dl = load_data(collate_fn, batch_size=batch_size, num_workers=num_workers)

        self.model_path = '../models/ex05_best_model_min_loss.pt'

    def get_model(self):
        model = GPT2LMHeadModel.from_pretrained(self.model_name)
        model.config.use_cache = True
        return model

## Task 4: Training loop (1.5 points) <br>
Implement the training loop. Which uses your dataloader, the DPO loss function and the model you chose. Also use sensible hyperparameters.

In [7]:
def train_one_epoch(model_theta, model_ref, train_dl, optimizer, criterion, device):
    model_theta.train()
    epoch_loss = 0
    for batch in tqdm(train_dl):
        chosen, rejected = batch['chosen'], batch['rejected']
        w_input_ids = chosen['input_ids'].to(device)
        w_attention_mask = chosen['attention_mask'].to(device)
        l_input_ids = rejected['input_ids'].to(device)
        l_attention_mask = rejected['attention_mask'].to(device)

        logits_theta_w = model_theta(w_input_ids, attention_mask=w_attention_mask).logits
        logits_theta_l = model_theta(l_input_ids, attention_mask=l_attention_mask).logits
        with torch.inference_mode():
            logits_ref_w = model_ref(w_input_ids, attention_mask=w_attention_mask).logits
            logits_ref_l = model_ref(l_input_ids, attention_mask=l_attention_mask).logits

        loss = criterion(logits_theta_w, logits_theta_l, logits_ref_w, logits_ref_l)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return epoch_loss / len(train_dl)


@torch.inference_mode()
def validate(model_theta, model_ref, val_dl, criterion, device):
    model_theta.eval()
    epoch_loss = 0
    for batch in tqdm(val_dl):
        chosen, rejected = batch['chosen'], batch['rejected']
        chosen_input_ids = chosen['input_ids'].to(device)
        chosen_attention_mask = chosen['attention_mask'].to(device)
        rejected_input_ids = rejected['input_ids'].to(device)
        rejected_attention_mask = rejected['attention_mask'].to(device)

        logits_theta_chosen = model_theta(chosen_input_ids, attention_mask=chosen_attention_mask)
        logits_ref_chosen = model_ref(chosen_input_ids, attention_mask=chosen_attention_mask)
        logits_theta_rejected = model_theta(rejected_input_ids, attention_mask=rejected_attention_mask)
        logits_ref_rejected = model_ref(rejected_input_ids, attention_mask=rejected_attention_mask)

        loss = criterion(logits_theta_chosen, logits_ref_chosen, logits_theta_rejected, logits_ref_rejected)
        epoch_loss += loss.item()

        epoch_loss += loss.item()

    return epoch_loss / len(val_dl)


def finetune(config):
    dict_log = {'train_loss': [], 'val_loss': []}
    best_val_loss = float('inf')
    model_theta = config.model_theta.to(config.device)
    model_ref = config.model_ref.to(config.device)

    pbar = trange(config.num_epochs)
    for epoch in pbar:
        train_loss = train_one_epoch(model_theta, model_ref, config.train_dl, config.optimizer, config.criterion, config.device)
        val_loss = validate(model_theta, model_ref, config.val_dl, config.criterion, config.device)

        print(f'Epoch {epoch + 1}/{config.num_epochs}: Train Loss = {train_loss:.4f} | Val Loss = {val_loss:.4f}')

        dict_log['train_loss'].append(train_loss)
        dict_log['val_loss'].append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model_theta.state_dict(),
                'optimizer_state_dict': config.optimizer.state_dict(),
                'loss': val_loss,
            }, config.model_path)

In [8]:
config = Configuration()

In [9]:
finetune(config)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/20100 [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 17.71 GB, other allocations: 232.69 MB, max allowed: 18.13 GB). Tried to allocate 209.48 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

## Task 5: Generation of Question Answer pairs (0.5 points) <br>
Generate question-answer pairs using the finetuned model and compare them to the base model.

In [None]:
def qa(config, max_batches=5):
    for i, batch in enumerate(config.val_dl):
        if i == max_batches: break

        encoded = batch['input_ids'].to(config.device)
        prompt = batch["prompts"][0]
        attention_mask = batch['attention_mask'].to(config.device)

        gen_ids = config.model.generate(
            encoded,
            attention_mask=attention_mask,
            max_new_tokens=100,
            do_sample=True,
            top_k=50,
            temperature=0.8,
            pad_token_id=config.tokenizer.eos_token_id
        )

        out = config.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
        assert out.startswith(prompt)

        print(out)
        print("-" * 40)