# Fine-tuning GPT-2 on custom datasets

In this notebook we will see how to fine-tune a transformer model on our own custom datasets. We will use pre-trained transformer models, an advanced neural network architecture primarily used for text understanding and generation. We will cover transformers in more detail in Term 2 (AI for Media).

To run this code you will need to install the `transformers` library from [Hugging Face](https://huggingface.co/docs/transformers/index), which lets us use and fine-tune many pre-trained transformer models.

This code was originally [sourced from here](https://github.com/mf1024/Transformers/blob/master/Fine-tuning%20GPT2-medium%20in%20PyTorch.ipynb) and has been adapted to be clearer and to make it easier to load different kinds of datasets.

In [1]:
!pip install transformers

First let's do some imports and set our device:

In [2]:
import os
import csv
import torch
import logging
import warnings
import numpy as np

from torch.optim import AdamW  # the AdamW shipped with transformers is deprecated
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import get_linear_schedule_with_warmup

logging.getLogger().setLevel(logging.CRITICAL)
warnings.filterwarnings('ignore')

device = 'cpu'


We will now need to download the GPT-2 medium model. The weights are about 1.5 GB, so this may take some time:

In [3]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = model.to(device)

vocab.json: 100%|██████████| 1.04M/1.04M [00:00<00:00, 2.78MB/s]
merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 2.88MB/s]
tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 1.71MB/s]
config.json: 100%|██████████| 718/718 [00:00<00:00, 476kB/s]
model.safetensors: 100%|██████████| 1.52G/1.52G [03:18<00:00, 7.65MB/s]
generation_config.json: 100%|██████████| 124/124 [00:00<00:00, 49.1kB/s]


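Here is our code for sampling the next token from the model's predictions. It keeps only the `n` most likely tokens, renormalises their probabilities and picks one of them at random: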

In [4]:
def choose_from_top(probs, n=5):
    ind = np.argpartition(probs, -n)[-n:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob) # Normalize
    choice = np.random.choice(n, 1, p = top_prob)
    token_id = ind[choice][0]
    return int(token_id)
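
To see what top-n sampling does without loading the model, here is a minimal sketch using only the Python standard library. The helper name `sample_from_top_n` and the toy probabilities are made up for illustration; this is not the notebook's `choose_from_top`, just the same idea on a six-token vocabulary.

```python
import random

def sample_from_top_n(probs, n=5, rng=random):
    """Pick an index from the n most likely entries of probs (hypothetical helper)."""
    # Indices of the n largest probabilities, most likely first
    top_indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:n]
    top_probs = [probs[i] for i in top_indices]
    # random.choices treats the weights as relative, so they need not sum to 1
    return rng.choices(top_indices, weights=top_probs, k=1)[0]

# Toy distribution over a six-token vocabulary (illustrative values)
toy_probs = [0.05, 0.40, 0.02, 0.30, 0.20, 0.03]
rng = random.Random(0)
samples = [sample_from_top_n(toy_probs, n=3, rng=rng) for _ in range(1000)]
counts = {i: samples.count(i) for i in sorted(set(samples))}
print(counts)  # only indices 1, 3 and 4 can appear
```

A smaller `n` makes the output more predictable, while a larger `n` allows more variety at the cost of occasional unlikely tokens.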

### Dataset classes

Here we define our dataset classes by inheriting from the [PyTorch Dataset class](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html). The `TXTDataset` class loads every `.txt` file in a folder (such as the nursery rhymes dataset), and the `TSVDataset` class loads one column of a `.tsv` file. Change the `dataset_path` parameter when initialising a class to load your own dataset.

In [5]:
class TXTDataset(Dataset):
    def __init__(self, dataset_path = '../data/nursery-rhymes'):
        super().__init__()

        self.data_list = []
        self.end_of_text_token = "<|endoftext|>"
        
        for root, _, files in os.walk(dataset_path):
            for file in files:
                if file.endswith(".txt"):
                    with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
                        data_str = f.read()
                        self.data_list.append(f'TEXT:{data_str}{self.end_of_text_token}')
        
    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, item):
        return self.data_list[item]
    
class TSVDataset(Dataset):
    def __init__(self, dataset_path = '../data/lyric_data.tsv', data_row_index = 4):
        super().__init__()

        self.data_list = []
        self.end_of_text_token = "<|endoftext|>"
        
        with open(dataset_path, 'r', encoding='utf-8', newline='') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter='\t')
            
            for row in csv_reader:
                # data_row_index is the column of each row that holds the text
                data_str = f"TEXT:{row[data_row_index]}{self.end_of_text_token}"
                self.data_list.append(data_str)
        
    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, item):
        return self.data_list[item]
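
Both classes turn every text into a single string of the form `TEXT:<your text><|endoftext|>`. The sketch below reproduces that formatting with plain functions so you can check it on your own files before training. The helper names `load_txt_folder` and `load_tsv_column`, and the temporary files it writes, are assumptions for illustration only.

```python
import csv
import os
import tempfile

END_OF_TEXT = "<|endoftext|>"

def load_txt_folder(folder):
    """Hypothetical helper: one training string per .txt file under folder."""
    items = []
    for root, _, files in os.walk(folder):
        for name in sorted(files):
            if name.endswith(".txt"):
                with open(os.path.join(root, name), "r", encoding="utf-8") as f:
                    items.append(f"TEXT:{f.read()}{END_OF_TEXT}")
    return items

def load_tsv_column(path, column):
    """Hypothetical helper: one training string per row, taken from one column."""
    with open(path, "r", encoding="utf-8", newline="") as f:
        return [f"TEXT:{row[column]}{END_OF_TEXT}" for row in csv.reader(f, delimiter="\t")]

# Write two tiny example files to a temporary folder and load them back
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "rhyme.txt"), "w", encoding="utf-8") as f:
        f.write("Twinkle, twinkle, little star")
    tsv_path = os.path.join(tmp, "lyrics.tsv")
    with open(tsv_path, "w", encoding="utf-8", newline="") as f:
        csv.writer(f, delimiter="\t").writerow(["artist", "title", "la la la"])

    txt_items = load_txt_folder(tmp)
    tsv_items = load_tsv_column(tsv_path, column=2)

print(txt_items[0])
print(tsv_items[0])
```

The `TEXT:` prefix is the same prompt used later when generating, and the end-of-text token tells the model where one training example stops.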


### Hyperparameters

Depending on the size of your dataset, you will want to adjust the number of epochs you train for. Each epoch takes a long time on a large dataset, so keep the number low. For a small dataset, however, a small number of epochs will not be enough to change the model's output.

In [6]:
BATCH_SIZE = 16
EPOCHS = 5
LEARNING_RATE = 3e-5
WARMUP_STEPS = 5000
MAX_SEQ_LEN = 400
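
As a rough check on `WARMUP_STEPS`, the sketch below works out the learning rate at a few optimizer steps. The formula is an assumption based on how a linear warmup schedule with linear decay is usually defined; treat it as an approximation of `get_linear_schedule_with_warmup` and check the library documentation before relying on it. The function name `linear_warmup_multiplier` is made up for this example.

```python
def linear_warmup_multiplier(step, warmup_steps, total_steps):
    """Learning-rate multiplier for linear warmup followed by linear decay (sketch)."""
    if step < warmup_steps:
        # Ramp up from 0 to 1 over the warmup period
        return step / max(1, warmup_steps)
    # Decay linearly towards 0 at total_steps, never going negative
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

LEARNING_RATE = 3e-5
WARMUP_STEPS = 5000

for step in (1, 100, 5000):
    # total_steps=-1 mirrors the value passed to the scheduler in the training cell
    lr = LEARNING_RATE * linear_warmup_multiplier(step, WARMUP_STEPS, total_steps=-1)
    print(f"optimizer step {step}: learning rate {lr:.2e}")
```

Under this assumption, a small dataset that only produces a few hundred optimizer steps never gets close to `LEARNING_RATE`, so lowering `WARMUP_STEPS` may be worth trying if the model barely changes.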

In [7]:
dataset = TXTDataset(dataset_path = '../data/nursery-rhymes')
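# The training loop reads one text per batch (text_str[0]), so load one text at a time.
# BATCH_SIZE is used in that loop as the number of sequences accumulated per optimizer step.
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)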

In [11]:
pip install torch


Note: you may need to restart the kernel to use updated packages.




### Model training

We will train the model and save the weights after each epoch, then generate text with each checkpoint to see which performs best.

**Warning:** depending on the size of the dataset this can take **a very long time** to train. Make sure your laptop is plugged in while doing this!

In [9]:
model = model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps = -1)
proc_seq_count = 0
sum_loss = 0.0
batch_count = 0

tmp_text_tokens = None
models_folder = "ckpt/gpt2"
os.makedirs(models_folder, exist_ok=True)

for epoch in range(EPOCHS):
    
    print(f"EPOCH {epoch} started" + '=' * 30)
    
    for idx, text_str in enumerate(data_loader):
        
        
        text_tokens = torch.tensor(tokenizer.encode(text_str[0])).unsqueeze(0).to(device)
        #Skip sample from dataset if it is longer than MAX_SEQ_LEN
        if text_tokens.size()[1] > MAX_SEQ_LEN:
            continue
        
        #The first text in the sequence
        if not torch.is_tensor(tmp_text_tokens):
            tmp_text_tokens = text_tokens
            continue
        else:
            #The next text does not fit, so we process the sequence and keep the new text
            #as the start of the next sequence
            if tmp_text_tokens.size()[1] + text_tokens.size()[1] > MAX_SEQ_LEN:
                work_text_tokens = tmp_text_tokens
                tmp_text_tokens = text_tokens
            else:
                #Add the text to the sequence, continue and try to add more
                tmp_text_tokens = torch.cat([tmp_text_tokens, text_tokens[:,1:]], dim=1)
                continue
        ################## Sequence ready, process it through the model ##################
            
        outputs = model(work_text_tokens, labels=work_text_tokens)
        loss, logits = outputs[:2]
        loss.backward()
        sum_loss = sum_loss + loss.item()

        #Gradients accumulate over BATCH_SIZE sequences before each optimizer step
        proc_seq_count = proc_seq_count + 1
        if proc_seq_count == BATCH_SIZE:
            proc_seq_count = 0
            batch_count += 1
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if batch_count == 100:
            print(f"sum loss {sum_loss}")
            batch_count = 0
            sum_loss = 0.0
    
    # Store the model after each epoch to compare the performance of them
    torch.save(model.state_dict(), os.path.join(models_folder, f"gpt2_medium_finetuned_{epoch}.pt"))
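
The training loop packs several short texts into one sequence of at most `MAX_SEQ_LEN` tokens before passing it to the model. The sketch below shows the same greedy packing on plain lists of integers, so it runs without a tokenizer. The function name `pack_sequences` and the toy token lists are illustrative, and flushing the leftover sequence at the end is an assumption (the loop above carries it over to the next epoch instead).

```python
def pack_sequences(token_lists, max_len):
    """Greedily pack tokenized texts into sequences of at most max_len tokens (sketch)."""
    buffer = None
    for tokens in token_lists:
        if len(tokens) > max_len:
            continue  # skip texts that cannot fit on their own
        if buffer is None:
            buffer = list(tokens)  # the first text starts the sequence
        elif len(buffer) + len(tokens) > max_len:
            yield buffer           # the next text does not fit: emit the sequence
            buffer = list(tokens)  # the current text starts the next sequence
        else:
            # Mirror the training loop, which drops the first token when appending
            buffer = buffer + list(tokens[1:])
    if buffer is not None:
        yield buffer  # assumption: emit whatever is left at the end

# Toy "tokenized" texts; the last one is too long and is skipped
toy_texts = [[1, 2, 3], [4, 5, 6], [7, 8, 9, 10], list(range(50))]
packed = list(pack_sequences(toy_texts, max_len=6))
print(packed)
```

Packing means fewer, longer forward passes, which is why the number of optimizer steps per epoch can be much smaller than the number of texts in the dataset.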
            



### Generating text from the model:

Here we will generate text from the model we have trained. `MODEL_EPOCH` selects which checkpoint to load; epochs are numbered from 0, so the default of 4 is the last of 5 epochs. If you changed `EPOCHS`, change `MODEL_EPOCH` to match. If the model does not sufficiently mimic your training data, you may need to train for more epochs:

In [10]:
MODEL_EPOCH = 4
models_folder = "ckpt/gpt2"
model_path = os.path.join(models_folder, f"gpt2_medium_finetuned_{MODEL_EPOCH}.pt")
model.load_state_dict(torch.load(model_path, map_location=device))

model.eval()

end_of_text_id = tokenizer.eos_token_id  # id of '<|endoftext|>' in GPT-2

with torch.no_grad():

    for idx in range(1):

        cur_ids = torch.tensor(tokenizer.encode("TEXT:")).unsqueeze(0).to(device)

        for i in range(300):
            outputs = model(cur_ids)
            logits = outputs[0]
            #Take the first (and only) batch item and the logits for the last position
            softmax_logits = torch.softmax(logits[0,-1], dim=0)
            #Choose from a wide range of tokens at the start, then narrow down
            if i < 3:
                n = 20
            else:
                n = 3
            #Randomly select the next token from the top-n probability distribution
            next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n)
            #Add the new token to the running sequence
            cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1)

            #Stop as soon as the model produces the end-of-text token
            if next_token_id == end_of_text_id:
                break

        output_list = list(cur_ids.squeeze().to('cpu').numpy())
        output_text = tokenizer.decode(output_list)
        print(output_text)

TEXT:

The first time we saw him in a video, he had been playing with his dog, and he was wearing a black T-shirt and black shorts.

"I was like, 'What's going on?'" he said. "And he said, 'I'm not wearing a shirt, I just want to play with my dog.'"


The video shows him playing with his dog and then walking away.

He said he was wearing his shirt and shorts because he was afraid he would be arrested. He said he had no idea what the police were doing.


"It's just crazy, I'm like, 'What is going on?'"

Police said he was arrested and charged with disorderly conduct and obstructing a peace officer.

The incident was caught on camera and posted on YouTube.

The video has since been taken down.

Police said they were investigating the case, but said they had not received any complaints from the public.

The video was taken down after it was posted.

Copyright 2017 by WKMG ClickOrlando - All rights reserved.<|endoftext|>
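
The generation loop can also be tried out without loading GPT-2. The sketch below replaces the model with a made-up function, `toy_next_token_probs`, over a six-word vocabulary. Everything in it (the vocabulary, the probabilities, the function names) is hypothetical; it only illustrates the loop structure of sampling one token at a time, appending it and stopping at the end-of-text token.

```python
import random

EOS = "<|endoftext|>"
VOCAB = ["the", "cat", "sat", "on", "mat", EOS]

def toy_next_token_probs(token_ids):
    """Hypothetical stand-in for the model: favours the token after the last one."""
    probs = [0.1] * len(VOCAB)
    probs[(token_ids[-1] + 1) % len(VOCAB)] = 0.5
    return probs

def generate(prompt_ids, max_new_tokens=300, rng=random):
    ids = list(prompt_ids)
    for i in range(max_new_tokens):
        probs = toy_next_token_probs(ids)
        # Wide choice for the first three tokens, narrow afterwards
        n = min(20 if i < 3 else 3, len(probs))
        top = sorted(range(len(probs)), key=lambda j: probs[j], reverse=True)[:n]
        next_id = rng.choices(top, weights=[probs[j] for j in top], k=1)[0]
        ids.append(next_id)
        if VOCAB[next_id] == EOS:
            break  # stop as soon as the end-of-text token is sampled
    return ids

ids = generate([0], max_new_tokens=20, rng=random.Random(1))
print(" ".join(VOCAB[i] for i in ids))
```

Because each new token is sampled from the model's own previous output, generated text that does not resemble the training data usually points to a problem in training rather than in this loop.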
