# Transformers QA

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import io
import re
import math

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F

from torchtext.datasets import YahooAnswers
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchtext.data.functional import sentencepiece_tokenizer, load_sp_model

from tqdm.notebook import trange, tqdm

In [None]:
from torch.distributions import Categorical

In [None]:
# Define the hyperparameters
# Step size handed to the optimizer
learning_rate = 1e-4

# Number of passes over the training DataLoader
nepochs = 100

# Number of (question, answer) pairs per mini-batch
batch_size = 32

# Truncation lengths (in tokens) used by the question / answer transforms
max_len_q = 32
max_len_a = 64

# Root directory under which the dataset CSV files are looked up
data_set_root = "../../datasets"

# We'll be using the YahooAnswers dataset.
# Note that in torchtext these datasets are NOT PyTorch Dataset classes: "YahooAnswers" is a function
# that returns a PyTorch DataPipe!

# PyTorch DataPipes vvv
# https://pytorch.org/data/main/torchdata.datapipes.iter.html

# vvv Good blog post on the difference between Dataset and DataPipe
# https://medium.com/deelvin-machine-learning/comparison-of-pytorch-dataset-and-torchdata-datapipes-486e03068c58

# Depending on the dataset, the download sometimes fails with an error
# and you'll have to download and extract the files manually.
# "The datasets supported by torchtext are datapipes from the torchdata project, which is still in Beta status"

# Un-comment to trigger the DataPipe to download the data vvv
# dataset_train = YahooAnswers(root=data_set_root, split="train")
# data = next(iter(dataset_train))

# Side note: I've noticed that the WikiText dataset can no longer be downloaded :(

In [None]:
# ## "Train" a Sentence Piece Tokenizer with the train data capping the vocab size to 20000 tokens
# from torchtext.data.functional import generate_sp_model

# with open(os.path.join(data_set_root, "datasets/YahooAnswers/train.csv")) as f:
#     with open(os.path.join(data_set_root, "datasets/YahooAnswers/data.txt"), "w") as f2:
#         for i, line in enumerate(f):
#             text_only = "".join(line.split(",")[1:])
#             filtered = re.sub(r'\\|\\n|;', ' ', text_only.replace('"', ' ').replace('\n', ' ')) # remove newline characters
#             f2.write(filtered.lower() + "\n")


# generate_sp_model(os.path.join(data_set_root, "datasets/YahooAnswers/data.txt"), 
#                   vocab_size=20000, model_prefix='spm_user_ya')

In [None]:
class YahooQA(Dataset):
    """Yahoo Answers question/answer pairs read from a local CSV file.

    The CSV is expected at ``<data_set_root>/datasets/YahooAnswers/<test_train>.csv``
    with the four header-less columns Class, Q_Title, Q_Content and A.
    Title and content are merged into a single question string ``"title: content"``.

    Args:
        num_datapoints: Accepted for backward compatibility; it is not used.
        test_train: File stem of the split to load ("train" or "test").

    Indexing returns a ``(question_text, answer_text)`` tuple of lower-cased strings.
    """

    # Escape-sequence residue that is replaced by a single space.
    # Longer alternatives must come first: regex alternation takes the first
    # alternative that matches, so a bare backslash listed before "\r" would
    # consume only the backslash and leave a stray "r" behind.
    _NOISE_PATTERN = r'\\r\\n|\\n|\\r|\\|\n|"'

    def __init__(self, num_datapoints, test_train="train"):
        csv_path = os.path.join(data_set_root, "datasets/YahooAnswers/" + test_train + ".csv")
        self.df = pd.read_csv(csv_path, names=["Class", "Q_Title", "Q_Content", "A"])

        # Missing fields become empty strings so the concatenation below never yields NaN
        self.df.fillna('', inplace=True)
        self.df['Q'] = self.df['Q_Title'] + ': ' + self.df['Q_Content']
        self.df.drop(['Q_Title', 'Q_Content'], axis=1, inplace=True)
        self.df['Q'] = self._clean(self.df['Q'])
        self.df['A'] = self._clean(self.df['A'])

    @staticmethod
    def _clean(text_column):
        """Replace literal escape sequences, real newlines and double quotes with spaces."""
        return text_column.str.replace(YahooQA._NOISE_PATTERN, ' ', regex=True)

    def __getitem__(self, index):
        # Scalar lookups avoid materialising the whole row as a Series
        question_text = self.df.loc[index, "Q"].lower()
        answer_text = self.df.loc[index, "A"].lower()

        return question_text, answer_text

    def __len__(self):
        return len(self.df)

