In [8]:
"""
Bigram Baseline for SimpleStories Dataset
==========================================
A bigram model predicts the next token based only on the current token.
This serves as a simple baseline for language modeling.
"""

import torch
import numpy as np
from collections import Counter
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm

# Configuration: the knobs a reader is most likely to change.
CONTEXT_LENGTH = 512                                  # tokens per evaluation chunk
DATASET_NAME = "SimpleStories/SimpleStories"          # dataset identifier passed to load_dataset
TOKENIZER_NAME = "SimpleStories/SimpleStories-1.25M"  # identifier passed to AutoTokenizer

# Echo the configuration so the cell output records what this run used.
print(
    f"Context length: {CONTEXT_LENGTH}",
    f"Dataset: {DATASET_NAME}",
    f"Tokenizer: {TOKENIZER_NAME}",
    sep="\n",
)


Context length: 512
Dataset: SimpleStories/SimpleStories
Tokenizer: SimpleStories/SimpleStories-1.25M


In [9]:
# Load the tokenizer named in the configuration and keep the two values the
# rest of the notebook relies on: the vocabulary size (used for smoothing) and
# the end-of-sequence id (appended after every story).
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# NOTE(review): Hugging Face documents `vocab_size` as the base vocabulary
# without added tokens; if this tokenizer defines added tokens,
# `len(tokenizer)` would be larger -- confirm which one smoothing should use.
vocab_size, eos_token_id = tokenizer.vocab_size, tokenizer.eos_token_id

print(
    f"Vocab size: {vocab_size}",
    f"EOS token ID: {eos_token_id}",
    sep="\n",
)


Vocab size: 4096
EOS token ID: 1


In [10]:
def tokenize_dataset_fast(dataset, tokenizer, text_column="story", batch_size=1000):
    """Tokenize a text column in batches and return one flat list of token ids.

    Args:
        dataset: Object indexable by column name; ``dataset[text_column]`` must
            support ``len()`` and slicing.
        tokenizer: Callable invoked as ``tokenizer(texts, add_special_tokens=False)``
            whose result exposes ``["input_ids"]`` (one id sequence per text),
            and which has an ``eos_token_id`` attribute.
        text_column: Name of the column holding the raw texts.
        batch_size: Number of texts handed to the tokenizer per call.

    Returns:
        A single list with the ids of every text in order. When the tokenizer's
        ``eos_token_id`` is not ``None`` it is appended after each text.
    """
    token_stream = []
    # Zero or one ids to emit after each text, decided once up front.
    eos = tokenizer.eos_token_id
    terminator = [] if eos is None else [eos]

    stories = dataset[text_column]
    batch_starts = range(0, len(stories), batch_size)

    for start in tqdm(batch_starts, desc="Tokenizing batches"):
        # One tokenizer call per slice of texts.
        batch = tokenizer(stories[start:start + batch_size], add_special_tokens=False)
        for ids in batch["input_ids"]:
            token_stream.extend(ids)
            token_stream.extend(terminator)

    return token_stream

def create_chunks(tokens, context_length):
    """Split ``tokens`` into consecutive, non-overlapping full-length chunks.

    Only complete chunks are returned; a trailing remainder shorter than
    ``context_length`` is discarded.

    Args:
        tokens: Sliceable sequence of token ids.
        context_length: Number of tokens per chunk; must be positive.

    Returns:
        A list of slices of ``tokens``, each exactly ``context_length`` long.

    Raises:
        ValueError: If ``context_length`` is not positive.
    """
    if context_length <= 0:
        raise ValueError("context_length must be a positive integer")

    # The last valid start index is len(tokens) - context_length. range() has an
    # exclusive stop, so use +1; otherwise the final chunk is dropped whenever
    # len(tokens) is an exact multiple of context_length.
    stop = len(tokens) - context_length + 1
    return [
        tokens[start:start + context_length]
        for start in range(0, stop, context_length)
    ]

# Load the train and test splits of the configured dataset.
print("Loading datasets...")
train_dataset = load_dataset(DATASET_NAME, split="train")
test_dataset = load_dataset(DATASET_NAME, split="test")

# Report how many examples each split contains.
print(
    f"Train examples: {len(train_dataset)}",
    f"Test examples: {len(test_dataset)}",
    sep="\n",
)


Loading datasets...


Train examples: 2115696
Test examples: 21371


In [11]:
# Flatten each split into one long list of token ids using the batched
# tokenizer helper, reporting the resulting length after each split.
print("Tokenizing train set...")
train_tokens = tokenize_dataset_fast(dataset=train_dataset, tokenizer=tokenizer)
print(f"Train tokens: {len(train_tokens):,}")

# The leading newline separates the second progress bar from the first report.
print("\nTokenizing test set...")
test_tokens = tokenize_dataset_fast(dataset=test_dataset, tokenizer=tokenizer)
print(f"Test tokens: {len(test_tokens):,}")


Tokenizing train set...


