# Notebook 5: From Pixels to Prose - Your First Language Model

So far, we've been working exclusively with images—classifying handwritten digits, recognizing patterns in pixels. But deep learning isn't limited to vision. Some of the most exciting breakthroughs in recent years have come from natural language processing (NLP), where models learn to understand and generate human language.

In this notebook, we'll make the shift from computer vision to natural language processing. The goal is simple: predict the next character in a sequence of text.

We'll start with the simplest possible language model: the **Bigram Model**. Its prediction for the next character is based only on the single character that comes immediately before it. It has no memory of anything earlier in the text. Despite this extreme simplicity, building a bigram model will teach us concepts that carry over to more sophisticated language models: tokenization, batch creation, and autoregressive generation.


## The Data and the Tokenizer

Neural networks operate on numbers, not letters. Before we can feed text into a model, we must convert it into a numerical representation. This process is called **tokenization**.

Because our model is simple, our tokenizer will also be simple: a **character-level tokenizer**. This means each distinct character in our text is assigned its own integer. For example:
- 'a' might become 0
- 'b' might become 1  
- 'c' might become 2
- And so on...

### The Components of Our Tokenizer

1. **`chars`** (the vocabulary): The sorted list of all unique characters in our text. This is the complete set of tokens (characters) the model can work with, and its length is `vocab_size`.

2. **`stoi` (string-to-integer)**: A dictionary that maps each character to a unique number. This is how we encode text into numbers.

3. **`itos` (integer-to-string)**: The reverse mapping—a dictionary that maps numbers back to characters. This is how we decode the model's numerical outputs back into readable text.

With these three components, we can convert any text made up of characters from the vocabulary into a sequence of integers (encoding) and convert any sequence of those integers back into text (decoding).
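
As a minimal sketch of these three pieces, the snippet below builds a tokenizer over a toy string rather than the Shakespeare corpus. The names `toy_text`, `toy_chars`, `toy_stoi`, and `toy_itos` are illustrative assumptions, chosen so they don't clash with the variables defined in the next cell.

```python
# Toy character-level tokenizer (illustrative only; the real one is built in the next cell)
toy_text = "hello world"

# Vocabulary: the sorted unique characters of the text
toy_chars = sorted(set(toy_text))

# Forward (character -> integer) and reverse (integer -> character) mappings
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for ch, i in toy_stoi.items()}

# Encode a string, then decode it again
toy_ids = [toy_stoi[ch] for ch in "hello"]
toy_roundtrip = "".join(toy_itos[i] for i in toy_ids)

print(toy_chars)      # the vocabulary
print(toy_ids)        # "hello" as integers
print(toy_roundtrip)  # back to text
```

Because the mappings are plain dictionaries, encoding a character that never appeared in the text raises a `KeyError`. That is acceptable here, since we only encode text drawn from the same corpus that defined the vocabulary.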


In [None]:
import urllib.request

import torch
import torch.nn as nn
import torch.nn.functional as F

# Download the tiny shakespeare dataset
# This is a small text corpus used for educational purposes
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

print("Downloading tiny shakespeare dataset...")
urllib.request.urlretrieve(url, 'tinyshakespeare.txt')
print("Download complete!\n")

# Read the text file
with open('tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(f"Dataset length: {len(text)} characters")
print(f"First 250 characters:\n{text[:250]}\n")

# Create the vocabulary - get all unique characters
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size} unique characters")
print(f"Characters: {''.join(chars)}\n")

# Create the mappings
stoi = {ch: i for i, ch in enumerate(chars)}  # string to integer
itos = {i: ch for i, ch in enumerate(chars)}  # integer to string

print(f"Example stoi mapping: 'a' -> {stoi.get('a', 'N/A')}")
print(f"Example stoi mapping: 'H' -> {stoi.get('H', 'N/A')}")
print(f"Example itos mapping: 0 -> '{itos[0]}'")
print(f"Example itos mapping: 42 -> '{itos[42]}'\n")

# Define encode and decode functions
encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string

# Test encoding and decoding
test_string = "Hello, world!"
encoded = encode(test_string)
decoded = decode(encoded)

print(f"Original string: '{test_string}'")
print(f"Encoded: {encoded}")
print(f"Decoded: '{decoded}'")
print(f"Round-trip successful: {test_string == decoded}")


## Creating Batches for Language Models

Unlike image classification, where each data point is independent (one image → one label), language models learn from sequences. The model needs to see patterns in how characters follow each other.

### The Input-Output Relationship

For a language model, we create input-target pairs from the same text:

- **Input (`x`)**: A chunk of text (e.g., "Hello")
- **Target (`y`)**: The same window moved one character forward in the text (e.g., "ello ", assuming a space follows "Hello")

For every character in the input, the model tries to predict the next character. In our example:
- Given 'H', predict 'e'
- Given 'e', predict 'l'
- Given 'l', predict 'l'
- Given 'l', predict 'o'
- Given 'o', predict ' '
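
The pairing above can be sketched in plain Python. The string `snippet` is a hypothetical piece of text whose trailing space stands in for whatever character follows "Hello" in the corpus.

```python
# Build (input, target) pairs by sliding the window one character forward
snippet = "Hello "          # hypothetical text; the trailing space is the character after "Hello"
block = len(snippet) - 1    # number of input positions (5 here)

inputs = snippet[:block]        # "Hello"
targets = snippet[1:block + 1]  # "ello "

# Position i of the target is the character that follows position i of the input
pairs = list(zip(inputs, targets))
for current_char, next_char in pairs:
    print(f"given {current_char!r} -> predict {next_char!r}")
```

A single chunk of length `block` therefore yields `block` separate training pairs, one per position.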

### Block Size (Context Length)

The **`block_size`** (also called context length) is the maximum number of characters a model can look at when making a prediction. A bigram model effectively has a context of 1, because each prediction depends only on the single previous character. We'll still sample chunks with a larger `block_size` (like 8): every position in a chunk is an independent (current character, next character) training pair, so one chunk of 8 gives the bigram model 8 examples at once, and the same batch layout carries over to more sophisticated models later.

### Batch Creation

We randomly sample multiple chunks from our text corpus to create a batch. Each chunk becomes one row of the batch, and every position in that row is a separate prediction. The batch shape will be `(batch_size, block_size)`, where:
- `batch_size`: Number of examples in the batch
- `block_size`: Length of each sequence
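
Before the PyTorch version in the next cell, here is a rough list-based sketch of the same sampling logic. Everything prefixed with `toy_`, along with `rng` and `starts`, is an assumption made for illustration; the real function works on tensors.

```python
import random

# Hypothetical stand-in for the encoded corpus: 20 consecutive token ids
toy_data = list(range(100, 120))
toy_batch_size, toy_block_size = 4, 8

rng = random.Random(0)  # fixed seed so the sketch is repeatable

# Valid start offsets are 0 .. len(toy_data) - toy_block_size - 1, so that the
# target window, which ends one position later than the input window, still fits.
starts = [rng.randrange(len(toy_data) - toy_block_size) for _ in range(toy_batch_size)]

# One row per sampled start offset
xb_toy = [toy_data[i:i + toy_block_size] for i in starts]
yb_toy = [toy_data[i + 1:i + toy_block_size + 1] for i in starts]

print(len(xb_toy), len(xb_toy[0]))  # rows, columns
print(xb_toy[0])
print(yb_toy[0])
```

Each row of `yb_toy` is the matching row of `xb_toy` moved forward by one position, which is the relationship the model is trained on.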


In [None]:
# Convert the entire text into a tensor of integers
data = torch.tensor(encode(text), dtype=torch.long)
print(f"Data shape: {data.shape}")
print(f"Data type: {data.dtype}")
print(f"Total tokens: {len(data)}")

# Split into train and validation sets (90% train, 10% val)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(f"\nTrain tokens: {len(train_data)}")
print(f"Validation tokens: {len(val_data)}")

# Define batch creation function
def get_batch(split, batch_size=4, block_size=8):
    """
    Generate a batch of input-target pairs.
    
    Args:
        split: 'train' or 'val' to choose which dataset to use
        batch_size: Number of examples in the batch
        block_size: Length of each sequence (context length)
    
    Returns:
        x: Input tensor of shape (batch_size, block_size)
        y: Target tensor of shape (batch_size, block_size)
    """
    source = train_data if split == 'train' else val_data

    # Randomly sample starting indices
    # We subtract block_size so that both the input window and the target
    # window (which ends one position later) fit inside the data
    ix = torch.randint(len(source) - block_size, (batch_size,))

    # Create input sequences
    x = torch.stack([source[i:i+block_size] for i in ix])

    # Create target sequences (shifted by 1)
    y = torch.stack([source[i+1:i+block_size+1] for i in ix])
    
    return x, y

# Test the batch creation
x, y = get_batch('train', batch_size=4, block_size=8)
print(f"\nBatch example:")
print(f"x shape: {x.shape}")
print(f"y shape: {y.shape}")
print(f"\nFirst example:")
print(f"Input (x[0]): {x[0].tolist()}")
print(f"Target (y[0]): {y[0].tolist()}")
print(f"\nAs text:")
print(f"Input:  '{decode(x[0].tolist())}'")
print(f"Target: '{decode(y[0].tolist())}'")
print(f"\nNotice how target is input shifted by 1 position!")


## The Bigram Model Architecture

The Bigram Model is perhaps the simplest possible language model. It is just one layer: an `nn.Embedding`.

You can think of `nn.Embedding` as a lookup table with one row per character. The input is the index of a character (e.g., 'H' is index 20). The model uses this index to look up row 20 of its table. That row contains a score (a logit) for every possible character in our vocabulary being the next character.

### How It Works

1. **Input**: A character index (e.g., 20 for 'H')
2. **Embedding Lookup**: The embedding layer returns a vector of size `vocab_size` containing logits (raw scores) for each possible next character
3. **Output**: These logits represent the model's prediction for which character should come next

The model learns these logits during training. Initially, they're random. After training, characters that frequently follow 'H' (like 'e' in "Hello", "He", "How") will have higher logits.
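
To make the lookup-table picture concrete, here is a small sketch in plain Python. The three-character vocabulary and the hand-written numbers in `toy_table` are assumptions for illustration; in the real model the table is `vocab_size × vocab_size` and its values are learned.

```python
import math

# Hypothetical 3-character vocabulary and a hand-written 3x3 logit table.
# Row i holds the scores for "which character follows character i".
toy_vocab = ['a', 'b', 'c']
toy_table = [
    [0.0, 2.0, 0.0],   # after 'a': 'b' scores highest
    [1.0, 0.0, 1.0],   # after 'b': 'a' and 'c' tie
    [3.0, 0.0, 0.0],   # after 'c': 'a' scores highest
]

def softmax(scores):
    """Turn raw scores (logits) into probabilities that sum to 1."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

# The "embedding lookup" is just selecting one row by index
row = toy_table[toy_vocab.index('a')]
probs_after_a = softmax(row)
print([round(p, 3) for p in probs_after_a])
```

Training nudges the numbers in each row so that characters which really do follow the row's character in the corpus end up with larger logits, and therefore larger probabilities after softmax.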

### Autoregressive Generation

To generate new text, we use an **autoregressive loop**:
1. Start with a seed character (or prompt)
2. Feed it to the model to get predictions for the next character
3. Turn the logits into probabilities with softmax and sample the next character from that distribution
4. Append the sampled character to our sequence
5. Use that new character as the next input
6. Repeat until we've generated the desired length

This is called "autoregressive" because the model generates each new token based on the tokens it has already generated.
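
The loop can be sketched without any neural network at all. The probabilities in `toy_next_probs` below are made up for illustration, and `toy_generate` is a hypothetical helper; in the real model each distribution comes from a softmax over a row of learned logits.

```python
import random

# Hypothetical next-character probabilities for a 3-character vocabulary
toy_next_probs = {
    'a': {'a': 0.1, 'b': 0.8, 'c': 0.1},
    'b': {'a': 0.4, 'b': 0.2, 'c': 0.4},
    'c': {'a': 0.9, 'b': 0.05, 'c': 0.05},
}

def toy_generate(seed_char, max_new_tokens, rng):
    """Autoregressive loop: sample a character, append it, and feed it back in."""
    sequence = [seed_char]
    for _ in range(max_new_tokens):
        dist = toy_next_probs[sequence[-1]]  # bigram: condition on the last character only
        candidates = list(dist.keys())
        weights = list(dist.values())
        next_char = rng.choices(candidates, weights=weights, k=1)[0]
        sequence.append(next_char)
    return ''.join(sequence)

toy_sample = toy_generate('a', max_new_tokens=20, rng=random.Random(0))
print(toy_sample)
```

Sampling, rather than always picking the most likely character, is what keeps the output from collapsing into a short repeating cycle.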


In [None]:
class BigramLanguageModel(nn.Module):
    """
    A simple bigram language model.
    
    This model predicts the next character based solely on the previous character.
    Despite its simplicity, it demonstrates the core concepts of language modeling.
    """
    
    def __init__(self, vocab_size):
        super().__init__()
        # Each token (character) directly reads off the logits for the next token
        # from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        """
        Forward pass of the model.
        
        Args:
            idx: Input tensor of shape (batch_size, block_size) containing token indices
            targets: Optional target tensor of shape (batch_size, block_size)
                    If provided, loss is computed
        
        Returns:
            logits: Tensor of shape (batch_size, block_size, vocab_size)
                    Contains logits for each position and each possible next token
            loss: Scalar loss value (only if targets provided)
        """
        # idx and targets are both (batch_size, block_size) tensors of integers
        
        # Get logits for each position
        # Shape: (batch_size, block_size, vocab_size)
        logits = self.token_embedding_table(idx)
        
        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects (N, C) logits and (N,) targets, so flatten
            # the batch and time dimensions just for the loss computation.
            # logits itself keeps its (batch_size, block_size, vocab_size) shape.
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        """
        Generate new tokens given an initial context.
        
        Args:
            idx: Tensor of shape (batch_size, block_size) containing starting token indices
            max_new_tokens: Number of new tokens to generate
        
        Returns:
            Tensor containing the original context plus generated tokens
        """
        # idx is (batch_size, block_size) array of indices in the current context
        
        # Autoregressive generation loop
        for _ in range(max_new_tokens):
            # A bigram model conditions only on the most recent token,
            # so pass just the last column of idx
            logits, _ = self(idx[:, -1:])  # (batch_size, 1, vocab_size)

            # Drop the time dimension
            logits = logits[:, -1, :]  # becomes (batch_size, vocab_size)
            
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (batch_size, vocab_size)
            
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
            
            # Append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, block_size + 1)
        
        return idx

# Instantiate the model
model = BigramLanguageModel(vocab_size)
print(f"Model:\n{model}\n")
print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

# Test forward pass
x, y = get_batch('train', batch_size=4, block_size=8)
logits, loss = model(x, y)
print(f"\nForward pass test:")
print(f"Input shape: {x.shape}")
print(f"Logits shape: {logits.shape}")
print(f"Loss: {loss.item():.4f}")


## Training and Generation

Now we'll train our bigram model using the same training loop pattern we've used in previous notebooks. The process is identical:
1. Forward pass to get predictions
2. Calculate loss
3. Backpropagation
4. Update weights
5. Repeat

After training, we'll use the `.generate()` method to produce new text. Don't expect Shakespeare—remember, this is an extremely simple model! But you should see that it produces valid characters and follows some basic patterns from the training data.
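
One reference point that helps when reading the training log, offered as a back-of-the-envelope sketch: a model that assigns equal probability to every character has a cross-entropy loss of $\ln(\text{vocab\_size})$. The value 65 below is an assumption about the tiny shakespeare vocabulary; substitute whatever `vocab_size` the tokenizer cell printed.

```python
import math

assumed_vocab_size = 65  # assumption: replace with the vocab_size printed earlier

# Cross-entropy of a uniform prediction: -ln(1 / V) = ln(V)
uniform_loss = -math.log(1.0 / assumed_vocab_size)
print(f"Loss of a uniform guess over {assumed_vocab_size} characters: {uniform_loss:.4f}")
```

A randomly initialized model will typically start somewhat above this value, because its random logits put too little probability on the correct character for many inputs. A training run that is learning should push the loss below it.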


In [None]:
# Create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Training loop
batch_size = 32
block_size = 8
max_iters = 5000
eval_interval = 500
eval_iters = 200

print("Starting training...\n")

for step in range(max_iters):
    # Every eval_interval steps (and on the final step), estimate the loss on both splits
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = {}
        model.eval()
        with torch.no_grad():  # evaluation does not need gradients
            for split in ['train', 'val']:
                split_losses = []
                for _ in range(eval_iters):
                    xb, yb = get_batch(split, batch_size, block_size)
                    _, loss = model(xb, yb)
                    split_losses.append(loss.item())
                losses[split] = sum(split_losses) / len(split_losses)
        print(f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        model.train()
    
    # Sample a batch of data
    xb, yb = get_batch('train', batch_size, block_size)
    
    # Forward pass, then backpropagate and update the weights
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print("\nTraining complete!")

# Generate some text
print("\n" + "="*60)
print("Generating text from the model:")
print("="*60)

# Start from index 0, which is the newline character in this vocabulary
context = torch.zeros((1, 1), dtype=torch.long)
with torch.no_grad():  # sampling does not need gradients
    generated = model.generate(context, max_new_tokens=500)[0].tolist()
print(decode(generated))
print("\n" + "="*60)
print("Note: The output may seem random or nonsensical.")
print("This is expected for such a simple model!")
print("The model is learning basic character-level patterns,")
print("but lacks the context needed for coherent text generation.")
print("="*60)
