# Training our mini-transformer to predict words

We are going to keep building on our TinyLM class to get more explicit about the transformer pieces.

But...!  We also need to look at how we train it.  How do we change the initial weights so they do a good job at generating word sequences?

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Here we create a super simple corpus and initialize a vocab dictionary with some useful tokens.

In [None]:
# Toy training corpus: four short sentences with whitespace-separated tokens.
corpus = [
    "hello , how are you ?",
    "hello , how is your day ?",
    "how are you ?",
    "how is your day ?",
]

# The special tokens occupy IDs 0, 1 and 2, in this order.
vocab = {token: idx for idx, token in enumerate(["<pad>", "<bos>", "<eos>"])}

We need to get the tokens from our corpus and add them to the vocab (i.e., establish their token IDs).

In [None]:
# Collect the unique whitespace-separated words across all corpus sentences.
tokens = {word for sentence in corpus for word in sentence.split()}
tokens

In [None]:
# Give every corpus word an ID after the three special tokens (IDs 0-2).
# Iterate in sorted order: a set of strings has no stable iteration order
# across interpreter runs (string hashing is randomized), so enumerating
# `tokens` directly would hand out different token IDs after every kernel
# restart and make results impossible to reproduce.
for ix, val in enumerate(sorted(tokens), start=3):
    vocab[val] = ix

In [None]:
# Inspect the token -> ID mapping (special tokens plus the corpus words).
vocab

In [None]:
# Invert the vocabulary so an ID can be mapped back to its token string.
id2token = dict(zip(vocab.values(), vocab.keys()))

In [None]:
# Inspect the ID -> token mapping.
id2token

In [None]:
def simple_tokenize(text):
    """Convert whitespace-separated text to token IDs wrapped in <bos>/<eos>.

    Each word is looked up directly in ``vocab``, so a word that is not in the
    vocabulary raises ``KeyError``.
    """
    return [
        vocab["<bos>"],
        *(vocab[word] for word in text.split()),
        vocab["<eos>"],
    ]

def detokenize(ids):
    """Map token IDs back to their token strings, joined by single spaces."""
    words = [id2token[token_id] for token_id in ids]
    return " ".join(words)

In [None]:
# This will give an error!
# Words such as 'Hi' never appear in the corpus, so they are missing from
# `vocab` and the dictionary lookup inside simple_tokenize raises KeyError.
simple_tokenize('Hi my name is Ben')

This is a problem we run into with text/tokens that are OOV -- Out of Vocabulary.

Here we ignore it and just use words that are in our vocabulary, but sub-word tokenization helps to alleviate this issue.

In [None]:
# Every word here appears in the corpus, so each lookup in `vocab` succeeds.
simple_tokenize('how is your day ?')

In [None]:
# Round trip: text -> token IDs -> text. The result also contains the
# <bos>/<eos> markers that simple_tokenize adds around the words.
tk_ids = simple_tokenize('how is your day ?')
detokenize(tk_ids)

