In [48]:
# !pip install torch==2.2.0
# !pip install torchtext==0.17.2
# !pip install torchdata==0.7.1
# !pip install transformers==4.35.2
# !pip install seaborn 

In [83]:
import os
import math
import time
import warnings
from typing import Any, Iterable, List, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Transformer
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam, AdamW, Optimizer
from torch.optim.lr_scheduler import StepLR, _LRScheduler
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab, build_vocab_from_iterator

# Suppress warnings
warnings.filterwarnings("ignore")

In [84]:
train_data, test_data = torch.load("data/imdb_dataset.pt")
print(f"Train size: {len(train_data)}, Test size: {len(test_data)}")

Train size: 25000, Test size: 25000


In [85]:
print(train_data[0])
print('=' * 50)
print(train_data[1])

(1, 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between

In [86]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [87]:
UNK_IDX, PAD_IDX, EOS_IDX = 0, 1, 2
special_symbols = ['<unk>', '<pad>', '<|endoftext|>']

tokenizer = get_tokenizer("basic_english")

In [88]:
def yield_tokens(data_iter):
    for _, data_sample in data_iter:
        yield tokenizer(data_sample)
        
vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=special_symbols)
vocab.set_default_index(UNK_IDX)

artifacts_dir = "artifacts"
os.makedirs(artifacts_dir, exist_ok=True)

vocab_path = os.path.join(artifacts_dir, "vocab.pth")
torch.save(vocab, vocab_path)

print(f"Vocabulary Length: {len(vocab)}")
print(f"Vocabulary saved to: {vocab_path}")

Vocabulary Length: 100685
Vocabulary saved to: artifacts/vocab.pth


In [56]:
torch.manual_seed(42)

train_size = int(0.8 * len(train_data))  # 20,000
val_size = len(train_data) - train_size  # 5,000
train_data, val_data = random_split(train_data, [train_size, val_size])

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

Train: 20000, Val: 5000, Test: 25000


In [57]:
text_to_index = lambda text: [vocab(token) for token in tokenizer(text)]
index_to_text = lambda seq_en: " ".join([vocab.get_itos()[index] for index in seq_en])

In [58]:
index_to_text(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))

"<unk> <pad> <|endoftext|> the . , and a of to '"

## Text Processing

In [59]:
END_OF_TEXT_TOKEN = '<|endoftext|>'

def get_training_sample(
    text: List[Any], 
    block_size: int
) -> Tuple[List[Any], List[Any]]:
    """
    Creates a single (input, target) training sample for a language model.

    From a given text, this function extracts a random contiguous block of
    `block_size` tokens for the source (X) and the subsequent, shifted
    block of tokens for the target (Y).

    Args:
        text (List[Any]): The input text, represented as a list of tokens.
        block_size (int): The desired length of the input/target sequences.

    Returns:
        A tuple containing the source sequence and the target sequence.
    """
    text_len = len(text)

    # Case 1: The text is long enough to extract a full random block.
    # We need at least `block_size + 1` tokens to create a source and a target.
    if text_len > block_size:
        max_start_idx = text_len - block_size - 1
        start_idx = torch.randint(low=0, high=max_start_idx + 1, size=(1,)).item()
        end_idx = start_idx + block_size
        
        src_sequence = text[start_idx : end_idx]
        tgt_sequence = text[start_idx + 1 : end_idx + 1]

    # Case 2: The text is too short. Use the entire available text.
    else:
        # The source is the whole text.
        src_sequence = text
        # The target is the source shifted by one, with an end-of-text token 
        # appended to create a valid target for the last source token.
        tgt_sequence = text[1:] + [END_OF_TEXT_TOKEN]
        
    return src_sequence, tgt_sequence

In [60]:
batch_of_tokens=[]

for i in range(2):
    _ , text = train_data[i]
    batch_of_tokens.append(tokenizer(text))

In [61]:
for i in range(2):
    text = batch_of_tokens[i][0:50]
    src_sequences, tgt_sequence = get_training_sample(text, 10)
    print("src: ", src_sequences)
    print("tgt: ", tgt_sequence)
    print('=' * 100)

src:  ['this', 'film', 'while', 'at', 'birmingham', 'southern', 'college', 'in', '1975', ',']
tgt:  ['film', 'while', 'at', 'birmingham', 'southern', 'college', 'in', '1975', ',', 'when']
src:  ['as', 'much', 'as', 'the', 'central', 'character', 'in', 'this', 'film', '.']
tgt:  ['much', 'as', 'the', 'central', 'character', 'in', 'this', 'film', '.', 'in']