In [None]:
# Build the train and test splits from their CSV files.
# NOTE(review): YahooQA.__init__ never reads `num_datapoints`, so the path string passed here
# has no effect; it is supplied only because the parameter has no default.
dataset_train = YahooQA(num_datapoints=data_set_root, test_train="train")
dataset_test = YahooQA(num_datapoints=data_set_root, test_train="test")

In [None]:
# Training loader: shuffled, and the last incomplete batch is dropped
data_loader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)
# Test loader: also shuffled, so each fresh iterator starts with a random batch
data_loader_test = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Load the SentencePiece model file and wrap it in a tokenizer function.
# NOTE(review): `tokenizer` is not referenced again in the visible cells; the transforms
# below construct their own T.SentencePieceTokenizer from the same model file.
sp_model = load_sp_model("spm_user_ya.model")
tokenizer = sentencepiece_tokenizer(sp_model)

In [None]:
def yield_tokens(file_path):
    """Yield the first tab-separated field of every line as a one-element list.

    The file is read as UTF-8. A line without a tab is yielded whole,
    including its trailing newline.
    """
    with io.open(file_path, encoding='utf-8') as vocab_file:
        for entry in vocab_file:
            token, _, _ = entry.partition("\t")
            yield [token]
            
# Special-case tokens; with special_first=True they are placed at the front of the vocabulary
special_tokens = ['<pad>', '<soq>', '<eoq>', '<soa>', '<eoa>', '<unk>']
vocab = build_vocab_from_iterator(
    yield_tokens("spm_user_ya.vocab"),
    specials=special_tokens,
    special_first=True,
)
# Tokens that are not in the vocabulary map to <unk>
vocab.set_default_index(vocab['<unk>'])

In [None]:
# Stand-alone SentencePiece tokenizer transform built from the same model file.
# NOTE(review): not referenced again in the visible cells.
tokenizer_transform = T.SentencePieceTokenizer("spm_user_ya.model")

In [None]:
# Text -> padded tensor of token indices for QUESTIONS.
# Special-token indices are looked up in the vocabulary rather than hard-coded,
# so they stay correct if the order of the specials ever changes.
q_tranform = T.Sequential(
    # Tokenize with the pre-existing SentencePiece model
    T.SentencePieceTokenizer("spm_user_ya.model"),
    # Convert the sub-word tokens to indices using the vocabulary
    T.VocabTransform(vocab=vocab),
    # Prepend the start-of-question token <soq>
    T.AddToken(vocab["<soq>"], begin=True),
    # Crop the sequence if it is longer than the max length
    # (the <eoq> token appended next is added after cropping)
    T.Truncate(max_seq_len=max_len_q),
    # Append the end-of-question token <eoq>
    T.AddToken(vocab["<eoq>"], begin=False),
    # Convert the list of lists to a tensor, padding the shorter sequences of the
    # batch with <pad> so that every row has the same length
    T.ToTensor(padding_value=vocab["<pad>"])
)

# Text -> padded tensor of token indices for ANSWERS.
# Special-token indices are looked up in the vocabulary rather than hard-coded,
# so they stay correct if the order of the specials ever changes.
a_tranform = T.Sequential(
    # Tokenize with the pre-existing SentencePiece model
    T.SentencePieceTokenizer("spm_user_ya.model"),
    # Convert the sub-word tokens to indices using the vocabulary
    T.VocabTransform(vocab=vocab),
    # Prepend the start-of-answer token <soa>
    T.AddToken(vocab["<soa>"], begin=True),
    # Crop the sequence if it is longer than the max length
    # (the <eoa> token appended next is added after cropping)
    T.Truncate(max_seq_len=max_len_a),
    # Append the end-of-answer token <eoa>
    T.AddToken(vocab["<eoa>"], begin=False),
    # Convert the list of lists to a tensor, padding the shorter sequences of the
    # batch with <pad> so that every row has the same length
    T.ToTensor(padding_value=vocab["<pad>"])
)

