# Exercise 3

## Group
- **ID**: 5

- **Members**:
    - Hasan Algafri
    - Emre Dursunluer
    - Taha El Amine Kassabi

## Hand-in
- Please hand in this notebook with your code implementation via Ilias
- Please make sure that there is exactly **one** submission per group

## Task Description

In this exercise, you will implement Supervised Finetuning (SFT) for the pretrained GPT-2 model. You should use the `transformers` library to load the pretrained model and tokenizer. You will finetune the model on the `Alpaca` dataset, which is a collection of instruction-following examples. The dataset can be found [here](https://huggingface.co/datasets/tatsu-lab/alpaca).
Your implementation should contain the four parts specified below.

## Grading scheme
Total: 5 points
1. **Preparing the Dataloader** (1 point)
2. **Sensible Configurations** (1 point)
3. **Training loop** (2 points)
4. **Generation of Question Answer pairs** (1 point)

## Task 1: Preparing the data

In [1]:
!pip install -U datasets



In [2]:
import os
from functools import partial

import torch
from IPython.core.debugger import prompt
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from tqdm.auto import tqdm, trange

In [3]:
# Load the Alpaca dataset from the Hugging Face Hub; as the last expression
# of the cell, its split/feature summary is displayed below.
load_dataset('tatsu-lab/alpaca')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output', 'text'],
        num_rows: 52002
    })
})

In [4]:
class AlpacaDataset(Dataset):
    """Map-style dataset over the pre-formatted ``text`` field of Alpaca records.

    Every item is a ``(prompt, text)`` pair:

    * ``return_full_text=True``  -> ``(prompt, full_text)``; the prompt is the
      prefix of ``full_text`` up to and including the ``Response:\\n`` marker.
    * ``return_full_text=False`` -> ``(None, prompt)``; only the prompt is
      exposed, and ``None`` signals that there is no separate prompt prefix.

    Args:
        data: Iterable of mappings that each provide a ``'text'`` string.
        return_full_text: Whether items carry the full text (prompt and
            response) or only the prompt.
    """

    RESPONSE_MARKER = 'Response:\n'

    def __init__(self, data, return_full_text=True):
        self.return_full_text = return_full_text
        self.data = []
        self.prompt_ends = []

        # Single pass over the records: keep the text and the character
        # offset at which the prompt ends.
        for instance in data:
            text = instance['text']
            marker_start = text.find(self.RESPONSE_MARKER)
            if marker_start == -1:
                # ``str.find`` returns -1 when the marker is absent; adding the
                # marker length to that would yield a bogus boundary. With no
                # response section, the whole text is treated as the prompt.
                prompt_end = len(text)
            else:
                prompt_end = marker_start + len(self.RESPONSE_MARKER)
            self.data.append(text)
            self.prompt_ends.append(prompt_end)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        prompt = text[:self.prompt_ends[idx]]
        if self.return_full_text:
            return prompt, text
        return None, prompt

In [5]:
def custom_collate_fn(batch, tokenizer):
    """Tokenize a batch of ``(prompt, text)`` pairs and build causal-LM labels.

    Args:
        batch: Sequence of ``(prompt, text)`` tuples. If the first prompt is
            ``None`` the batch is treated as prompt-only (each ``text`` is
            itself a prompt, e.g. for generation); otherwise ``prompt`` is the
            prefix of ``text`` that must not contribute to the loss.
        tokenizer: Callable tokenizer; called once with padding and
            ``return_tensors='pt'`` for the texts and once without padding
            for the prompts.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``labels`` (padding and
        prompt positions set to -100, the default ``ignore_index`` of
        PyTorch's cross-entropy loss) and ``prompts`` (the prompt strings).
    """
    prompts, texts = zip(*batch)
    has_prompt_prefix = prompts[0] is not None

    if has_prompt_prefix:
        # The tokenizer does not append an end-of-sequence token by itself,
        # so add it to every training target; otherwise the model is never
        # trained to stop after the response. Prompt-only batches are left
        # untouched because generation continues from them.
        eos_token = getattr(tokenizer, 'eos_token', None)
        if eos_token:
            texts = tuple(
                text if text.endswith(eos_token) else text + eos_token
                for text in texts
            )

    encodings = tokenizer(
        texts,
        padding='longest',
        truncation=True,
        return_tensors='pt',
        return_attention_mask=True,
        add_special_tokens=True
    )

    input_ids = encodings['input_ids']
    attention_mask = encodings['attention_mask']

    # Positions excluded from the loss: padding always, prompt tokens when a
    # prompt prefix is known.
    ignore_mask = attention_mask == 0
    if has_prompt_prefix:
        prompt_encodings = tokenizer(prompts,
                                     padding=False,
                                     truncation=True)
        prompt_lengths = torch.tensor(
            [len(ids) for ids in prompt_encodings['input_ids']],
            device=input_ids.device,
        )
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        ignore_mask = ignore_mask | (positions.unsqueeze(0) < prompt_lengths.unsqueeze(1))
    else:
        prompts = texts

    labels = input_ids.clone()
    labels[ignore_mask] = -100

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'prompts': prompts,
    }