In [62]:
for i in range(2):
    text = batch_of_tokens[i][0:50]
    src_sequences, tgt_sequence = get_training_sample(text, 50)
    print("src: ", src_sequences)
    print("tgt: ", tgt_sequence)
    print('=' * 100)

src:  ['i', 'saw', 'this', 'film', 'while', 'at', 'birmingham', 'southern', 'college', 'in', '1975', ',', 'when', 'it', 'was', 'shown', 'in', 'combination', 'with', 'the', 'red', 'balloon', '.', 'both', 'films', 'are', 'similar', 'in', 'their', 'dream-like', 'quality', '.', 'the', 'bulk', 'of', 'the', 'film', 'entails', 'a', 'fish', 'swimming', 'happily', 'in', 'his', 'bowl', 'while', 'his', 'new', 'owner', ',']
tgt:  ['saw', 'this', 'film', 'while', 'at', 'birmingham', 'southern', 'college', 'in', '1975', ',', 'when', 'it', 'was', 'shown', 'in', 'combination', 'with', 'the', 'red', 'balloon', '.', 'both', 'films', 'are', 'similar', 'in', 'their', 'dream-like', 'quality', '.', 'the', 'bulk', 'of', 'the', 'film', 'entails', 'a', 'fish', 'swimming', 'happily', 'in', 'his', 'bowl', 'while', 'his', 'new', 'owner', ',', '<|endoftext|>']
src:  ['hi', 'all', 'i', 'am', 'a', 'chess', 'enthusiast', 'since', 'the', 'age', 'of', 'about', '6', '.', 'i', 'supposed', 'i', 'am', 'quite', 'obsessed', 

In [63]:
# Initialize empty lists to store source and target sequences
src_batch, tgt_batch = [], []
BATCH_SIZE = 2
block_size = 20

for i in range(BATCH_SIZE):
    # Retrieve the next data point from the training iterator
    _, text = train_data[i]

    # Generate source and target sequences using the get_sample function
    src_sequence_text, tgt_sequence_text = get_training_sample(tokenizer(text), block_size)

    # Convert source and target sequences to tokenized vocabulary indices
    src_sequence_indices = vocab(src_sequence_text)
    tgt_sequence_indices = vocab(tgt_sequence_text)

    # Convert the sequences to PyTorch tensors with dtype int64
    src_sequence = torch.tensor(src_sequence_indices, dtype=torch.int64)
    tgt_sequence = torch.tensor(tgt_sequence_indices, dtype=torch.int64)

    # Append the source and target sequences to their respective batches
    src_batch.append(src_sequence)
    tgt_batch.append(tgt_sequence)

    print(f"Sample {i}:")
    print("Source Sequence (Text):", src_sequence_text)
    print("Source Sequence (Indices):", src_sequence_indices)
    print("Source Sequence (Shape):", src_sequence.shape)
    print("Target Sequence (Text):", tgt_sequence_text)
    print("Target Sequence (Indices):", tgt_sequence_indices)
    print("Target Sequence (Shape):", tgt_sequence.shape)
    print('=' * 100)

Sample 0:
Source Sequence (Text): ['complexity', '.', 'it', 'is', 'hard', 'to', 'imagine', 'how', 'the', 'director', 'could', "'", 've', 'pulled', 'the', 'technical', 'feat', 'back', 'in', '1959']
Source Sequence (Indices): [4600, 4, 12, 11, 274, 9, 823, 96, 3, 174, 105, 10, 147, 1860, 3, 1709, 5830, 154, 13, 6396]
Source Sequence (Shape): torch.Size([20])
Target Sequence (Text): ['.', 'it', 'is', 'hard', 'to', 'imagine', 'how', 'the', 'director', 'could', "'", 've', 'pulled', 'the', 'technical', 'feat', 'back', 'in', '1959', '--']
Target Sequence (Indices): [4, 12, 11, 274, 9, 823, 96, 3, 174, 105, 10, 147, 1860, 3, 1709, 5830, 154, 13, 6396, 377]
Target Sequence (Shape): torch.Size([20])
Sample 1:
Source Sequence (Text): ['suicide', 'in', '1924', '.', 'he', 'is', 'famous', 'for', 'a', 'game', 'he', 'played', 'against', 'steinitz', ',', 'where', 'a', 'beautiful', 'combination', 'was']
Source Sequence (Indices): [1746, 13, 16738, 4, 31, 11, 789, 20, 7, 502, 31, 260, 434, 49864, 5, 124,

## Collate Function

In [64]:
BLOCK_SIZE = 30

def collate_batch(batch):
    src_batch, tgt_batch = [], []
    for _, text in batch:
        src_sequence, tgt_sequence = get_training_sample(tokenizer(text), BLOCK_SIZE)
        src_sequence, tgt_sequence = vocab(src_sequence), vocab(tgt_sequence)
      
        src_sequence= torch.tensor(src_sequence, dtype=torch.int64)
        tgt_sequence = torch.tensor(tgt_sequence, dtype=torch.int64)
        
        src_batch.append(src_sequence)
        tgt_batch.append(tgt_sequence)


    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=False)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=False)

    return src_batch.to(DEVICE), tgt_batch.to(DEVICE)