In [None]:
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
    """Fixed sine/cosine positional embedding.

    Args:
        dim: Width of the returned embedding.

    ``forward`` takes a 1-D tensor of positions of length L and returns a
    tensor of shape (L, dim): the first ``dim // 2`` columns are sines, the
    next ``dim // 2`` are cosines, and for an odd ``dim`` one zero column is
    appended so the width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Guard the denominator: dim 2 or 3 gives half_dim == 1, which would divide by zero
        scale = math.log(10000) / max(half_dim - 1, 1)
        # Geometrically spaced frequencies from 1 down to 1/10000
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2:
            # Odd width: sin/cos halves only fill dim - 1 columns, pad the last one
            emb = torch.cat((emb, emb.new_zeros(emb.shape[0], 1)), dim=-1)
        return emb

    
# Multi-head attention wrapper, optionally with a causal (look-ahead) mask
class AttentionBlock(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` (batch-first).

    Args:
        hidden_size: Embedding width of queries, keys and values.
        num_heads: Number of attention heads.
        masking: When True, position i may only attend to positions <= i.
            The mask is square in the query length.

    ``forward(x_in, kv_in)`` uses ``x_in`` as the query and ``kv_in`` as both
    key and value, and returns only the attention output.
    """

    def __init__(self, hidden_size=128, num_heads=4, masking=True):
        super().__init__()
        self.masking = masking

        self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True)

    def forward(self, x_in, kv_in):
        causal_mask = None
        if self.masking:
            seq_len = x_in.shape[1]
            # True strictly above the diagonal marks the positions that must be hidden
            causal_mask = torch.ones(seq_len, seq_len, device=x_in.device).triu(diagonal=1).bool()

        attn_out, _ = self.multihead_attn(x_in, kv_in, kv_in, attn_mask=causal_mask)
        return attn_out

    
# One transformer layer: self-attention, optional cross-attention, then an MLP
class TransformerBlock(nn.Module):
    """Residual transformer layer with layer norm applied after each sub-layer.

    Args:
        hidden_size: Feature width of the layer.
        num_heads: Number of heads for every attention sub-layer.
        decoder: When True, add an unmasked cross-attention sub-layer that
            attends to ``kv_cross``.
        masking: Whether the self-attention sub-layer uses a causal mask.
    """

    def __init__(self, hidden_size=128, num_heads=4, decoder=False, masking=True):
        super().__init__()
        self.decoder = decoder

        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn1 = AttentionBlock(hidden_size=hidden_size, num_heads=num_heads, masking=masking)

        if decoder:
            self.norm2 = nn.LayerNorm(hidden_size)
            self.attn2 = AttentionBlock(hidden_size=hidden_size, num_heads=num_heads, masking=False)

        self.norm_mlp = nn.LayerNorm(hidden_size)
        mlp_width = hidden_size * 4
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_width),
            nn.ELU(),
            nn.Linear(mlp_width, hidden_size),
        )

    def forward(self, x, kv_cross=None):
        # Self-attention with residual connection
        x = self.norm1(self.attn1(x, x) + x)

        if self.decoder:
            # Cross-attention: queries from x, keys/values from kv_cross
            x = self.norm2(self.attn2(x, kv_cross) + x)

        # Position-wise MLP with residual connection
        return self.norm_mlp(self.mlp(x) + x)
    
    