In [6]:
def load_data(collate_fn, batch_size=8, test_size=0.05, num_workers=0, seed=None):
    """Split Alpaca into train/validation sets and wrap them in DataLoaders.

    Args:
        collate_fn: Batch collation callable passed to both loaders.
        batch_size: Batch size of the training loader.
        test_size: Held-out share passed to ``train_test_split``.
        num_workers: Number of worker processes for both loaders.
        seed: Optional seed for the shuffled split. ``None`` (the default)
            keeps the previous behaviour of an unseeded split; pass an
            integer to make the train/validation partition reproducible.

    Returns:
        ``(train_dl, val_dl)``. The training loader yields shuffled batches of
        full texts; the validation loader yields prompt-only items and uses
        the DataLoader default batch size.
    """
    splits = load_dataset('tatsu-lab/alpaca')['train'].train_test_split(
        test_size=test_size,
        shuffle=True,
        seed=seed,
    )

    train_ds = AlpacaDataset(splits['train'])
    # Validation items expose only the prompt, so the same loader can be used
    # to drive generation.
    val_ds = AlpacaDataset(splits['test'], return_full_text=False)

    train_dl = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn,
                          num_workers=num_workers, shuffle=True)
    val_dl = DataLoader(val_ds, collate_fn=collate_fn, num_workers=num_workers)
    return train_dl, val_dl

## Task 2: Sensible Config

In [7]:
class Configuration:
    """Bundle everything needed for fine-tuning: device, tokenizer, model,
    optimizer, data loaders and the checkpoint path.

    Attributes set here: ``num_epochs``, ``device``, ``model_name``,
    ``tokenizer``, ``learning_rate``, ``weight_decay``, ``model``,
    ``optimizer``, ``train_dl``, ``val_dl``, ``model_path``.
    """

    def __init__(self):
        self.num_epochs = 3

        # Prefer CUDA, then Apple MPS, then CPU. ``torch.backends.mps`` is the
        # documented availability check for the MPS backend.
        if torch.cuda.is_available():
            device_name = 'cuda'
        elif torch.backends.mps.is_available():
            device_name = 'mps'
        else:
            device_name = 'cpu'
        self.device = torch.device(device_name)

        self.model_name = 'gpt2'
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
        # Reuse the EOS token for padding so batches can be padded.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Optimizer hyperparameters live on the instance so ``reset_model``
        # can rebuild the optimizer. 1e-3 is far too aggressive for updating
        # all weights of a pretrained model; 5e-5 is a conventional
        # fine-tuning learning rate.
        self.learning_rate = 5e-5
        self.weight_decay = 1e-2
        self.reset_model()

        collate_fn = partial(custom_collate_fn, tokenizer=self.tokenizer)

        test_size = 1e-2
        num_workers = 4
        batch_size = 16
        self.train_dl, self.val_dl = load_data(collate_fn, batch_size=batch_size, test_size=test_size, num_workers=num_workers)

        self.model_path = './models/best_model_min_val_loss.pt'

    def reset_model(self):
        """Reload the pretrained weights and rebuild the optimizer for them.

        The optimizer is recreated together with the model: an optimizer
        built for the previous model would keep updating the discarded
        parameters instead of the freshly loaded ones.
        """
        self.model = GPT2LMHeadModel.from_pretrained(self.model_name).to(self.device)
        self.model = torch.compile(self.model)
        self.model.config.use_cache = True
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )


# Shared configuration instance used by the cells below; constructing it
# loads the tokenizer and model and builds the data loaders.
config = Configuration()



In [8]:
# Sanity check: print the input shape and the labels of the first four
# validation batches, then stop.
for i, batch in enumerate(config.val_dl):
    print(batch['input_ids'].shape)
    print(batch['labels'])
    if i == 3: break

torch.Size([1, 45])
tensor([[21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
           257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,
         21017, 46486,    25,   198,  9771,  3129,   378,   262,  1612,   286,
           262,  3146,   362,    11,    18,    11,    22,    11,    16,   198,
           198, 21017, 18261,    25,   198]])
torch.Size([1, 49])
tensor([[21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
           257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,
         21017, 46486,    25,   198, 15946,   485,   257,  1351,   286,  4568,
           329,   257,  1115,   614,  1468,  1141,   257,   718,    12,  9769,
          6614,  6594,    13,   198,   198, 21017, 18261,    25,   198]])