In [65]:
dataloader = DataLoader(train_data, batch_size=2, shuffle=True, collate_fn=collate_batch)

In [66]:
for idx, batch in enumerate(dataloader):
    print(f"\nBatch {idx}")
    src, tgt = batch  # unpack the tuple

    print("Source (src) shape:", src.shape)
    print("Target (tgt) shape:", tgt.shape)
    print('=' * 100)
    
    print("\nSamples:")
    for i in range(min(2, src.shape[1])):  # only show up to 2 samples
        print(f"Sample {i+1}:")
        print("  source:", index_to_text(src[:, i]))  
        print("  target:", index_to_text(tgt[:, i])) 
        print('=' * 100)
    break  # only inspect first batch


Batch 0
Source (src) shape: torch.Size([30, 2])
Target (tgt) shape: torch.Size([30, 2])

Samples:
Sample 1:
  source: photographed with that wonderful opening of guinness and his son driving down the champs elysee with the arc de triomphe in the background . unfortunately it goes downhill from there
  target: with that wonderful opening of guinness and his son driving down the champs elysee with the arc de triomphe in the background . unfortunately it goes downhill from there .
Sample 2:
  source: . and i don ' t believe there are many cynics who would say that people aren ' t capable of change and redemption . this film version portrays all
  target: and i don ' t believe there are many cynics who would say that people aren ' t capable of change and redemption . this film version portrays all of


## Masking

In transformers, masking is crucial for ensuring certain positions are not attended to. The function ```generate_square_subsequent_mask``` produces an upper triangular matrix, which ensures that during decoding, a token can't attend to future tokens of target.

In [67]:
# 1. the raw attention scores (from Q @ K.T)  # [4, 4]
raw_scores = torch.tensor([
    [0.8, 0.2, 0.9, 1.4],
    [0.5, 0.7, 1.1, 0.1],
    [1.2, 0.3, 0.6, 0.4],
    [0.9, 1.5, 0.8, 0.2]
])

# the causal mask
# This mask prevents positions from attending to subsequent positions.
# 0.0 means "allowed", -inf means "prevented".
mask = torch.tensor([
    [0.0, float('-inf'), float('-inf'), float('-inf')],
    [0.0, 0.0,        float('-inf'), float('-inf')],
    [0.0, 0.0,        0.0,        float('-inf')],
    [0.0, 0.0,        0.0,        0.0]
])

print("Step 1: Raw Attention Scores")
print(raw_scores)
print("-" * 30)

print("Step 2: Causal Mask")
print(mask)
print("-" * 30)


# Add the mask to the raw scores
masked_scores = raw_scores + mask

print("Step 3: Masked Scores (Scores + Mask)")
print(masked_scores)
print("-" * 30)

# Apply the softmax function to get the final attention weights
# dim=1 ensures softmax is applied row-wise.
attention_weights = F.softmax(masked_scores, dim=1)

print("Step 4: Final Attention Weights after Softmax")
print(attention_weights)
print("-" * 30)

print("Sum of each row in the final weights:")
print(attention_weights.sum(dim=1))

Step 1: Raw Attention Scores
tensor([[0.8000, 0.2000, 0.9000, 1.4000],
        [0.5000, 0.7000, 1.1000, 0.1000],
        [1.2000, 0.3000, 0.6000, 0.4000],
        [0.9000, 1.5000, 0.8000, 0.2000]])