class Encoder(nn.Module):
    """Stack of unmasked transformer blocks over embedded input tokens.

    Args:
        num_emb: Size of the token vocabulary.
        hidden_size: Embedding / feature width.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.

    ``forward`` takes a (batch, length) tensor of token indices and returns a
    (batch, length, hidden_size) tensor.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Encoder, self).__init__()

        # Create an embedding for each token
        self.embedding = nn.Embedding(num_emb, hidden_size)
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads, decoder=False, masking=False) for _ in range(num_layers)
        ])

    def forward(self, input_seq):
        input_embs = self.embedding(input_seq)
        bs, l, h = input_embs.shape

        # Add a unique embedding to each token embedding depending on its position in the sequence
        seq_indx = torch.arange(l, device=input_seq.device)
        pos_emb = self.pos_emb(seq_indx).reshape(1, l, h).expand(bs, l, h)
        hidden = input_embs + pos_emb

        # Feed each block the output of the previous one. Passing the raw
        # embeddings to every block would discard all layers but the last.
        for block in self.blocks:
            hidden = block(hidden)

        return hidden

    
class Decoder(nn.Module):
    """Stack of causally masked transformer blocks with cross-attention.

    Args:
        num_emb: Size of the token vocabulary (also the width of the output logits).
        hidden_size: Embedding / feature width.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.

    ``forward`` takes a (batch, length) tensor of target token indices plus the
    encoder output, and returns (batch, length, num_emb) logits.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Decoder, self).__init__()

        # Create an embedding for each token
        self.embedding = nn.Embedding(num_emb, hidden_size)
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads, decoder=True) for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq, encoder_output):
        input_embs = self.embedding(input_seq)
        bs, l, h = input_embs.shape

        # Add a unique embedding to each token embedding depending on its position in the sequence
        seq_indx = torch.arange(l, device=input_seq.device)
        pos_emb = self.pos_emb(seq_indx).reshape(1, l, h).expand(bs, l, h)
        hidden = input_embs + pos_emb

        # Feed each block the output of the previous one. Passing the raw
        # embeddings to every block would discard all layers but the last.
        for block in self.blocks:
            hidden = block(hidden, kv_cross=encoder_output)

        return self.fc_out(hidden)

    
