In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import Counter
import regex as re
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
from datasets import load_dataset
import math
import os

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# ============= Part 1: BPE Tokenizer Implementation =============

class BPETokenizer:
    """Byte-Pair Encoding tokenizer implementation"""
    
    def __init__(self, vocab_size=32000):
        self.vocab_size = vocab_size
        self.word_tokenize = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        self.encoder = {}
        self.decoder = {}
        self.bpe_ranks = {}
        
    def get_stats(self, vocab):
        """Count frequency of adjacent pairs"""
        pairs = Counter()
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs
    
    def merge_vocab(self, pair, v_in):
        """Merge the most frequent pair"""
        v_out = {}
        bigram = re.escape(' '.join(pair))
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        for word in v_in:
            w_out = p.sub(''.join(pair), word)
            v_out[w_out] = v_in[word]
        return v_out
    
    def train(self, texts, min_frequency=2):
        """Train BPE on texts"""
        # Get word frequencies
        word_freqs = Counter()
        for text in tqdm(texts, desc="Counting words"):
            words = self.word_tokenize.findall(text)
            for word in words:
                word_freqs[word] += 1
        
        # Initialize vocabulary with characters
        vocab = {}
        for word, freq in word_freqs.items():
            if freq >= min_frequency:
                vocab[' '.join(list(word)) + ' </w>'] = freq
        
        # Build initial vocabulary
        alphabet = []
        for word in vocab.keys():
            for symbol in word.split():
                if symbol not in alphabet:
                    alphabet.append(symbol)
        
        # Start with alphabet tokens
        for i, symbol in enumerate(alphabet):
            self.encoder[symbol] = i
            self.decoder[i] = symbol
        
        # Learn merges
        num_merges = self.vocab_size - len(alphabet)
        for i in tqdm(range(num_merges), desc="Learning BPE"):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
                
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            
            # Add new token
            token = ''.join(best)
            token_id = len(self.encoder)
            self.encoder[token] = token_id
            self.decoder[token_id] = token
            self.bpe_ranks[best] = i
    
    def tokenize(self, text):
        """Tokenize text using learned BPE"""
        tokens = []
        words = self.word_tokenize.findall(text)
        
        for word in words:
            word = ' '.join(list(word)) + ' </w>'
            
            if word in self.encoder:
                tokens.append(self.encoder[word])
                continue
                
            while True:
                pairs = []
                symbols = word.split()
                for i in range(len(symbols) - 1):
                    pair = (symbols[i], symbols[i + 1])
                    if pair in self.bpe_ranks:
                        pairs.append((self.bpe_ranks[pair], i, pair))
                
                if not pairs:
                    break
                    
                _, i, pair = min(pairs)
                symbols[i] = ''.join(pair)
                symbols.pop(i + 1)
                word = ' '.join(symbols)
            
            for symbol in symbols:
                if symbol in self.encoder:
                    tokens.append(self.encoder[symbol])
        
        return tokens
    
    def decode(self, tokens):
        """Decode tokens back to text"""
        text = ''.join([self.decoder.get(token, '') for token in tokens])
        text = text.replace('</w>', ' ')
        return text.strip()
    
    def save(self, path):
        """Save tokenizer to file"""
        with open(path, 'w') as f:
            json.dump({
                'encoder': self.encoder,
                'decoder': {int(k): v for k, v in self.decoder.items()},
                'bpe_ranks': {f"{k[0]} {k[1]}": v for k, v in self.bpe_ranks.items()},
                'vocab_size': self.vocab_size
            }, f)
    
    def load(self, path):
        """Load tokenizer from file"""
        with open(path, 'r') as f:
            data = json.load(f)
        self.encoder = data['encoder']
        self.decoder = {int(k): v for k, v in data['decoder'].items()}
        self.bpe_ranks = {tuple(k.split()): v for k, v in data['bpe_ranks'].items()}
        self.vocab_size = data['vocab_size']

In [6]:
# ============= Part 2: GPT-2 Model Architecture =============