------------------------------
Step 2: Causal Mask
tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])
------------------------------
Step 3: Masked Scores (Scores + Mask)
tensor([[0.8000,   -inf,   -inf,   -inf],
        [0.5000, 0.7000,   -inf,   -inf],
        [1.2000, 0.3000, 0.6000,   -inf],
        [0.9000, 1.5000, 0.8000, 0.2000]])
------------------------------
Step 4: Final Attention Weights after Softmax
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4502, 0.5498, 0.0000, 0.0000],
        [0.5114, 0.2079, 0.2807, 0.0000],
        [0.2368, 0.4314, 0.2142, 0.1176]])
------------------------------
Sum of each row in the final weights:
tensor([1.0000, 1.0000, 1.0000, 1.0000])


In [68]:
def generate_square_subsequent_mask(sz, device=DEVICE):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

    src_mask (Causal Mask): This is the "No Cheating" rule.
    src_padding_mask: This is the "Ignore Blank Pages" rule. ignore padding tokens

In [69]:
def create_mask(src, device=DEVICE):
    src_seq_len = src.shape[0]       # (sequence_length, batch_size)
    # This is our "No Cheating" mask, filled with 0.0s and -inf to prevent the model from seeing future tokens during self-attention.
    src_mask = generate_square_subsequent_mask(src_seq_len)  # (src_seq_len, src_seq_len)
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    return src_mask, src_padding_mask

In [70]:
src[0:3, :] = PAD_IDX

In [71]:
src_mask, src_padding_mask = create_mask(src)

In [72]:
src_mask.shape, src_padding_mask.shape

(torch.Size([30, 30]), torch.Size([2, 30]))

In [73]:
src_mask

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -i

In [74]:
src_padding_mask

tensor([[ True,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False]],
       device='cuda:0')

## Positional encoding

    emb_size = 512
    maxlen = 5000
    batch_size = 64
    seq_len = 100   (the length of an input sequence in the forward pass)

In [75]:
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size, dropout, maxlen=5000):
        super(PositionalEncoding, self).__init__()

        # pos: (maxlen, 1) -> (5000, 1)
        # Creates a column vector of positions [0, 1, ..., 4999]
        pos = torch.arange(maxlen).unsqueeze(1)

        # i: (emb_size / 2,) -> (256,)
        # Creates a row vector for the even dimension indices [0, 2, ..., 510]
        i = torch.arange(0, emb_size, 2)

        # angle_rates: (maxlen, emb_size / 2) -> (5000, 256)
        # Calculates the arguments for sin/cos using broadcasting.
        # (5000, 1) / (256,) results in a (5000, 256) matrix.
        angle_rates = pos / (10000 ** (i.float() / emb_size))

        # pos_encoding: (maxlen, emb_size) -> (5000, 512)
        pos_encoding = torch.zeros(maxlen, emb_size)
        
        # Fills even indices (0, 2, ...) with sin values.
        # The slice pos_encoding[:, 0::2] has shape (5000, 256).
        pos_encoding[:, 0::2] = torch.sin(angle_rates)

        # Fills odd indices (1, 3, ...) with cos values.
        # The slice pos_encoding[:, 1::2] has shape (5000, 256).
        pos_encoding[:, 1::2] = torch.cos(angle_rates)

        # --- Finalizing and Storing ---
        # pos_encoding: (maxlen, 1, emb_size) -> (5000, 1, 512)
        # Adds a dimension for batch broadcasting in the forward pass.
        pos_encoding = pos_encoding.unsqueeze(1)
        
        # Registers 'pos_encoding' as a buffer. It's part of the model's state
        # but not a parameter to be trained.
        self.register_buffer('pos_encoding', pos_encoding)
        self.dropout = nn.Dropout(dropout)

    def forward(self, token_embedding: Tensor):
        # token_embedding (input): (seq_len, batch_size, emb_size) -> (100, 64, 512)
        seq_len = token_embedding.size(0)

        # Add positional encoding to token embedding.
        # self.pos_encoding[:seq_len, :] slices the buffer to get shape (100, 1, 512).
        # Broadcasting adds this to token_embedding (100, 64, 512).
        # The result has shape (100, 64, 512).
        output = token_embedding + self.pos_encoding[:seq_len, :]
        return self.dropout(output)

