# Transformer, many-to-many with text
### Decoder Only Architecture

Let's revisit text generation and see how we can use Attention to build a Transformer instead of an LSTM. The type of Transformer we will be using is called "Decoder-only" because it has no cross-attention.
Instead, we will use a type of Self-Attention called "Masked Self-Attention". It works like regular Self-Attention, but applies "Causal Masking" to prevent a token from querying tokens that come later in the input sequence.<br>
Why?
<br>
We are doing next-token prediction, so if a token could query every token in the sequence it could simply "look" at what comes next and return exactly that! This is not useful, because at test time, when we want the model to generate text, the next token does not exist yet!<br>
With "Causal Masking" we mask out the regions of the attention map where a token queries tokens that are later in the sequence. In practice the masked scores are set to negative infinity before the softmax, so their attention weights become exactly zero. As a result, a token can only query itself or tokens that came BEFORE it in the sequence!

<img src="../data/llm_architecture_comparison.png" width="600" align="center">
<br>
We won't be implementing exactly the Decoder-only Transformer shown above; our network will basically be the Decoder from the Encoder-Decoder network without the Cross-Attention (the input from the Encoder).
<br>
NOTE: We will cover Encoder-Decoder networks in the next notebook!
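
The causal masking described above can be sketched without any tensor library, so that the arithmetic is visible. This is a minimal illustration rather than the code used by the model below: `causal_attention_weights` is a hypothetical helper, and it assumes the raw attention scores are given as a square list of lists.

```python
import math

def causal_attention_weights(scores):
    """Row-wise softmax over `scores` with future positions masked out."""
    n = len(scores)
    weights = []
    for q in range(n):
        # Query position q may only look at keys 0..q; later keys get -inf
        masked = [scores[q][k] if k <= q else -math.inf for k in range(n)]
        row_max = max(masked)  # finite, because key q itself is always visible
        exps = [math.exp(s - row_max) for s in masked]  # exp(-inf) is 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

scores = [[0.0] * 4 for _ in range(4)]  # uniform raw scores, for illustration
for row in causal_attention_weights(scores):
    print([round(w, 2) for w in row])
```

Each row still sums to one, but all of the weight sits on the current and earlier positions, which is the lower-triangular pattern the mask is meant to produce.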

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import io
import re
import math
from tqdm.notebook import trange, tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.distributions import Categorical

from torchtext.datasets import WikiText2, EnWik9, AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchtext.data.functional import sentencepiece_tokenizer, load_sp_model

torch.backends.cuda.matmul.allow_tf32 = True

In [7]:
# Define the hyperparameters
# Learning rate for the optimizer
learning_rate = 1e-4

# Number of epochs for training
nepochs = 100

# Batch size for data loaders
batch_size = 128

# Maximum sequence length for text inputs
max_len = 64

# Root directory of the dataset
data_set_root = "/home/artem216/transformer-recsys/data/test_data/"

## Data processing and Tokenization

In [8]:
# We'll be using the AG News Dataset
# Which contains a short news article and a single label to classify the "type" of article
# Note that in torchtext these datasets are NOT PyTorch Dataset classes: "AG_NEWS" is a function that
# returns a PyTorch DataPipe!

# Pytorch DataPipes vvv
# https://pytorch.org/data/main/torchdata.datapipes.iter.html

# vvv Good Blog on the difference between DataSet and DataPipe
# https://medium.com/deelvin-machine-learning/comparison-of-pytorch-dataset-and-torchdata-datapipes-486e03068c58
# Depending on the dataset sometimes the dataset doesn't download and gives an error
# and you'll have to download and extract manually 
# "The datasets supported by torchtext are datapipes from the torchdata project, which is still in Beta status"

# Un-comment to trigger the DataPipe to download the data vvv
# dataset_train = AG_NEWS(root=data_set_root, split="train")
# data = next(iter(dataset_train))

# Side-Note I've noticed that the WikiText dataset is no longer able to be downloaded :(

In [4]:
# Train the sentence-piece model used for the tokenizer and vocab (this only needs to run once)

from torchtext.data.functional import generate_sp_model

with open(os.path.join(data_set_root, "train.csv")) as f:
    with open(os.path.join(data_set_root, "datasets/AG_NEWS/data.txt"), "w") as f2:
        for i, line in enumerate(f):
            # Drop the first CSV column (the class label) and keep the text
            text_only = "".join(line.split(",")[1:])

            # Decode the HTML-entity remnants first: they end in ';',
            # which the regex below would otherwise strip before they can match
            filtered = text_only.replace(' #39;', "'")
            filtered = filtered.replace(' #38;', "&")
            filtered = filtered.replace(' #36;', "$")
            filtered = filtered.replace(' #151;', "-")

            # Remove quotes, newlines, literal "\n", backslashes and leftover semicolons
            filtered = re.sub(r'\\n|\\|;', ' ', filtered.replace('"', ' ').replace('\n', ' '))

            f2.write(filtered.lower() + "\n")

generate_sp_model(os.path.join(data_set_root, "datasets/AG_NEWS/data.txt"), 
                  vocab_size=20000, model_prefix='spm_ag_news')
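
One detail of the cleaning step is worth isolating: the entity remnants (` #39;`, ` #38;`, ...) end in a semicolon, so they have to be decoded before semicolons are stripped. The sketch below shows that ordering on a single made-up line; `clean_line` and `ENTITY_MAP` are hypothetical names, and the example sentence is not taken from the dataset.

```python
import re

# Entity remnants and their replacements, as listed in the cell above
ENTITY_MAP = {" #39;": "'", " #38;": "&", " #36;": "$", " #151;": "-"}

def clean_line(text):
    """Decode entity remnants first, then strip quotes, backslashes and semicolons."""
    for code, char in ENTITY_MAP.items():
        text = text.replace(code, char)
    text = text.replace('"', " ").replace("\n", " ")
    # Literal "\n" is matched before a lone backslash so it is removed as a unit
    text = re.sub(r"\\n|\\|;", " ", text)
    return text.lower()

print(repr(clean_line("Wall Street #39;s band; AT #38;T\n")))
```

If the semicolons were removed first, ` #39;` would already have become ` #39 ` and the apostrophe would never be restored.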

In [5]:
class AGNews(Dataset):
    """
    The AGNews class is a custom Dataset for handling the AG News dataset.
    This dataset consists of news articles categorized into four classes.
    The class loads the data from CSV files, preprocesses the text by cleaning and combining
    relevant columns, and provides an interface to access individual samples along with their
    corresponding class labels.
    
    Attributes:
        df (pd.DataFrame): The DataFrame containing the preprocessed dataset.
    """
    
    def __init__(self, num_datapoints, test_train="train"):
        # Load the dataset from the specified CSV file
        self.df = pd.read_csv(os.path.join(data_set_root, "datasets/AG_NEWS/" + test_train + ".csv"),
                              names=["Class", "Title", "Content"])
        
        # Fill any missing values with empty strings
        self.df.fillna('', inplace=True)
        
        # Combine the Title and Content columns into a single Article column
        self.df['Article'] = self.df['Title'] + " : " + self.df['Content']
        
        # Drop the now redundant Title and Content columns
        self.df.drop(['Title', 'Content'], axis=1, inplace=True)
        
        # Clean the Article column by removing unwanted characters and replacing HTML codes
        # (longer escape sequences are listed first so they are matched as a unit)
        self.df['Article'] = self.df['Article'].str.replace(r'\\r\\n|\\n|\\r|\\|\n|"', ' ', regex=True)
        self.df['Article'] = self.df['Article'].replace({' #39;': "'", 
                                                         ' #38;': "&", 
                                                         ' #36;': "$",
                                                         ' #151;': "-"}, 
                                                        regex=True)

    def __getitem__(self, index):
        # Retrieve the article text and convert it to lowercase
        text = self.df.loc[index]["Article"].lower()

        # Return only the article text; the class label is not needed for text generation
        return text
    
    def __len__(self):
        # Return the number of data points in the dataset
        return len(self.df)

In [6]:
# Create training and testing datasets
dataset_train = AGNews(num_datapoints=data_set_root, test_train="train")
dataset_test = AGNews(num_datapoints=data_set_root, test_train="test")

# Create data loaders for the training and testing datasets
data_loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)
data_loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=8)


In [None]:
def yield_tokens(file_path):
    with io.open(file_path, encoding='utf-8') as f:
        # Iterate through each line in the file
        for line in f:
            # Yield the token from the first column (split by tab)
            yield [line.split("\t")[0]]

# Build a vocabulary from the tokens yielded by the yield_tokens function
# We will also add "special" tokens that we'll use to signal something to our model
# <pad> is a padding token that is added to the end of a sentence to ensure 
# the length of all sequences in a batch is the same
# <sos> signals the "Start-Of-Sentence" aka the start of the sequence
# <eos> signals the "End-Of-Sentence" aka the end of the sequence
# <unk> "unknown" token is used if a token is not contained in the vocab
vocab = build_vocab_from_iterator(yield_tokens("spm_ag_news.vocab"), 
                                  specials=['<pad>', '<sos>', '<eos>', '<unk>'],
                                  special_first=True)

# Set the default index for unknown tokens to the index of the '<unk>' token
vocab.set_default_index(vocab['<unk>'])
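
To make the role of the special tokens and the default index concrete, here is a simplified, dictionary-based sketch. It is not how `build_vocab_from_iterator` is implemented (which, for one thing, orders tokens by frequency); `build_vocab` and `lookup` are hypothetical helpers and the three example tokens are made up.

```python
SPECIALS = ["<pad>", "<sos>", "<eos>", "<unk>"]

def build_vocab(tokens, specials=SPECIALS):
    """Map each token to an index, with the special tokens placed first."""
    itos = list(specials)
    for tok in tokens:
        if tok not in itos:
            itos.append(tok)
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return stoi, itos

def lookup(stoi, token):
    # Unknown tokens fall back to the <unk> index, like set_default_index above
    return stoi.get(token, stoi["<unk>"])

stoi, itos = build_vocab(["the", "news", "today"])
print(lookup(stoi, "news"), lookup(stoi, "never_seen"))
```

Because the specials come first, `<pad>`, `<sos>`, `<eos>` and `<unk>` get indices 0 to 3, which is what the hard-coded `0`, `1` and `2` in the later cells rely on.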

In [None]:
class TokenDrop(nn.Module):
    """For a batch of token indices, randomly replace non-special tokens with <pad>.
    
    Args:
        prob (float): probability of dropping a token
        pad_token (int): index for the <pad> token
        num_special (int): Number of special tokens, assumed to be at the start of the vocab
    """

    def __init__(self, prob=0.1, pad_token=0, num_special=4):
        # Initialise the nn.Module state before setting any attributes
        super().__init__()
        self.prob = prob
        self.num_special = num_special
        self.pad_token = pad_token

    def __call__(self, sample):
        # Randomly sample a bernoulli distribution with p=prob
        # to create a mask where 1 means we will replace that token
        mask = torch.bernoulli(self.prob * torch.ones_like(sample)).long()
        
        # only replace if the token is not a special token
        can_drop = (sample >= self.num_special).long()
        mask = mask * can_drop
        
        replace_with = (self.pad_token * torch.ones_like(sample)).long()
        
        sample_out = (1 - mask) * sample + mask * replace_with
        
        return sample_out
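
The same idea can be written for a single Python list, which makes the two conditions (a random draw, and "not a special token") easy to see. This is only an illustrative sketch: `token_drop` is a hypothetical function, the token ids are made up, and it uses Python's `random` module instead of `torch.bernoulli`.

```python
import random

def token_drop(tokens, prob=0.1, pad_token=0, num_special=4, rng=random):
    """Replace each non-special token with pad_token with probability `prob`."""
    out = []
    for tok in tokens:
        droppable = tok >= num_special  # specials occupy indices 0..num_special-1
        if droppable and rng.random() < prob:
            out.append(pad_token)
        else:
            out.append(tok)
    return out

sequence = [1, 57, 912, 8, 2, 0, 0]  # <sos>, three tokens, <eos>, two <pad>
print(token_drop(sequence, prob=0.5, rng=random.Random(0)))
```

Even with `prob=1.0` the `<sos>` and `<eos>` tokens survive, so the model always sees where a sequence starts and ends.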

In [None]:
train_tranform = T.Sequential(
    # Tokenize with the pre-trained SentencePiece tokenizer
    T.SentencePieceTokenizer("spm_ag_news.model"),
    # Convert the tokens to indices based on the given vocabulary
    T.VocabTransform(vocab=vocab),
    # Add <sos> at the beginning of each sentence. 1 because the index for <sos> in the vocabulary is
    # 1 as seen in the previous section
    T.AddToken(1, begin=True),
    # Crop the sentence if it is longer than the max length
    T.Truncate(max_seq_len=max_len),
    # Add <eos> at the end of each sentence. 2 because the index for <eos> in the vocabulary is
    # 2 as seen in the previous section
    T.AddToken(2, begin=False),
    # Convert the list of lists to a tensor, this will also
    # pad a sentence with the <pad> token if it is shorter than the longest sentence in the batch
    # This ensures all sentences in a batch are the same length!
    T.ToTensor(padding_value=0),
)

gen_tranform = T.Sequential(
    # Tokenize with the pre-trained SentencePiece tokenizer
    T.SentencePieceTokenizer("spm_ag_news.model"),
    ## converts the sentences to indices based on given vocabulary
    T.VocabTransform(vocab=vocab),
    ## Add <sos> at beginning of each sentence. 1 because the index for <sos> in vocabulary is
    # 1 as seen in previous section
    T.AddToken(1, begin=True),
    # Convert the list of lists to a tensor, this will also
    # Pad a sentence with the <pad> token if it is shorter than the max length
    # This ensures all sentences are the same length!
    T.ToTensor(padding_value=0)
)
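
The order of the steps in `train_tranform` matters, so here is a plain-Python sketch of what happens to the token ids after tokenization. `encode` and `pad_batch` are hypothetical helpers, the ids are made up, and the sketch assumes padding is done to the longest sequence in the batch.

```python
def encode(token_ids, max_len, sos=1, eos=2):
    """Mirror the order above: add <sos>, truncate, then add <eos>."""
    ids = [sos] + list(token_ids)
    ids = ids[:max_len]   # truncation counts the <sos> token
    return ids + [eos]    # so the result holds at most max_len + 1 ids

def pad_batch(batch, pad=0):
    """Right-pad every sequence to the longest one in the batch."""
    longest = max(len(seq) for seq in batch)
    return [seq + [pad] * (longest - len(seq)) for seq in batch]

batch = pad_batch([encode([11, 12], max_len=4), encode(range(20, 30), max_len=4)])
for row in batch:
    print(row)
```

Because `<eos>` is appended after truncation, a cropped sentence still ends with `<eos>`, at the cost of being one token longer than `max_len`.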


In [None]:
text = next(iter(data_loader_train))
index = 0
input_tokens = train_tranform(list(text))
print("SENTENCE")
print(text[index])
print()
print("TOKENS")
print(vocab.lookup_tokens(input_tokens[index].numpy()))

In [None]:
print("TOKENS BACK TO SENTENCE")

pred_text = "".join(vocab.lookup_tokens(input_tokens[index].numpy()))
pred_text.replace("▁", " ")

## Create Model

In [None]:
# Sinusoidal positional embeddings
class SinusoidalPosEmb(nn.Module):
    """
    Sinusoidal positional embeddings module.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # Calculate sinusoidal positional embeddings
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

    
# Transformer block with Attention and causal masking
class TransformerBlock(nn.Module):
    """
    Transformer block with self-attention and causal masking.
    """

    def __init__(self, hidden_size=128, num_heads=4):
        super(TransformerBlock, self).__init__()

        # Layer normalization for input
        self.norm1 = nn.LayerNorm(hidden_size)

        # Multi-head self-attention mechanism
        self.multihead_attn = nn.MultiheadAttention(hidden_size, 
                                                    num_heads=num_heads, 
                                                    batch_first=True,
                                                    dropout=0.1)

        # Layer normalization for attention output
        self.norm2 = nn.LayerNorm(hidden_size)

        # Feedforward neural network
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.ELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )

    def forward(self, x, padding_mask):
        # Create causal mask for Attention
        bs, l, h = x.shape
        mask = torch.triu(torch.ones(l, l, device=x.device), 1).bool()

        # Layer normalization
        norm_x = self.norm1(x)

        # Apply multi-head Attention
        x = self.multihead_attn(norm_x, norm_x, norm_x, attn_mask=mask, key_padding_mask=padding_mask)[0] + x

        # Layer normalization
        norm_x = self.norm2(x)

        # Apply feedforward neural network
        x = self.mlp(norm_x) + x
        return x

    
# "Decoder-Only" Style Transformer with Attention
class Transformer(nn.Module):
    """
    "Decoder-Only" Style Transformer with self-attention.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Transformer, self).__init__()

        # Token embeddings
        self.embedding = nn.Embedding(num_emb, hidden_size)

        # Positional embeddings
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        # List of Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(hidden_size, num_heads) for _ in range(num_layers)
        ])

        # Output layer
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def forward(self, input_seq):
        # Mask for padding tokens
        input_key_mask = input_seq == 0

        # Embedding input sequence
        input_embs = self.embedding(input_seq)
        bs, l, h = input_embs.shape

        # Add positional embeddings to token embeddings
        seq_indx = torch.arange(l, device=input_seq.device)
        pos_emb = self.pos_emb(seq_indx).reshape(1, l, h).expand(bs, l, h)
        embs = input_embs + pos_emb

        # Pass through Transformer blocks
        for block in self.blocks:
            embs = block(embs, padding_mask=input_key_mask)

        # Output predictions
        return self.fc_out(embs)
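
To see what `SinusoidalPosEmb` produces, the same formula can be evaluated for one position at a time with the `math` module. This is a sketch for inspection only; `sinusoidal_embedding` is a hypothetical function that follows the layout used above (all sines first, then all cosines).

```python
import math

def sinusoidal_embedding(position, dim):
    """Embedding of one position: sines of the scaled position, then cosines."""
    half_dim = dim // 2
    step = math.log(10000) / (half_dim - 1)
    # Frequencies decay geometrically from 1 down to 1/10000
    freqs = [math.exp(-step * i) for i in range(half_dim)]
    angles = [position * f for f in freqs]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

emb0 = sinusoidal_embedding(0, dim=8)
emb5 = sinusoidal_embedding(5, dim=8)
print([round(v, 3) for v in emb0])
print([round(v, 3) for v in emb5])
```

Position 0 maps to all zeros followed by all ones, and every later position gets a distinct pattern, which gives the otherwise order-blind attention layers a way to tell positions apart.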


## Initialise Model and Optimizer

In [None]:
# Use the GPU if one is available, otherwise fall back to the CPU
# (pass e.g. "cuda:1" instead to pin a specific GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Embedding Size
hidden_size = 256

# Number of transformer blocks
num_layers = 8

# MultiheadAttention Heads
num_heads = 8

# Create model
tf_generator = Transformer(num_emb=len(vocab), num_layers=num_layers, 
                           hidden_size=hidden_size, num_heads=num_heads).to(device)

# Initialize the optimizer with above parameters
optimizer = optim.Adam(tf_generator.parameters(), lr=learning_rate)

# Scaler for mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Define the loss function
loss_fn = nn.CrossEntropyLoss(reduction="none")

# Custom transform that will randomly replace a token with <pad>
td = TokenDrop(prob=0.2)

# Initialize training loss logger and entropy logger
training_loss_logger = []
entropy_logger = []

In [None]:
# Let's see how many Parameters our Model has!
num_model_params = 0
for param in tf_generator.parameters():
    num_model_params += param.flatten().shape[0]

print("-This Model Has %d (Approximately %d Million) Parameters!" % (num_model_params, num_model_params//1e6))
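
As a sanity check on the number printed above, the parameter count can also be estimated by hand from the layer shapes. The sketch below assumes the architecture defined earlier (two LayerNorms, a multi-head attention layer with a packed query/key/value projection plus an output projection, and a two-layer MLP per block; the sinusoidal embedding has no parameters). The `vocab_size=20004` value is a placeholder, since the real value is `len(vocab)`.

```python
def block_params(h):
    """Parameters in one TransformerBlock with hidden size h."""
    layer_norms = 2 * (2 * h)                      # weight + bias for two LayerNorms
    attention = (3 * h * h + 3 * h) + (h * h + h)  # packed q/k/v projection + output projection
    mlp = (h * 4 * h + 4 * h) + (4 * h * h + h)    # two Linear layers with biases
    return layer_norms + attention + mlp

def transformer_params(vocab_size, h, num_layers):
    embedding = vocab_size * h
    fc_out = h * vocab_size + vocab_size
    return embedding + num_layers * block_params(h) + fc_out

# vocab_size is an assumed placeholder, not the notebook's actual len(vocab)
total = transformer_params(vocab_size=20004, h=256, num_layers=8)
print(f"{total:,} parameters")
```

With these settings roughly a third of the parameters sit in the Transformer blocks; the rest are split between the token embedding and the output layer, both of which scale with the vocabulary size.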


## Training

In [None]:
for epoch in trange(0, nepochs, leave=False, desc="Epoch"):    
    tf_generator.train()
    steps = 0
    for text in tqdm(data_loader_train, desc="Training", leave=False):
        # Convert text to tokenized input
        text_tokens = train_tranform(list(text)).to(device)
        bs = text_tokens.shape[0]
        
        # Randomly drop input tokens
        input_text = td(text_tokens[:, 0:-1])
        output_text = text_tokens[:, 1:]

        # Generate predictions
        with torch.cuda.amp.autocast():
            pred = tf_generator(input_text)

        # Calculate loss with masked cross-entropy
        mask = (output_text != 0).float()
        loss = (loss_fn(pred.transpose(1, 2), output_text) * mask).sum()/mask.sum()
        
        # Backpropagation
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Log training loss and entropy
        training_loss_logger.append(loss.item())
        with torch.no_grad():
            dist = Categorical(logits=pred)
            entropy_logger.append(dist.entropy().mean().item())
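
Two steps in the loop above are easy to miss: the input and target are the same sequence shifted by one position, and the loss is averaged only over targets that are not `<pad>`. The following sketch shows both on a toy sequence. `masked_mean_nll` is a hypothetical helper, and the probabilities are invented for illustration rather than produced by a model.

```python
import math

def masked_mean_nll(token_probs, targets, pad=0):
    """Average -log p(target) over positions whose target is not <pad>."""
    losses = [-math.log(p) for p in token_probs]
    mask = [1.0 if t != pad else 0.0 for t in targets]
    return sum(l * m for l, m in zip(losses, mask)) / sum(mask)

tokens = [1, 45, 67, 2, 0, 0]              # <sos>, two tokens, <eos>, padding
inputs, targets = tokens[:-1], tokens[1:]  # predict token t+1 from tokens up to t

# Made-up probabilities the model assigned to each target token
token_probs = [0.5, 0.25, 0.5, 0.01, 0.01]
print(inputs, targets, round(masked_mean_nll(token_probs, targets), 4))
```

The two padded positions have very low probability here, but they do not affect the result because the mask zeroes them out before averaging.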

## Plot Metrics

In [None]:
_ = plt.figure(figsize=(10, 5))
_ = plt.plot(training_loss_logger[10000:])
_ = plt.title("Training Loss")

In [None]:
_ = plt.figure(figsize=(10, 5))
_ = plt.plot(entropy_logger[10000:])
_ = plt.title("Distribution Entropy")
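
For reading the entropy plot, it helps to know the scale: entropy is measured in nats, a uniform distribution over the vocabulary gives the maximum value `log(vocab_size)`, and a confident prediction gives a value near zero. A small sketch with a made-up 8-token vocabulary (`entropy` is a hypothetical helper, not the `Categorical` implementation):

```python
import math

def entropy(probs):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

vocab_size = 8
uniform = [1 / vocab_size] * vocab_size  # model has no idea what comes next
peaked = [0.93] + [0.01] * 7             # model is confident about one token
print(round(entropy(uniform), 4), round(entropy(peaked), 4))
```

A curve that falls over training therefore indicates that the predicted next-token distributions are becoming more peaked.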

## Testing

In [None]:
# Get an example from the test set
text = next(iter(data_loader_test))

In [None]:
# Set index of the example to use
index = 0

# Set temperature for sampling
temp = 0.6

# Split text into title and content at the first ":" only,
# so that any colons inside the content are kept
title, content = text[index].split(":", 1)
init_prompt = [title + ":"]  # Create initial prompt using the title

# Tokenize the initial prompt
init_tokens = gen_tranform(init_prompt)

# Print initial prompt, original content, and tokenized prompt
print("INITIAL PROMPT:")
print(title)
print("")
print("ORIGINAL CONTENT:")
print(content)
print("")
print("PROMPT TOKENS:")
print(init_tokens)
print(vocab.lookup_tokens(init_tokens[0].cpu().numpy()))

In [None]:
# List to log generated tokens
log_tokens = [init_tokens]

# Set the generator model to evaluation mode
tf_generator.eval()

# Generate tokens
with torch.no_grad():    
    for i in range(10):
        # Concatenate tokens from previous iterations
        input_tokens = torch.cat(log_tokens, 1)
        
        # Get model predictions for the next token
        data_pred = tf_generator(input_tokens.to(device))
        
        # Sample the next token from the distribution of probabilities
        dist = Categorical(logits=data_pred[:, -1] / temp)
        next_tokens = dist.sample().reshape(1, 1)
        
        # Append the sampled token to the list of generated tokens
        log_tokens.append(next_tokens.cpu())
        
        # Check for end-of-sequence token and stop generation
        if next_tokens.item() == 2:
            break
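
The effect of dividing the logits by `temp` before sampling can be checked on a tiny example. The sketch below uses made-up logits for a 3-token vocabulary; `softmax_with_temperature` and `sample_token` are hypothetical helpers that stand in for the `Categorical(logits=... / temp)` call above.

```python
import math
import random

def softmax_with_temperature(logits, temp):
    """Softmax of logits / temp; temp < 1 sharpens, temp > 1 flattens."""
    scaled = [l / temp for l in logits]
    top = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_token(logits, temp, rng=random):
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temp)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.1]
for temp in (1.0, 0.6):
    print(temp, [round(p, 3) for p in softmax_with_temperature(logits, temp)])
print(sample_token(logits, temp=0.6, rng=random.Random(0)))
```

A lower temperature moves probability mass towards the most likely token, which trades diversity for more conservative generations.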

In [None]:
# Concatenate generated tokens into a single string
pred_text = "".join(vocab.lookup_tokens(torch.cat(log_tokens, 1)[0].numpy()))

# Print the generated text
print(pred_text)

In [None]:
# Replace special tokens and characters in the generated text
pred_text_cleaned = pred_text.replace("▁", " ").replace("<unk>", "").replace("<sos>", "").replace("<eos>", "")

# Print the cleaned generated text
print(pred_text_cleaned)

In [None]:
# Plot the softmax probabilities of the next token
_ = plt.plot(F.softmax(data_pred[0, -1] / temp, -1).cpu().numpy().flatten())