class GPT2Config:
    """Configuration for GPT-2 model"""
    def __init__(self, model_size='small'):
        configs = {
            'small': {'n_layer': 12, 'n_head': 12, 'n_embd': 768},
            'medium': {'n_layer': 24, 'n_head': 16, 'n_embd': 1024},
            'large': {'n_layer': 36, 'n_head': 20, 'n_embd': 1280},
            'xl': {'n_layer': 48, 'n_head': 25, 'n_embd': 1600}
        }
        
        config = configs[model_size]
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.n_embd = config['n_embd']
        self.vocab_size = 32000
        self.block_size = 128
        self.dropout = 0.1


class GPT2Block(nn.Module):
    """Transformer block with pre-layer norm"""
    
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = nn.MultiheadAttention(
            config.n_embd, 
            config.n_head, 
            dropout=config.dropout,
            batch_first=True
        )
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout)
        )
        
    def forward(self, x, mask=None):
        # Pre-norm attention with residual
        attn_out, _ = self.attn(self.ln_1(x), self.ln_1(x), self.ln_1(x), attn_mask=mask)
        x = x + attn_out
        
        # Pre-norm MLP with residual
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT2Model(nn.Module):
    """GPT-2 model implementation"""
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Token and position embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        
        # Final layer norm and output projection
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    
    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"
        
        # Token and position embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device).unsqueeze(0)
        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)
        
        # Create causal mask
        mask = torch.triu(torch.ones(t, t, device=idx.device) * float('-inf'), diagonal=1)
        
        # Forward through transformer blocks
        for block in self.blocks:
            x = block(x, mask)
        
        # Final layer norm and output projection
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        # Calculate loss if targets provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens, temperature=1.0, strategy='greedy', top_k=50, top_p=0.9):
        """Generate text using different sampling strategies"""
        for _ in range(max_new_tokens):
            # Crop context if needed
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            
            # Forward pass
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            
            # Apply sampling strategy
            if strategy == 'greedy':
                idx_next = torch.argmax(logits, dim=-1).unsqueeze(-1)
            elif strategy == 'top_k':
                # Top-k sampling
                top_k_logits, top_k_indices = torch.topk(logits, top_k, dim=-1)
                probs = F.softmax(top_k_logits, dim=-1)
                idx_next = torch.gather(top_k_indices, -1, torch.multinomial(probs, 1))
            elif strategy == 'nucleus':
                # Nucleus (top-p) sampling
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                
                # Remove tokens with cumulative probability above threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = float('-inf')
                
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, 1)
            
            idx = torch.cat((idx, idx_next), dim=1)
        
        return idx

In [7]:
# ============= Part 3: Dataset and Training =============

class TextDataset(Dataset):
    """Dataset for language modeling"""
    
    def __init__(self, texts, tokenizer, block_size):
        self.tokenizer = tokenizer
        self.block_size = block_size
        
        # Tokenize all texts
        self.tokens = []
        for text in tqdm(texts, desc="Tokenizing dataset"):
            self.tokens.extend(tokenizer.tokenize(text))
        
        print(f"Total tokens: {len(self.tokens)}")
    
    def __len__(self):
        return len(self.tokens) - self.block_size
    
    def __getitem__(self, idx):
        chunk = self.tokens[idx:idx + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y


def create_dataloaders(train_texts, val_texts, tokenizer, block_size, batch_size):
    """Create train and validation dataloaders"""
    train_dataset = TextDataset(train_texts, tokenizer, block_size)
    val_dataset = TextDataset(val_texts, tokenizer, block_size)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader


class CosineWarmupScheduler:
    """Learning rate scheduler with warmup and cosine decay"""
    
    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=1e-5):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        
    def step(self, step):
        if step < self.warmup_steps:
            # Linear warmup
            factor = step / self.warmup_steps
        else:
            # Cosine decay
            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            factor = 0.5 * (1 + math.cos(math.pi * progress))
        
        for i, group in enumerate(self.optimizer.param_groups):
            group['lr'] = self.min_lr + (self.base_lrs[i] - self.min_lr) * factor


