# Fine-tuning GPT2-medium on a jokes dataset in PyTorch

This is an experimental notebook for fine-tuning the pre-trained GPT2-medium model on a jokes dataset. Let's see if it can learn to crack some jokes.

For this purpose, I will use the pre-trained GPT2 medium size model from huggingface [transformers repository](https://github.com/huggingface/transformers).

#### First, check out the *GPT2LMHeadModel* text generation experiments in this [gist](www.com). 

In [2]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np

import logging
logging.getLogger().setLevel(logging.CRITICAL)

import warnings
warnings.filterwarnings('ignore')

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = model.to(device)

In [3]:
# Select the n most probable tokens from the probability array `probs`, renormalize
# their probabilities, and sample one of them at random from that distribution
def choose_from_top(probs, n=5):
    ind = np.argpartition(probs, -n)[-n:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob) # Normalize
    choice = np.random.choice(n, 1, p = top_prob)
    token_id = ind[choice][0]
    return int(token_id)
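
To make the sampling step concrete, here is a small sketch of the same idea using only the standard library. The name `choose_from_top_py` and the toy probability list are made up for illustration; the notebook itself uses the NumPy version above.

```python
import random

def choose_from_top_py(probs, n=5, rng=random):
    """Sample an index from the n most probable entries of probs."""
    # Indices of the n largest probabilities (their order does not matter).
    top_idx = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:n]
    # random.choices normalizes the weights internally, so the top-n
    # probabilities can be passed directly as weights.
    top_weights = [probs[i] for i in top_idx]
    return rng.choices(top_idx, weights=top_weights, k=1)[0]

# Toy distribution over a 6-token vocabulary (hypothetical values).
toy_probs = [0.05, 0.40, 0.10, 0.25, 0.15, 0.05]
samples = [choose_from_top_py(toy_probs, n=3) for _ in range(1000)]
# Only the three most probable tokens (ids 1, 3 and 4) can ever be drawn.
print(sorted(set(samples)))
```

A smaller `n` makes the output more predictable, while a larger `n` lets less likely tokens through.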

### PyTorch Dataset module for Reddit jokes

For fine-tuning the GPT2 model, I will use [this](https://github.com/taivop/joke-dataset/blob/master/reddit_jokes.json) jokes dataset. After each joke sample, I add "<|endoftext|>", which the GPT2 model recognizes as an end-of-text marker. The marker allows me to concatenate many jokes into one sequence.

In [8]:
from torch.utils.data import Dataset, DataLoader
import os
import json

class JokesDataset(Dataset):
    def __init__(self, jokes_dataset_path = 'jokes_data/'):
        super().__init__()

        reddit_jokes_path = os.path.join(jokes_dataset_path, 'reddit_jokes.json')

        with open(reddit_jokes_path) as f:
            data = json.load(f)

        self.joke_list = []
        self.end_of_text_token = "<|endoftext|>"

        for idx, joke_json in enumerate(data):
            joke_str = f"{self.end_of_text_token}START:{joke_json['title']} {joke_json['body']}{self.end_of_text_token}"
            self.joke_list.append(joke_str)
            
    def __len__(self):
        return len(self.joke_list)

    def __getitem__(self, item):
        return self.joke_list[item]


In [9]:
dataset = JokesDataset()
joke_loader = DataLoader(dataset, batch_size=1, shuffle=True)

### Fine-tuning GPT2-medium on a single GPU
 
Large Transformer models are usually trained in multi-GPU (or TPU) settings because training with a reasonable batch size and sequence length requires a lot of GPU (or TPU) memory. My machine is equipped with a single GeForce 1080 Ti, which has 11 GB of memory. Through empirical tests on the GPT2 medium model, I found that the maximum total number of sequence elements (tokens) across all batches that my GPU can backpropagate through is approximately 550, which is not a lot and might not be sufficient for successful fine-tuning.

But there are a few things we can take into account to improve the situation.

The first thing to notice is that the batch size in the forward/backward pass of the Transformer does not affect normalization, because [Layer Normalization](https://arxiv.org/abs/1607.06450) is used instead of Batch Normalization. In [Layer Normalization](https://mlexplained.com/2018/11/30/an-overview-of-normalization-methods-in-deep-learning/), each sample is normalized across its own feature dimension, independently of the other samples in the batch.
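
As a rough illustration of why the batch does not matter here, the sketch below normalizes a sample on its own and as part of a batch and compares the results. It is a simplified, pure-Python version of the normalization step only: the learnable gain and bias of a real Layer Normalization layer are left out, and the function name and values are made up.

```python
import math

def layer_norm(sample, eps=1e-5):
    """Normalize one sample (a list of features) to zero mean and unit variance."""
    mean = sum(sample) / len(sample)
    var = sum((x - mean) ** 2 for x in sample) / len(sample)
    return [(x - mean) / math.sqrt(var + eps) for x in sample]

sample_a = [1.0, 2.0, 3.0, 4.0]
sample_b = [10.0, 0.0, -10.0, 5.0]

# Normalizing sample_a alone ...
alone = layer_norm(sample_a)
# ... gives the same values as normalizing it as part of a batch,
# because the statistics are computed per sample and never mix samples.
in_batch = [layer_norm(s) for s in [sample_a, sample_b]][0]
print(alone == in_batch)
```

With Batch Normalization the mean and variance would be computed across the samples in the batch, so the result for `sample_a` would depend on what else is in the batch.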

Second, we can accumulate gradients over multiple forward-backward passes and only then do the model weight update (optimization step). This way, the GPU memory only has to hold the computational graph of one sequence at a time instead of the graph of the whole batch. With this strategy, we can get the same result as if the batch had been processed in a single forward/backward pass, only with roughly *BATCH_SIZE* times less memory.
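
The equivalence can be checked on a toy problem. The sketch below uses a one-parameter model `w * x` with a squared-error loss (both chosen only for illustration, not taken from the notebook) and compares the gradient of the mean loss over a batch with the gradient accumulated one sample at a time.

```python
def grad(w, x, target):
    """Gradient of the squared error (w * x - target) ** 2 with respect to w."""
    return 2.0 * (w * x - target) * x

w = 0.5
batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]  # (x, target) pairs

# One pass over the whole batch: gradient of the mean loss.
full_batch_grad = sum(grad(w, x, t) for x, t in batch) / len(batch)

# Accumulation: one sample per pass, each contribution scaled by 1 / batch size.
accumulated_grad = 0.0
for x, t in batch:
    accumulated_grad += grad(w, x, t) / len(batch)

print(abs(full_batch_grad - accumulated_grad) < 1e-12)
```

Note that the training loop in this notebook calls `backward()` on the unscaled loss of each sequence, so what it accumulates is the sum of the per-sequence gradients, which is *BATCH_SIZE* times the mean.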

Putting it all together: I will process one sequence at a time with a maximum length of 400. The length of joke sequences varies a lot in the dataset I use, so to make the total number of sequence elements in one optimization step more consistent, I will try to fit as many jokes as possible into each 400-element sequence.
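
The packing idea can be sketched on plain lists of token ids. This is a simplified stand-in for the logic in the training loop further down: the function name and the toy data are hypothetical, and it leaves out the detail of dropping the duplicated end-of-text token when two jokes are joined.

```python
def pack_sequences(token_lists, max_len):
    """Greedily concatenate token lists into chunks of at most max_len tokens."""
    packed, current = [], []
    for tokens in token_lists:
        if len(tokens) > max_len:
            continue  # skip samples that can never fit
        if len(current) + len(tokens) > max_len:
            packed.append(current)  # chunk is full, start a new one
            current = []
        current = current + tokens
    if current:
        packed.append(current)  # keep the final, partially filled chunk
    return packed

# Toy "jokes" as lists of token ids (hypothetical values), packed with max_len=5.
toy_jokes = [[1, 2], [3, 4, 5], [6], [7, 8, 9, 10, 11, 12], [13, 14]]
print(pack_sequences(toy_jokes, max_len=5))
# → [[1, 2, 3, 4, 5], [6, 13, 14]]
```

The six-token sample is skipped because it exceeds the limit on its own, mirroring how the training loop skips jokes longer than `MAX_SEQ_LEN`.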

### Hyperparameters

I tested several (I think five) hyperparameter sets until I found one that works best. I mostly changed ***BATCH_SIZE*** (in this case, it's the number of forward-backward passes between each optimization step), ***EPOCHS***, and ***LEARNING_RATE***.

As a starting point for the fine-tuning parameter values, I took inspiration from [this](https://github.com/huggingface/transformers/blob/master/examples/run_squad.py) and [this](https://github.com/huggingface/transformers/blob/master/examples/run_glue.py).

In [11]:
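# Note: WarmupLinearSchedule comes from older transformers releases; newer
# releases provide get_linear_schedule_with_warmup instead.
from transformers import AdamW, WarmupLinearSchedule

BATCH_SIZE = 8        # forward-backward passes accumulated per optimization step
EPOCHS = 3
LEARNING_RATE = 3e-5
WARMUP_STEPS = 10000  # scheduler steps over which the learning rate ramps up
MAX_SEQ_LEN = 400     # maximum number of tokens in one packed sequence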

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
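
`WARMUP_STEPS` sets how many scheduler steps the learning rate takes to ramp up from zero to `LEARNING_RATE`. The sketch below shows only that linear ramp. The function name is hypothetical, and holding the rate constant after warmup is an assumption made for this example; what `WarmupLinearSchedule` does after warmup depends on its total-step setting and is not modelled here.

```python
def warmup_multiplier(step, warmup_steps):
    """Learning-rate multiplier that ramps linearly from 0 to 1 during warmup."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return 1.0  # assumption: hold the base rate; real schedulers may decay it

base_lr = 3e-5
warmup_steps = 10000
for step in (0, 2500, 5000, 10000):
    # Effective learning rate at this step under the assumptions above.
    print(step, base_lr * warmup_multiplier(step, warmup_steps))
```

Halfway through warmup the effective learning rate is half of `LEARNING_RATE`, which keeps the first updates to the pre-trained weights small.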

### Model training

In [12]:
model = model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=WARMUP_STEPS, t_total = -1)
proc_seq_count = 0
sum_loss = 0.0
batch_count = 0

tmp_jokes_tens = None

for epoch in range(EPOCHS):
    
    print(f"EPOCH {epoch} started" + '=' * 30)
    
    for idx,joke in enumerate(joke_loader):
        
        #################### "Fit as many joke sequences into MAX_SEQ_LEN sequence as possible" logic start ####
        joke_tens = torch.tensor(tokenizer.encode(joke[0])).unsqueeze(0).to(device)
        #Skip sample from dataset if it is longer than MAX_SEQ_LEN
        if joke_tens.size()[1] > MAX_SEQ_LEN:
            continue
        
        #The first joke sequence in the sequence
        if not torch.is_tensor(tmp_jokes_tens):
            tmp_jokes_tens = joke_tens
            continue
        else:
            #The next joke does not fit in so we process the sequence and leave the last joke 
            #as the start for next sequence 
            if tmp_jokes_tens.size()[1] + joke_tens.size()[1] > MAX_SEQ_LEN:
                work_jokes_tens = tmp_jokes_tens
                tmp_jokes_tens = joke_tens
            else:
                #Add the joke to sequence, continue and try to add more
                tmp_jokes_tens = torch.cat([tmp_jokes_tens, joke_tens[:,1:]], dim=1)
                continue
        ################## Sequence ready, process it through the model ##################
            
        outputs = model(work_jokes_tens, labels=work_jokes_tens)
        loss, logits = outputs[:2]                        
        loss.backward()
        sum_loss = sum_loss + loss.item()  # accumulate the loss as a plain Python float
                       
        proc_seq_count = proc_seq_count + 1
        if proc_seq_count == BATCH_SIZE:
            proc_seq_count = 0    
            batch_count += 1
            optimizer.step()
            scheduler.step() 
            optimizer.zero_grad()
            model.zero_grad()
            
        if batch_count == 1000:
            print(f"sum loss {sum_loss}")
            batch_count = 0
            sum_loss = 0.0

sum loss 3436.03076171875
sum loss 3271.525390625
sum loss 3117.43603515625
sum loss 3020.50732421875
sum loss 2940.70654296875
sum loss 2915.05078125
sum loss 2895.248779296875
sum loss 2849.1494140625
sum loss 2863.0771484375
sum loss 2827.261474609375
sum loss 2824.12109375
sum loss 2795.527587890625
sum loss 2803.104248046875
sum loss 2802.185791015625
sum loss 2786.28515625
sum loss 2778.56982421875
sum loss 2762.30615234375
sum loss 2770.957763671875
sum loss 2754.240478515625
sum loss 2747.343994140625
sum loss 2726.651611328125
sum loss 2729.69189453125
sum loss 2744.18212890625
sum loss 2724.10400390625
sum loss 2714.15283203125
sum loss 2711.803955078125
sum loss 2703.7119140625
sum loss 2704.89208984375
sum loss 2697.152099609375
sum loss 2689.19189453125
sum loss 2671.143798828125
sum loss 2677.786376953125
sum loss 2645.66064453125
sum loss 2622.15966796875
sum loss 2662.0419921875
sum loss 2641.111083984375
sum loss 2650.60546875
sum loss 2637.7275390625
sum loss 2634.773

### Generating some jokes

In [None]:
model.eval()
with torch.no_grad():
    
    for joke_idx in range(500):

        cur_ids = torch.tensor(tokenizer.encode("<|endoftext|>START:")).unsqueeze(0).to(device)
        
        for i in range(250):
            outputs = model(cur_ids, labels=cur_ids)
            loss, logits = outputs[:2]
            softmax_logits = torch.softmax(logits[0,-1], dim=0) #Take the first (and only) batch element and the logits at the last position
            if i < 2:
                n = 15
            else:
                n = 3
            next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n) #Randomly choose the next token from the top n tokens, weighted by their probabilities
            cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Append the sampled token

            if next_token_id in tokenizer.encode('<|endoftext|>'):
                break
            
        output_list = list(cur_ids.squeeze().to('cpu').numpy())
        output_text = tokenizer.decode(output_list)
        print(f"JOKE NR {joke_idx}: {output_text} \n")
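
The structure of the generation loop above can be followed more easily on a toy version. In the sketch below, `toy_next_token_probs` is a made-up stand-in for the model that always returns the same distribution, and `EOS` is a hypothetical end-of-text id; nothing here calls GPT2.

```python
import random

EOS = 0  # hypothetical end-of-text token id for this toy vocabulary

def toy_next_token_probs(token_ids):
    """Stand-in for the language model: returns a fixed toy distribution."""
    # Assumption: a real model would compute this from token_ids.
    return [0.25, 0.30, 0.20, 0.15, 0.05, 0.05]

def sample_top_n(probs, n, rng):
    """Sample one index from the n most probable entries of probs."""
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:n]
    return rng.choices(top, weights=[probs[i] for i in top], k=1)[0]

def generate(prompt_ids, max_new_tokens=250, rng=None):
    """Append sampled tokens to the prompt until EOS or the length limit."""
    if rng is None:
        rng = random.Random(0)
    ids = list(prompt_ids)
    for i in range(max_new_tokens):
        # Wider choice for the first two tokens, narrower afterwards.
        n = 15 if i < 2 else 3
        next_id = sample_top_n(toy_next_token_probs(ids), n, rng)
        ids.append(next_id)
        if next_id == EOS:
            break  # stop once the end-of-text token is produced
    return ids

print(generate([EOS, 1], max_new_tokens=20))
```

Sampling from a wider set for the first two tokens gives each joke a more varied opening, while the narrow top-3 choice afterwards keeps the continuation coherent.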

The output was too long, so I stored it in [this file](github.link).

### Store the model for later use

In [None]:
torch.save(model.state_dict(), "gpt2_medium_joker.pt")

### Load stored model to generate some more jokes

In [None]:
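# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load("gpt2_medium_joker.pt", map_location=device))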