<link rel="stylesheet" href="berkeley.css">

<h1 class="cal cal-h1">Lecture 21 – CS 189, Fall 2025</h1> 



In [None]:
#!pip install -U plotly

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pickle

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Pick the best available accelerator so the notebook runs unchanged on
# Colab / NVIDIA GPUs (CUDA), Apple Silicon (MPS), or a plain CPU.
# Assign `device` by hand below this block to force a particular backend.
if torch.cuda.is_available():
    device = "cuda"  # Colab
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"  # Mac with M1/M2
else:
    device = "cpu"  # Local CPU

In [None]:
import plotly.io as pio
# Select the plotly renderer matching the front end that displays this
# notebook; switch the commented line when running on Colab.
pio.renderers.default = "vscode" # VSCode
# pio.renderers.default = "colab" # Colab support

<link rel="stylesheet" href="berkeley.css">

<h2 class="cal cal-h2">Sinusoidal Embeddings</h2> 



In [None]:
# Toy sizes used to explore the pieces of the encoding formula.
D, n, L = 6, 16, 1000
# Even dimension indices 0, 2, 4, ... (the "2i" term of the exponent).
torch.arange(start=0, end=D, step=2, dtype=torch.float)

In [None]:
def sinusoidal_pe(N, D, L=1000):
    """
    Build the (N, D) sinusoidal positional-encoding matrix.

    For position `pos` and frequency index i = 0, 1, ...:
        pe[pos, 2i]     = sin(pos / L ** (2i / D))
        pe[pos, 2i + 1] = cos(pos / L ** (2i / D))

    Args:
        N: number of positions (rows).
        D: embedding dimension (columns); odd values are supported.
        L: base of the geometric progression of wavelengths.

    Returns:
        A float tensor of shape (N, D).
    """
    pe = torch.zeros(N, D)
    positions = torch.arange(N, dtype=torch.float).unsqueeze(1)  # (N, 1)
    # arange(0, D, 2) already enumerates the even indices 2i, so the exponent
    # is 2i / D; multiplying by 2 again would square every wavelength.
    div_term = L ** (torch.arange(0, D, 2, dtype=torch.float) / D)
    angles = positions / div_term  # (N, ceil(D / 2))
    pe[:, 0::2] = torch.sin(angles)
    # With an odd D there is one fewer cosine column than sine columns.
    pe[:, 1::2] = torch.cos(angles[:, : D // 2])
    return pe

In [None]:
# Heat map of the encoding matrix: one column per position, one row per
# embedding dimension.
n = 64
D = 128
L = 10
pe = sinusoidal_pe(N=n, D=D, L=L)
px.imshow(
    pe.cpu().numpy().T,
    aspect='auto',
    color_continuous_scale='RdBu_r',
    width=1100,
    height=500,
)

In [None]:
# Entry (i, j) is the dot product between the encodings of positions i and j.
dist = pe @ pe.T
fig = px.imshow(dist.cpu().numpy(), color_continuous_scale='Viridis',
                width=700, height=700,
                title='Dot Product of Positional Encodings')
fig.show()
# Plot a single row of the similarity matrix. The title is built from the same
# index as the data so the two cannot disagree (it used to say "Position 200"
# while plotting row 10).
query_pos = 10
px.line(x=np.arange(0, n), y=dist[query_pos].cpu().numpy(), width=800, height=400,
        title=f'Dot Product of Positional Encodings for Position {query_pos}')


<link rel="stylesheet" href="berkeley.css">

<h2  class="cal cal-h2">Tokenization</h2> 



In [None]:
#!pip install transformers

In [None]:
from transformers import AutoTokenizer

# Load the Qwen tokenizer by its Hugging Face model id
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")
# Display the vocabulary size reported by the tokenizer
tokenizer.vocab_size

In [None]:
# Token ids the pretrained tokenizer assigns to a short sentence
tokenizer.encode("Hello, how are you?")

<link rel="stylesheet" href="berkeley.css">

<h2  class="cal cal-h2">Byte Pair Encoding</h2> 



Byte Pair Encoding (BPE) is fairly simple to implement. We start by splitting each word into its characters and appending a special end-of-word symbol `</w>` to each word. Then we repeatedly find the most common adjacent pair of symbols across all words and merge it into a new symbol. This process is repeated for a specified number of merges.

Once we have learned the merges, we can use them to tokenize new words by applying the merges in order until no more merges can be applied.



In [None]:
from typing import Dict, List, Tuple
from collections import Counter

EOW = "</w>"

In [None]:
def build_initial_vocab(corpus: str) -> Dict[str, int]:
    """
    Create the initial (pre-merge) vocabulary.

    The vocabulary holds every character observed in the corpus (except the
    whitespace characters space, newline, tab and carriage return), every
    ASCII letter, digit and punctuation character, and the EOW marker.

    IDs are assigned deterministically by sorting the tokens first. Iterating
    the raw set would make the ids depend on string-hash ordering, which can
    differ from one Python process to the next.

    Args:
        corpus: Raw training text.

    Returns:
        vocab: dict mapping token string -> integer id (0, 1, 2, ...)
    """
    import string
    ALL_ASCII = set(string.ascii_letters + string.digits + string.punctuation)
    ALL_ASCII.add(EOW)  # include the end-of-word symbol
    corpus_set = set(corpus) - set(' \n\t\r')  # exclude whitespace

    return {tok: i for i, tok in enumerate(sorted(corpus_set | ALL_ASCII))}

vocab = build_initial_vocab("CS189 is CS.")
print(vocab)

In [None]:
def corpus_to_char_seq_with_eow(corpus: str) -> List[str]:
    """
    Flatten a corpus into one list of symbols.

    The corpus is split on whitespace; each word is expanded into its
    individual characters and terminated by the end-of-word marker `</w>`.

    Example:
        "low lower" -> ['l','o','w','</w>','l','o','w','e','r','</w>']

    The result is deliberately a single flat list rather than a list of
    per-word lists: the `</w>` entries are what mark the word boundaries, and
    the later counting/merging steps rely on them to avoid merging across words.
    """
    return [symbol for word in corpus.split() for symbol in (*word, EOW)]

print(corpus_to_char_seq_with_eow("CS189 is great!"))


In [None]:
def count_pair_frequencies(seq: List[str]) -> Counter:
    """
    Tally every adjacent (left, right) symbol pair in the flat sequence.

    Boundary rule:
    - A pair whose LEFT symbol ends with `</w>` is ignored, because merging it
      would glue the end of one word onto the start of the next.
    - A pair whose RIGHT symbol is `</w>` (e.g. ('w', '</w>')) is counted; it
      produces word-final tokens such as 'w</w>'.

    Returns:
        A Counter mapping (left_symbol, right_symbol) -> number of occurrences.
    """
    return Counter(
        (left, right)
        for left, right in zip(seq, seq[1:])
        if not left.endswith(EOW)
    )

corpus = "CS189 is CS."
seq = corpus_to_char_seq_with_eow(corpus)
pair_freqs = count_pair_frequencies(seq)
print(pair_freqs)

In [None]:
def merge_pair_in_sequence(seq: List[str], pair: Tuple[str, str]) -> List[str]:
    """
    Replace every non-overlapping occurrence of `pair` in the flat sequence
    with the single concatenated symbol.

    Invariant:
    - A pair whose left symbol ends with `</w>` is never merged, since that
      would cross a word boundary. This is the same "ends with" rule that
      count_pair_frequencies applies, so a merged word-final token such as
      'w</w>' is protected as well as the bare `</w>` marker.
    - The scan is left-to-right and skips past each merged pair, so
      overlapping matches are resolved in favour of the leftmost one.

    Args:
        seq: Flat symbol sequence (not modified).
        pair: (left_symbol, right_symbol) to merge.

    Returns:
        A new list with the merges applied.
    """
    a, b = pair
    if a.endswith(EOW):
        # Merging would cross a word boundary: return an unchanged copy.
        return list(seq)
    merged_token = a + b
    new_seq: List[str] = []
    i = 0
    n = len(seq)
    while i < n:
        if i < n - 1 and seq[i] == a and seq[i + 1] == b:
            new_seq.append(merged_token)
            i += 2  # skip the merged pair
        else:
            new_seq.append(seq[i])
            i += 1
    return new_seq

corpus = "CS189 is CS."
seq = corpus_to_char_seq_with_eow(corpus)
pair_freqs = count_pair_frequencies(seq)
pair, freq = pair_freqs.most_common(1)[0]
print("Merging pair:", pair, "with frequency:", freq)
new_seq = merge_pair_in_sequence(seq, pair)
print(new_seq)

In [None]:
from tqdm import tqdm  # pip install tqdm

def learn_bpe_merges(corpus: str, num_merges: int = 1000, min_frequency: int = 2) -> Tuple[List[Tuple[str, str]], dict]:
    """
    Learn an ordered list of BPE merge rules from a corpus.

    Each round counts the adjacent symbol pairs (respecting the word-boundary
    rule), merges the most frequent one everywhere, and records it. Learning
    stops after `num_merges` rounds, when no pairs remain, or when the best
    pair is rarer than `min_frequency`.

    Args:
        corpus: Raw text (spaces separate words).
        num_merges: Upper bound on the number of merges to learn.
        min_frequency: Minimum count the best pair must reach to be merged.

    Returns:
        merges: (left_symbol, right_symbol) tuples in the order learned.
        vocab: Mapping token -> id; the initial vocabulary plus one new id
            per newly created merged token.
    """
    symbols = corpus_to_char_seq_with_eow(corpus)
    vocab = build_initial_vocab(corpus)
    next_id = 1 + max(vocab.values())
    merges: List[Tuple[str, str]] = []

    # tqdm shows how many merge rounds have run and the latest decision.
    progress = tqdm(range(num_merges), desc="Learning BPE merges", ncols=80)
    for _ in progress:
        top = count_pair_frequencies(symbols).most_common(1)
        if not top:
            progress.set_postfix_str("done (no pairs left)")
            break
        best_pair, best_count = top[0]
        if best_count < min_frequency:
            progress.set_postfix_str(f"stopped (min freq < {min_frequency})")
            break

        # Apply the merge, remember the rule, and register the new token.
        symbols = merge_pair_in_sequence(symbols, best_pair)
        merges.append(best_pair)
        merged = "".join(best_pair)
        if merged not in vocab:
            vocab[merged] = next_id
            next_id += 1

        progress.set_postfix_str(f"merge {best_pair} ({best_count})")

    progress.close()
    return merges, vocab

corpus = "This is the best CS class. This is CS 189."
merges, vocab = learn_bpe_merges(corpus, num_merges=100, min_frequency=2)
print("Learned merges:", merges)
print("Vocabulary:", vocab)

In [None]:
def bpe_encode(text: str, merges: List[Tuple[str, str]], vocab: Dict[str, int]) -> Tuple[List[str], List[int]]:
    """
    Encode a string with previously learned BPE merges:
      1) Convert text -> flat char+EOW sequence
      2) Apply the learned merges in the order they were learned
      3) Map the final tokens to IDs via vocab

    Args:
        text: Raw text to tokenize (whitespace separates words).
        merges: Ordered merge rules, as returned by learn_bpe_merges.
        vocab: Mapping token -> id, as returned by learn_bpe_merges.

    Returns:
        (tokens, token_ids): the final token strings and their ids, aligned
        element by element.

    Raises:
        KeyError: if a final token is missing from `vocab` (e.g. a character
            that is neither ASCII nor present in the training corpus).

    Note:
    - This simple teaching encoder applies each merge over the whole
      sequence; it assumes the learned merges were derived from a similar
      distribution (your corpus).
    - For speed, production systems use a 'rank' map and greedy longest-match;
      here we stick to the clearest didactic approach.
    """
    seq = corpus_to_char_seq_with_eow(text)
    # Iterating through tqdm advances and closes the progress bar automatically.
    for pair in tqdm(merges, desc="Applying BPE merges", ncols=80):
        seq = merge_pair_in_sequence(seq, pair)
    return seq, [vocab[tok] for tok in seq]

corpus = "This is the best CS class. This is CS 189 the best class."
merges, vocab = learn_bpe_merges(corpus, num_merges=100, min_frequency=2)
encoded_seq, token_ids = bpe_encode("CS 189 is the   best   class.", merges, vocab)
print("Encoded sequence:", encoded_seq)
print("Token IDs:", token_ids)

In [None]:
def bpe_decode(token_ids: List[int], vocab: Dict[str, int]) -> str:
    """
    Turn a list of token ids back into text.

    How it works:
    - The vocabulary is inverted (id -> token string).
    - Every `</w>` occurring in a token is replaced by a single space, so
      word-final tokens contribute the separator that follows their word.
    - The pieces are concatenated and surrounding whitespace is stripped,
      which removes the space left by the last word.

    Raises:
        KeyError: if an id in `token_ids` is not present in `vocab`.
    """
    id_to_token = {idx: tok for tok, idx in vocab.items()}
    pieces = (id_to_token[tid].replace(EOW, " ") for tid in token_ids)
    return "".join(pieces).strip()

In [None]:
# Round trip: decode the ids produced by the bpe_encode call above.
decoded_text = bpe_decode(token_ids, vocab)
print(f"Decoded text: \"{decoded_text}\"")

<link rel="stylesheet" href="berkeley.css">

<h2  class="cal cal-h2">Implementing the Decoder Transformer for Generative Pre-training</h2> 



<link rel="stylesheet" href="berkeley.css">

<h3  class="cal cal-h3">Getting the Data</h3> 



In [None]:
import os
# Download the corpus once and cache it next to the notebook; later runs read
# the cached copy instead of hitting the network again.
if not os.path.exists("shakespeare.txt"):
    print('downloading corpus...')
    import requests
    url = "https://www.gutenberg.org/cache/epub/100/pg100.txt"
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, timeout=60)
    # Fail loudly on an HTTP error instead of caching an error page as the corpus.
    response.raise_for_status()
    shakespeare_corpus = response.text
    # Explicit UTF-8 so the cache is written and read identically on every
    # platform (the default text encoding is locale dependent).
    with open("shakespeare.txt", "w", encoding="utf-8") as f:
        f.write(shakespeare_corpus)
else:
    print('loading cached file...')
    with open("shakespeare.txt", "r", encoding="utf-8") as f:
        shakespeare_corpus = f.read()
print(f"Corpus length: {len(shakespeare_corpus)} characters") 
print(shakespeare_corpus[:1000])

<link rel="stylesheet" href="berkeley.css">

<h3  class="cal cal-h3">Byte Pair Encoding</h3> 



In [None]:
# if not os.path.exists("bpe_state.pkl"):
#     print('learning BPE merges on Shakespeare corpus...')
#     merges, vocab = learn_bpe_merges(shakespeare_corpus, 
#                                      num_merges=200, 
#                                      min_frequency=2)
#     with open("bpe_state.pkl", "wb") as f:
#         pickle.dump((merges, vocab), f)
# else:
#     print("loading cached BPE state...")
#     with open("bpe_state.pkl", "rb") as f:
#         merges, vocab = pickle.load(f)
# print("Learned merges:", merges)
# print("Vocabulary:", vocab)
# vocab_size = len(vocab)
# print("Vocabulary size:", vocab_size)

In [None]:
# if not os.path.exists("encoded_text_ids.pkl"):
#     print('encoding Shakespeare corpus...')
#     encoded_seq, token_ids = bpe_encode(shakespeare_corpus, merges, vocab)
#     with open("encoded_text_ids.pkl", "wb") as f:
#         pickle.dump((encoded_seq,token_ids), f)
# else:
#     print("loading cached encoded text...")
#     with open("encoded_text_ids.pkl", "rb") as f:
#         encoded_seq, token_ids = pickle.load(f)
# print("Encoded sequence length:", len(encoded_seq))
# corpus_tokens = torch.tensor(token_ids, dtype=torch.long, device=device)

# def encode(text: str) -> torch.Tensor:
#     _, token_ids = bpe_encode(text, merges, vocab)
#     return torch.tensor(token_ids, dtype=torch.long, device=device)

# def decode(token_ids: torch.Tensor) -> str:
#     return bpe_decode(token_ids.tolist(), vocab)

# tok = encode("To be, or not to be, that is the question.")
# decode(tok)

<link rel="stylesheet" href="berkeley.css">

<h3  class="cal cal-h3">Word Encoding</h3> 



In [None]:
import re
from collections import Counter

def split_words(corpus: str) -> List[str]:
    """
    Lowercase the corpus, delete a fixed set of punctuation characters
    (so "o'er" becomes "oer"), and return the remaining runs of word
    characters in order of appearance.
    """
    lowered = corpus.lower()
    without_punctuation = re.sub(r'[._,!?;"\'`()\[\]{}<>]', '', lowered)
    return re.findall(r'\b\w+\b', without_punctuation)

words = split_words(shakespeare_corpus)
# counter = Counter(words)
# vocab_set = {tok for tok, cnt in counter.items() if cnt > 1}
vocab_set = set(words)
# Id 0 is reserved for out-of-vocabulary words; real words get ids 1..V in
# sorted order. "<unknown>" used to be assigned id 1, which collided with the
# first sorted word and made that word decode as "<unknown>".
UNKNOWN_TOKEN = "<unknown>"
vocab = {UNKNOWN_TOKEN: 0}
vocab.update({word: i for i, word in enumerate(sorted(vocab_set), start = 1)})
inv_vocab = {i: word for word, i in vocab.items()}
vocab_size = len(vocab)  # V real words + 1 unknown token; ids span 0..V
print("Vocabulary size:", vocab_size)


def encode(text: str):
    """
    Encode a string into a 1-D tensor of token IDs using the word vocabulary.

    The text is split with split_words; words absent from the vocabulary are
    mapped to the ID of the "<unknown>" token. The tensor is created on the
    notebook's `device`.
    """
    words = split_words(text)
    unknown_id = vocab[UNKNOWN_TOKEN]
    return torch.tensor(
        [vocab.get(word, unknown_id) for word in words],
        dtype=torch.long, device=device)

def decode(tokens: torch.Tensor) -> str:
    """
    Decode a 1-D tensor of token IDs back to space-separated words.

    IDs that are not in the vocabulary are rendered as "<error>".
    """
    words = [inv_vocab.get(t.item(), "<error>") for t in tokens]
    return " ".join(words)

corpus_tokens = encode(shakespeare_corpus)
print("Length of token IDs:", len(corpus_tokens))
decode(encode("to be or not to be that is the question tokenizer"))

<link rel="stylesheet" href="berkeley.css">

<h3  class="cal cal-h3">Data Preparation</h3> 



In [None]:
N = corpus_tokens.shape[0]
seq_length = 64
split_ratio = 0.90
seed = 189
# Each context needs seq_length input tokens, and the targets are the inputs
# shifted by one, so num_contexts * seq_length + 1 tokens are required in
# total. Using (N - 1) // seq_length also handles the case where N is an exact
# multiple of seq_length, in which the extra "+1" token does not exist.
num_contexts = (N - 1) // seq_length
x = corpus_tokens[:num_contexts * seq_length + 1]
y = x[1:].reshape(num_contexts, seq_length)   # next-token targets
x = x[:-1].reshape(num_contexts, seq_length)  # model inputs


from torch.utils.data import random_split, TensorDataset
dataset = TensorDataset(x, y)
# A seeded generator makes the train/validation split reproducible.
generator = torch.Generator().manual_seed(seed)
training_data, validation_data = random_split(
    dataset, [split_ratio, 1 - split_ratio], 
    generator=generator) 
print("training contexts", len(training_data))
print("validation contexts", len(validation_data))
training_data[0]

In [None]:
# Show the first training context and its targets (the same text shifted
# forward by one token).
print(decode(training_data[0][0]))
print(decode(training_data[0][1]))

<link rel="stylesheet" href="berkeley.css">

<h3  class="cal cal-h3">Attention</h3> 



In [None]:
def scaled_dot_product_attention(Q: torch.Tensor, 
                                 K: torch.Tensor, 
                                 V: torch.Tensor, 
                                 mask=None):
    """
    Compute softmax(Q K^T / sqrt(d_k)) V.

    Q: tensor of shape (..., N, d_k)
    K: tensor of shape (..., M, d_k)
    V: tensor of shape (..., M, d_v)
    mask: optional boolean tensor broadcastable to (..., N, M), e.g. (1, N, N).
          Positions where mask is True are INCLUDED; the rest get zero weight.

    Any number of leading dimensions is accepted (a batch dimension, or batch
    and head dimensions), as long as they broadcast under matrix multiplication.

    Returns:
        Tensor of shape (..., N, d_v).

    Note:
        A query row whose mask is False everywhere produces NaNs, because the
        softmax is taken over a row of -inf.
    """
    d_k = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)  # (..., N, M)

    if mask is not None:
        # -inf scores become exactly zero weight after the softmax.
        scores = scores.masked_fill(mask.logical_not(), float('-inf'))

    attention = F.softmax(scores, dim=-1)  # (..., N, M), rows sum to 1
    return attention @ V  # (..., N, M) @ (..., M, d_v) = (..., N, d_v)

