# 🚀 GPT

This notebook is an **unofficial PyTorch implementation** of the excellent [Keras example](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/09_transformer/gpt/gpt.ipynb) for transformers and GPT, originally created by David Foster as part of the companion code for his book [Generative Deep Learning, 2nd Edition](https://www.oreilly.com/library/view/generative-deep-learning/9781098134174/).

In this notebook, we'll walk through the steps required to train your own GPT model on the Wine Reviews dataset using PyTorch.

In [None]:
%load_ext autoreload
%autoreload 2

import os

# Resolve the experiment directory relative to the directory the kernel was started in.
# NOTE(review): assumes the notebook server was launched from the repository root — adjust otherwise.
working_dir = os.getcwd()
# No trailing separator: later cells build paths as exp_dir + "/log", exp_dir + "/checkpoints", ...
# and a trailing "/" here would put "//" into every one of them.
exp_dir = os.path.join(working_dir, "notebooks/09_transformer/01_gpt")

In [None]:
import html
import json
import re
import string

import numpy as np
import torch
import torch.nn.functional as F
from IPython.display import HTML, display
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, trainers
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary

## 0. Parameters <a name="parameters"></a>

In [None]:
VOCAB_SIZE = 10000
MAX_LEN = 80
# nn.MultiheadAttention does not take a separate projection size for keys, queries and values;
# it always uses embed_dim / num_heads per head. With 2 heads we therefore cannot reproduce the
# Keras configuration (embedding_dim = 256 AND key_dim = 256). Either:
#   - EMBEDDING_DIM = 256 -> per-head key projection of 128 (less powerful than Keras), or
#   - EMBEDDING_DIM = 512 -> per-head key projection of 256 (but a larger embedding than Keras).
EMBEDDING_DIM = 256
# EMBEDDING_DIM = 512
KEY_DIM = 256  # kept for reference to the Keras config; not used by the PyTorch model in this notebook
N_HEADS = 2
# N_HEADS = 1
FEED_FORWARD_DIM = 256
VALIDATION_SPLIT = 0.2  # NOTE: not used in this notebook — the whole dataset is used for training
SEED = 42
LOAD_MODEL = False
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.00001

# nn.MultiheadAttention splits the embedding evenly across heads
assert EMBEDDING_DIM % N_HEADS == 0, "EMBEDDING_DIM must be divisible by N_HEADS"

# Apply the seed so weight initialisation, DataLoader shuffling and token sampling are repeatable
torch.manual_seed(SEED)
np.random.seed(SEED)

## 1. Load the data <a name="load"></a>

In [None]:
# Location of the Wine Reviews dataset, relative to the working directory
data_dir = f"{working_dir}/data"
dataset_dir = f"{data_dir}/wine-reviews"
data_file = f"{dataset_dir}/winemag-data-130k-v2.json"

In [None]:
# JSON text is UTF-8 by specification (RFC 8259) and the reviews may contain accented
# characters, so don't depend on the platform's default encoding (e.g. cp1252 on Windows).
with open(data_file, "r", encoding="utf-8") as json_data:
    wine_data = json.load(json_data)

In [None]:
# Peek at one raw review record (a dict; the fields used below are country, province, variety, description)
wine_data[10]

In [None]:
filtered_data = ["wine review : " +
                 f"{x['country']} : " +
                 f"{x['province']} : " +
                 f"{x['variety']} : " +
                 f"{x['description']}"
                 for x in wine_data if "country" in x and x["country"] is not None and
                 "province" in x and x["province"] is not None and 
                 "variety" in x and x["variety"] is not None and
                 "description" in x and x["description"] is not None]

In [None]:
# One flattened review string, before punctuation padding and tokenization
filtered_data[10]

In [None]:
# How many reviews survived the filtering above
n_wine_reviews = len(filtered_data)
print("The number of available wine reviews = " + str(n_wine_reviews))

## 2. Tokenize the data <a name="tokenize"></a>

In [None]:
def pad_punctuation(str):
    """Surround every ASCII punctuation character with spaces, then collapse runs of spaces.

    This turns punctuation marks into stand-alone tokens for the whitespace tokenizer,
    e.g. ``"Hello   there!" -> "Hello there ! "``.

    Args:
        str: the text to pad (the name shadows the built-in ``str`` inside this function).

    Returns:
        The padded text.
    """
    spaced_out = re.sub(f"([{string.punctuation}])", r" \1 ", str)
    return re.sub(" +", " ", spaced_out)

In [None]:
test_text = "Hello   there!"
test_text = pad_punctuation(test_text)
print(test_text) 

In [None]:
# Separate punctuation from words in every review so the tokenizer sees it as its own token
text_data = [pad_punctuation(review) for review in filtered_data]

In [None]:
# Show one punctuation-padded review and the total number of reviews
print(text_data[10])
print(f"Text data size = {len(text_data)}")

In [None]:
# We modify the TextSeqDataset class implemented in chapter 5 (LSTM) so that it can build the
# tokenizer itself and process raw string data when required. This makes it reusable whenever we
# want the same tokenizer specs again.
class TextSeqAdvancedDataset(Dataset):
    """Next-token-prediction dataset over a list of texts.

    Each item is an ``(x, y)`` pair of 1-D tensors where ``y`` is ``x`` shifted left by one token.

    Args:
        data_list: raw strings when ``tokenize=True``; otherwise sequences of token ids that are
            already padded/truncated to a common length.
        tokenize: if True, train a word-level tokenizer on ``data_list`` and encode it.
        vocab_size: maximum vocabulary size (including the special tokens).
        max_seq_len: length of every ``x`` / ``y``. Texts are padded/truncated to
            ``max_seq_len + 1`` tokens so the shifted pair still has ``max_seq_len`` tokens.
        verbose: if 1, print vocabulary details after training the tokenizer.

    When ``tokenize=False`` no vocabulary is built, so ``pad_idx``, ``unk_idx`` and
    ``indx_to_word`` are ``None``.
    """

    def __init__(self, data_list, tokenize=False,
                 vocab_size=VOCAB_SIZE, max_seq_len=MAX_LEN,
                 verbose=0):
        super().__init__()
        # Defaults for pre-tokenized input, so the getters return None instead of raising
        self.pad_idx = None
        self.unk_idx = None
        self.indx_to_word = None

        if not tokenize:
            self.vectorized_data_list = data_list
            return

        # Word-level tokenizer from the Hugging Face `tokenizers` package: it assigns one id per
        # (lower-cased) word and maps out-of-vocabulary words to "<unk>".
        tokenizer = Tokenizer(models.WordLevel(unk_token="<unk>"))
        tokenizer.normalizer = normalizers.Lowercase()
        # The pre-tokenizer splits the text into words before the model maps them to ids
        tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

        # Build the vocabulary from THIS dataset's texts (previously a global variable was used)
        trainer = trainers.WordLevelTrainer(special_tokens=["<pad>", "<unk>"], vocab_size=vocab_size)
        tokenizer.train_from_iterator(data_list, trainer)

        vocab = tokenizer.get_vocab()
        self.pad_idx = vocab["<pad>"]
        self.unk_idx = vocab["<unk>"]

        # Pad/truncate every text to max_seq_len + 1 tokens; the extra token is consumed by the
        # one-position shift between inputs and targets in get_data_pair.
        tokenizer.enable_padding(length=max_seq_len + 1, pad_id=self.pad_idx, pad_token="<pad>")
        tokenizer.enable_truncation(max_length=max_seq_len + 1)

        self.vectorized_data_list = [tokenizer.encode(sentence).ids for sentence in data_list]
        self.indx_to_word = {idx: word for word, idx in vocab.items()}

        if verbose == 1:
            # print details of the resulting vocabulary
            print("Vocabulary size:", tokenizer.get_vocab_size())
            print("Vocabulary:", vocab)
            print("pading index = ", self.pad_idx)
            print(self.indx_to_word)

    def __len__(self):
        return len(self.vectorized_data_list)

    def get_pad_idx(self):
        """Id of the "<pad>" token, or None when the data was not tokenized here."""
        return self.pad_idx

    def get_unk_idx(self):
        """Id of the "<unk>" token, or None when the data was not tokenized here."""
        return self.unk_idx

    def get_idx_to_word(self):
        """Dict mapping token id -> word, or None when the data was not tokenized here."""
        return self.indx_to_word

    def get_data_pair(self, idx):
        """Return ``(x, y)`` for item ``idx``, where ``y`` is ``x`` shifted left by one token."""
        token_ids = self.vectorized_data_list[idx]
        x = torch.tensor(token_ids[:-1])
        y = torch.tensor(token_ids[1:])
        return x, y

    def __getitem__(self, idx):
        return self.get_data_pair(idx)

In [None]:
# we will set the value for the token paralization to avoid getting warning
os.environ["TOKENIZERS_PARALLELISM"] = "true"

train_dataset = TextSeqAdvancedDataset(data_list=text_data, tokenize=True, 
                                       vocab_size=VOCAB_SIZE, max_seq_len=MAX_LEN, 
                                       verbose=1)

pad_idx = train_dataset.get_pad_idx()
unk_idx = train_dataset.get_unk_idx()
vocab_idx_to_word = train_dataset.get_idx_to_word()

for i in range(10):
    print(f"{i}:{vocab_idx_to_word[i]}")

x, y = train_dataset.get_data_pair(0)

print(x.shape)
print(y.shape)
print(x)
print(y)

## 3. Create the Training Set <a name="create"></a>

In [None]:
# Shuffled mini-batches of (input, target) pairs, loaded by 4 worker processes
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

## 5. Create the causal attention mask function <a name="causal"></a>

In [None]:
def causal_attention_mask(num_keys, num_query, dtype):
    """Build a causal attention mask in the format expected by ``nn.MultiheadAttention``.

    Note 1: in PyTorch a boolean ``True`` means "attention disabled" and ``False`` means
    "attention enabled". This is the opposite convention to Keras.

    Note 2: no batch dimension is needed. A 2-D mask of shape ``(num_query, num_keys)`` is
    broadcast over every batch element and head. A 3-D mask would instead need
    ``batch_size * num_heads`` as its first dimension.

    As in the Keras example, the queries are aligned with the *last* ``num_query`` keys.
    Query ``i`` may attend to key ``j`` iff ``j <= i + (num_keys - num_query)``. For ordinary
    self-attention (``num_keys == num_query``) this is the strictly upper-triangular mask.

    Args:
        num_keys: number of key/value positions (source length).
        num_query: number of query positions (target length).
        dtype: dtype of the returned mask; use ``torch.bool``. PyTorch *adds* a floating-point
            mask to the attention scores, so a 0/1 float mask would not block attention.

    Returns:
        Tensor of shape ``(num_query, num_keys)`` that is True where attention is disabled.
    """
    # For query row i, keys from column i + first_masked_offset onwards lie in its "future"
    first_masked_offset = num_keys - num_query + 1
    return torch.triu(torch.ones(num_query, num_keys, dtype=dtype), diagonal=first_masked_offset)

In [None]:
# Visual check: True (masked) strictly above the diagonal, so no position can attend to a later one
causal_attention_mask(10, 10, dtype=torch.bool)

## 6. Create a Transformer Block layer <a name="transformer"></a>

In [None]:
class TransformerBlock(nn.Module):
    """Single transformer decoder block, with post-layer-normalisation as in the Keras example.

    The block applies causal multi-head self-attention, then a position-wise feed-forward
    network. Each sub-layer has dropout, a residual connection and layer normalisation.

    In PyTorch the per-head projection size of the queries, keys and values (dq, dk, dv) cannot
    be set directly. ``nn.MultiheadAttention`` computes it as ``embed_dim / num_heads``. For this
    reason, unlike the Keras layer, this class has no ``key_dim`` argument. ``embed_dim`` is the
    size of the input and output embeddings, and PyTorch divides it by ``num_heads`` internally
    to get the projection size.

    Args:
        num_heads: number of attention heads (``embed_dim`` must be divisible by it).
        embed_dim: size of the input and output embeddings.
        ff_dim: hidden size of the feed-forward network.
        seq_len: maximum sequence length. It is stored only; the causal mask is built from the
            actual input length in ``forward``.
        dropout_rate: dropout applied after the attention and after the feed-forward network.
        verbose: if truthy, ``forward`` prints every intermediate shape.
    """

    def __init__(self, num_heads, 
                 embed_dim, ff_dim,
                 seq_len, 
                 dropout_rate=0.1,
                 verbose=0,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate
        self.seq_len = seq_len
        self.verbose = verbose

        # batch_first=True: attention inputs/outputs are laid out as (batch, seq, embed)
        self.atten = nn.MultiheadAttention(num_heads=self.num_heads, embed_dim=self.embed_dim, batch_first=True)
        self.dropout_1 = nn.Dropout(self.dropout_rate)
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        # Position-wise feed-forward network: embed_dim -> ff_dim -> embed_dim
        self.ff_1 = nn.Linear(in_features=self.embed_dim, out_features=self.ff_dim)
        self.ff_2 = nn.Linear(in_features=self.ff_dim, out_features=self.embed_dim)
        self.dropout_2 = nn.Dropout(self.dropout_rate)
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=1e-6)

    def forward(self, inputs, padding_mask=None):
        """Apply the block to a batch of embeddings.

        Args:
            inputs: embeddings; assumed (batch, seq_len, embed_dim) because the attention layer
                uses ``batch_first=True``.
            padding_mask: optional bool ``key_padding_mask``, presumably (batch, seq_len) and
                True at padding positions so they are ignored as keys.

        Returns:
            ``(output, atten_weights)``: ``output`` has the same shape as ``inputs``.
            ``atten_weights`` are the weights returned by ``nn.MultiheadAttention``, which by
            default are averaged over the heads.
        """
        B = inputs.shape[0]
        K = inputs.shape[1]
        if self.verbose: print(f"Batch size {B} key and query size {K}")

        device = inputs.device

        # Self-attention, so keys and queries have the same length K
        causal_mask = causal_attention_mask(K, K, torch.bool).to(device)
        if self.verbose: print(f"causal_mask shape : {causal_mask.shape}")

        # We also mask the padding tokens as keys. The Keras implementation does not do this;
        # it is meant to stop the model from spending attention on padding.
        atten_output, atten_weights = self.atten(inputs, inputs, inputs, 
                                                 attn_mask=causal_mask, key_padding_mask=padding_mask)
        
        if self.verbose:
            print(f"atten output shape : {atten_output.shape}")
            print(f"atten weights shape : {atten_weights.shape}")

        x = self.dropout_1(atten_output)
        if self.verbose: print(f"drop out 1 size : {x.shape}")

        # Residual connection around the attention, then layer norm (post-norm)
        x = self.layer_norm_1(x +  inputs)
        if self.verbose: print(f"layer_norm 1 size : {x.shape}")

        residual = x

        x = self.ff_1(x)
        if self.verbose: print(f"ff 1 shape : {x.shape}")

        x = F.relu(x)

        x = self.ff_2(x)
        if self.verbose: print(f"ff 2 shape {x.shape}")

        x = self.dropout_2(x)

        # Residual connection around the feed-forward network, then layer norm
        x = self.layer_norm_2(residual + x)
        if self.verbose: print(f"layer_norm 2 size : {x.shape}")

        return x, atten_weights


In [None]:
# Smoke-test a single-head transformer block on random embeddings (verbose prints every shape)
batch_size, seq_length = 2, 10
embed_dim = EMBEDDING_DIM

test_sequence = torch.rand(batch_size, seq_length, embed_dim)

transformer = TransformerBlock(num_heads=1, embed_dim=embed_dim, ff_dim=200,
                               seq_len=seq_length, verbose=1)

In [None]:
# Run the block on the random batch; verbose=1 prints every intermediate shape
output, atten_weights = transformer(test_sequence)

## 7. Create the Token and Position Embedding <a name="embedder"></a>

In [None]:
class TokenAndPositionEmbedding(nn.Module):
    """Sum of a learned token embedding and a learned absolute-position embedding.

    Args:
        vocab_size: number of token ids.
        max_seq_len: number of position embeddings, i.e. the longest sequence accepted.
        embed_dim: size of both embeddings.
        pad_idx: token id whose embedding stays zero (``padding_idx`` of ``nn.Embedding``).
        verbose: if truthy, print the embedding shapes during ``forward``.
    """

    def __init__(self, vocab_size, max_seq_len,
                 embed_dim, pad_idx=0, verbose=0,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.pad_idx = pad_idx
        self.verbose = verbose

        self.token_emb = nn.Embedding(num_embeddings=self.vocab_size,
                                      embedding_dim=self.embed_dim,
                                      padding_idx=self.pad_idx)

        self.pos_emb = nn.Embedding(num_embeddings=self.max_seq_len,
                                    embedding_dim=self.embed_dim)

    def forward(self, x):
        """Embed a batch of token ids.

        Args:
            x: integer tensor of token ids of shape (batch, seq_len), seq_len <= max_seq_len.

        Returns:
            Float tensor of shape (batch, seq_len, embed_dim).

        Raises:
            IndexError: if seq_len exceeds ``max_seq_len`` (no position embedding exists for
                the extra positions). Checked up front to give a clear message instead of a
                cryptic embedding error or CUDA device-side assert.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_seq_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        # Create the positions directly on the input's device; the leading dimension of size 1
        # broadcasts over the batch.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        pos_embedding = self.pos_emb(positions)
        if self.verbose: print(f"Embed: pos embed size = {pos_embedding.shape}")
        token_embedding = self.token_emb(x)
        if self.verbose: print(f"Embed: token embed size = {token_embedding.shape}")
        return token_embedding + pos_embedding

## 8. Build the GPT model <a name="transformer_decoder"></a>

In [None]:
class GPT(nn.Module):
    """Single-block GPT: token+position embedding -> one causal transformer block -> vocabulary logits.

    Args:
        num_heads: number of attention heads.
        embed_dim: embedding size.
        ff_dim: hidden size of the transformer's feed-forward network.
        vocab_size: number of token ids (size of the output logits).
        max_seq_len: longest accepted input sequence (number of position embeddings).
        pad_idx: padding token id; padding positions are masked as attention keys.
        dropout_rate: dropout rate inside the transformer block.
        verbose: if truthy, print intermediate shapes.
        log_dir: TensorBoard log directory for the per-epoch training loss.
    """

    def __init__(self, num_heads, embed_dim,
                 ff_dim, vocab_size, max_seq_len,
                 pad_idx=0, dropout_rate=0.1,
                 verbose=0, log_dir="./log",
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.pad_idx = pad_idx
        self.verbose = verbose
        self.dropout_rate = dropout_rate

        # TensorBoard writer for the training loss (attribute name kept as-is for compatibility)
        self.writter = SummaryWriter(log_dir=log_dir)

        # Creating the GPT layers
        self.embedding_layer = TokenAndPositionEmbedding(self.vocab_size, self.max_seq_len,
                                                         self.embed_dim, self.pad_idx, verbose=self.verbose)
        self.transformer_layer = TransformerBlock(self.num_heads, self.embed_dim, self.ff_dim,
                                                  self.max_seq_len, self.dropout_rate, verbose=self.verbose)
        # Projects every position to vocabulary logits; softmax is left to the loss / sampler
        self.ff_layer = nn.Linear(in_features=self.embed_dim, out_features=self.vocab_size)

    def forward(self, x):
        """Return ``(logits, atten_weights)`` for a batch of token ids ``x`` of shape (batch, seq_len).

        The logits have shape (batch, seq_len, vocab_size). No softmax is applied because
        ``nn.CrossEntropyLoss`` works on raw logits and applies log-softmax internally.
        """
        # True where the input is padding, so those positions are ignored as attention keys
        padding_mask = (x == self.pad_idx)

        x = self.embedding_layer(x)
        if self.verbose: print(f"GPT: Embedding size = {x.shape}")

        x, atten_weights = self.transformer_layer(x, padding_mask=padding_mask)
        if self.verbose: print(f"GPT: transformer output size = {x.shape}")

        x = self.ff_layer(x)
        if self.verbose: print(f"GPT: FF output size = {x.shape}")

        return x, atten_weights

    def fit(self, training_dataloader, optimizer, epochs, loss_fn, device, callbacks=None):
        """Train the model, logging the mean loss per epoch to stdout and TensorBoard.

        Args:
            training_dataloader: iterable of ``(inputs, targets)`` token-id batches.
            optimizer: optimizer over this model's parameters.
            epochs: number of passes over ``training_dataloader``.
            loss_fn: loss taking ``(N, vocab_size)`` logits and ``(N,)`` targets.
            device: device the batches are moved to (the model must already be on it).
            callbacks: optional objects with ``on_epoch_end(epoch, logs=...)``. ``logs`` holds
                "model", "device", "model_state_dict" and "loss".
        """
        self.optimizer = optimizer
        # NOTE(review): assigning an nn.Module loss registers it as a sub-module of the model
        self.loss_fn = loss_fn
        self.device = device

        n_batches = len(training_dataloader)
        for epoch in range(epochs):
            acc_loss = 0
            for training_data in training_dataloader:
                acc_loss += self.train_step(training_data)

            # Mean batch loss; max() guards against an empty dataloader
            acc_loss /= max(n_batches, 1)
            print(f"Epoch {epoch + 1}/{epochs}: loss = {acc_loss}")

            self.writter.add_scalar("train_loss", acc_loss, global_step=epoch)
            # SummaryWriter only flushes periodically; flush so TensorBoard sees every epoch now
            self.writter.flush()

            # run call back functions
            if callbacks is not None:
                logs = {"model": self,
                        "device": self.device,
                        "model_state_dict": self.state_dict(),
                        "loss": acc_loss
                        }

                for callback in callbacks:
                    callback.on_epoch_end(epoch, logs=logs)

    def train_step(self, training_data):
        """Run one optimisation step on a batch and return its loss as a Python float.

        Args:
            training_data: ``(inputs, targets)`` pair of token-id tensors of shape
                (batch, seq_len), where the targets are the inputs shifted by one position.
        """
        text_input, text_gt = training_data
        text_input = text_input.to(self.device)
        text_gt = text_gt.to(self.device)

        self.train()
        # Clear the gradients accumulated by the previous step
        self.optimizer.zero_grad()

        pred_text, _ = self(text_input)

        # CrossEntropyLoss expects (N, C) logits and (N,) targets: flatten batch and time
        vocab_size = pred_text.shape[-1]
        loss = self.loss_fn(pred_text.reshape(-1, vocab_size), text_gt.reshape(-1))

        loss.backward()
        self.optimizer.step()

        return loss.item()



In [None]:
# Small verbose instance used only to sanity-check shapes (it uses the default pad_idx=0)
gpt_test = GPT(
    num_heads=N_HEADS,
    embed_dim=EMBEDDING_DIM,
    ff_dim=FEED_FORWARD_DIM,
    vocab_size=VOCAB_SIZE,
    max_seq_len=MAX_LEN,
    verbose=1,
    log_dir=(exp_dir + "/log"),
)

In [None]:
# Push a batch of random token ids through the test model (verbose prints every shape)
batch_size, seq_length = 2, MAX_LEN
embed_dim = EMBEDDING_DIM

test_sequence = torch.randint(low=0, high=VOCAB_SIZE, size=(batch_size, seq_length))
output, weights = gpt_test(test_sequence)

In [None]:
# Layer-by-layer summary for a batch of 2 sequences of MAX_LEN token ids (no hard-coded length)
summary(gpt_test, (2, MAX_LEN), device="cpu", dtypes=(torch.int32,))

## 9. Train the Transformer <a name="train"></a>

In [None]:
# Output folders for TensorBoard logs, generated samples and model checkpoints
log_dir = exp_dir + "/log"
sample_dir = exp_dir + "/sample_gen"
checkpoint_dir = exp_dir + "/checkpoints"

# exist_ok=True keeps this cell safe to re-run
os.makedirs(log_dir, exist_ok=True)
os.makedirs(sample_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

### Callbacks

In [None]:
class Callback:
    """Base class for training callbacks; ``GPT.fit`` calls ``on_epoch_end`` after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook run after each epoch; the default implementation does nothing.

        Args:
            epoch: 0-based index of the epoch that just finished.
            logs: dict of training information, or None.
        """
        return None

In [None]:
class SaveCheckpoint(Callback):
    """Save the model weights every ``save_every`` completed epochs.

    Files are named ``checkpoint_<n>.pth``, where ``n`` is the number of completed epochs. With
    ``save_every=2`` and 10 epochs the files are checkpoint_2 ... checkpoint_10. The final epoch
    is therefore saved whenever the number of epochs is a multiple of ``save_every``.
    """

    def __init__(self, save_dir, save_every=10):
        super().__init__()
        self.save_dir = save_dir
        self.save_every = save_every

    def on_epoch_end(self, epoch, logs=None):
        """Write a checkpoint when ``epoch + 1`` is a multiple of ``save_every``.

        Args:
            epoch: 0-based index of the epoch that just finished (stored as "epoch" in the file).
            logs: dict with at least "model_state_dict" and "loss"; nothing is saved if empty/None.
        """
        if not logs:
            return

        # `epoch` is 0-based. Testing the number of *completed* epochs means the first save
        # happens after `save_every` epochs (not after the first one) and the last epoch is
        # saved too.
        completed_epochs = epoch + 1
        if completed_epochs % self.save_every != 0:
            return

        checkpoint = {"epoch": epoch,
                      "model_state_dict": logs["model_state_dict"],
                      "loss": logs["loss"]
                      }
        checkpoint_file = self.save_dir + f"/checkpoint_{completed_epochs}.pth"

        torch.save(checkpoint, checkpoint_file)

In [None]:
class TextGenerator(Callback):
    """Sample text from the model token by token; as a callback it samples after every epoch.

    Args:
        index_to_word: id -> word lookup. Either a dict ``{id: word}`` (as returned by
            ``TextSeqAdvancedDataset.get_idx_to_word``) or a list indexed by id.
        max_tokens: maximum length, in tokens, of the text generated by ``on_epoch_end``.
        top_k: stored for reference; not used by the sampling below.
    """

    def __init__(self, index_to_word, max_tokens=100, top_k=10):
        self.index_to_word = index_to_word
        # Build the inverse lookup. enumerate() is only right for a list vocabulary: on a dict it
        # enumerates the integer keys, which made every prompt word fall back to <unk>.
        if isinstance(index_to_word, dict):
            id_word_pairs = index_to_word.items()
        else:
            id_word_pairs = enumerate(index_to_word)
        self.word_to_index = {word: index for index, word in id_word_pairs}
        self.max_tokens = max_tokens
        self.top_k = top_k
        # Special-token ids, falling back to the values previously hard-coded (unk=1, pad=0)
        self.unk_idx = self.word_to_index.get("<unk>", 1)
        self.pad_idx = self.word_to_index.get("<pad>", 0)

    def sample_from(self, probs, temperature):
        """Sample one token id from ``probs`` after applying ``temperature``.

        Raising the probabilities to ``1 / temperature`` and renormalising is equivalent to
        dividing the logits by the temperature. Values below 1 sharpen the distribution and
        values above 1 flatten it.

        Returns:
            ``(token_id, adjusted_probs)``.
        """
        probs = probs ** (1 / temperature)
        probs = probs / torch.sum(probs)
        sample_token = torch.multinomial(probs, 1).item()
        return sample_token, probs

    def generate(self, model, start_prompt, max_tokens, temperature, device):
        """Autoregressively extend ``start_prompt`` and print the generated text.

        Generation stops when the sequence reaches ``max_tokens`` tokens or the padding token
        is sampled. The model is put in eval mode (dropout off) while sampling, and its previous
        train/eval mode is restored afterwards. Prompt words are looked up lower-cased, because
        the tokenizer lower-cases the corpus; unknown words map to the ``<unk>`` id.

        Returns:
            A list with one dict per generated token. Each dict holds "prompt" (the text before
            that token), "word_probs" (the temperature-adjusted next-token distribution) and
            "atts" (the last position's attention weights).
        """
        start_tokens = [self.word_to_index.get(word.lower(), self.unk_idx)
                        for word in start_prompt.split()]
        sample_token = None
        info = []
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                while len(start_tokens) < max_tokens and sample_token != self.pad_idx:
                    x = torch.tensor([start_tokens]).to(device)
                    y, atten_weights = model(x)
                    y = y.detach().to("cpu")
                    atten_weights = atten_weights.detach().to("cpu")
                    # The model outputs logits (no softmax), so convert them to probabilities here
                    y_prob = torch.softmax(y, dim=-1)
                    sample_token, probs = self.sample_from(y_prob[0][-1], temperature)
                    info.append({"prompt": start_prompt,
                                 "word_probs": probs,
                                 "atts": atten_weights[0, -1, :]})
                    start_tokens.append(sample_token)
                    start_prompt = start_prompt + " " + self.index_to_word[sample_token]
        finally:
            model.train(was_training)
        print(f"\ngenerated text:\n{start_prompt}\n")
        return info

    def on_epoch_end(self, epoch, logs=None):
        """Generate a sample starting from "wine review" with the model passed in ``logs``."""
        if logs:
            model = logs["model"]
            device = logs["device"]
            self.generate(model, "wine review", max_tokens=self.max_tokens, temperature=1.0, device=device)

In [None]:
# Periodic checkpoints (save_every=2) plus a generated sample after every epoch
callbacks = [
    SaveCheckpoint(save_dir=checkpoint_dir, save_every=2),
    TextGenerator(index_to_word=vocab_idx_to_word, max_tokens=MAX_LEN),
]

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {device}")

In [None]:
# The model to train; pad_idx comes from the tokenizer so the padding mask matches the data
gpt = GPT(
    num_heads=N_HEADS,
    embed_dim=EMBEDDING_DIM,
    ff_dim=FEED_FORWARD_DIM,
    vocab_size=VOCAB_SIZE,
    max_seq_len=MAX_LEN,
    pad_idx=pad_idx,
    log_dir=log_dir,
).to(device)

In [None]:
# check if we have checkpoint to load
if LOAD_MODEL:
    checkpoint_file = checkpoint_dir + "/checkpoint_10.pth"
    checkpoint = torch.load(checkpoint_file)
    gpt.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# Unlike the Keras implementation, the loss ignores padding targets, so the model is not
# trained to predict padding tokens.
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = Adam(params=gpt.parameters(), lr=LEARNING_RATE)


Note: the quality of the generated text is not as good as with the Keras implementation, and more debugging may be required. Possible reasons include:
- The key dimension cannot be set independently: with an embedding size of 256 and 2 heads, PyTorch uses a per-head key dimension of 128 rather than 256
- The Keras tokenizer may be more sophisticated than the simple word-level tokenizer used here
- PyTorch training appears to be unstable with a large learning rate (e.g. 0.001), so a smaller learning rate is used, which needs more epochs to train
- Some hyperparameter tuning may be needed
(Improving this is work in progress.)

In [None]:
# Train the model; after every epoch the callbacks may save a checkpoint and print a sample
gpt.fit(train_dataloader, optimizer=optimizer, 
        epochs=EPOCHS, loss_fn=loss_fn, 
        device=device, callbacks=callbacks)

In [None]:
def print_probs(info, vocab, top_k=5):
    """Visualise the output of ``TextGenerator.generate``.

    For every generation step, the prompt is rendered with each word highlighted in proportion
    to the attention it received from the last position. The ``top_k`` most probable next words
    are then printed.

    Args:
        info: list of dicts with "prompt" (str), "atts" (1-D tensor of attention weights) and
            "word_probs" (1-D tensor of next-token probabilities).
        vocab: id -> word lookup (dict or list).
        top_k: number of candidate next words printed per step.
    """
    for step in info:
        atts = step["atts"]
        # Normalise by the strongest weight so the most-attended word is fully opaque
        # (`or 1.0` avoids dividing by zero if every weight is 0).
        max_att = atts.max().item() or 1.0
        highlighted_words = []
        for word, att_score in zip(step["prompt"].split(), atts):
            opacity = att_score.item() / max_att
            # Escape the word: tokens such as "<unk>" or "<pad>" would otherwise be parsed as
            # HTML tags and disappear from the rendered output.
            highlighted_words.append(
                '<span style="background-color:rgba(135,206,250,'
                + str(opacity)
                + ');">'
                + html.escape(word)
                + "</span>"
            )
        display(HTML(" ".join(highlighted_words)))

        word_probs = step["word_probs"].numpy()
        top_ids = np.argsort(word_probs)[::-1][:top_k]
        for token_id in top_ids:
            token_id = int(token_id)
            print(f"{vocab[token_id]}:   \t{np.round(100 * word_probs[token_id], 2)}%")
        print("--------\n")

In [None]:
# Stand-alone generator for sampling from the trained model in the cells below
text_generator = TextGenerator(max_tokens=MAX_LEN, index_to_word=vocab_idx_to_word)

In [None]:
# temperature=1.0 samples from the model's unmodified next-token distribution
info = text_generator.generate(
    gpt,
    start_prompt="wine review : us",
    temperature=1.0,
    max_tokens=80,
    device=device,
)

In [None]:
# temperature=0.5 sharpens the distribution, favouring the most likely words
info = text_generator.generate(
    gpt,
    start_prompt="wine review : italy",
    temperature=0.5,
    max_tokens=80,
    device=device,
)

In [None]:
# Generate at a low temperature, then show per-step attention and the top next-word candidates
info = text_generator.generate(
    gpt,
    start_prompt="wine review : germany",
    temperature=0.5,
    max_tokens=80,
    device=device,
)
print_probs(info, vocab_idx_to_word)