In [1]:
# Prints a fixed greeting string; the cell's recorded output is the same text.
print("Hello worldds")

Hello worldds


In [None]:
https://www.youtube.com/watch?v=oLUrXDFiJAc

# How do LLMs learn while predicting the next token?

## The Deceptive Simplicity of Next-Token Prediction

## The Mechanics of Token Prediction - Minimal Implementation of a Transformer Language Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np 
from torch.utils.data import Dataset, DataLoader

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input.

    The encoding table is computed once in ``__init__`` and stored as the
    non-trainable buffer ``pe`` with shape ``[max_len, d_model]``. Even
    columns hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        d_model: Size of the embedding dimension (odd values are supported).
        max_len: Number of positions to precompute; also the longest sequence
            that ``forward`` accepts.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Create position encodings once and for all
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so trim div_term to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the position encoding of each time step.

        Args:
            x: Tensor of shape ``[seq_len, batch_size, embedding_dim]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        # Insert a singleton batch axis so the [seq_len, d_model] table
        # broadcasts across the batch dimension instead of being aligned
        # with it (which fails, or silently mis-indexes when
        # batch_size == seq_len).
        return x + self.pe[:seq_len].unsqueeze(1)
    
class TransformerLM(nn.Module):
    """Minimal causal Transformer language model (sequence-first layout).

    Token ids are embedded, scaled by ``sqrt(d_model)``, combined with
    positional encodings, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks under a causal mask, and projected
    to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden size.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        droupout: Dropout probability. The misspelled name is kept so that
            existing keyword callers continue to work.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, droupout=0.1):
        super().__init__()
        self.model_type = 'Transformer'
        self.d_model = d_model

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        # Create a standard transformer encoder (batch_first is left at its
        # default of False, so inputs are [seq_len, batch_size, d_model]).
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, droupout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)

        # Final layer to predict token logits
        self.output_layer = nn.Linear(d_model, vocab_size)

        self.init_weights()

    def init_weights(self):
        """Initialise embedding and output weights uniformly; zero the output bias."""
        initrange = 0.1
        nn.init.uniform_(self.embedding.weight, -initrange, initrange)
        # zero_() takes no arguments; the uniform range belongs to the
        # output projection's weight matrix.
        nn.init.zeros_(self.output_layer.bias)
        nn.init.uniform_(self.output_layer.weight, -initrange, initrange)

    @staticmethod
    def _causal_mask(size, dtype, device):
        """Return a ``[size, size]`` additive mask with -inf above the diagonal."""
        return torch.triu(
            torch.full((size, size), float('-inf'), dtype=dtype, device=device),
            diagonal=1,
        )

    def forward(self, src, src_mask=None):
        """Compute next-token logits for every position.

        Args:
            src: Long tensor of token ids with shape ``[seq_len, batch_size]``.
            src_mask: Optional attention mask forwarded to the encoder. When
                ``None`` a causal mask is created so that no position can
                attend to later positions.

        Returns:
            Tensor of logits with shape ``[seq_len, batch_size, vocab_size]``.
        """
        # Create embeddings
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        # Pass through transformer
        if src_mask is None:
            # Build the mask directly on the input's device and dtype.
            src_mask = self._causal_mask(src.size(0), src.dtype, src.device)

        output = self.transformer_encoder(src, src_mask)

        # Project to vocabulary logits
        return self.output_layer(output)

### The learning process, implementing the training loop - Maximizing log-likelihood: How does the LLM actually "learn"?

In [None]:
def train_transformer_lm(model, data_loader, optimizer,  criterion, device, clip_grad=1.0):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    For each ``(data, targets)`` batch the function moves both tensors to
    ``device``, runs the model, flattens logits and targets, computes the
    loss, back-propagates, clips the gradient norm and steps the optimizer.

    Args:
        model: Module called as ``model(data)``; its output's last dimension
            is treated as the class/vocabulary dimension.
        data_loader: Iterable yielding ``(data, targets)`` tensor pairs.
        optimizer: Optimizer over ``model.parameters()``.
        criterion: Loss called as ``criterion(flat_logits, flat_targets)``.
        device: Device the batches are moved to.
        clip_grad: Maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Average loss per batch as a float, or ``0.0`` if the loader yielded
        no batches.
    """
    model.train()
    total_loss = 0.
    num_batches = 0
    for batch_idx, (data, targets) in enumerate(data_loader):
        # Both tensors must live on the same device as the model output.
        data, targets = data.to(device), targets.to(device)

        # Zero gradients from previous iteration
        optimizer.zero_grad()

        # Forward pass
        output = model(data)

        # Reshape for loss computation; reshape (unlike view) also accepts
        # non-contiguous tensors such as transposed batches.
        output = output.reshape(-1, output.size(-1))
        targets = targets.reshape(-1)

        # Compute loss (negative log-likelihood)
        loss = criterion(output, targets)

        # Backward pass
        loss.backward()

        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

        # Update weights
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        if batch_idx % 200 == 0:
            print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

    # Count batches ourselves so loaders without __len__ work and an empty
    # loader does not divide by zero.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