## Token embedding

The `TokenEmbedding` class below converts numerical tokens into embeddings:  

    * math.sqrt(self.emb_size)
    From the original "Attention Is All You Need" paper. The output of the embedding lookup is scaled by the square root of the embedding size.


In [76]:
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        # Creates a lookup table of shape (vocab_size, emb_size)
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # Input 'tokens' shape: (seq_len, batch_size)
        # Output shape: (seq_len, batch_size, emb_size)
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

## Custom GPT Model Architecture

The `CustomGPTModel` class defines a transformer-based model architecture for generative pre-trained models. This model aims to generate text and perform various NLP tasks.

In [77]:
class CustomGPTModel(nn.Module):
    """
    A custom GPT-style Transformer model for language generation.
    """
    def __init__(self, 
                 vocab_size: int, 
                 embed_size: int, 
                 num_heads: int, 
                 num_layers: int, 
                 dropout: float = 0.1,
                 max_seq_len: int = 5000):
        super().__init__()
        
        # Input Embedding Pipeline
        self.embedding_pipeline = nn.Sequential(
            TokenEmbedding(vocab_size, embed_size),
            PositionalEncoding(embed_size, dropout=dropout, maxlen=max_seq_len)
        )

        # Core Transformer Blocks
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size, 
            nhead=num_heads, 
            dropout=dropout, 
            batch_first=False     # Expects (seq_len, batch_size, embed_size)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        # Output Projection Layer
        self.lm_head = nn.Linear(embed_size, vocab_size)

        # Initialize weights after all layers are defined
        self.init_weights()

    def init_weights(self):
        """Initializes model weights using Xavier uniform distribution."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    @staticmethod
    def create_masks(src: Tensor, device):
        """
        Creates the causal (look-ahead) and padding masks for the source sequence.
        """
        seq_len = src.shape[0]
        
        # Causal mask: Prevents attending to future tokens.
        # Shape: (seq_len, seq_len)
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device)
        
        # Padding mask: Prevents attending to <pad> tokens.
        # Shape: (batch_size, seq_len)
        padding_mask = (src == PAD_IDX).transpose(0, 1)
        
        return causal_mask, padding_mask
    

    def forward(self, src: Tensor):
        """
        Defines the forward pass of the model.
        
        Args:
            src (Tensor): Input tensor of token IDs.
                          Shape: (seq_len, batch_size)
        
        Returns:
            Tensor: Output logits over the vocabulary.
                    Shape: (seq_len, batch_size, vocab_size)
        """
        # Create masks based on the input tensor.
        src_mask, src_padding_mask = self.create_masks(src, device=src.device)
        
        # Prepare input: Apply token embedding and positional encoding.
        # src shape: (seq_len, batch_size) -> (seq_len, batch_size, embed_size)
        src_emb = self.embedding_pipeline(src)

        # Pass through the main Transformer blocks.
        # Shape remains: (seq_len, batch_size, embed_size)
        output = self.transformer_encoder(
            src_emb, 
            mask=src_mask, 
            src_key_padding_mask=src_padding_mask
        )
        
        # Project to vocabulary space to get final logits.
        # output shape: (seq_len, batch_size, embed_size) -> (seq_len, batch_size, vocab_size)
        logits = self.lm_head(output)
        
        return logits

In [82]:
a = [1, 2, 3, 4, 5]
a[-15:]

[1, 2, 3, 4, 5]

In [78]:
def encode_prompt(
    prompt: str, 
    tokenizer,
    vocab,
    block_size,
    device
):
    """
    Encodes a string prompt into a tensor suitable for model input.

    This function handles prompt validation, tokenization, truncation of long 
    prompts, and conversion to a correctly shaped tensor on the specified device.

    Returns:
        Tensor: The encoded prompt as a tensor of shape (sequence_length, 1).
    """
    if not prompt or not prompt.strip():
        raise ValueError("Prompt cannot be empty or contain only whitespace.")

    tokens = tokenizer(prompt)

    # Truncate from the left if prompt exceeds the block size
    if len(tokens) > block_size:
        tokens = tokens[-block_size:]

    # Convert tokens to numerical indices
    indices = vocab(tokens)
    # Shape: [seq_len] -> [seq_len, 1]
    return torch.tensor(indices, dtype=torch.long, device=device).unsqueeze(1)

In [32]:
def decode_tokens(token_ids, vocab):
    id_list = token_ids.flatten().tolist()
    tokens = vocab.get_itos()(id_list)
    return " ".join(tokens)

In [33]:
@torch.no_grad()
def generate_text(
    model,
    prompt,
    tokenizer,
    vocab,
    block_size,
    max_new_tokens,
    device
):
    """
    Generates a sequence of text autoregressively using a trained model.

    Returns:
        str: The generated text, including the prompt.
    """
    model.eval()

    # Encode the initial prompt
    context = encode_prompt(
        prompt=prompt,
        tokenizer=tokenizer,
        vocab=vocab,
        block_size=block_size,
        device=device
    ) # Shape: (prompt_len, 1)

    # The autoregressive generation loop
    for _ in range(max_new_tokens):
        context_cond = context[-block_size:]
        
        # Forward pass with the conditioned context
        # logits shape: (current_seq_len, 1, vocab_size)
        logits = model(context_cond)

        # Get logits for the very last token in the sequence
        # Shape: (1, vocab_size)
        last_token_logits = logits[-1, :, :]
        
        # Greedily select the most likely next token
        # Shape: (1, 1)
        next_token = torch.argmax(last_token_logits, dim=-1, keepdim=True)
        
        # Check for end-of-sequence token
        if next_token.item() == EOS_IDX:
            break
            
        # Append the predicted token to the running sequence
        context = torch.cat([context, next_token], dim=0)

    # Decode the final sequence of token IDs back to a string
    generated_text = decode_tokens(context, vocab)
    
    return generated_text

In [34]:
tokenizer

<function torchtext.data.utils._basic_english_normalize(line)>

In [35]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

VOCAB_SIZE = len(vocab)
EMBED_SIZE = 256
NUM_HEADS = 2
NUM_LAYERS = 2
BLOCK_SIZE = 10
DROPOUT = 0.1

# Instantiate the Model
model = CustomGPTModel(
    vocab_size=VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    max_seq_len=BLOCK_SIZE,
    dropout=DROPOUT
).to(DEVICE)

In [36]:
prompt = "The sun rises in the"
print(f"--- Prompt --- \n{prompt}\n")

generated_text = generate_text(
    model=model,
    prompt=prompt,
    tokenizer=tokenizer,
    vocab=vocab,
    block_size=BLOCK_SIZE,
    max_new_tokens=100,
    device=DEVICE
)

print(f"--- Generated Text --- \n{generated_text}")

--- Prompt --- 
The sun rises in the

--- Generated Text --- 
pseudo-shocking pseudo-shocking pseudo-shocking pseudo-shocking pseudo-shocking pseudo-shocking pseudo-shocking pseudo-shocking pseudo-shocking pseudo-shocking pseudo-shocking


In [43]:
def train_one_epoch(
    model,
    dataloader,
    criterion,
    optimizer,
    scheduler,
    device
):
    """
    Trains the GPT-style language model for one epoch.
    """
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Training Epoch")

    for src, tgt in progress_bar:
        src, tgt = src.to(device), tgt.to(device)

        # Forward pass
        logits = model(src)

        # Reshape for loss calculation
        # logits: [seq_len, batch_size, vocab_size] -> [seq_len * batch_size, vocab_size]
        # tgt:    [seq_len, batch_size] -> [seq_len * batch_size]
        logits_flat = logits.reshape(-1, logits.shape[-1])
        loss = criterion(logits_flat, tgt.reshape(-1))

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    return total_loss / len(dataloader)

In [44]:
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Tuple[float, float, float]:
    """
    Evaluates the GPT-style language model.
    
    Returns:
        Tuple[float, float, float]: Average loss, accuracy, and perplexity.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating")
        for src, tgt in progress_bar:
            src, tgt = src.to(device), tgt.to(device)

            logits = model(src)
            
            # --- Loss Calculation ---
            logits_flat = logits.reshape(-1, logits.shape[-1])
            tgt_flat = tgt.reshape(-1)
            loss = criterion(logits_flat, tgt_flat)
            total_loss += loss.item()

            # --- Accuracy Calculation (Per-Token) ---
            preds = torch.argmax(logits_flat, dim=1)
            non_pad_mask = (tgt_flat != PAD_IDX)
            total_correct += (preds[non_pad_mask] == tgt_flat[non_pad_mask]).sum().item()
            total_tokens += non_pad_mask.sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / total_tokens if total_tokens > 0 else 0
    perplexity = math.exp(avg_loss)
    
    return avg_loss, accuracy, perplexity