torch.Size([1, 42])
tensor([[21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
           257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,
         21017, 46486,    25, 

## Task 3: Train loop

In [9]:
def train_one_epoch(model, train_dl, optimizer, device):
    """Run one optimisation pass over ``train_dl``.

    Each batch must provide ``input_ids``, ``attention_mask`` and ``labels``
    tensors; they are moved to ``device`` before the forward pass.

    Returns:
        The mean of the per-batch losses over the epoch.
    """
    model.train()

    running_loss = 0
    for step_batch in tqdm(train_dl):
        tensors = {
            key: step_batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'labels')
        }

        loss = model(
            tensors['input_ids'],
            attention_mask=tensors['attention_mask'],
            labels=tensors['labels'],
        ).loss
        running_loss += loss.item()

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return running_loss / len(train_dl)


@torch.inference_mode()
def validate(model, val_dl, device):
    """Evaluate ``model`` on ``val_dl`` and return the mean per-batch loss.

    The model is switched to eval mode and the whole function runs under
    ``torch.inference_mode``, so no gradients are tracked.
    """
    model.eval()

    batch_losses = []
    for val_batch in tqdm(val_dl):
        outputs = model(
            val_batch['input_ids'].to(device),
            attention_mask=val_batch['attention_mask'].to(device),
            labels=val_batch['labels'].to(device),
        )
        batch_losses.append(outputs.loss.item())

    return sum(batch_losses) / len(val_dl)


def finetune(config):
    """Fine-tune ``config.model`` for ``config.num_epochs`` epochs.

    After every epoch the model is validated; whenever the validation loss
    improves, a checkpoint (epoch, model and optimizer state, loss) is written
    to ``config.model_path``.

    Args:
        config: Object providing ``model``, ``device``, ``num_epochs``,
            ``train_dl``, ``val_dl``, ``optimizer`` and ``model_path``.

    Returns:
        dict with the per-epoch ``'train_loss'`` and ``'val_loss'`` lists, so
        the recorded history is available to the caller instead of being
        discarded.
    """
    model = config.model.to(config.device)

    dict_log = {'train_loss': [], 'val_loss': []}
    best_val_loss = float('inf')

    pbar = trange(config.num_epochs)
    for epoch in pbar:
        train_loss = train_one_epoch(model, config.train_dl, config.optimizer, config.device)
        val_loss = validate(model, config.val_dl, config.device)

        pbar.set_postfix_str(f'Train Loss = {train_loss:.4f} | Val Loss = {val_loss:.4f}')

        dict_log['train_loss'].append(train_loss)
        dict_log['val_loss'].append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss

            # ``os.path.dirname`` is '' for a bare file name and
            # ``os.makedirs('')`` raises, so only create a directory when the
            # path actually has one.
            checkpoint_dir = os.path.dirname(config.model_path)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': config.optimizer.state_dict(),
                'loss': val_loss,
            }, config.model_path)

    return dict_log

In [None]:
# Run fine-tuning with the shared configuration.
finetune(config)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3218 [00:00<?, ?it/s]

W0524 02:56:58.241000 4318 torch/_inductor/utils.py:1137] [0/0_1] Not enough SMs to use max_autotune_gemm mode
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


## Task 4: Generation of QA Pairs

In [None]:
def qa(config, max_batches=5):
    """Print sampled answers for the first ``max_batches`` validation prompts.

    For every batch only the first sequence is decoded and printed, followed
    by a separator line.

    Args:
        config: Object providing ``val_dl``, ``device``, ``model`` and
            ``tokenizer``.
        max_batches: Number of validation batches to generate for.
    """
    # Make sure dropout is disabled even if the model was left in train mode.
    config.model.eval()

    for i, batch in enumerate(config.val_dl):
        if i == max_batches:
            break

        encoded = batch['input_ids'].to(config.device)
        attention_mask = batch['attention_mask'].to(config.device)
        prompt = batch["prompts"][0]

        gen_ids = config.model.generate(
            encoded,
            attention_mask=attention_mask,
            max_new_tokens=100,
            do_sample=True,
            top_k=50,
            temperature=0.8,
            pad_token_id=config.tokenizer.eos_token_id
        )

        # The generated sequence starts with the input tokens. Decode only
        # the continuation and prepend the original prompt string, so the
        # printed text always starts with the prompt. This replaces an
        # ``assert`` on the decoded text, which is stripped under ``-O`` and
        # could abort the loop if decoding normalised whitespace.
        completion = config.tokenizer.decode(
            gen_ids[0, encoded.size(1):],
            skip_special_tokens=True,
        )

        print(prompt + completion)
        print("-" * 40)

In [None]:
# Sample answers with the model currently held by config (fine-tuned by the
# cells above).
qa(config)

In [None]:
# Baseline for comparison: reload the pretrained weights, then sample answers
# for the same validation prompts.
config.reset_model()
qa(config)