We're going to build, train, and evaluate a PyTorch-based decoder-only transformer to predict tokens. This is an exercise from Chapter 4 of **The Hundred Page Language Models Book**.

Here is the architecture for a single decoder block:

<img src="../images/transformer-decoder-block-architecture.png" style="width:800px; border: 1px solid #000;"/>

Each decoder-only transformer model will have 1 or more of these blocks.

In [4]:
from datetime import datetime, timedelta
import math
import os
import random
import re
import tarfile
import torch
import torch.nn as nn
import torch.nn.functional as F  # softmax
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoTokenizer
import urllib.request

## Utilities

In [5]:
#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def set_seed(seed):
    """
    Seeds the Python and PyTorch random number generators so runs are repeatable.

    Args:
        seed (int): Value used to seed every generator
    """
    # Configure cuDNN first: request deterministic behaviour and turn off benchmark mode
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Apply the same seed to Python's RNG, the PyTorch CPU RNG and every CUDA device RNG
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

#
# source: https://github.com/aburkov/theLMbook/blob/main/news_decoder_language_model.ipynb
#
class RMSNorm(nn.Module):
    """
    Root Mean Square normalization layer.

    Divides every vector along the last dimension by its root mean square and
    multiplies the result by a learnable per-feature scale.

    Args:
        emb_dim (int): Size of the last (feature) dimension
        epsilon (float): Constant added under the square root so an all-zero
                         vector does not cause a division by zero
    """
    def __init__(self, emb_dim, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon
        # One learnable gain per feature, starting at 1 (no rescaling)
        self.scale = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        # Mean of the squared features, kept as a trailing dim for broadcasting
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        rms = (mean_sq + self.epsilon).sqrt()
        return (x / rms) * self.scale


#
# source: https://github.com/aburkov/theLMbook/blob/main/news_decoder_language_model.ipynb
#
def rope(x, theta_base=10000.0):
    """
    Applies Rotary Position Embedding (RoPE) to a batch of sequences.

    Neighbouring feature pairs (0,1), (2,3), ... are treated as 2D points and
    rotated by an angle equal to ``position * frequency``, where the frequency
    depends on the pair index.

    Args:
        x: Input tensor of shape (batch_size, seq_len, emb_dim); emb_dim must be even
        theta_base: Base used to derive the per-pair rotation frequencies (default: 10000.0)

    Returns:
        Tensor of the same shape as ``x`` with every feature pair rotated
    """
    batch_size, seq_len, emb_dim = x.size()
    assert emb_dim % 2 == 0, "Embedding dimensionality must be even for RoPE"

    # One frequency per feature pair: theta_base ** (-2 * (p - 1) / emb_dim) for p = 1..emb_dim/2
    pair_index = torch.arange(1, emb_dim // 2 + 1, dtype=torch.float32, device=x.device)
    frequencies = 1.0 / (theta_base ** (2 * (pair_index - 1) / emb_dim))

    # Token positions 0..seq_len-1, repeated for every batch element
    positions = torch.arange(0, seq_len, dtype=torch.float32, device=x.device)
    positions = positions.unsqueeze(0).expand(batch_size, seq_len)

    # angles[b, t, p] = position t * frequency p
    angles = positions.unsqueeze(-1) * frequencies
    sin_part, cos_part = torch.sin(angles), torch.cos(angles)

    # Even-indexed features are the first coordinate of each pair, odd-indexed the second
    first, second = x[..., 0::2], x[..., 1::2]

    # Standard 2D rotation of each (first, second) pair
    rotated_pairs = torch.stack(
        (first * cos_part - second * sin_part, first * sin_part + second * cos_part),
        dim=-1,
    )

    # Interleave the rotated coordinates back into the original feature layout
    return rotated_pairs.reshape(batch_size, seq_len, emb_dim)



#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def check_file_exists(filename):
    """
    Reports whether the given path exists on disk.

    Relative paths are resolved against the current working directory.

    Args:
        filename (str): Path to check
    Returns:
        bool: True if the path exists, False otherwise
    """
    path_exists = os.path.exists(filename)
    return path_exists


#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def create_collate_fn(tokenizer):
    """
    Builds the collate function used by the data loaders.

    The returned function right-pads every sequence in a batch with the
    tokenizer's padding token so the batch can be stacked into one tensor.

    Args:
        tokenizer: Object exposing ``pad_token_id``

    Returns:
        function: ``collate_fn(batch)`` returning ``(input_padded, target_padded)``
    """
    def collate_fn(batch):
        # The padding id is looked up on every call, not captured once
        pad_index = tokenizer.pad_token_id

        def pad(sequences):
            # Stack into (batch, longest_len), filling the gaps with pad_index
            return nn.utils.rnn.pad_sequence(
                sequences, batch_first=True, padding_value=pad_index
            )

        # batch is a list of (input, target) pairs; transpose it into two tuples
        inputs, targets = zip(*batch)
        return pad(inputs), pad(targets)

    return collate_fn

#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def download_and_prepare_data(url, batch_size, tokenizer, max_length=30):
    """
    Runs the whole data pipeline: download, extract, build datasets, build loaders.

    Args:
        url (str): URL where the dataset archive can be downloaded
        batch_size (int): Batch size for data loading
        tokenizer: Tokenizer object for text processing
        max_length (int): Maximum sequence length for tokenization (default: 30)

    Returns:
        tuple: (train_dataloader, test_dataloader)
    """
    # Fetch the archive (skipped if already present) and unpack the two text files
    archive_name = download_file(url)
    train_file, test_file = extract_dataset(archive_name)

    # Wrap both files in streaming datasets
    datasets = create_datasets(train_file, test_file, tokenizer, max_length)

    # Batches are padded with the tokenizer's padding token
    padding_collate = create_collate_fn(tokenizer)

    return create_dataloaders(*datasets, batch_size, padding_collate)


#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def download_file(url):
    """
    Downloads a file from the given URL if it doesn't exist locally.
    Uses a custom User-Agent to help prevent download blocks.

    The response is streamed to a temporary ``.part`` file and only renamed to
    its final name once the transfer has finished, so an interrupted download
    never leaves a truncated archive that later runs would treat as complete.

    Args:
        url (str): URL of the file to download
    Returns:
        str: Name of the downloaded file ("news.tar.gz")
    Raises:
        urllib.error.URLError: If the request fails (propagated from urllib)
        OSError: If the file cannot be written or renamed
    """
    import shutil  # local import: only needed by this function

    # Always use news.tar.gz as the filename, regardless of URL
    filename = "news.tar.gz"

    if check_file_exists(filename):
        print(f"{filename} already downloaded.")
        return filename

    print(f"Downloading dataset from {url}...")
    req = urllib.request.Request(
        url,
        headers={"User-Agent": "Mozilla/5.0"}
    )

    partial_name = filename + ".part"
    try:
        with urllib.request.urlopen(req) as response, open(partial_name, "wb") as out_file:
            # Copy in chunks instead of holding the whole archive in memory
            shutil.copyfileobj(response, out_file)
        # Publish the file only after the full body has been written
        os.replace(partial_name, filename)
    finally:
        # After a successful rename the partial file is gone; after a failure, remove it
        if os.path.exists(partial_name):
            os.remove(partial_name)

    print("Download completed.")
    return filename


#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def extract_dataset(filename):
    """
    Extracts train.txt and test.txt from the downloaded archive.
    Includes debug information about archive contents.

    The archive is unpacked into the directory that contains it, which is also
    where the extracted ``news/`` folder is looked up. Extraction is skipped
    when both files are already present.

    Args:
        filename (str): Name of the archive file
    Returns:
        tuple: Paths to extracted train and test files
    Raises:
        FileNotFoundError: If train.txt or test.txt is missing after extraction
    """
    # An archive given without a directory part lives in the current directory
    target_dir = os.path.dirname(filename) or "."
    data_dir = os.path.join(os.path.dirname(filename), "news")
    train_path = os.path.join(data_dir, "train.txt")
    test_path = os.path.join(data_dir, "test.txt")

    if check_file_exists(train_path) and check_file_exists(test_path):
        print("Data files already extracted.")
        return train_path, test_path

    print("\nListing archive contents:")
    with tarfile.open(filename, "r:gz") as tar:
        for member in tar.getmembers():
            print(f"Archive member: {member.name}")

        print("\nExtracting files...")
        # The archive comes from the network: where the running Python supports
        # extraction filters, use the "data" filter so members cannot be written
        # outside target_dir.
        extract_options = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        tar.extractall(target_dir, **extract_options)

    if not (check_file_exists(train_path) and check_file_exists(test_path)):
        raise FileNotFoundError("Required files not found in the archive. Please check the paths above.")

    print("Extraction completed.")
    return train_path, test_path

    
#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def create_datasets(train_file, test_file, tokenizer, max_length=30):
    """
    Builds the streaming training and test datasets and reports their sizes.

    Args:
        train_file (str): Path to training data file
        test_file (str): Path to test data file
        tokenizer: Tokenizer object for text processing
        max_length (int): Maximum sequence length passed to each dataset (default: 30)

    Returns:
        tuple: (train_dataset, test_dataset)
    """
    # Training dataset is built first, then the test dataset
    train_dataset, test_dataset = (
        IterableTextDataset(path, tokenizer, max_length)
        for path in (train_file, test_file)
    )

    # Report how many lines each file holds
    print(f"Training sentences: {len(train_dataset)}")
    print(f"Test sentences: {len(test_dataset)}")

    return train_dataset, test_dataset

    
#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def create_dataloaders(train_dataset, test_dataset, batch_size, collate_fn):
    """
    Wraps the two datasets in DataLoader objects that share the same settings.

    Args:
        train_dataset: Training dataset
        test_dataset: Test dataset
        batch_size (int): Number of sequences per batch
        collate_fn: Function that turns a list of samples into a batch

    Returns:
        tuple: (train_dataloader, test_dataloader)
    """
    # Both loaders use identical options; num_workers=0 loads data in the main process
    loader_options = dict(
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=0,
    )
    train_dataloader = DataLoader(train_dataset, **loader_options)
    test_dataloader = DataLoader(test_dataset, **loader_options)
    return train_dataloader, test_dataloader


#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
class IterableTextDataset(IterableDataset):
    """
    Streaming dataset that reads one sentence per line from a text file.

    The file is re-opened and tokenized lazily on every iteration, so the
    corpus is never held in memory as a whole.

    Args:
        file_path (str): Path to the text file containing sentences
        tokenizer: Object exposing ``encode(text, max_length=..., truncation=...)``
        max_length (int): Maximum number of tokens kept per sentence (default: 30)
    """
    def __init__(self, file_path, tokenizer, max_length=30):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Count the lines once up front so len() is cheap afterwards
        self._count_sentences()

    def __iter__(self):
        """
        Streams (input, target) training pairs from the file.

        Yields:
            tuple: (input_sequence, target_sequence) as LongTensors, where the
                   input is every token but the last and the target is every
                   token but the first (the input shifted by one position)
        """
        with open(self.file_path, 'r', encoding="utf-8") as f:
            for raw_line in f:
                # Trim whitespace and collapse every run of digits into "###"
                text = re.sub(r"\d+", "###", raw_line.strip())

                token_ids = self.tokenizer.encode(
                    text,
                    max_length=self.max_length,
                    truncation=True
                )

                # A pair needs at least one input token and one target token
                if len(token_ids) < 2:
                    continue

                inputs = torch.tensor(token_ids[:-1], dtype=torch.long)
                targets = torch.tensor(token_ids[1:], dtype=torch.long)
                yield inputs, targets

    def __len__(self):
        # Number of lines in the file; iteration may yield fewer items because
        # sentences shorter than two tokens are skipped
        return self._num_sentences

    def _count_sentences(self):
        print(f"Counting sentences in {self.file_path}...")
        line_total = 0
        with open(self.file_path, 'r', encoding="utf-8") as f:
            for _ in f:
                line_total += 1
        self._num_sentences = line_total
        print(f"Found {self._num_sentences} sentences in {self.file_path}.")


#
# source: https://github.com/aburkov/theLMbook/blob/main/news_decoder_language_model.ipynb
#
def initialize_weights(model):
    """
    Walks every sub-module of ``model`` and initializes its parameters in place.

    The first matching rule below wins for each module:
      - nn.Linear: Xavier-uniform weight, zero bias
      - nn.Embedding: normal(0, 0.02) weight, padding row reset to zero
      - AttentionHead: Xavier-uniform W_Q, W_K, W_V
      - MultiHeadAttention: Xavier-uniform W_O
      - DecoderLanguageModel: Xavier-uniform output projection
      - RMSNorm: scale set to ones
      - MLP: Xavier-uniform W_1 and W_2, zero B_1 and B_2

    Args:
        model (nn.Module): Model whose parameters are initialized
    """
    for layer in model.modules():
        if isinstance(layer, nn.Linear):
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
            continue

        if isinstance(layer, nn.Embedding):
            nn.init.normal_(layer.weight, mean=0, std=0.02)
            if layer.padding_idx is not None:
                # The random init overwrote the padding row; zero it again
                with torch.no_grad():
                    layer.weight[layer.padding_idx].fill_(0)
            continue

        if isinstance(layer, AttentionHead):
            # Query, key and value projections, in that order
            for projection in (layer.W_Q, layer.W_K, layer.W_V):
                nn.init.xavier_uniform_(projection)
            continue

        if isinstance(layer, MultiHeadAttention):
            # Projection applied to the concatenated head outputs
            nn.init.xavier_uniform_(layer.W_O)
            continue

        if isinstance(layer, DecoderLanguageModel):
            # Final projection from embeddings to vocabulary logits
            nn.init.xavier_uniform_(layer.output)
            continue

        if isinstance(layer, RMSNorm):
            # A scale of ones means no rescaling at the start
            nn.init.ones_(layer.scale)
            continue

        if isinstance(layer, MLP):
            nn.init.xavier_uniform_(layer.W_1)
            nn.init.xavier_uniform_(layer.W_2)
            nn.init.zeros_(layer.B_1)
            nn.init.zeros_(layer.B_2)


#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def compute_loss_and_perplexity(model, dataloader, tokenizer, criterion, device, max_sentences=1000):
    """
    Evaluates model performance by computing loss and perplexity on data.

    Args:
        model (nn.Module): The language model to evaluate
        dataloader (DataLoader): Iterable of (input_seq, target_seq) batches
        tokenizer: Tokenizer providing ``pad_token_id``
        criterion: Loss function (usually CrossEntropyLoss); assumed to return
                   the mean loss over the tokens it is given
        device: Device to run computation on (cuda/cpu)
        max_sentences (int): Evaluation stops after the batch that reaches this
                           many sentences (default: 1000)

    Returns:
        tuple: (average_loss, perplexity)
               - average_loss: Mean loss per token (excluding padding)
               - perplexity: exp(average_loss), lower is better

    Raises:
        ValueError: If no non-padding token was evaluated (e.g. empty dataloader)

    Note:
        The model is switched to eval mode and left that way; call
        ``model.train()`` afterwards if training continues.
    """
    # Set model to evaluation mode (disables dropout, etc.)
    model.eval()

    # tqdm is only used for a progress bar and is not imported at file level,
    # so import it lazily and fall back to plain iteration when it is missing.
    try:
        from tqdm import tqdm
        batches = tqdm(dataloader, desc="Evaluating", leave=False)
    except ImportError:
        batches = dataloader

    total_loss = 0.0           # Sum of per-token losses across all batches
    total_tokens = 0           # Number of non-padding target tokens seen
    sentences_processed = 0    # Number of sentences seen

    # Disable gradient computation for efficiency
    with torch.no_grad():
        for input_seq, target_seq in batches:
            input_seq = input_seq.to(device)      # Shape: (batch_size, seq_len)
            target_seq = target_seq.to(device)    # Shape: (batch_size, seq_len)

            logits = model(input_seq)             # Shape: (batch_size, seq_len, vocab_size)

            # Flatten so every token position is one row
            logits = logits.reshape(-1, logits.size(-1))
            target = target_seq.reshape(-1)

            # Keep only positions whose target is a real (non-padding) token
            mask = target != tokenizer.pad_token_id
            num_tokens = mask.sum().item()

            # A mean loss over zero tokens is NaN and would poison the total
            if num_tokens > 0:
                loss = criterion(logits[mask], target[mask])
                # Convert the batch mean back to a sum so batches are weighted by token count
                total_loss += loss.item() * num_tokens
                total_tokens += num_tokens

            # input_seq.size(0) is the batch size (may be smaller for the last batch)
            sentences_processed += input_seq.size(0)
            if sentences_processed >= max_sentences:
                break

    if total_tokens == 0:
        raise ValueError("No non-padding tokens were evaluated; cannot compute loss and perplexity.")

    average_loss = total_loss / total_tokens
    perplexity = math.exp(average_loss)

    return average_loss, perplexity


#
# source: https://github.com/aburkov/theLMbook/blob/main/news_RNN_language_model.ipynb
#
def generate_text(model, start_string, tokenizer, device, max_length=50):
    """
    Generates text continuation from a given start string using greedy decoding.
    This method always chooses the most likely next token.

    Args:
        model (nn.Module): Trained language model; called with a (1, seq_len)
                           tensor of token ids and expected to return logits
                           of shape (1, seq_len, vocab_size)
        start_string (str): Initial text to continue from
        tokenizer: Tokenizer providing ``encode``, ``decode`` and ``eos_token_id``
        device: Device to run generation on (cuda/cpu)
        max_length (int): Limit applied twice: the prompt is truncated to at
                          most this many tokens, and at most this many new
                          tokens are generated

    Returns:
        str: The decoded prompt followed by the generated continuation
    """
    # Set model to evaluation mode
    model.eval()

    # Convert start string to token ids and move to device
    # return_tensors="pt" returns PyTorch tensor instead of list
    tokens = tokenizer.encode(start_string, return_tensors="pt", max_length=max_length, truncation=True).to(device)

    # Initialize generated sequence with input tokens
    generated = tokens

    # Inference only: without no_grad every step would build an autograd graph
    # over the growing sequence, wasting memory and time.
    with torch.no_grad():
        # Generate new tokens one at a time
        for _ in range(max_length):
            # Get model's predictions
            output = model(generated)                    # Shape: (1, seq_len, vocab_size)
            # Get logits for the next token (last position)
            next_token_logits = output[0, -1, :]        # Shape: (vocab_size)

            # Choose token with highest probability (greedy decoding)
            # unsqueeze twice to match expected shape (1, 1)
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0).unsqueeze(0)

            # Add new token to generated sequence
            generated = torch.cat((generated, next_token_id), dim=1)

            # Stop if end of sequence token is generated
            if next_token_id.item() == tokenizer.eos_token_id:
                break

    # Convert token ids back to text; index the single batch row explicitly so
    # the result is always a list, even for a one-token sequence
    generated_text = tokenizer.decode(generated[0].tolist())
    return generated_text

## Build model

To build the model, we'll define:

1. **a single attention head**, which is used to capture relationships between the input tokens
2. **multi-head attention**, which is made up of multiple attention heads and concatenates and projects their output using trainable parameters
3. **multilayer perceptron** (**MLP**), which transforms the output of self-attention to enable the model to learn complex data patterns
4. **decoder block**, which combines multi-head attention, MLP, and introduces RMS normalization and residual connections
5. **decoder language model**, which converts inputs to embeddings and is made up of 1 or more decoder blocks


### Attention Head
This is used to capture relations between input tokens. It does this by calculating **query** (`Q`), **key** (`K`), and **value** (`V`) matrices, along with separate tuned parameters (`W_Q`, `W_K`, and `W_V`, respectively). 

This also applies **Rotary position embedding** (**RoPE**), which applies position-dependent rotations to the query and key vectors in order to account for word order. Tokens that are closer together in the input are rotated by more similar angles than tokens that are further apart; furthermore, the rotation angle grows with a token's position in the input, and each successive pair of embedding dimensions rotates at a lower frequency.

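The `mask` is a **causal mask** that prevents each token from attending to tokens that come later in the input. In the code it is a lower-triangular matrix of 1s and 0s; wherever the mask holds a 0, the attention score is replaced with -∞ before the softmax, so that position receives zero weight. In the matrix below, `1` marks a score that is kept and `-∞` marks a score that is masked out: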

```
M = [ 1 -∞ -∞ -∞
      1  1 -∞ -∞
      1  1  1 -∞
      1  1  1  1 ]
```

In [6]:
class AttentionHead(nn.Module):
    """
    A single causal self-attention head with rotary position embeddings.

    Args:
        emb_dim (int): Size of the incoming token embeddings
        d_h (int): Size of this head's query/key/value vectors
    """
    def __init__(self, emb_dim, d_h):
        super().__init__()
        self.d_h = d_h
        # Projection matrices are allocated uninitialized (torch.empty);
        # they must be initialized or loaded before the head is used.
        self.W_Q = nn.Parameter(torch.empty(emb_dim, d_h))
        self.W_K = nn.Parameter(torch.empty(emb_dim, d_h))
        self.W_V = nn.Parameter(torch.empty(emb_dim, d_h))

    def forward(self, x, mask):
        # Project the input, then rotate queries and keys (values are not rotated)
        queries = rope(x @ self.W_Q)
        keys = rope(x @ self.W_K)
        values = x @ self.W_V

        # Scaled dot-product scores between every pair of positions
        scaled_scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.d_h)

        # Positions where the mask is 0 get -inf, i.e. zero weight after softmax
        hidden = mask == 0
        attention_weights = torch.softmax(
            scaled_scores.masked_fill(hidden, float("-inf")), dim=-1
        )  # each row is a probability distribution over the visible positions

        return attention_weights @ values

### Multi-Head Attention

Made up of multiple attention heads and concatenates and projects their output using trainable parameters (*projection matrix*).

In [7]:
class MultiHeadAttention(nn.Module):
    """
    Runs several attention heads in parallel, concatenates their outputs and
    projects the result with a trainable matrix.

    Args:
        emb_dim (int): Size of the token embeddings
        num_heads (int): Number of attention heads; must divide ``emb_dim``

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``emb_dim``
    """
    def __init__(self, emb_dim, num_heads):
        super().__init__()
        # Fail fast: with a remainder, the concatenated head outputs would have
        # fewer than emb_dim features and the W_O product would only fail in forward().
        if num_heads <= 0 or emb_dim % num_heads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by num_heads ({num_heads})"
            )
        d_h = emb_dim // num_heads   # dimensionality of each head
        self.heads = nn.ModuleList([
            AttentionHead(emb_dim, d_h) for _ in range(num_heads)
        ])
        # Projection matrix; allocated uninitialized, so it must be initialized
        # or loaded before use.
        self.W_O = nn.Parameter(torch.empty(emb_dim, emb_dim))

    def forward(self, x, mask):
        # Every head sees the same input and mask
        head_outputs = [head(x, mask) for head in self.heads]
        # Join the per-head outputs along the feature dimension
        x = torch.cat(head_outputs, dim=-1)
        return x @ self.W_O

### Multilayer Perceptron (MLP)

Transforms the output of self-attention to enable the model to learn complex data patterns. 

This is like a typical feedforward network layer, except that the transformation applies additional parameters (`W_2` and `B_2`) after the ReLU function.

In [8]:
class MLP(nn.Module):
    """
    Position-wise feed-forward network: expand to 4x the embedding size,
    apply ReLU, then project back to the embedding size.

    Args:
        emb_dim (int): Size of the token embeddings
    """
    def __init__(self, emb_dim):
        super().__init__()
        hidden_dim = emb_dim * 4
        # Parameters are allocated uninitialized (torch.empty); they must be
        # initialized or loaded before the layer is used.
        self.W_1 = nn.Parameter(torch.empty(emb_dim, hidden_dim))
        self.B_1 = nn.Parameter(torch.empty(hidden_dim))
        self.W_2 = nn.Parameter(torch.empty(hidden_dim, emb_dim))
        self.B_2 = nn.Parameter(torch.empty(emb_dim))

    def forward(self, x):
        # Expand and apply the non-linearity
        hidden = torch.relu(x @ self.W_1 + self.B_1)
        # Project back to the embedding size
        return hidden @ self.W_2 + self.B_2

### Decoder Block

This combines multi-head attention, MLP, and introduces RMS normalization and residual connections. 

**Root Mean Square normalization** is used to keep the scale of inputs to each layer consistent, preventing gradients from becoming excessively large or small (improving "numerical stability"). Note each RMS layer contains trainable parameters, hence multiple `RMSNorm` instances.

**Residual connections** are used to address the vanishing gradient problem. (The input `x` is added to the output of each sub-layer. Mathematically, this prevents the gradient of earlier layers from approaching 0.)

In [9]:
class DecoderBlock(nn.Module):
    """
    One decoder block: normalized multi-head attention followed by a
    normalized MLP, each wrapped in a residual connection.

    Args:
        emb_dim (int): Size of the token embeddings
        num_heads (int): Number of attention heads
    """
    def __init__(self, emb_dim, num_heads):
        super().__init__()
        # Each sub-layer has its own normalization with its own learnable scale
        self.norm1 = RMSNorm(emb_dim)
        self.attn = MultiHeadAttention(emb_dim, num_heads)
        self.norm2 = RMSNorm(emb_dim)
        self.mlp = MLP(emb_dim)

    def forward(self, x, mask):
        # Attention sub-layer: normalize, attend, add the input back (residual)
        x = x + self.attn(self.norm1(x), mask)
        # Feed-forward sub-layer: normalize, transform, add the input back (residual)
        return x + self.mlp(self.norm2(x))

### Decoder Language Model

This is the final model. It converts inputs to embeddings and is made up of 1 or more decoder blocks.

In [10]:
class DecoderLanguageModel(nn.Module):
    """
    Decoder-only transformer language model.

    Token ids are embedded, passed through a stack of decoder blocks and
    projected to one logit per vocabulary entry.

    Args:
        vocab_size (int): Number of tokens in the vocabulary
        emb_dim (int): Size of the token embeddings
        num_heads (int): Number of attention heads per block
        num_blocks (int): Number of decoder blocks
        pad_idx (int): Id of the padding token (passed to the embedding as padding_idx)
    """
    def __init__(self, vocab_size, emb_dim, num_heads, num_blocks, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        blocks = [DecoderBlock(emb_dim, num_heads) for _ in range(num_blocks)]
        self.layers = nn.ModuleList(blocks)
        # Final projection from embeddings to vocabulary logits
        self.output = nn.Parameter(torch.rand(emb_dim, vocab_size))

    def forward(self, x):
        hidden = self.embedding(x)
        _batch, seq_len, _dim = hidden.shape

        # Lower-triangular causal mask: position i may attend to positions 0..i
        all_ones = torch.ones(seq_len, seq_len, device=hidden.device)
        causal_mask = torch.tril(all_ones)

        for block in self.layers:
            hidden = block(hidden, causal_mask)

        # Project the final hidden states to vocabulary logits
        return hidden @ self.output

## Train model

In [11]:
# define hyperparameters
emb_dim = 128          # size of each token embedding
num_heads = 8          # attention heads per decoder block
num_blocks = 2         # number of decoder blocks in the model
batch_size = 128       # sentences per batch
learning_rate = 0.001  # learning rate passed to the optimizer
num_epochs = 1         # number of passes over the training data
context_size = 30      # passed as max_length: sentences are truncated to this many tokens

In [38]:
# find the device to run training on (hopefully GPU...)
def get_device_label():
    """
    Picks the name of the preferred available compute device.

    Returns:
        str: "mps" if the MPS backend is available, otherwise "cuda" if CUDA
             is available, otherwise "cpu"
    """
    if torch.backends.mps.is_available():
        return "mps"
    return "cuda" if torch.cuda.is_available() else "cpu"


# Automatic device selection is disabled; uncomment the next line to re-enable it.
#device = torch.device(get_device_label())
# The CPU is forced here, per the author's note below.
device = torch.device("cpu") # cpu faster than mps, tweak batch_size?
print(f'Using device: {device}')

# verify: allocate a small tensor on the chosen device to confirm it is usable
x = torch.ones(1, device=device)
print(x)

Using device: cpu
tensor([1.])


In [25]:
%%time
# Load the pretrained tokenizer for the named model
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
pad_idx = tokenizer.pad_token_id  # used to pad shorter inputs
# len(tokenizer) is used as the vocabulary size when the model is built
vocab_size = len(tokenizer)
print(f"\nVocabulary size: {vocab_size}\n")


Vocabulary size: 32011

CPU times: user 54.3 ms, sys: 27.1 ms, total: 81.3 ms
Wall time: 262 ms


In [26]:
%%time
# Download (if needed), extract and wrap the news dataset in train/test data loaders.
# context_size is passed as max_length, the token limit applied to each sentence.
data_url = "https://www.thelmbook.com/data/news"
train_dataloader, test_dataloader = download_and_prepare_data(
    data_url, batch_size, tokenizer, context_size
)

news.tar.gz already downloaded.
Data files already extracted.
Counting sentences in news/train.txt...
Found 22034911 sentences in news/train.txt.
Counting sentences in news/test.txt...
Found 449693 sentences in news/test.txt.
Training sentences: 22034911
Test sentences: 449693
CPU times: user 2.97 s, sys: 407 ms, total: 3.38 s
Wall time: 4.04 s


In [27]:
%%time
set_seed(42) # reproducible

# - - - - - - - - 
# build model
# - - - - - - - - 
model = DecoderLanguageModel(
    vocab_size, emb_dim, num_heads, num_blocks, pad_idx
)
model.to(device)

# Parameters created with torch.empty hold arbitrary values until this call
initialize_weights(model)   

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal trainable parameters: {total_params}\n")

# - - - - - - - - 
# train model
# - - - - - - - - 
start = datetime.now()

# The loader length is derived from the line count of the training file, so it
# is an upper bound: very short sentences are skipped while streaming.
batches_per_epoch = len(train_dataloader)
total_batches = num_epochs * batches_per_epoch

# print an update every 0.1% of an epoch; at least 1 so the modulo below never
# divides by zero when there are fewer than 1000 batches
update_batch_cnt = max(1, batches_per_epoch // 1000)

for epoch in range(num_epochs):
    model.train()

    for batch_idx, (input_seq, target_seq) in enumerate(train_dataloader):

        if batch_idx % update_batch_cnt == 0:
            ellapsed = datetime.now() - start
            # Progress is measured over the whole run (all epochs) because the
            # elapsed time is also measured from the start of the whole run
            batches_done = epoch * batches_per_epoch + batch_idx
            percent_done = 100 * batches_done / total_batches
            estimated_total_time = 'unknown'
            if percent_done > 0:
                estimated_total_time = f'{ellapsed.total_seconds() / (percent_done/100) / (60 * 60):.1f} hr'
            print(f'[epoch={epoch}] batch {batch_idx} of {batches_per_epoch} - {percent_done:.2f}% done - {timedelta(seconds=ellapsed.total_seconds())} ellapsed, est. total time: {estimated_total_time}')
        
        input_seq = input_seq.to(device)
        target_seq = target_seq.to(device)

        optimizer.zero_grad() # clear gradients

        # Flatten so every token position is one row for the loss
        logits = model(input_seq)
        logits = logits.reshape(-1, logits.size(-1))
        target = target_seq.reshape(-1)
        mask = target != pad_idx # mask to exclude padding tokens
        
        loss = criterion(logits[mask], target[mask]) # compute loss
        loss.backward()
        optimizer.step()


Total trainable parameters: 8589824

[epoch=0] batch 0 of 172148 - 0.00% done - 0:00:00.008975 ellapsed, est. total time: unknown
[epoch=0] batch 172 of 172148 - 0.10% done - 0:01:29.935767 ellapsed, est. total time: 25.0 hr
[epoch=0] batch 344 of 172148 - 0.20% done - 0:02:59.453377 ellapsed, est. total time: 24.9 hr
[epoch=0] batch 516 of 172148 - 0.30% done - 0:04:29.481362 ellapsed, est. total time: 25.0 hr
[epoch=0] batch 688 of 172148 - 0.40% done - 0:06:04.502384 ellapsed, est. total time: 25.3 hr
[epoch=0] batch 860 of 172148 - 0.50% done - 0:07:38.608208 ellapsed, est. total time: 25.5 hr
[epoch=0] batch 1032 of 172148 - 0.60% done - 0:09:13.259641 ellapsed, est. total time: 25.6 hr
[epoch=0] batch 1204 of 172148 - 0.70% done - 0:10:46.948874 ellapsed, est. total time: 25.7 hr
[epoch=0] batch 1376 of 172148 - 0.80% done - 0:12:20.584162 ellapsed, est. total time: 25.7 hr
[epoch=0] batch 1548 of 172148 - 0.90% done - 0:13:55.248231 ellapsed, est. total time: 25.8 hr
[epoch=0] 

In [28]:
# save model
# Only the state dict (parameter tensors) is written, so the model must be
# rebuilt with the same constructor arguments before it can be loaded again.
model_file = "./transformer.pth"
torch.save(model.state_dict(), model_file)

## Evaluate model

In [35]:
# Prompts the model is asked to continue during evaluation
contexts = [
    "Moscow",
    "New York",
    "A hurricane",
    "The President",
    "The Washington Nationals",
    "Canada is known for"
]

def evaluate_model(a_model, contexts=contexts):
    """
    Prints a greedy continuation of every prompt in ``contexts``.

    Uses the notebook-level ``tokenizer`` and ``device``.

    Args:
        a_model (nn.Module): Model to generate text with
        contexts (list): Prompt strings (defaults to the list defined above)
    """
    a_model.eval()
    for prompt in contexts:
        print(
            generate_text(
                model=a_model,
                start_string=prompt,
                tokenizer=tokenizer,
                device=device,
                max_length=50
            )
        )

In [37]:
# test the saved model
# Rebuild the architecture with the same hyperparameters, then load the saved weights.
model2 = DecoderLanguageModel(
    vocab_size, emb_dim, num_heads, num_blocks, pad_idx
)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved from a different device (e.g. a GPU) can still be loaded here.
model2.load_state_dict(torch.load(model_file, map_location=device, weights_only=True))
model2.to(device)

evaluate_model(model2)

Moscow has been accused of being a member of . ' '' 'The . ' '' 's statement said . 's statement said . 'The . ' '' 's statement was written . 's . 's 's 's 'not a '
New York City 's ##-year-old son , Michael , was found dead in a car park in the town of <rare> , near the city 's capital . 's Central Park . 's Day . 's 's time . '
A hurricane has been hit by the storms , the New York Times reported . '### % of the ##-year-old 's ##-year-old , and the ##-thee for ## minutes of the up toions of the way
The President of the United States has been in the midst of a war in the past . '' 's presidential election . ' '' 's statement said . 's statement said . 's decision was 's 's 's ' ' ' ' he
The Washington Nationals 's first-ever presidential election campaign was announced in #### . '' 's #### manifesto . ' '' 's statement said . 's . 's presidential campaign . 's president . 's . 's . ' of the
Canada is known for its most recent years . ' '' 's report said . ' '' 'The ##-year-old was a '