In [45]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VOCAB_SIZE = len(vocab)
EMBED_SIZE = 256
NUM_HEADS = 2
NUM_LAYERS = 2
DROPOUT = 0.1
MAX_SEQ_LEN = 512
LEARNING_RATE = 0.0001
NUM_EPOCHS = 3
BATCH_SIZE = 32

UNK_IDX, PAD_IDX, EOS_IDX = 0, 1, 2


train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)


loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)


# A simple learning rate scheduler
scheduler = StepLR(optimizer, step_size=1.0, gamma=0.95)

print("Initializing model...")
model = CustomGPTModel(
    vocab_size=VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    max_seq_len=MAX_SEQ_LEN
).to(DEVICE)

Initializing model...


In [46]:
print("Starting training...")
best_val_loss = float('inf')

for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()

    train_loss = train_one_epoch(model, train_dataloader, loss_fn, optimizer, scheduler, DEVICE)

    val_loss, val_accuracy, val_perplexity = evaluate(model, val_dataloader, loss_fn, DEVICE)

    epoch_duration = time.time() - epoch_start_time
    
    # --- 5. LOGGING RESULTS ---
    print("-" * 60)
    print(f"| End of Epoch {epoch:3d} | Time: {epoch_duration:5.2f}s | "
          f"Train Loss: {train_loss:5.3f} | Val Loss: {val_loss:5.3f} | "
          f"Val PPL: {val_perplexity:8.2f} | Val Acc: {val_accuracy*100:5.2f}%")
    print("-" * 60)