def train_model(model, train_loader, val_loader, config, device, epochs=10, 
                learning_rate=3e-4, weight_decay=0.1, warmup_ratio=0.1,
                checkpoint_dir='checkpoints', log_interval=100):
    """Train the GPT-2 model"""
    
    # Setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    # Setup scheduler
    total_steps = len(train_loader) * epochs
    warmup_steps = int(warmup_ratio * total_steps)
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps, total_steps)
    
    # Training history
    train_losses = []
    val_losses = []
    
    # Create checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    step = 0
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss_sum = 0
        train_steps = 0
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch_idx, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            
            # Forward pass
            logits, loss = model(x, y)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step(step)
            
            train_loss_sum += loss.item()
            train_steps += 1
            step += 1
            
            # Update progress bar
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'})
            
            # Log and save checkpoint
            if step % log_interval == 0:
                avg_train_loss = train_loss_sum / train_steps
                train_losses.append((step, avg_train_loss))
                
                # Validation
                model.eval()
                val_loss_sum = 0
                val_steps = 0
                
                with torch.no_grad():
                    for x_val, y_val in val_loader:
                        x_val, y_val = x_val.to(device), y_val.to(device)
                        _, loss = model(x_val, y_val)
                        val_loss_sum += loss.item()
                        val_steps += 1
                        
                        if val_steps >= 50:  # Evaluate on subset
                            break
                
                avg_val_loss = val_loss_sum / val_steps
                val_losses.append((step, avg_val_loss))
                
                print(f"\nStep {step}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")
                
                # Save checkpoint
                checkpoint = {
                    'step': step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': avg_train_loss,
                    'val_loss': avg_val_loss,
                    'config': config
                }
                torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint_step_{step}.pt'))
                
                model.train()
                train_loss_sum = 0
                train_steps = 0
    
    return train_losses, val_losses


def evaluate_perplexity(model, test_loader, device):
    """Evaluate model perplexity on test set"""
    model.eval()
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for x, y in tqdm(test_loader, desc="Evaluating perplexity"):
            x, y = x.to(device), y.to(device)
            _, loss = model(x, y)
            
            total_loss += loss.item() * x.size(0) * x.size(1)
            total_tokens += x.size(0) * x.size(1)
    
    avg_loss = total_loss / total_tokens
    perplexity = math.exp(avg_loss)
    
    return perplexity, avg_loss

In [8]:
# ============= Part 4: Main Training Script =============