### Emergent capabilities: Beyond Next-Token Prediction

In [None]:
def generate_text(model, start_sequence, max_new_tokens=50, temperature=1.0):
    """Generate token ids autoregressively, starting from ``start_sequence``.

    The prompt is fed to the model in sequence-first layout
    (``[seq_len, 1]``) together with an additive causal mask. At each step
    the logits of the last position select the next token, which is appended
    to the input for the following step.

    Args:
        model: Module called as ``model(input_ids, attn_mask)`` that returns
            logits of shape ``[seq_len, batch_size, vocab_size]``.
        start_sequence: Non-empty 1-D sequence of integer token ids.
        max_new_tokens: Number of tokens to generate.
        temperature: Softmax temperature for sampling. Values ``<= 0``
            select the highest-scoring token (greedy decoding).

    Returns:
        List of the generated token ids (the prompt is not included).

    Raises:
        ValueError: If ``start_sequence`` is empty.
    """
    if len(start_sequence) == 0:
        raise ValueError("start_sequence must contain at least one token")

    model.eval()

    # Keep all tensors on the model's device; fall back to CPU for modules
    # that have no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Sequence-first layout: [seq_len, batch_size=1]
    input_ids = torch.tensor(start_sequence, dtype=torch.long, device=device).unsqueeze(1)
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            seq_len = input_ids.size(0)
            # Additive float mask: -inf above the diagonal blocks attention
            # to future tokens, 0 elsewhere leaves scores untouched.
            attn_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=device),
                diagonal=1,
            )

            # Forward pass; take the logits of the last time step
            outputs = model(input_ids, attn_mask)
            next_token_logits = outputs[-1, 0, :]

            if temperature > 0:
                # Sample from the temperature-scaled distribution
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
            else:
                # Zero temperature is the greedy limit of sampling
                next_token = torch.argmax(next_token_logits, dim=-1).item()

            generated_tokens.append(next_token)

            # Append the predicted token to the input for the next iteration
            next_ids = torch.tensor([[next_token]], dtype=torch.long, device=device)
            input_ids = torch.cat([input_ids, next_ids], dim=0)

    return generated_tokens
    
def analyze_activation_patterns(model, input_text, layer_idx= -1):
    """Project one encoder layer's activations for ``input_text`` to 2-D with PCA.

    A forward hook captures the output of
    ``model.transformer_encoder.layers[layer_idx]`` during a single forward
    pass. The hook is removed afterwards, and the captured per-token
    activations are reduced to two principal components.

    Args:
        model: Module exposing ``transformer_encoder.layers`` and callable as
            ``model(input_ids)`` with ``input_ids`` of shape ``[seq_len, 1]``.
        input_text: 1-D sequence of integer token ids. PCA with two
            components needs at least two tokens.
        layer_idx: Index into ``model.transformer_encoder.layers``.

    Returns:
        Array of shape ``[seq_len, 2]`` as returned by ``PCA.fit_transform``.
    """
    # Imported locally, as in the original cell, so the notebook's other
    # cells do not require scikit-learn.
    from sklearn.decomposition import PCA

    model.eval()

    # Build the input on the model's device; fall back to CPU for modules
    # that have no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Sequence-first layout [seq_len, batch_size=1], so that squeeze(1)
    # below yields one row per token.
    input_ids = torch.tensor(input_text, dtype=torch.long, device=device).unsqueeze(1)

    # Register hook to get activations
    activations = {}
    key = f'transformer_layer_{layer_idx}'

    def hook(module, inputs, output):
        activations[key] = output.detach()

    # Attach hook to the specified transformer layer
    handle = model.transformer_encoder.layers[layer_idx].register_forward_hook(hook)
    try:
        # Forward pass
        with torch.no_grad():
            model(input_ids)
    finally:
        # Always detach the hook so repeated calls do not accumulate hooks
        # on the model.
        handle.remove()

    # Analyze the activations (e.g. compute principal components, clusters, etc.)
    layer_activations = activations[key]

    # Compute PCA for visualization (example); move to CPU before the
    # NumPy conversion so CUDA models work too.
    pca = PCA(n_components=2)
    activation_2d = pca.fit_transform(layer_activations.squeeze(1).cpu().numpy())

    return activation_2d

### Information Compression and Internal Representations


In [None]:
def compute_mutual_information(model, dataset, num_samples=1000):