In [None]:
_mask_cache = {}
def get_mask_with_cache(N, device):
    """
    Returns a lower triangular boolean mask of shape (1, N, N) to be used for
    masked (causal) attention: entry (i, j) is True exactly when j <= i.

    Masks are cached per (N, device). Keying on N alone would hand back a mask
    living on whichever device happened to request that size first.

    Args:
        N: sequence length.
        device: a torch.device or a device string such as "cpu" or "cuda".
    """
    key = (N, torch.device(device))
    if key not in _mask_cache:
        _mask_cache[key] = torch.ones(
            (N, N), dtype=torch.bool, 
            device=device).tril().unsqueeze(0)  
    return _mask_cache[key] #  (1, N, N)

class MaskedAttentionHead(nn.Module):
    """
    One causal self-attention head: projects the input to queries, keys and
    values with bias-free linear maps and applies masked scaled dot-product
    attention.
    """

    def __init__(self, d_model=512, d_v=512, d_k=64):
        super().__init__()
        self.d_model, self.d_k, self.d_v = d_model, d_k, d_v

        # Key, query and value projections (created in this order).
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)

    def forward(self, x):
        """
        x: tensor of shape (B, N, d_model); it supplies the queries, keys and
        values (self-attention).

        A lower-triangular mask is applied so position i only attends to
        positions <= i. Returns a tensor of shape (B, N, d_v).
        """
        _, seq_len, _ = x.shape
        queries = self.W_q(x)
        keys = self.W_k(x)
        vals = self.W_v(x)

        causal_mask = get_mask_with_cache(seq_len, device=x.device)
        # A faster alternative is F.scaled_dot_product_attention(
        #     query=..., key=..., value=..., attn_mask=causal_mask),
        # which needs the head dimension unsqueezed first.
        return scaled_dot_product_attention(queries, keys, vals, mask=causal_mask)