Training Epoch:   1%|          | 6/625 [00:00<00:10, 59.80it/s, loss=11.5]

Starting training...


Training Epoch: 100%|██████████| 625/625 [00:10<00:00, 60.59it/s, loss=11.5]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 118.52it/s]
Training Epoch:   1%|          | 7/625 [00:00<00:09, 61.96it/s, loss=11.5]

------------------------------------------------------------
| End of Epoch   1 | Time: 11.64s | Train Loss: 11.519 | Val Loss: 11.517 | Val PPL: 100443.22 | Val Acc:  0.00%
------------------------------------------------------------


Training Epoch: 100%|██████████| 625/625 [00:10<00:00, 60.60it/s, loss=11.5]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 122.16it/s]
Training Epoch:   1%|          | 7/625 [00:00<00:10, 61.07it/s, loss=11.5]

------------------------------------------------------------
| End of Epoch   2 | Time: 11.60s | Train Loss: 11.519 | Val Loss: 11.518 | Val PPL: 100499.13 | Val Acc:  0.00%
------------------------------------------------------------


Training Epoch: 100%|██████████| 625/625 [00:10<00:00, 60.71it/s, loss=11.5]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 122.33it/s]

------------------------------------------------------------
| End of Epoch   3 | Time: 11.58s | Train Loss: 11.518 | Val Loss: 11.517 | Val PPL: 100440.23 | Val Acc:  0.00%
------------------------------------------------------------





In [47]:
prompt = "The meaning of life is"
generated_text = generate_text(model, prompt, tokenizer, vocab, MAX_SEQ_LEN, 100, DEVICE)
print("\n--- Example Generation ---")
print(generated_text)


--- Example Generation ---
the meaning of life is spitied additive additive tarantinism tarantinism tarantinism tarantinism montford montford montford montford 15minutes 15minutes 15minutes 15minutes 15minutes 15minutes 15minutes 15minutes 15minutes struggles spanky montford confused confused confused --until --until rainforests rainforests rainforests rainforests rainforests tarantinism prepare prepare prepare prepare prepare prepare prepare prepare prepare prepare prepare rainforests rainforests rainforests rainforests rainforests rainforests rainforests rainforests semi-rural semi-rural semi-rural semi-rural rainforests rainforests rainforests rainforests rainforests pelicule pelicule pelicule rainforests rainforests rainforests rainforests rainforests sushant sushant tarantinism tarantinism t4 t4 t4 blecher blecher rainforests rainforests rainforests rainforests rainforests rainforests rainforests reviews reviews rainforests rainforests rainforests rainforests prosecutor boy-o-boy