Tokenizing batches: 100%|██████████| 2116/2116 [05:07<00:00,  6.88it/s]


Train tokens: 608,617,592

Tokenizing test set...


Tokenizing batches: 100%|██████████| 22/22 [00:04<00:00,  5.44it/s]

Test tokens: 6,148,046





In [12]:
# Build the bigram model: tally adjacent token pairs in the training stream.
print("Building bigram statistics from training data...")

# bigram_counts[prev_token] is a Counter mapping next_token -> occurrences of
# that pair; unigram_counts tallies how often each individual token occurs.
bigram_counts = {}
unigram_counts = Counter()

for i in tqdm(range(len(train_tokens) - 1), desc="Counting bigrams"):
    prev_token, next_token = train_tokens[i], train_tokens[i + 1]

    unigram_counts[prev_token] += 1

    if prev_token in bigram_counts:
        bigram_counts[prev_token][next_token] += 1
    else:
        # First time this token appears as a context: start its row.
        bigram_counts[prev_token] = Counter({next_token: 1})

# The loop only visits tokens that have a successor, so the final token of the
# stream is added to the unigram tally separately.
unigram_counts[train_tokens[-1]] += 1

print(f"Unique tokens seen: {len(unigram_counts)}")
print(f"Total bigrams counted: {sum(sum(c.values()) for c in bigram_counts.values()):,}")


Building bigram statistics from training data...


Counting bigrams: 100%|██████████| 608617591/608617591 [04:17<00:00, 2367447.79it/s]

Unique tokens seen: 4018
Total bigrams counted: 608,617,591





In [13]:
# Add-1 (Laplace) smoothing for the bigram model:
#   P(next | prev) = (count(prev, next) + 1) / (sum_w count(prev, w) + vocab_size)

print("Converting counts to log probabilities with add-1 (Laplace) smoothing...")

# Log probabilities are computed on the fly during evaluation; only the
# per-context denominators are precomputed here.
#
# The denominator must be the number of times `token` was observed as a
# *context* (its bigram row total). `unigram_counts` is not that number: it
# also counts the final training token, which has no successor, so using it
# would leave that token's smoothed distribution summing to slightly under 1.
# Keys still cover every token in `unigram_counts`; a token that never acted
# as a context gets a row total of 0.
smoothed_denominators = {
    token: sum(bigram_counts.get(token, {}).values()) + vocab_size
    for token in unigram_counts
}

print("Smoothing alpha: 1")
print(f"Vocab size for smoothing: {vocab_size}")


Converting counts to log probabilities with add-1 (Laplace) smoothing...
Smoothing alpha: 1
Vocab size for smoothing: 4096


In [14]:
# Score the bigram model on the test stream.
# Cross-entropy loss = -1/N * sum(log P(next_token | prev_token))

print("Evaluating bigram model on test set...")
print(f"Test tokens: {len(test_tokens):,}")

# Evaluation is done per fixed-length chunk; pairs are formed inside a chunk
# only, so a pair that straddles two chunks is never scored.
test_chunks = create_chunks(test_tokens, CONTEXT_LENGTH)
print(f"Test chunks (context_length={CONTEXT_LENGTH}): {len(test_chunks)}")

total_log_prob = 0.0      # running sum of log P(next | prev)
total_predictions = 0     # number of scored (prev, next) pairs
unseen_prev_tokens = 0    # pairs whose context never appeared in training

for chunk in tqdm(test_chunks, desc="Evaluating"):
    for prev_token, next_token in zip(chunk, chunk[1:]):
        if prev_token not in bigram_counts:
            # Context never seen in training: fall back to a uniform distribution.
            count = 1
            denom = vocab_size
            unseen_prev_tokens += 1
        else:
            # Add-1 smoothed numerator and the precomputed denominator.
            count = bigram_counts[prev_token].get(next_token, 0) + 1
            denom = smoothed_denominators[prev_token]

        prob = count / denom
        total_log_prob += np.log(prob)
        total_predictions += 1

# Mean negative log-likelihood over all scored pairs.
ce_loss = -total_log_prob / total_predictions

print(
    f"\n{'='*50}",
    f"Results on test set:",
    f"{'='*50}",
    f"Total predictions: {total_predictions:,}",
    f"Unseen previous tokens: {unseen_prev_tokens:,} ({100*unseen_prev_tokens/total_predictions:.2f}%)",
    f"Cross-Entropy Loss: {ce_loss:.4f}",
    f"Perplexity: {np.exp(ce_loss):.2f}",
    sep="\n",
)


Evaluating bigram model on test set...
Test tokens: 6,148,046
Test chunks (context_length=512): 12007


Evaluating: 100%|██████████| 12007/12007 [00:04<00:00, 2403.57it/s]


Results on test set:
Total predictions: 6,135,577
Unseen previous tokens: 0 (0.00%)
Cross-Entropy Loss: 3.8769
Perplexity: 48.28