In [None]:
class MaskedMultiHeadAttention(nn.Module):
    """
    Runs `num_heads` independent causal attention heads on the same input,
    concatenates their outputs along the feature axis, and projects the
    result back to d_model.
    """

    def __init__(self, num_heads=8, d_model=512, d_k=64):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_k
        # Per-head value width; integer division, so num_heads * d_v may be
        # smaller than d_model when d_model is not divisible by num_heads.
        self.d_v = d_model // num_heads

        heads = []
        for _ in range(num_heads):
            heads.append(
                MaskedAttentionHead(d_model=d_model, d_v=self.d_v, d_k=d_k))
        self.attention_heads = nn.ModuleList(heads)

        # Output projection from the concatenated heads back to d_model.
        self.W_out = nn.Linear(num_heads * self.d_v, d_model)

    def forward(self, x):
        """x: (B, N, d_model) -> (B, N, d_model)."""
        B, N, _ = x.shape
        assert x.shape == (B, N, self.d_model)

        per_head = [head(x) for head in self.attention_heads]
        head_shape = (B, N, self.d_v)
        for head_out in per_head:
            assert head_out.shape == head_shape

        stacked = torch.cat(per_head, dim=-1)
        assert stacked.shape == (B, N, self.num_heads * self.d_v)
        return self.W_out(stacked)

