⚠️ **Static Version Notice**

This is a static export of an interactive marimo notebook. Some features have been modified for compatibility:

- Interactive UI elements (sliders, dropdowns, text inputs) have been removed
- UI variable references have been replaced with default values
- Some cells may have been simplified or removed entirely

For the full interactive experience, please run the original marimo notebook (.py file) using:
```bash
uv run marimo edit notebook_name.py
```

---


In [None]:
import numpy as np


# Module 9: Practical - Transformer Architecture

We start with the same data preparation steps as in Module 6.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import re

# Load and preprocess The Count of Monte Cristo (Project Gutenberg eBook #1184).
url = "https://www.gutenberg.org/cache/epub/1184/pg1184.txt"

import requests

# Without a timeout requests.get can block indefinitely on a stalled connection;
# raise_for_status() stops here on an HTTP error instead of tokenizing an error page.
response = requests.get(url, timeout=60)
response.raise_for_status()
text = response.text

# Keep only the span from the first "Chapter 1." up to (not including) the last
# "Chapter 5." -- presumably Chapters 1-4 -- which keeps the corpus small.
# To use the whole book, end at text.rfind("End of the Project Gutenberg") instead.
# NOTE(review): if a table of contents also lists "Chapter 1.", the slice starts
# there rather than at the chapter itself -- confirm against the downloaded text.
start_idx = text.find("Chapter 1.")
end_idx = text.rfind("Chapter 5.")
# str.find/rfind return -1 when a marker is missing, which would otherwise
# silently yield an empty or wrong slice (and a broken dataset further down).
if start_idx == -1 or end_idx <= start_idx:
    raise ValueError(
        "Could not locate the 'Chapter 1.' / 'Chapter 5.' markers in the downloaded text"
    )
text = text[start_idx:end_idx]

# Pre-processing: drop every character that is not an ASCII letter, a digit or
# whitespace, then lowercase.
text = re.sub(r"[^a-zA-Z0-9\s]", "", text)
text = text.lower()

# Tokenization: split the cleaned, lowercased text on whitespace.
tokens = text.split()

# Vocabulary construction
from collections import Counter
counter = Counter(tokens)

# Ids 0 and 1 are reserved for the special tokens "<PAD>" and "<UNK>". The 9998
# most frequent words get ids 2..9999 in descending frequency order, so the
# vocabulary holds at most 10000 entries; every other word maps to "<UNK>".
vocab = {
    word: token_id
    for token_id, (word, _count) in enumerate(counter.most_common(9998), start=2)
}
vocab["<PAD>"] = 0
vocab["<UNK>"] = 1
# Reverse lookup id -> word (same iteration order as vocab).
inv_vocab = dict(zip(vocab.values(), vocab.keys()))

# Encode tokens: out-of-vocabulary words become the <UNK> id.
encoded = [vocab[word] if word in vocab else vocab["<UNK>"] for word in tokens]


Since we are training the model to predict the next word in a sequence, each training input is a sequence of 30 consecutive words from the text, and the corresponding label is the same sequence shifted forward by one word.

In [None]:
# Create sequences
SEQ_LEN = 30


class TextDataset(Dataset):
    """Sliding-window next-word-prediction dataset over a list of token ids.

    Item ``idx`` is the pair ``(inputs, targets)``: ``inputs`` is the window
    ``data[idx : idx + seq_len]`` and ``targets`` is the same window shifted one
    token to the right, ``data[idx + 1 : idx + seq_len + 1]``.

    Args:
        data: Sequence of integer token ids (e.g. the encoded corpus).
        seq_len: Window length; defaults to ``SEQ_LEN``.
    """

    def __init__(self, data, seq_len=SEQ_LEN):
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        # Each start index needs seq_len + 1 tokens (the window plus one shifted
        # target). Clamp at 0 so a corpus that is too short yields an empty
        # dataset instead of a negative length.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")
        inputs = torch.tensor(self.data[idx:idx + self.seq_len])
        targets = torch.tensor(self.data[idx + 1:idx + self.seq_len + 1])
        return inputs, targets

# Sliding-window dataset over the whole encoded corpus, served in shuffled
# mini-batches of 64 (input, target) pairs.
train_datasets = TextDataset(encoded)
train_loader = DataLoader(
    dataset=train_datasets,
    batch_size=64,
    shuffle=True,
)


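Let's look at the first batch of input/target sequences. Because the loader shuffles the data, this is a random batch of (input, target) pairs rather than the opening words of the text.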

In [None]:
# Peek at one batch from the loader -- presumably a two-element [inputs, targets]
# batch of token-id tensors (depends on TextDataset.__getitem__). The loader
# shuffles, so this is a random batch rather than the start of the text. As the
# cell's last expression, its value is shown as the cell output.
next(iter(train_loader))


We now define the causal attention mask. The mask records, for each position, which positions it may attend to: itself and the tokens before it. Attention to future tokens is blocked, so their attention weights are zero. This ensures the model cannot see future tokens when making predictions.

In [None]:
def causal_attention_mask(n_dest, n_src, device):
    """Return a boolean ``(n_dest, n_src)`` mask that is True where attention is allowed.

    Entry ``[i, j]`` is True when ``j <= i``: destination (query) position ``i``
    may attend to source (key) position ``j`` only if it is not in the future.
    Both position ranges start at 0.
    """
    rows = torch.arange(n_dest, device=device)[:, None]
    cols = torch.arange(n_src, device=device)[None, :]
    return rows >= cols


# Example usage: build a 10x10 mask on the CPU (it is only displayed, so no
# accelerator is needed) and print the whole mask as 0/1 integers.
# Row i is the query position; a 1 in column j means position i may attend to
# position j, so each row allows itself and the earlier tokens.
mask = causal_attention_mask(10, 10, torch.device("cpu"))
print(mask.int())


Recall that we also need a position embedding. Here we use a simple learned positional embedding: each position index in the sequence has its own trainable embedding vector, which is added to the embedding of the token at that position.

In [None]:
class TokenAndPositionEmbedding(nn.Module):
    """Sum of a learned token embedding and a learned absolute-position embedding.

    Args:
        max_len: Number of rows in the position-embedding table, i.e. the longest
            sequence that can be embedded.
        vocab_size: Number of token ids (rows of the token-embedding table).
        embed_dim: Size of each embedding vector.
    """

    def __init__(self, max_len, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_len, embed_dim)

    def forward(self, x):
        """Embed a batch of token ids.

        Args:
            x: Integer tensor of token ids indexed as ``(batch, position)``.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            # Without this check the position lookup fails with an opaque index
            # error (or a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len} of the position embedding"
            )
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        # (1, seq_len, embed_dim) position vectors broadcast across the batch.
        return self.token_emb(x) + self.pos_emb(positions)


Next we define the Transformer block. In addition to the usual fully connected (feed-forward) layers, it contains a multi-head self-attention layer and layer-normalization layers.

In [None]:
class TransformerBlock(nn.Module):
    """Single post-LayerNorm Transformer block with causal self-attention.

    Layout: ``x -> LN(x + Dropout(SelfAttn(x))) -> LN(x + FFN(x))``.

    Args:
        num_heads: Number of attention heads.
        key_dim: Unused. ``nn.MultiheadAttention`` derives the per-head size as
            ``embed_dim // num_heads``; the argument is kept so existing calls
            keep working.
        embed_dim: Model width (``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``).
        ff_dim: Hidden width of the position-wise feed-forward network.
        dropout_rate: Dropout on the attention weights, on the attention output
            and at the end of the feed-forward network.
    """

    def __init__(self, num_heads, key_dim, embed_dim, ff_dim, dropout_rate=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout_rate, batch_first=True)
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.dropout_1 = nn.Dropout(dropout_rate)

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout_rate)
        )
        self.ln_2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_dim)`` (``batch_first=True``).

        Returns:
            ``(x, attn_weights)``. ``attn_weights`` is whatever
            ``nn.MultiheadAttention`` returns with its default
            ``average_attn_weights=True``: one ``(seq_len, seq_len)`` map per
            batch item, averaged over heads.
        """
        # Unpacking three sizes enforces a batched 3-D input.
        _, seq_len, _ = x.size()
        # causal_attention_mask marks *allowed* (query, key) pairs with True, while
        # nn.MultiheadAttention's boolean attn_mask marks *forbidden* pairs -> invert.
        allowed = causal_attention_mask(seq_len, seq_len, x.device)
        attn_output, attn_weights = self.attn(x, x, x, attn_mask=~allowed.bool())
        x = self.ln_1(x + self.dropout_1(attn_output))
        # self.ffn already ends in Dropout, so no separate dropout layer here.
        x = self.ln_2(x + self.ffn(x))
        return x, attn_weights


Finally, let's put it all together into a GPT (Generative Pre-trained Transformer) architecture and train the model using the dataloader defined earlier.

In [None]:
# GPT-style transformer wrapper
class GPT(nn.Module):
    """Minimal GPT-style language model: embeddings -> one Transformer block -> vocabulary logits.

    Args:
        max_len: Size of the position-embedding table.
        vocab_size: Number of token ids; also the width of the output logits.
        embed_dim: Model width.
        num_heads: Number of attention heads in the Transformer block.
        key_dim: Forwarded to ``TransformerBlock``.
        ff_dim: Hidden width of the block's feed-forward network.
    """

    def __init__(self, max_len, vocab_size, embed_dim, num_heads, key_dim, ff_dim):
        super().__init__()
        # Submodules are created in this order so parameter initialisation
        # consumes the RNG in a fixed sequence.
        self.embed = TokenAndPositionEmbedding(
            max_len=max_len, vocab_size=vocab_size, embed_dim=embed_dim
        )
        self.transformer = TransformerBlock(
            num_heads=num_heads, key_dim=key_dim, embed_dim=embed_dim, ff_dim=ff_dim
        )
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` (2-D, batch first) to ``(logits, attn_weights)``.

        ``logits`` has a final dimension of size ``vocab_size``; ``attn_weights``
        comes straight from the Transformer block.
        """
        hidden, attn_weights = self.transformer(self.embed(x))
        return self.lm_head(hidden), attn_weights


In [None]:
from tqdm import tqdm

def train_gpt(model, dataloader, optimizer, criterion, epochs, device):
    """Train ``model`` on next-token prediction for ``epochs`` passes over ``dataloader``.

    Each ``(inputs, targets)`` batch is moved to ``device``. The model's logits
    are flattened to ``(-1, vocab_size)`` and the targets to ``(-1,)``, and
    ``criterion`` is minimised with ``optimizer``. Each epoch gets a tqdm
    progress bar. Its postfix shows the running mean batch loss and is refreshed
    on every 100th batch and on the final batch.

    Args:
        model: Module whose forward returns ``(logits, attn_weights)``.
        dataloader: Iterable of ``(inputs, targets)`` batches; ``len()`` is used
            to detect the final batch.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking flattened logits and targets (e.g.
            ``nn.CrossEntropyLoss()``).
        epochs: Number of passes over ``dataloader``.
        device: Device that the model and the batches are moved to.

    Returns:
        None. The model is trained in place and left in training mode.
    """
    model.to(device)
    model.train()

    for epoch in range(epochs):
        epoch_loss = 0
        progress_bar = tqdm(
            iterable=dataloader,
            ncols=120,
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        for step, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            logits, _ = model(inputs)
            vocab_size = logits.size(-1)
            loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            is_report_step = step % 100 == 0 or step == len(dataloader) - 1
            if is_report_step:
                mean_loss = epoch_loss / (step + 1)
                progress_bar.set_postfix({"avg loss": f"{mean_loss:.4f}"})


We can now use the trained GPT to generate text. Starting from an input prompt, the model repeatedly samples the next token. We then use the inverse mapping from our vocabulary to "translate" the generated token ids back into natural text.

In [None]:
class TextGenerator:
    """Autoregressively sample text from a trained language model.

    The model is called as ``logits, attn_weights = model(x)`` on a
    ``(1, context_len)`` tensor of token ids, as ``GPT.forward`` does.

    Args:
        model: Module returning ``(logits, attn_weights)``.
        index_to_word: Token-id -> word lookup. Either a mapping such as
            ``inv_vocab`` or a sequence indexed by token id.
        top_k: Currently unused; kept for backward compatibility.
        device: Device to run generation on. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        max_context: Maximum number of most-recent tokens fed to the model per
            step. Defaults to ``model.embed.pos_emb.num_embeddings`` when the
            model has that attribute (as ``GPT`` does); otherwise no limit.
    """

    def __init__(self, model, index_to_word, top_k=10, device=None, max_context=None):
        self.model = model
        if device is None:
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        self.device = device
        self.model.to(device)
        self.index_to_word = index_to_word
        self.top_k = top_k
        # index_to_word maps id -> word. Iterating a dict directly yields its
        # keys (the ids), so mappings must be inverted through .items().
        if hasattr(index_to_word, "items"):
            id_word_pairs = index_to_word.items()
        else:
            id_word_pairs = enumerate(index_to_word)
        self.word_to_index = {word: idx for idx, word in id_word_pairs}
        if max_context is None:
            # Crop to the position-embedding size so long generations do not
            # index past the table; models without it get no cropping.
            pos_emb = getattr(getattr(model, "embed", None), "pos_emb", None)
            max_context = getattr(pos_emb, "num_embeddings", None)
        self.max_context = max_context

    def sample_from(self, probs, temperature):
        """Sample one token id from temperature-scaled logits.

        Args:
            probs: 1-D tensor of *unnormalised logits* over the vocabulary (the
                name is kept for backward compatibility). It is not modified.
            temperature: Positive divisor for the logits. Values above 1 flatten
                the distribution and values below 1 sharpen it.

        Returns:
            ``(next_id, probs)``: the sampled id and the softmax probabilities it
            was drawn from, with the ``<UNK>`` entry (index 1) exactly 0.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        logits = probs.clone()
        # A logit of -inf (not 0) is what removes <UNK> (index 1) from the
        # softmax; clone() keeps the caller's logits untouched.
        logits[1] = float("-inf")
        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1).item()
        return next_id, probs

    def generate(self, start_prompt, max_tokens, temperature):
        """Extend ``start_prompt`` until ``max_tokens`` tokens or a ``<PAD>`` (id 0) is sampled.

        Prints the generated token ids and the decoded text.

        Args:
            start_prompt: Whitespace-separated prompt; unknown words map to id 1.
            max_tokens: Total sequence length (prompt included) at which
                generation stops.
            temperature: Sampling temperature passed to ``sample_from``.

        Returns:
            A list with one dict per generated token. Each dict has the keys
            ``"prompt"``, ``"word_probs"`` (sampling distribution) and ``"atts"``
            (attention weights of the first batch item, as a NumPy array).

        Raises:
            ValueError: if ``start_prompt`` contains no words.
        """
        self.model.eval()
        start_tokens = [self.word_to_index.get(w, 1) for w in start_prompt.split()]
        if not start_tokens:
            raise ValueError("start_prompt must contain at least one word")
        generated_tokens = start_tokens[:]
        info = []

        with torch.no_grad():
            while len(generated_tokens) < max_tokens:
                context = generated_tokens
                if self.max_context is not None:
                    # Only the most recent tokens fit in the position embedding.
                    context = generated_tokens[-self.max_context:]
                x = torch.tensor([context], dtype=torch.long, device=self.device)
                logits, attn_weights = self.model(x)
                last_logits = logits[0, -1]
                sample_token, probs = self.sample_from(last_logits, temperature)
                generated_tokens.append(sample_token)
                info.append({
                    "prompt": start_prompt,
                    "word_probs": probs,
                    "atts": attn_weights[0].cpu().numpy()
                })
                if sample_token == 0:
                    break
        print("GEN", generated_tokens)
        generated_words = [self._lookup_word(idx) for idx in generated_tokens]
        print("generated text:" + " ".join(generated_words))
        return info

    def _lookup_word(self, idx):
        """Return the word for token id ``idx``, or ``"<UNK>"`` if there is none."""
        if hasattr(self.index_to_word, "get"):
            return self.index_to_word.get(idx, "<UNK>")
        if 0 <= idx < len(self.index_to_word):
            return self.index_to_word[idx]
        return "<UNK>"


In [None]:
# Wrap the trained model in a generator and sample from the prompt "captain".
# Generation stops once the sequence (prompt included) reaches 180 tokens or
# <PAD> is sampled; temperature=3.0 flattens the sampling distribution.
text_generator = TextGenerator(model=model, index_to_word=inv_vocab)
info = text_generator.generate(
    start_prompt="captain ",
    max_tokens=180,
    temperature=3.0,
)