# "Encoder-Decoder" style transformer: the decoder cross-attends to the encoded input
class EncoderDecoder(nn.Module):
    """Sequence-to-sequence transformer built from an Encoder and a Decoder.

    Args:
        num_emb: Vocabulary size shared by encoder and decoder.
        hidden_size: Feature width of both stacks.
        num_layers: Pair ``(encoder_layers, decoder_layers)``.
        num_heads: Attention heads per block.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=(3, 3), num_heads=4):
        super().__init__()

        # Settings common to both halves of the model
        shared = dict(num_emb=num_emb, hidden_size=hidden_size, num_heads=num_heads)

        self.encoder = Encoder(num_layers=num_layers[0], **shared)
        self.decoder = Decoder(num_layers=num_layers[1], **shared)

    def forward(self, input_seq, target_seq):
        # Encode the source sequence, then decode the target conditioned on it
        return self.decoder(target_seq, self.encoder(input_seq))

In [None]:
# Run on GPU 0 when CUDA is available, otherwise fall back to the CPU
device = torch.device(0 if torch.cuda.is_available() else 'cpu')

In [None]:
# Model size
hidden_size = 512
num_heads = 16
# (encoder layers, decoder layers)
num_layers = (3, 6)

# Build the model, then move it to the selected device
tf_generator = EncoderDecoder(
    num_emb=len(vocab),
    num_layers=num_layers,
    hidden_size=hidden_size,
    num_heads=num_heads,
)
tf_generator = tf_generator.to(device)

# Adam optimizer over all model parameters
optimizer = optim.Adam(
    tf_generator.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

# Cross-entropy loss over the predicted token logits
loss_fn = nn.CrossEntropyLoss()

# Custom transform that will randomly replace a token with <pad>
# td = TokenDrop(prob=0.2)

In [None]:
# Count every scalar parameter of the model
num_model_params = sum(param.numel() for param in tf_generator.parameters())

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))

In [None]:
# Per-step training loss; the training loop appends one value per batch
training_loss_logger = []

In [None]:
# Teacher-forced training: the decoder sees the answer shifted right by one
# token and is trained to predict the next token at every position.
for epoch in trange(0, nepochs, leave=False, desc="Epoch"):
    tf_generator.train()
    for q_text, a_text in tqdm(data_loader_train, desc="Training", leave=False):
        # Tokenize, index and pad the raw strings of the batch
        q_text_tokens = q_tranform(list(q_text)).to(device)
        a_text_tokens = a_tranform(list(a_text)).to(device)
        # Decoder input drops the last token, the target drops the first
        a_input_text = a_text_tokens[:, 0:-1]
        a_output_text = a_text_tokens[:, 1:]

        pred = tf_generator(q_text_tokens, a_input_text)

        # CrossEntropyLoss expects the class dimension second: (batch, vocab, length)
        loss = loss_fn(pred.transpose(1, 2), a_output_text)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        training_loss_logger.append(loss.item())
        

In [None]:
# Plot the raw per-step training loss, skipping the first 1000 logged steps.
# Assigning to `_` keeps the notebook from echoing the returned objects.
_ = plt.figure(figsize=(10, 5))
_ = plt.plot(training_loss_logger[1000:])
_ = plt.title("Training Loss")

In [None]:
# Smooth the loss with a moving average over `window_size` steps.
# mode="valid" keeps only the positions where the window fully overlaps the data.
window_size = 512
data = np.convolve(np.array(training_loss_logger), np.ones(window_size)/window_size, mode="valid")
_ = plt.figure(figsize=(10, 5))
# Skip the first 10000 smoothed values
_ = plt.plot(data[10000:])
_ = plt.title("Training Loss")

In [None]:
# Draw one batch of (questions, answers) from the shuffled test loader
q_text, a_text = next(iter(data_loader_test))

In [None]:
# Show the first question of the sampled batch
q_text[0]

In [None]:
# Show the reference answer that belongs to that question
a_text[0]

In [None]:
# init_prompt = ["what is that largest ocean in the world?: "]
# Use the first question of the sampled test batch as the prompt
init_prompt = [q_text[0]]

# Tokenize and index the prompt exactly like the training questions
input_tokens = q_tranform(init_prompt).to(device)

# Add Start-Of-Answer token to prompt the network to start generating the answer!
# input_tokens = torch.cat((input_tokens, 3 * torch.ones(1, 1, device=device).long()), 1)
# A (1, 1) tensor holding index 3, the <soa> token
soa_token = torch.full((1, 1), 3, dtype=torch.long)

print(input_tokens)
print(vocab.lookup_tokens(input_tokens[0].cpu().numpy()))

In [None]:
# Sampling temperature: the logits are divided by this value before sampling
temp = 0.8

In [None]:
# The generated answer starts with <soa>; every sampled token is appended to this list
log_tokens = [soa_token]
tf_generator.eval()

with torch.no_grad():
    # Encode the question once and reuse the result at every decoding step
    encoded_seq = tf_generator.encoder(input_tokens.to(device))

    for _ in range(100):
        # Keep the decoder input under its own name: rebinding `input_tokens` here
        # would overwrite the question prompt, so re-running this cell would encode
        # the previously generated answer instead of the question.
        decoder_tokens = torch.cat(log_tokens, 1)
        data_pred = tf_generator.decoder(decoder_tokens.to(device), encoded_seq)
#         We can take the token with the highest prob
#         next_tokens = data_pred[:, -1].argmax().reshape(1, 1)

        # Or sample from the distribution of probs!
        dist = Categorical(logits=data_pred[:, -1]/temp)
        next_tokens = dist.sample().reshape(1, 1)

        log_tokens.append(next_tokens.cpu())

        # Index 4 is <eoa>: stop once the model ends the answer
        if next_tokens.item() == 4:
            break

In [None]:
# Map the generated indices back to sub-word tokens and concatenate them without separators
pred_text = "".join(vocab.lookup_tokens(torch.cat(log_tokens, 1)[0].numpy()))
print(pred_text)

In [None]:
# Turn the "▁" word-boundary marker into spaces and strip the <unk> / <eoa> tokens for display
pred_text.replace("▁", " ").replace("<unk>", "").replace("<eoa>", "")

In [None]:
# Plot the temperature-scaled softmax over the vocabulary for the last decoding step
plt.plot(F.softmax(data_pred[:, -1]/temp, -1).cpu().numpy().flatten())