def main():
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load dataset (using WikiText-2 as example)
    print("Loading dataset...")
    dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')
    
    # Split data
    train_texts = dataset['train']['text'][:10000]  # Limit for demo
    val_texts = dataset['validation']['text'][:1000]
    test_texts = dataset['test']['text'][:1000]
    
    # Filter out empty texts
    train_texts = [t for t in train_texts if t.strip()]
    val_texts = [t for t in val_texts if t.strip()]
    test_texts = [t for t in test_texts if t.strip()]
    
    # Initialize and train tokenizer
    print("Training tokenizer...")
    tokenizer = BPETokenizer(vocab_size=32000)
    tokenizer.train(train_texts[:5000])  # Train on subset
    tokenizer.save('tokenizer.json')
    
    # Create model
    print("Creating model...")
    config = GPT2Config(model_size='small')
    model = GPT2Model(config).to(device)
    
    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    
    # Create dataloaders
    print("Creating dataloaders...")
    train_loader, val_loader = create_dataloaders(
        train_texts, val_texts, tokenizer, 
        block_size=config.block_size, batch_size=8
    )
    
    # Train model
    print("Training model...")
    train_losses, val_losses = train_model(
        model, train_loader, val_loader, config, device,
        epochs=3, learning_rate=3e-4, log_interval=100
    )
    
    # Plot training curves
    plt.figure(figsize=(10, 6))
    train_steps, train_loss_values = zip(*train_losses)
    val_steps, val_loss_values = zip(*val_losses)
    plt.plot(train_steps, train_loss_values, label='Train Loss')
    plt.plot(val_steps, val_loss_values, label='Validation Loss')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('training_curves.png')
    plt.close()
    
    # Evaluate on test set
    print("Evaluating on test set...")
    test_loader = DataLoader(
        TextDataset(test_texts, tokenizer, config.block_size),
        batch_size=8, shuffle=False
    )
    perplexity, test_loss = evaluate_perplexity(model, test_loader, device)
    print(f"Test Perplexity: {perplexity:.2f}")
    print(f"Test Loss: {test_loss:.4f}")
    
    # Generate text samples
    print("\nGenerating text samples...")
    model.eval()
    
    prompts = [
        "The weather today is",
        "In the beginning of time",
        "Machine learning is"
    ]
    
    strategies = ['greedy', 'top_k', 'nucleus']
    
    for prompt in prompts:
        print(f"\nPrompt: '{prompt}'")
        tokens = tokenizer.tokenize(prompt)
        x = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
        
        for strategy in strategies:
            with torch.no_grad():
                generated = model.generate(x, max_new_tokens=50, strategy=strategy)
                text = tokenizer.decode(generated[0].tolist())
                print(f"\n{strategy.capitalize()} sampling: {text}")
    
    # Save final model
    torch.save(model.state_dict(), 'gpt2_final.pt')
    
    # Create report
    report = f"""
# GPT-2 Training Report

## Model Configuration
- Model Size: Small
- Parameters: {total_params:,}
- Layers: {config.n_layer}
- Heads: {config.n_head}
- Embedding Dimension: {config.n_embd}
- Sequence Length: {config.block_size}
- Vocabulary Size: {config.vocab_size}

## Training Results
- Final Training Loss: {train_loss_values[-1]:.4f}
- Final Validation Loss: {val_loss_values[-1]:.4f}
- Test Perplexity: {perplexity:.2f}
- Test Loss: {test_loss:.4f}

## Text Generation Examples

### Prompt: "The weather today is"
- Greedy: [Generated text would appear here]
- Top-k: [Generated text would appear here]
- Nucleus: [Generated text would appear here]

### Analysis
The model successfully learned to generate coherent text. Different sampling strategies produced varying levels of creativity and coherence:
- Greedy decoding produced the most predictable but coherent text
- Top-k sampling added more diversity while maintaining quality
- Nucleus sampling provided the best balance between creativity and coherence

## Training Curves
See training_curves.png for loss visualization.
"""
    
    with open('training_report.md', 'w') as f:
        f.write(report)
    
    print("\nTraining complete! Check training_report.md for results.")

In [9]:
if __name__ == "__main__":
    main()

Using device: cpu
Loading dataset...


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Generating test split: 100%|████████████████████████████████████████████| 4358/4358 [00:00<00:00, 484167.53 examples/s]
Generating train split: 100%|████████████████████████████████████████| 36718/36718 [00:00<00:00, 1311329.36 examples/s]
Generating validation split: 100%|██████████████████████████████████████| 3760/3760 [00:00<00:00, 626264.12 examples/s]


Training tokenizer...


Counting words: 100%|███████████████████████████████████████████████████████████| 5000/5000 [00:00<00:00, 19305.05it/s]
Learning BPE:  85%|████████████████████████████████████████████████████▋         | 27022/31811 [07:24<01:18, 60.82it/s]


Creating model...
Total parameters: 134,306,304
Creating dataloaders...


Tokenizing dataset: 100%|████████████████████████████████████████████████████████| 6520/6520 [00:03<00:00, 2155.81it/s]


Total tokens: 613594


Tokenizing dataset: 100%|██████████████████████████████████████████████████████████| 644/644 [00:00<00:00, 2586.35it/s]


Total tokens: 56053
Training model...


Epoch 1/3:   0%|                                      | 51/76684 [06:15<156:43:08,  7.36s/it, loss=8.8837, lr=1.06e-05]


KeyboardInterrupt: 

In [13]:
import torch
x = torch.tensor(2.0, requires_grad=True)
y = x**2

y.backward()

print(x.grad)

tensor(4.)