In [None]:
class TinyLM(nn.Module):
    """Minimal transformer language model (no attention mask is applied).

    Pipeline: token embedding -> ``n_layers`` encoder layers -> final
    LayerNorm -> linear head that scores every vocabulary entry per position.
    """

    def __init__(self, vocab_size, d_model, n_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)

        encoder_blocks = [
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=2,
                dim_feedforward=64,
                batch_first=True,
            )
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(encoder_blocks)

        self.ln_f = nn.LayerNorm(d_model)
        self.out_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        """Map token IDs (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        hidden = self.embed(input_ids)
        for block in self.layers:
            hidden = block(hidden)
        normed = self.ln_f(hidden)
        return self.out_head(normed)

In [None]:
# Initialize an instance of our model
vocab_size = len(vocab)
model = TinyLM(vocab_size=vocab_size, d_model=32, n_layers=2)

# Inference only: a freshly built module starts in training mode, where the
# dropout inside the encoder layers is active. Switch to eval mode so the
# printed distribution depends only on the weights and the input.
model.eval()

# Convert a sample input text to numerical IDs
text = "hello , how are you ?"
token_ids = torch.tensor([simple_tokenize(text)])  # shape: (1, seq_len)

# Get probable next word from our model
with torch.no_grad():
    logits = model(token_ids)                      # (1, seq_len, vocab_size)
    next_token_logits = logits[0, -1]              # last position
    probs = F.softmax(next_token_logits, dim=-1)

# Print out input text, input IDs, and distribution of probable next words
print("Input text: ", text)
print("Token IDs:  ", token_ids.tolist())
print("Next-token distribution (top 5):")
top_probs, top_ids = probs.topk(5)
for p, i in zip(top_probs, top_ids):
    print(f"  {id2token[i.item()]:>6s}: {p.item():.3f}")

In [None]:
def generate(model, prompt_ids, max_new_tokens=10, eos_id=None):
    """
    Greedy generation (argmax) with a simple loop.

    model: module mapping token IDs of shape (1, cur_len) to logits of shape
        (1, cur_len, vocab_size). It is switched to eval mode and left there.
    prompt_ids: LongTensor of shape (1, seq_len); not modified.
    max_new_tokens: upper bound on the number of tokens appended.
    eos_id: token ID that stops generation once produced. When omitted it
        falls back to the notebook-level ``vocab["<eos>"]``, so existing
        calls behave as before.

    Returns a LongTensor of shape (1, seq_len + number_of_generated_tokens).
    The stop token, if produced, is included in the result.
    """
    if eos_id is None:
        eos_id = vocab["<eos>"]

    model.eval()
    generated = prompt_ids.clone()

    # No gradients are needed anywhere during generation.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(generated)          # (1, cur_len, vocab_size)
            next_token_logits = logits[0, -1]  # (vocab_size,)

            # turn into probabilities (not strictly needed for argmax, but fine)
            probs = F.softmax(next_token_logits, dim=-1)

            # greedy: pick the most likely token
            next_token_id = torch.argmax(probs)  # scalar tensor

            # append new token
            generated = torch.cat(
                [generated, next_token_id.view(1, 1)], dim=1
            )

            # stop if we hit the end-of-sequence token
            if next_token_id.item() == eos_id:
                break

    return generated  # (1, new_seq_len)

In [None]:
# Example input
# NOTE: no training code appears before this cell, so (when cells are run in
# order) `model` still has its random initial weights and the continuation
# below is not expected to be meaningful.
text = "hello , how are you ?"
full_ids = simple_tokenize(text)           # [<bos>, hello, ',', how, are, you, '?', <eos>]

# Drop the final <eos> so the model can generate its own ending
prompt_ids = torch.tensor([full_ids[:-1]]) # shape: (1, seq_len_without_eos)

# Greedy decoding; stops early if the model emits <eos>
generated_ids = generate(model, prompt_ids, max_new_tokens=15)
generated_ids_list = generated_ids[0].tolist()

print("Prompt text:    ", detokenize(prompt_ids[0].tolist()))
print("Generated IDs:  ", generated_ids_list)
print("Generated text: ", detokenize(generated_ids_list))

### Make the model causal

In [None]:
class TinyLM(nn.Module):
    """Tiny causal transformer language model.

    Pipeline: token embedding -> ``n_layers`` encoder layers run under a
    causal attention mask -> final LayerNorm -> linear head producing
    per-position vocabulary logits. No positional encoding is added to the
    embeddings.
    """

    def __init__(self, vocab_size, d_model, n_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=2,
                dim_feedforward=64,
                batch_first=True,
            )
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.out_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        """Return logits of shape (batch, seq_len, vocab_size).

        input_ids: integer token IDs of shape (batch, seq_len). The logits at
        position i depend only on tokens at positions <= i.
        """
        x = self.embed(input_ids)  # (batch, seq_len, d_model)

        seq_len = x.size(1)
        # causal mask: (seq_len, seq_len)
        # mask[i, j] = -inf if j > i (can't attend to future), 0 otherwise.
        # Build it on the activations' device and dtype so the model keeps
        # working after being moved off the CPU or cast to another float type.
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device, dtype=x.dtype)
            * float("-inf"),
            diagonal=1,
        )

        for layer in self.layers:
            x = layer(x, src_mask=mask)

        x = self.ln_f(x)
        logits = self.out_head(x)  # (batch, seq_len, vocab_size)
        return logits


### Training!

We need inputs and target outputs to compare against our model outputs.

* Example input sequence: "\<bos\> hello , how are you ?"
* Example target sequence: "hello , how are you ? \<eos\>"

Each position predicts the next token.

In [None]:
def make_example(text):
    """Build one next-token training pair from a sentence.

    Returns ``(input_ids, target_ids)``: the tokenized sentence without its
    last token, and the same sequence without its first token, so that
    ``target_ids[k]`` is the token that follows ``input_ids[k]``.
    """
    token_ids = simple_tokenize(text)  # [<bos>, ..., <eos>]
    return token_ids[:-1], token_ids[1:]

# One (input, target) pair per corpus sentence, split into two parallel lists.
examples = [make_example(sentence) for sentence in corpus]
train_inputs = [inputs for inputs, _ in examples]
train_targets = [targets for _, targets in examples]

In [None]:
# First training input as text: starts with <bos>, final <eos> dropped.
detokenize(train_inputs[0])

In [None]:
# Matching target as text: same sentence shifted by one, ending in <eos>.
detokenize(train_targets[0])

In [None]:
# Make a single batch by padding to same length:
# first find the length of the longest input sequence.
max_len = max(map(len, train_inputs))

def pad(seq, max_len, pad_id=0):
    """Return a new list: ``seq`` right-padded with ``pad_id`` to ``max_len``.

    A sequence that already has ``max_len`` or more items gets no padding and
    is never truncated (a non-positive repeat count yields an empty filler).
    """
    n_missing = max_len - len(seq)
    filler = [pad_id] * n_missing
    return seq + filler

# Pad every sequence to max_len, then stack into rectangular tensors.
padded_inputs = [pad(seq, max_len) for seq in train_inputs]
padded_targets = [pad(seq, max_len) for seq in train_targets]

input_batch = torch.tensor(padded_inputs)     # (batch, seq_len)
target_batch = torch.tensor(padded_targets)   # (batch, seq_len)

In [None]:
# Inspect the padded input IDs; shorter sentences end in pad ID 0.
input_batch

In [None]:
# Third input row as text, including any trailing <pad> tokens.
detokenize(list(input_batch[2].numpy()))

In [None]:
# Third target row as text, including any trailing <pad> tokens.
detokenize(list(target_batch[2].numpy()))

### Training loop

Use CrossEntropyLoss over the vocabulary at every position.

In [None]:
model = TinyLM(vocab_size=vocab_size, d_model=32, n_layers=2)

# Token-level cross entropy. Target positions holding the pad ID are skipped
# via ignore_index, so the loss and gradients come only from real tokens and
# the model is not trained to predict <pad> after shorter sentences.
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 200  # small corpus, will overfit fast

# Training mode (enables dropout); nothing inside the loop changes it.
model.train()

# Full-batch training: every epoch is one gradient step on the whole corpus.
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # forward
    logits = model(input_batch)  # (batch, seq_len, vocab_size)

    # reshape for CrossEntropyLoss: (batch * seq_len, vocab_size)
    logits_flat = logits.view(-1, vocab_size)
    targets_flat = target_batch.view(-1)   # (batch * seq_len,)

    loss = criterion(logits_flat, targets_flat)

    # backward
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1}/{num_epochs} - loss: {loss.item():.4f}")


Reuse the greedy generate loop from before:

In [None]:
# Prompt built from in-vocabulary words in an order that never occurs in the
# corpus, to see how the trained model continues it.
text = "how you hello"
ids = simple_tokenize(text)
prompt_ids = torch.tensor([ids[:-1]])  # drop <eos>

# Greedy decoding; prints the prompt plus the generated continuation as text.
generated_ids = generate(model, prompt_ids, max_new_tokens=15)
print(detokenize(generated_ids[0].tolist()))


### What is the model structure?  How many weights?

In [None]:
# Print the module tree (embedding, encoder layers, final norm, output head).
print(model)

In [None]:
# Total number of scalar weights across all parameter tensors of the model.
sum(p.numel() for p in model.parameters())

In [None]:
# Hand count of the parameters, to compare with the total computed above.
# Constants: 12 = vocab size (9 corpus words + 3 special tokens),
# 32 = d_model, 64 = dim_feedforward, leading 2* = number of encoder layers.
# Term by term (layer internals as reported by print(model) -- verify there):
#   12*32          token embedding table
#   4*33*32        attention: Q, K, V and output projections,
#                  each 32x32 weights + 32 biases
#   33*64          feed-forward 32 -> 64 (weights + 64 biases)
#   65*32          feed-forward 64 -> 32 (weights + 32 biases)
#   2*32 + 2*32    the layer's two LayerNorms (scale + shift each)
#   2*32           final LayerNorm
#   33*12          output head 32 -> 12 (weights + 12 biases)
12*32 + 2*(4*33*32 + 33*64 + 65*32 + 2*32 + 2*32) + 2*32 + 33*12