# Supervised Fine-Tuning

Supervised Fine-Tuning (SFT) is the first step in the entire RLHF fine-tuning pipeline (see Figure 2 in [RLHF paper](https://arxiv.org/abs/2305.18438)).
This notebook uses GPT-2 and its matching tokenizer from the Hugging Face `transformers` library to perform SFT on the `stanfordnlp/sst2` dataset.

### Initialise gpt2 tokenizer and model

In [None]:
# Load the pretrained GPT-2 checkpoint and its matching tokenizer from the Hugging Face Hub.
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = 'gpt2'
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)

## Testing the Tokenizer

### Encoding

In [None]:
text = "Hello, this is the first step of RLHF training."
# Tokenize one string and show the tokenizer output (token ids plus attention mask).
print(tokens := tokenizer(text))

### Decoding

In [None]:
# Round-trip: turn the token ids from the encoding cell back into text.
round_trip = tokenizer.decode(tokens['input_ids'])
print(round_trip)

### Tokenize a batch

In [None]:
texts = [
    'Hello, this is the first step of RLHF training.',
    'I have a dog',
    'I also have a cat',
]
# Tokenizing a list yields one id sequence per input string (no padding requested).
tokens_obj = tokenizer(texts)

In [None]:
# Decode each tokenized sequence of the batch back to text.
# The loop variable is `ids` rather than `tokens` so it does not overwrite the
# `tokens` BatchEncoding from the encoding cell; the decoding cell indexes it
# with tokens['input_ids'] and would fail on re-run if it were clobbered.
for ids in tokens_obj['input_ids']:
    print(tokenizer.decode(ids))

## Working with a dataset

In [None]:
# Install the Hugging Face `datasets` library, pinned to 3.5.0 for reproducibility.
%pip install datasets==3.5.0

### Loading a dataset

In [None]:
from datasets import load_dataset

# Use the canonical Hub id for SST-2 (the dataset named in the introduction);
# the bare 'sst2' name is a legacy alias for the same data.
dataset_name = 'stanfordnlp/sst2'
ds = load_dataset(dataset_name)

In [None]:
# Display the DatasetDict: its splits, columns and row counts.
ds

In [None]:
# Training and validation splits used throughout the rest of the notebook.
ds_train, ds_val = ds['train'], ds['validation']
ds_train

In [None]:
# Inspect a single training example (one value per column).
ds_train[6]

In [None]:
# A batch of rows: slicing a Dataset returns a dict that maps each column name
# to the list of that column's values for the selected rows.
ds_train[:10]  # column-oriented ("collated") view of 10 rows

## Tokenizing a Dataset

In [None]:
def tokenize(batch):
    """Tokenize the 'sentence' column of a batch of dataset rows."""
    return tokenizer(batch['sentence'])

# Batched map in chunks of 512 rows; drop the raw columns so only the
# tokenizer outputs remain in the resulting datasets.
map_kwargs = dict(
    batched=True,
    batch_size=512,
    remove_columns=['idx', 'sentence', 'label'],
)

tokenized_dataset_train, tokenized_dataset_val = (
    split.map(tokenize, **map_kwargs) for split in (ds_train, ds_val)
)

In [None]:
# First tokenized example: only the tokenizer-produced columns remain.
tokenized_dataset_train[0]

In [None]:
# Rows 5-9 of the tokenized training set, in column-oriented form.
tokenized_dataset_train[5:10]

### Decoding from the dataset

In [None]:
# Decode a few tokenized rows back to text to sanity-check the mapping (numbered from 1).
sample_input_ids = tokenized_dataset_train[5:10]['input_ids']
for i, seq in enumerate(sample_input_ids):
    line_no = i + 1
    print(f'{line_no}: {tokenizer.decode(seq)}')

### Filter out sentences with 5 or fewer tokens

In [None]:
# Number of train / validation examples before length filtering.
print(len(tokenized_dataset_train), len(tokenized_dataset_val))

In [None]:
def _has_more_than_5_tokens(example):
    """Return True when the tokenized sentence is longer than 5 tokens."""
    return len(example['input_ids']) > 5


# Drop very short sentences from both splits (the same threshold for each).
tokenized_dataset_train, tokenized_dataset_val = (
    split.filter(_has_more_than_5_tokens)
    for split in (tokenized_dataset_train, tokenized_dataset_val)
)

In [None]:
# Number of train / validation examples remaining after length filtering.
print(len(tokenized_dataset_train), len(tokenized_dataset_val))

## Preparing a dataloader

### Set PyTorch format

In [None]:
# Make both splits return torch tensors when indexed (set_format works in place).
for split in (tokenized_dataset_train, tokenized_dataset_val):
    split.set_format(type='torch')

In [None]:
# Same first example, now returned as torch tensors.
tokenized_dataset_train[0]

In [None]:
# First five examples in torch format; sequences still have different lengths
# because padding only happens later, in the collator.
tokenized_dataset_train[:5]

### Padding

In [None]:
# check what the pad token is set to: GPT-2 ships without one, so this prints None
print(tokenizer.pad_token)

In [None]:
# check what the eos token is set to (GPT-2's end-of-text token)
print(tokenizer.eos_token)

In [None]:
# GPT-2 has no pad token, so reuse the EOS token for padding.
# The "N+ Implementation Details of RLHF" paper (page 5) argues for a different
# padding choice (NOTE(review): confirm its exact recommendation); here padded
# positions are excluded from attention via attention_mask.
# NOTE(review): because pad == eos, a genuine EOS token inside a sequence would be
# indistinguishable from padding; the sentences here are tokenized without an
# EOS, so this is presumably harmless.
tokenizer.pad_token = tokenizer.eos_token

### Collation with Padding

In [None]:
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling

# With mlm=False the collator pads each batch dynamically and builds causal-LM
# `labels` from `input_ids`, with padding positions excluded from the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

dataloader_params = {
    'batch_size': 32,
    'collate_fn': data_collator,
}

# Shuffle the training set every epoch. DataLoader defaults to shuffle=False,
# which would feed the examples in the same fixed dataset order every time.
train_dataloader = DataLoader(tokenized_dataset_train, shuffle=True, **dataloader_params)
# Validation order does not affect the averaged loss, so keep it deterministic.
val_dataloader = DataLoader(tokenized_dataset_val, **dataloader_params)

In [None]:
# Number of training batches per epoch.
len(train_dataloader)

In [None]:
# Number of batches times batch size: an upper bound on the training examples
# seen per epoch (the final batch may be partial). Computed from the live
# objects instead of hard-coded numbers, so it stays correct if the data,
# filter or batch size change.
len(train_dataloader) * dataloader_params['batch_size']

In [None]:
# Pull one collated batch to inspect what the model will receive.
batch = next(iter(train_dataloader))
print(batch.keys())

In [None]:
# (batch_size, padded_length): the collator pads to the longest sequence in this batch.
batch['input_ids'].shape

In [None]:
# Token ids of the first sequence, including any trailing pad (= EOS) ids.
batch['input_ids'][0]

In [None]:
# Labels mirror input_ids, but padded positions are set to -100 so the loss ignores them.
batch['labels'][0]

In [None]:
# 1 for real tokens, 0 for padding.
batch['attention_mask'][0]

## Supervised Fine-tuning (SFT)

In [None]:
import torch

# A single pass over the training data with AdamW at a fixed learning rate.
num_epochs = 1
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

### Training loop

In [None]:
def validate(epoch):
    """Evaluate ``model`` on ``val_dataloader`` and print the validation loss.

    ``model`` and ``val_dataloader`` are notebook globals defined in earlier
    cells. Each batch is run without gradient tracking, and the per-batch loss
    computed by the model is averaged over the whole split.

    The model's loss for a batch is a mean over that batch's scored label
    tokens. Each batch loss is therefore weighted by its number of scored tokens,
    which gives the exact per-token average for the split. A plain mean of batch
    means would over-weight the final, usually smaller, batch.

    Args:
        epoch: Epoch number, used only in the printed message.

    Returns:
        The token-weighted mean validation loss as a float, or ``nan`` if no
        label token was scored.
    """
    model.eval()
    # Use the model's own device so this helper does not rely on a global
    # `device` that is only defined in a later cell.
    model_device = model.device
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(model_device)
            outputs = model(**batch)
            # Loss computed by transformers.loss.loss_utils.ForCausalLMLoss: the
            # logit at position t is scored against label t+1, and labels equal
            # to -100 (padding) are ignored, so count scored labels from index 1.
            n_tokens = int((batch['labels'][..., 1:] != -100).sum().item())
            if n_tokens:  # a batch with nothing scored yields a NaN loss; skip it
                total_loss += outputs.loss.item() * n_tokens
                total_tokens += n_tokens
    val_loss = total_loss / total_tokens if total_tokens else float('nan')
    print(f'val_loss at {epoch} epoch:', val_loss)
    return val_loss

Code for loss calculation: [transformers.loss.loss_utils.ForCausalLMLoss](https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py)

In [None]:
# Fine-tune for `num_epochs`: validate once before training (baseline), then
# after every epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Print the training loss every `log_every` steps. Printing every step floods
# the notebook with ~1.5k lines per epoch; set to 1 to restore per-step output.
log_every = 50

validate(0)
num_steps = len(train_dataloader)
for epoch in range(num_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader, start=1):
        batch = batch.to(device)
        outputs = model(**batch)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % log_every == 0 or step == num_steps:
            print(f'epoch {epoch + 1} step {step}/{num_steps} loss: {loss.item():.4f}')
    validate(epoch + 1)

### Save the model

In [None]:
# Persist the fine-tuned checkpoint. The tokenizer is saved alongside it because
# this notebook changed it (pad_token = eos_token); that way the directory can be
# reloaded with both AutoTokenizer and AutoModelForCausalLM. The model is saved
# last so the cell does not display the tokenizer's returned file list.
save_dir = './sft_model_epoch_1'
tokenizer.save_pretrained(save_dir)
model.save_pretrained(save_dir)

In [None]:
# `from_pretrained` is a classmethod that *returns* a newly loaded model; calling
# it on an instance does not load weights into that instance. Bind the result to
# its own name so the in-memory trained `model` stays untouched.
from transformers import AutoModelForCausalLM

sft_model = AutoModelForCausalLM.from_pretrained('./sft_model_epoch_1')
sft_model

### Zip the saved model (Optional)

In [None]:
# Shell out (IPython `!`) to archive the saved checkpoint directory, e.g. for download.
!zip -r sft_model_epoch_1.zip sft_model_epoch_1/