Testing the Layer with Masked Multi-Head Attention:

In [None]:
# A 512-dimensional embedding table over the word vocabulary, and a
# multi-head attention layer with 7 heads, d_model=512 and d_k=64.
emb = nn.Embedding(vocab_size, 512).to(device)
layer = MaskedMultiHeadAttention(7, 512, 64).to(device)

In [None]:
batch_size = 7
# Take the first `batch_size` (input, target) contexts of the training split.
# NOTE: this rebinds the notebook-level names x and y to the small batch.
x, y = training_data[:batch_size]
# Shape of the embedded inputs
emb(x).shape

In [None]:
# Output shape of the multi-head attention layer applied to the embedded batch
layer(emb(x)).shape

<link rel="stylesheet" href="berkeley.css">

<h3  class="cal cal-h3">Decoder Architecture</h3> 



In [None]:
class DecoderBlock(nn.Module):
    """
    One pre-norm transformer decoder block:

        x = x + Dropout(MaskedMultiHeadAttention(LayerNorm(x)))
        x = x + Dropout(FFN(LayerNorm(x)))

    Args:
        d_model: width of the residual stream.
        num_heads: number of attention heads.
        d_k: query/key width of each head.
        dropout: dropout probability, used after attention, inside the
            feed-forward network, and after the feed-forward network.
    """

    def __init__(self, d_model=512, num_heads=8, d_k=64, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ffn = 4 * self.d_model  # hidden width of the feed-forward network
        self.d_k = d_k

        # One parameter-free Dropout module shared by every dropout site.
        self.dropout = nn.Dropout(dropout)

        self.mh_attention = MaskedMultiHeadAttention(
            num_heads=self.num_heads, 
            d_model=self.d_model,
            d_k=self.d_k)
        self.ffn = nn.Sequential(
            nn.Linear(self.d_model, self.d_ffn),
            nn.ReLU(),
            self.dropout,
            nn.Linear(self.d_ffn, self.d_model),
        )
        self.layernorm1 = nn.LayerNorm(self.d_model)
        self.layernorm2 = nn.LayerNorm(self.d_model)

    def forward(self, x):
        """x: (B, N, d_model) -> (B, N, d_model)."""
        mh = self.mh_attention(self.layernorm1(x)) # Prenorm
        mh = self.dropout(mh)
        x = x + mh
        ffn = self.ffn(self.layernorm2(x)) # Prenorm
        ffn = self.dropout(ffn)
        # Second residual connection: without adding `ffn` here the
        # feed-forward sub-layer would be computed and then thrown away.
        return x + ffn

In [None]:
# Instantiate a single decoder block and check its output shape on the
# embedded batch.
block = DecoderBlock(d_model=512, num_heads=8, d_k=64, dropout=0.1).to(device)
block(emb(x)).shape

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch of embeddings."""

    def __init__(self, d_model: int = 512, max_len: int = 1024, L: float = 10000.0):
        """
        Sinusoidal positional encoding as in 'Attention is All You Need'.

        Args:
            d_model: embedding dimension; odd values are supported.
            max_len: number of positions to precompute.
            L: base of the geometric progression of wavelengths.
        """
        super().__init__()

        pos = torch.zeros(max_len, d_model, dtype=torch.float32)
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_terms = L ** (torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
        quotient = positions / div_terms  # (max_len, ceil(d_model / 2))
        pos[:, 0::2] = torch.sin(quotient)  # even indices
        # With an odd d_model there is one fewer odd index than even index.
        pos[:, 1::2] = torch.cos(quotient[:, : d_model // 2])  # odd indices
        # Register as non-parameter buffer so it moves with the module
        self.register_buffer("pos", pos)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, seq_len, d_model) with seq_len <= max_len.
        Returns x plus the encodings of positions 0..seq_len-1.
        """
        seq_len = x.size(1)
        # pos: (1, seq_len, d_model) broadcasts along batch dimension
        return x + self.pos[:seq_len].unsqueeze(0)

In [None]:
class TransformerDecoderOnly(nn.Module):
    """
    Decoder-only transformer language model:
    token embedding -> sinusoidal positions -> `num_layers` decoder blocks
    -> linear projection to vocabulary logits.

    Args:
        max_length: longest sequence the positional encoding supports; also
            the context window used by `generate`.
        vocab_size: number of token ids.
        d_model: width of the residual stream.
        d_k: query/key width of each attention head.
        num_layers: number of decoder blocks.
        num_heads: attention heads per block.
        dropout: dropout probability inside each block.
    """

    def __init__(self, 
                 max_length=1024,
                 vocab_size=6000, 
                 d_model=512,
                 d_k=64,
                 num_layers=6, 
                 num_heads=8, 
                 dropout=0.1):
        super().__init__()
        # Stored so `generate` can crop its context without relying on a
        # notebook-level global.
        self.max_length = max_length
        self.vocab_size = vocab_size
        self.d_k = d_k
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.layers = nn.Sequential()
        self.layers.append(
            nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        )
        self.layers.append(
            PositionalEncoding(d_model=d_model, max_len=max_length, L=10000)
        )
        for _ in range(num_layers):
            self.layers.append(
                DecoderBlock(d_model=d_model, num_heads=num_heads, 
                             d_k=self.d_k, dropout=self.dropout)
            )
        self.layers.append(
            nn.Linear(in_features=d_model, out_features=vocab_size)
        )
        
    def num_parameters(self):
        """Total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())
    
    def forward(self, x):
        """x: (B, N) token ids -> (B, N, vocab_size) next-token logits."""
        return self.layers(x)
    
    def generate(self, x, max_new_tokens):
        """
        Sample `max_new_tokens` tokens autoregressively.

        x: (B, N) tensor of input token IDs
        max_new_tokens: number of tokens to generate

        Returns the full (B, N + max_new_tokens) sequence: the prompt followed
        by the sampled tokens. Only the last `max_length` tokens are fed to the
        model at each step, but earlier tokens are kept in the returned tensor.

        Side effect: the model is switched to eval mode and left there.
        """
        self.eval()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                context = x[:, -self.max_length:]  # crop to the model's window
                logits = self.forward(context)  # (B, N_ctx, vocab_size)
                next_token_logits = logits[:, -1, :]  # (B, vocab_size)
                next_token_probs = F.softmax(next_token_logits, dim=-1)  # (B, vocab_size)
                next_tokens = torch.multinomial(next_token_probs, num_samples=1)  # (B, 1)
                x = torch.cat([x, next_tokens], dim=1)  # (B, N+1)
        return x

In [None]:
# Build a small model whose context window equals the training sequence length.
model = TransformerDecoderOnly(
    max_length=seq_length, 
    vocab_size=vocab_size, 
    d_model=256, d_k=16, num_layers=6, num_heads=8, 
    dropout=0.1).to(device)

# Print the module tree and the parameter count in millions.
print(model)
print("Number of parameters:", model.num_parameters()/1e6, "million")

In [None]:
# Sample 20 tokens from the freshly constructed (not yet trained) model.
decode(model.generate(encode("to be or not to be").unsqueeze(0), max_new_tokens=20)[0])

<link rel="stylesheet" href="berkeley.css">

<h3  class="cal cal-h3">Training Loop</h3> 



In [None]:
def batch_cross_entropy(pred, y):
    """
    Mean next-token cross-entropy over every position of every sequence.

    Args:
        pred: logits of shape (B, N, vocab_size).
        y: integer targets of shape (B, N).

    Returns:
        Scalar tensor: the cross-entropy averaged over all B * N positions.
    """
    # Flatten batch and sequence into one axis. The class count is read from
    # the logits themselves, so the function does not depend on a global, and
    # reshape (unlike view) also accepts non-contiguous tensors.
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), y.reshape(-1))

In [None]:
# Loss of the current model on the small (x, y) batch sliced earlier.
batch_cross_entropy(model(x), y)

In [None]:
def minibatch_gd(model, loss_fn, 
                 training_data,
                 batch_size, 
                 nsteps, 
                 learning_rate,
                 visualizer=None,
                 weight_decay=1e-4):
    """
    Train `model` in place for `nsteps` AdamW updates on shuffled mini-batches.

    Args:
        model: module to optimize.
        loss_fn: callable (prediction, target) -> scalar loss tensor.
        training_data: dataset yielding (input, target) pairs.
        batch_size: number of examples per step.
        nsteps: number of optimizer steps (batches wrap around across epochs).
        learning_rate: AdamW learning rate.
        visualizer: optional callable (step, model, loss_fn, train_loss),
            invoked after every step with the model in eval mode and
            gradients disabled.
        weight_decay: AdamW weight decay.
    """
    # A fixed seed makes the shuffling order reproducible.
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(189)
    loader = DataLoader(
        training_data,
        batch_size=batch_size,
        shuffle=True,  # reshuffled every time the loader is re-iterated
        generator=shuffle_rng,
    )

    # The optimizer implements the parameter update rule.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model.train()  # training mode matters for dropout/batchnorm

    batches = iter(loader)
    for step in tqdm(range(nsteps)):
        # Fetch the next batch, starting a new epoch when the loader runs dry.
        batch = next(batches, None)
        if batch is None:
            batches = iter(loader)
            batch = next(batches)
        inputs, targets = batch

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        train_loss = loss.item()
        loss.backward()
        optimizer.step()

        if visualizer is None:
            continue
        # Report progress with dropout/batchnorm disabled and no autograd.
        model.eval()
        with torch.no_grad():
            visualizer(step, model, loss_fn, train_loss)
        model.train()

In [None]:
class LossVisualizer:
    """
    Callback that tracks training and validation loss and mirrors them into a
    figure: trace 0 receives the validation curve, trace 1 the training curve.
    """

    def __init__(self, loss_fig, validation_data):
        self.loss_fig = loss_fig
        # Fixed order and batches of 32 for evaluating the validation set.
        self.val_loader = DataLoader(validation_data, batch_size=32, shuffle=False)
        self.epochs, self.losses_val, self.losses_tr = [], [], []

    def reset(self):
        """Forget the recorded history and blank both figure traces."""
        self.epochs, self.losses_val, self.losses_tr = [], [], []
        with self.loss_fig.batch_update():
            for idx in (0, 1):
                trace = self.loss_fig.data[idx]
                trace.x = []
                trace.y = []

    def __call__(self, epoch, model, loss_fn, loss_tr):
        """
        Record one point: evaluate `model` on the whole validation set
        (mean of the per-batch losses), append it together with `loss_tr`,
        print both, and refresh the figure. The model is left in training mode.
        """
        model.eval()
        with torch.no_grad():
            batch_losses = [
                loss_fn(model(x_val), t_val).item()
                for x_val, t_val in self.val_loader
            ]
            loss_val = np.mean(batch_losses)

        self.epochs.append(epoch)
        self.losses_val.append(loss_val)
        self.losses_tr.append(loss_tr)
        print("training loss:", loss_tr, "validation loss:", loss_val)

        # Push the updated curves to the figure in a single batched update.
        with self.loss_fig.batch_update():
            val_trace = self.loss_fig.data[0]
            val_trace.x = self.epochs
            val_trace.y = self.losses_val
            train_trace = self.loss_fig.data[1]
            train_trace.x = self.epochs
            train_trace.y = self.losses_tr

        model.train()

In [None]:
# Live figure with two line traces. LossVisualizer writes the validation loss
# into trace 0 and the training loss into trace 1, so the order matters.
loss_fig = go.FigureWidget()
loss_fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Val. Loss'))
loss_fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Train. Loss'))
visualizer = LossVisualizer(loss_fig, validation_data)
display(loss_fig)

In [None]:
# Clear any curves from a previous run, then train a freshly built model.
visualizer.reset()
model = TransformerDecoderOnly(
    max_length=seq_length, 
    vocab_size=vocab_size, 
    d_model=1024, d_k=32, num_layers=2, num_heads=8, 
    dropout=0.0).to(device)

#model = torch.compile(model)

# Alternative configuration (narrower, deeper, with dropout):
# model = TransformerDecoderOnly(
#     max_length=seq_length, 
#     vocab_size=vocab_size, 
#     d_model=128, d_k=32, num_layers=8, num_heads=8, 
#     dropout=0.1).to(device)

# The visualizer is called after every step, so each step also runs a full
# pass over the validation set.
minibatch_gd(
    model=model,
    loss_fn=batch_cross_entropy,
    training_data=training_data,
    batch_size=128,
    nsteps=200,
    learning_rate=3e-4,
    weight_decay=1e-7,
    visualizer=visualizer
)

In [None]:
# Sample 20 tokens from the trained model, continuing a short prompt.
decode(model.generate(encode("whether tis nobler").unsqueeze(0), max_new_tokens=20)[0])