# Bigram Language Model

This notebook demonstrates how to implement and use a simple bigram language model.
A bigram model is one of the simplest possible language models: it predicts the next token
based only on the current token.
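To make "based only on the current token" concrete, here is a minimal count-based sketch. It is not how this notebook builds the model (below, the table of next-character scores is learned with gradient descent), but it shows what a bigram model ultimately represents: for each character, a probability distribution over the character that follows it. The function name `bigram_probs` and the toy string are illustrative only.

```python
from collections import Counter, defaultdict


def bigram_probs(text):
    """Estimate P(next char | current char) by counting adjacent character pairs."""
    counts = defaultdict(Counter)
    for cur, nxt in zip(text, text[1:]):
        counts[cur][nxt] += 1
    # Normalize each row of counts into a probability distribution
    return {
        cur: {nxt: n / sum(row.values()) for nxt, n in row.items()}
        for cur, row in counts.items()
    }


probs = bigram_probs("the theory then")
print(probs["t"])  # {'h': 1.0}
print(probs["e"])  # ' ', 'o' and 'n' each follow 'e' once, so each gets 1/3
```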

## Dataset
- For this lesson (and a few others) we will be using a slice of the [arXiv Dataset](https://www.kaggle.com/datasets/Cornell-University/arxiv). 
- We've collected ~51,000 paper summaries and will be training our bigram model using this subset.
- The dataset is stored in `src/datasets/arxiv_data.csv`.

In [3]:
import torch
import torch.nn.functional as F
import pandas as pd
from datasets import load_dataset

In [16]:
# Load the (sliced) ArXiv dataset
dataset = load_dataset(path="../../src/datasets/")

data_str = "\n".join(dataset["train"]["summaries"])

# Split into training and testing sets (90/10 split)
len_train = int(0.9*len(data_str))
train_data = data_str[:len_train]
test_data = data_str[len_train:]

print("\nExample summary:")
print(train_data[:500])


Example summary:
Stereo matching is one of the widely used techniques for inferring depth from
stereo images owing to its robustness and speed. It has become one of the major
topics of research since it finds its applications in autonomous driving,
robotic navigation, 3D reconstruction, and many other fields. Finding pixel
correspondences in non-textured, occluded and reflective areas is the major
challenge in stereo matching. Recent developments have shown that semantic cues
from image segmentation can be used 


In [17]:
pd.DataFrame(dataset["train"]).head()

| Unnamed: 0 | titles | summaries | terms |
|---|---|---|---|
| 0 | Survey on Semantic Stereo Matching / Semantic ... | Stereo matching is one of the widely used tech... | ['cs.CV', 'cs.LG'] |
| 1 | FUTURE-AI: Guiding Principles and Consensus Re... | The recent advancements in artificial intellig... | ['cs.CV', 'cs.AI', 'cs.LG'] |
| 2 | Enforcing Mutual Consistency of Hard Regions f... | In this paper, we proposed a novel mutual cons... | ['cs.CV', 'cs.AI'] |
| 3 | Parameter Decoupling Strategy for Semi-supervi... | Consistency training has proven to be an advan... | ['cs.CV'] |
| 4 | Background-Foreground Segmentation for Interio... | To ensure safety in automated driving, the cor... | ['cs.CV', 'cs.LG'] |


## Tokenization

Tokenization is the process of converting text into a sequence of integer tokens that our model can process.
Since we're building a character-level bigram model, each unique character gets its own integer ID.

![Tokenization Diagram](assets/tokenization_diagram.png)

In [5]:
# Find all unique characters in our dataset
chars = sorted(list(set(train_data+test_data)))
vocab_size = len(chars)
print('All characters used:', ''.join(chars))
print('\nNumber of unique characters:', vocab_size)

# Create mapping dictionaries between characters and integers
str_to_i = {ch: i for i, ch in enumerate(chars)}
i_to_str = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    """Convert a string to a list of integers based on our character mapping."""
    return [str_to_i[ch] for ch in s]

def decode(nums):
    """Convert a list of integers back to a string using our character mapping."""
    return ''.join(i_to_str[i] for i in nums)

# Test our encoding/decoding functions
print("\nEncoding 'hello':", encode("hello"))
print("Decoding back:", decode(encode("hello")))

All characters used: 
 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~“”

Number of unique characters: 98

Encoding 'hello': [73, 70, 77, 77, 80]
Decoding back: hello
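One consequence of building the vocabulary from the data itself is that `encode` only knows characters it has already seen; anything else raises a `KeyError`. That is harmless here, because `chars` was built from both the training and test text, but it matters as soon as you encode new input. Here is a self-contained toy version of the same mapping (the corpus `"hello world"` is just for illustration):

```python
corpus = "hello world"
chars = sorted(set(corpus))  # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
str_to_i = {ch: i for i, ch in enumerate(chars)}
i_to_str = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map each character to its integer ID."""
    return [str_to_i[ch] for ch in s]


def decode(nums):
    """Map integer IDs back to characters."""
    return "".join(i_to_str[i] for i in nums)


print(encode("hello"))           # [3, 2, 4, 4, 5]
print(decode(encode("world")))   # world

try:
    encode("help")               # 'p' never appears in the corpus
except KeyError as err:
    print("unknown character:", err)
```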


## Preparing the Data for Training

We'll convert our text data into PyTorch tensors for efficient processing during training.

In [6]:
# Encode the entire dataset into tensors
train_encoded = torch.tensor(encode(train_data), dtype=torch.long)
test_encoded = torch.tensor(encode(test_data), dtype=torch.long)
print("Training data shape:", train_encoded.shape, train_encoded.dtype)
print("First 100 tokens:", train_encoded[:100])

# Define our context window size (block_size)
block_size = 8
print("\nExample context window:", train_encoded[:block_size+1])

Training data shape: torch.Size([56138589]) torch.int64
First 100 tokens: tensor([52, 85, 70, 83, 70, 80,  1, 78, 66, 85, 68, 73, 74, 79, 72,  1, 74, 84,
         1, 80, 79, 70,  1, 80, 71,  1, 85, 73, 70,  1, 88, 74, 69, 70, 77, 90,
         1, 86, 84, 70, 69,  1, 85, 70, 68, 73, 79, 74, 82, 86, 70, 84,  1, 71,
        80, 83,  1, 74, 79, 71, 70, 83, 83, 74, 79, 72,  1, 69, 70, 81, 85, 73,
         1, 71, 83, 80, 78,  0, 84, 85, 70, 83, 70, 80,  1, 74, 78, 66, 72, 70,
        84,  1, 80, 88, 74, 79, 72,  1, 85, 80])

Example context window: tensor([52, 85, 70, 83, 70, 80,  1, 78, 66])
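A chunk of `block_size + 1` tokens actually contains `block_size` training examples: at each position `t`, the tokens up to `t` form the context and the token at `t + 1` is the target. The sketch below enumerates them for the chunk printed above. A bigram model only ever uses the last token of each context; the models in later lessons can use more of it.

```python
chunk = [52, 85, 70, 83, 70, 80, 1, 78, 66]  # train_encoded[:block_size + 1] from the output above
block_size = len(chunk) - 1

pairs = []
for t in range(block_size):
    context = chunk[:t + 1]  # everything seen so far
    target = chunk[t + 1]    # the token that comes next
    pairs.append((context, target))
    print(f"context {context} -> target {target}")
```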


## Creating Training Batches

To train our model efficiently, we'll create batches (groups) of data with inputs and their corresponding targets. If the code below doesn't explain it well enough, there's a great explanation of batches [here](https://stats.stackexchange.com/questions/153531/what-is-batch-size-in-neural-network).

In [7]:
# Function to generate random batches from our dataset
batch_size = 4  # Number of sequences in a batch
block_size = 8  # Length of each sequence

def get_batch(split):
    """
    Generate a small batch of data for training or evaluation.
    
    Args:
        split: Either 'train' or 'test' to determine which dataset to sample from
        
    Returns:
        x: Input sequences (B, T)
        y: Target sequences (B, T) - shifted by 1 position
    """
    # Choose the appropriate dataset
    data = train_encoded if split == "train" else test_encoded
    
    # Generate random starting indices
    ix = torch.randint(len(data) - block_size, (batch_size,))
    
    # Extract sequences of length block_size
    x = torch.stack([data[i:i+block_size] for i in ix])
    
    # Target is the next character in the sequence (shifted by 1)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    
    return x, y

# Test our batch generation
xb, yb = get_batch('train')
print("Input batch shape:", xb.shape)
print("Inputs:\n", xb)
print("\nTargets:\n", yb)

# Notice how the target for each input character is the next character in the sequence.

Input batch shape: torch.Size([4, 8])
Inputs:
 tensor([[74, 70, 83, 84,  1, 70, 89, 74],
        [78, 66, 79,  1, 85, 80,  1, 78],
        [70, 68, 80, 72, 79, 74, 85, 74],
        [80, 79,  1, 80, 71,  1, 85, 73]])

Targets:
 tensor([[70, 83, 84,  1, 70, 89, 74, 84],
        [66, 79,  1, 85, 80,  1, 78, 66],
        [68, 80, 72, 79, 74, 85, 74, 80],
        [79,  1, 80, 71,  1, 85, 73, 70]])
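Because `get_batch` slices the targets one position later than the inputs, each row of `yb` is the corresponding row of `xb` shifted left by one. The sketch below checks this on a toy dataset; using `torch.arange(100)` as a stand-in for `train_encoded` is an assumption made purely so the shift is easy to see (every target is then exactly its input plus one).

```python
import torch

torch.manual_seed(0)
data = torch.arange(100)  # toy stand-in for train_encoded: token at position i is just i
batch_size, block_size = 4, 8

# Same slicing logic as get_batch above
ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([data[i:i + block_size] for i in ix])
y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])

# Each target row is its input row shifted left by one position
print(torch.equal(x[:, 1:], y[:, :-1]))
# With data = arange, every target is exactly input + 1
print(torch.equal(y, x + 1))
```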


## Bigram Language Model Implementation

Now we'll implement our bigram language model. This model predicts the next character based solely on the current character.

![Bigram Image](assets/bigram_model_diagram.png)

In [8]:
class BigramLanguageModel:
    def __init__(self, vocab_size):
        """
        Initialize the bigram language model.
        
        Args:
            vocab_size: Size of the vocabulary (number of unique characters)
        """
        # Create a lookup table of size vocab_size x vocab_size
        # Row i holds the logits (unnormalized log-probabilities) for the character that follows
        # character i; applying softmax to a row turns it into a probability distribution
        self.token_embedding_table = torch.randn((vocab_size, vocab_size), requires_grad=True)
    
    def forward(self, idx):
        """
        Forward pass of the model.
        
        Args:
            idx: Batch of sequences (B, T)
            
        Returns:
            logits: Prediction scores for next character (B, T, C)
        """
        # For each position in the sequence, look up the embedding for that character
        # This gives us the logits (unnormalized log-probabilities) for the next character
        logits = self.token_embedding_table[idx]  # (B, T, C)
        return logits
    
    def parameters(self):
        """Return the parameters of the model for optimization."""
        return [self.token_embedding_table]
    
    def generate(self, idx, max_new_tokens):
        """
        Generate new text by sampling from the model's predictions.
        
        Args:
            idx: Starting sequence (B, T)
            max_new_tokens: Number of new tokens to generate
            
        Returns:
            idx: Extended sequence with generated tokens (B, T+max_new_tokens)
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # Get the predictions
            logits = self.forward(idx)  # (B, T, C)
            
            # Focus only on the last time step
            logits = logits[:, -1, :]  # (B, C)
            
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            
            # Append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
            
        return idx
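Indexing `token_embedding_table` with `idx` simply picks out one row per input character. That is equivalent to multiplying a one-hot encoding of the input by the table, which is one way to see the bigram model as a single linear layer without a bias. The check below uses a random 5×5 table (toy sizes chosen for illustration) and also compares against `torch.nn.Embedding`, which performs the same lookup.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 5
table = torch.randn(vocab_size, vocab_size)   # plays the role of token_embedding_table
idx = torch.tensor([[0, 3, 1], [4, 4, 2]])    # (B=2, T=3) batch of character ids

lookup = table[idx]                                       # (B, T, C) via indexing
one_hot = F.one_hot(idx, num_classes=vocab_size).float()  # (B, T, vocab_size)
matmul = one_hot @ table                                  # (B, T, C) via matrix product

emb = torch.nn.Embedding(vocab_size, vocab_size)
with torch.no_grad():
    emb.weight.copy_(table)                               # same weights inside an nn.Embedding

print(torch.allclose(lookup, matmul), torch.allclose(lookup, emb(idx)))
```

Indexing is the cheapest of the three, since it never materializes the one-hot tensor.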

## Training the Model

First, we'll check the model's output shapes, define a loss function, and look at the loss and generated text before any training.

In [9]:
model = BigramLanguageModel(vocab_size)

# Test the model's forward pass
xb, yb = get_batch('train')
logits = model.forward(xb)
print(f"Input shape: {xb.shape}")
print(f"Output logits shape: {logits.shape}")

def loss_fn(logits, targets):
    """
    Calculate the cross-entropy loss between predictions and targets.
    
    Args:
        logits: Prediction scores (B, T, C)
        targets: Target indices (B, T)
        
    Returns:
        loss: Scalar loss value
    """
    B, T, C = logits.shape
    logits = logits.view(B*T, C)  # Reshape for cross_entropy
    targets = targets.view(B*T)   # Reshape to match
    loss = F.cross_entropy(logits, targets)
    return loss

# Calculate initial loss
loss = loss_fn(logits, yb)
print(f"Initial loss: {loss.item()}")

# Generate some text before training
idx = torch.zeros((1, 1), dtype=torch.long)
print("\nText generated before training:")
generated_text = model.generate(idx, max_new_tokens=100)[0]
print(decode(generated_text.tolist()))

Input shape: torch.Size([4, 8])
Output logits shape: torch.Size([4, 8, 98])
Initial loss: 4.660195350646973

Text generated before training:

BF+;b}Hc8k:V|M2eh1}B`]“\)IA-(8;Wk”“Y?>"Z[7V`vX{FJPX*n'*RhhO3/\w[vV|"oR
"5X6BS\+~H8(OtUjq.F:?rOGjETCf

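The initial loss of about 4.66 is close to what we would expect from a model that knows nothing: if all 98 characters were equally likely, the cross-entropy would be −ln(1/98) = ln 98 ≈ 4.585. `loss_fn` flattens the logits to `(B*T, C)` because `F.cross_entropy` expects the class dimension second. The sketch below uses random logits and targets (purely illustrative) to check that `F.cross_entropy` is the mean negative log-probability assigned to the correct next character, and that uniform predictions score exactly ln C.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 4, 8, 98
logits = torch.randn(B, T, C)
targets = torch.randint(C, (B, T))

# Library version, with the same reshape as loss_fn
ce = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

# Manual version: mean negative log-probability of the correct next character
log_probs = F.log_softmax(logits, dim=-1)                      # (B, T, C)
manual = -log_probs.gather(-1, targets.unsqueeze(-1)).mean()

# A model whose logits are all equal predicts a uniform distribution and scores ln(C)
uniform = F.cross_entropy(torch.zeros(B * T, C), targets.view(B * T))

print(ce.item(), manual.item())
print(uniform.item(), math.log(C))
```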

## Training Loop

Now we'll train our model by repeatedly sampling batches and updating the parameters.

![Loss visualization](assets/loss_minimization_diagram.png)
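Each iteration of the loop below calls `optimizer.zero_grad()` before `loss.backward()`. This matters because PyTorch accumulates gradients: every `backward()` call adds to the `.grad` already stored on each parameter. A tiny sketch with a single scalar parameter `w` (an illustrative stand-in for the model's table):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()
print(w.grad)  # d(3w)/dw = 3

(3 * w).backward()  # a second backward pass without clearing
print(w.grad)       # the new gradient is added to the old one: 6

w.grad.zero_()      # manual reset; optimizer.zero_grad() clears the gradients of every parameter it manages
(3 * w).backward()
print(w.grad)       # back to 3
```

Without the reset, each update would use the sum of all gradients computed so far rather than the gradient of the current batch.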

In [14]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

def train(model, epochs=10000, eval_interval=1000):
    """
    Train the bigram language model.
    
    Args:
        model: The BigramLanguageModel instance
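        epochs: Number of training steps (one random batch per step, not full passes over the data)
        eval_interval: How often (in steps) to print the current batch loss; a text sample
            is generated every 10 * eval_interval steps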
    """
    for i in range(epochs):
        # Sample a batch of data
        xb, yb = get_batch('train')
        
        # Evaluate the loss
        logits = model.forward(xb)
        loss = loss_fn(logits, yb)
        
        # Zero gradients
        optimizer.zero_grad()
        # Backward pass
        loss.backward()
        # Update parameters
        optimizer.step()
        
        if i % eval_interval == 0:
            print(f"Step {i}: Loss {loss.item():.4f}")
            
            # Generate some text
            if i % (eval_interval * 10) == 0:
                idx = torch.zeros((1, 1), dtype=torch.long)
                generated = model.generate(idx, max_new_tokens=100)[0]
                print(decode(generated.tolist()))
                print('-' * 80)

print("Training the model...")
train(model, epochs=20000, eval_interval=1000)

Training the model...
Step 0: Loss 2.4621

whexisNchicaved ove f belash
Mond jue siourk, atueurice utinmax*UDr,"'9, vexprprerob|$J(:Uashatoycaf
--------------------------------------------------------------------------------
Step 1000: Loss 2.5266
Step 2000: Loss 2.6730
Step 3000: Loss 2.4812
Step 4000: Loss 2.3409
Step 5000: Loss 2.9113
Step 6000: Loss 2.5932
Step 7000: Loss 2.5385
Step 8000: Loss 2.8261
Step 9000: Loss 2.3512
Step 10000: Loss 2.3055

L-lerin ly atin, CNLapraicshel twiorise age, t 12 t an mathe 7"Hut. gidutas nasuro-GS
asersty ospe o
--------------------------------------------------------------------------------
Step 11000: Loss 2.6383
Step 12000: Loss 2.5833
Step 13000: Loss 2.3321
Step 14000: Loss 2.3365
Step 15000: Loss 2.7909
Step 16000: Loss 2.4855
Step 17000: Loss 2.4216
Step 18000: Loss 2.4325
Step 19000: Loss 2.5154
Step 20000: Loss 2.4621

corulintsegeclee. ndeveve th t bechethetonowagncr th levermat forercthesus, IDLSGAExMETThatopessexta
--------------------------------------------------------------------------------
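The losses printed during training come from a single batch of `batch_size * block_size = 32` characters, so they jump around from step to step (between roughly 2.3 and 2.9 above) and say nothing about the test split. A common remedy is to average the loss over many batches for each split. The helper below, `estimate_loss`, is hypothetical (it is not defined elsewhere in this notebook); it assumes a `forward` function and a `get_batch(split)` that follow the same conventions as above. It is demonstrated on random toy data with an all-zero (uniform) table, for which both losses should equal ln 10.

```python
import math

import torch
import torch.nn.functional as F


@torch.no_grad()  # no gradients needed for evaluation
def estimate_loss(forward, get_batch, eval_iters=200):
    """Average cross-entropy over `eval_iters` random batches for each split."""
    out = {}
    for split in ("train", "test"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            logits = forward(x)  # (B, T, C)
            B, T, C = logits.shape
            losses[k] = F.cross_entropy(logits.view(B * T, C), y.view(B * T))
        out[split] = losses.mean().item()
    return out


# Toy setup (illustrative): random token ids and a uniform bigram table
torch.manual_seed(0)
vocab_size, block_size, batch_size = 10, 8, 4
data = {"train": torch.randint(vocab_size, (1000,)), "test": torch.randint(vocab_size, (200,))}
table = torch.zeros(vocab_size, vocab_size)


def toy_get_batch(split):
    d = data[split]
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i + block_size] for i in ix])
    y = torch.stack([d[i + 1:i + block_size + 1] for i in ix])
    return x, y


result = estimate_loss(lambda idx: table[idx], toy_get_batch, eval_iters=50)
print(result, math.log(vocab_size))
```

In this notebook it could be called as `estimate_loss(model.forward, get_batch)`, for example every `eval_interval` steps inside `train`.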

## Generating Text with the Trained Model

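Let's generate a longer piece of text with our trained model to see what it has learned. At first glance the result is, well, not impressive. But compared with text from an untrained model, it is clear the model has learned some structure: the output is mostly lowercase letters separated by spaces, with plausible letter pairs and the occasional real word such as "the" or "can".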

In [15]:
# Generate a longer text sample
print("\nGenerating text with the trained model:")
idx = torch.zeros((1, 1), dtype=torch.long)
generated_text = model.generate(idx, max_new_tokens=500)[0]
print(decode(generated_text.tolist()))

# Generate text with a random model for comparison
print("\n\nGenerating text with a random model for comparison:")
random_model = BigramLanguageModel(vocab_size)
idx = torch.zeros((1, 1), dtype=torch.long)
random_generated_text = random_model.generate(idx, max_new_tokens=500)[0]
print(decode(random_generated_text.tolist()))


Generating text with the trained model:

e, exio wodsosent
fistet ctin foninited borod honve e
tins me isedrk pth lsser, s.9 akly morith oletatis $\%`KB lo oiotrouighacish $
plt bolethod congorod ade abintititathitho sevole thatwe the atred onind dop
s ly, jete, stoites a s tinstitt Fure GGarong Wed.
ss alir, mua uly gorognss.
cormed f
cs can ulurermors averenyincapapre w on.
mocom e t tatontic anseatrorathtisy
edefrisontope Witatoribubsiconelactileng wheoupanensintro ke ity (Viosuly
ictivimoderofinered parmars pectiteNsoror fthe-cem m


Generating text with a random model for comparison:

g0{LApUt>1,#ha'$:{YZC,QK&"Vgk$U@/xM^QGU?AIl]0"Dv6261\EWUT4vJE3dQpnKpTU'pg.r4}fvJg)#7k7BA7Xdaq)Wv6s9)ZkR=h;yg0\-i9ktm.j. {R%#oF<#N67jdr)&F}/D)xAFb`|M!yh3%C]$I+DeI**JA~+eV“LJic3}ZvJG\]“wdPJ/x6^#%amM z^#hCd@dGy_
As^#4aa'Y5sF8,g“:W;'RqFtgMhT"\PKp0mgq/U/]-rw&[I(m>Xrq[eI"bQ3pclMqF8|.O”\I*dpj<Q3M'|05CH51`4^~j_“)”tgM B>gv,T``4bglFV|MCH”\.J+NtsUQzg“”'.b”k^#'p;2bU
kR&lU]Ow&Tg. 9|'8H"YM'p1wC4P_F%I61IM“\$<C?p/?
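The `generate` method samples each next character with `torch.multinomial` instead of always picking the most likely one. A short sketch of why that matters for a bigram model: with greedy (argmax) decoding, the next character is a fixed function of the current one, so after at most `vocab_size` steps the output must repeat in a cycle. The 6-token random table below is an illustrative stand-in for the trained `token_embedding_table`, and `greedy`/`sample` are hypothetical helpers.

```python
import torch

torch.manual_seed(0)
vocab_size = 6
table = torch.randn(vocab_size, vocab_size)  # stand-in for learned bigram logits


def greedy(start, n):
    """Always take the most likely next token."""
    seq = [start]
    for _ in range(n):
        seq.append(int(table[seq[-1]].argmax()))
    return seq


def sample(start, n):
    """Draw the next token from the softmax distribution, as generate() does."""
    seq = [start]
    for _ in range(n):
        probs = torch.softmax(table[seq[-1]], dim=-1)
        seq.append(int(torch.multinomial(probs, num_samples=1)))
    return seq


print(greedy(0, 12))  # quickly settles into a repeating loop
print(sample(0, 12))  # varies from run to run (with different seeds)
```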

## Conclusion

Well done! You should now understand what a bigram model is and how to implement a simple version of one. While bigram models aren't nearly as impressive as the LLMs used today, LLMs wouldn't exist without the foundational techniques (tokenization, batching, loss functions, gradient-based training) that we covered in this lesson.

Key limitations of the bigram model:
1. Limited context - only considers the immediately preceding character
2. Cannot capture long-range dependencies in language
3. Generates text that may look somewhat like the training data but lacks coherence

In the next lessons, we'll explore more sophisticated models that can use longer contexts and generate far more coherent text.


## Exercises

1. **Modify the context window size**: Change the `block_size` parameter and observe how it affects training and generation. Does a larger context window improve the model's output?

2. **Implement word-level tokenization**: Modify the code to work with words instead of characters. How does this change the model's behavior?

3. **Add temperature control**: Implement a temperature parameter in the `generate` method to control the randomness of the generated text.

4. **Visualize the learned probabilities**: Create a heatmap visualization of the model's learned transition probabilities for common characters.

5. **Implement perplexity calculation**: Add a method to evaluate the model's performance using perplexity, a standard metric for language models.