<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/Small_Language_Model_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import time
import math
import requests
import io

# --- CORRECTED HYPERPARAMETERS (Increased Capacity and Context) ---
# Module-level settings read by the data pipeline, the model classes and the
# training loop defined below.
BLOCK_SIZE = 64        # Context length T: tokens per sampled sequence; also sizes the causal mask and the position table
BATCH_SIZE = 128       # Sequences drawn per get_batch call (increased for better GPU utilization)
N_EMBED = 256          # Embedding / residual-stream width (doubled for capacity)
N_HEAD = 8             # Attention heads per block; head size = N_EMBED // N_HEAD = 32
N_LAYER = 6            # Number of stacked Transformer blocks (doubled)
DROPOUT = 0.1          # Dropout probability used by the attention and feed-forward layers
LEARNING_RATE = 5e-4   # Learning rate passed to AdamW
MAX_ITERS = 10000      # Maximum number of optimizer steps
EVAL_INTERVAL = 1000   # Steps between loss evaluations (evaluate less often to save time)
EVAL_ITERS = 20        # Batches averaged per split for each loss estimate
PATIENCE = 5           # Consecutive evaluations without a better validation loss before early stopping

# Set a manual seed for reproducibility
torch.manual_seed(1337)

# Pick the compute device: CUDA if available, then MPS, otherwise CPU
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# --- 1. Data Loading and Tokenization (Character-level) ---

def load_data_and_tokenize():
    """Load the training corpus and build a character-level tokenizer.

    Tries to download the Shakespeare corpus; on any failure (network error,
    bad HTTP status, suspiciously small download) it reports the reason and
    falls back to a short built-in paragraph.

    Returns:
        A tuple ``(vocab_size, train_data, val_data, encode, decode)`` where
        ``train_data`` / ``val_data`` are 1-D ``torch.long`` tensors of token
        ids, ``encode`` maps a string to a list of ids and ``decode`` maps a
        list of ids back to a string.
    """
    # --- WORKING LARGE CORPUS URL (Complete Works of Shakespeare) ---
    CORPUS_URL = "https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt"
    try:
        # A timeout keeps a stalled connection from hanging the run forever.
        response = requests.get(CORPUS_URL, timeout=30)
        response.raise_for_status()  # Raise an error for bad status codes
        text = response.text

        if len(text) < 100000:
            raise ValueError("Downloaded corpus is too small.")

        print("Successfully loaded corpus (The Complete Works of Shakespeare).")
    except Exception as e:
        # Deliberately broad: any download problem should degrade to the
        # built-in corpus instead of aborting the run.
        print(f"ERROR: Failed to download corpus (Reason: {e}). Falling back to small corpus.")
        # Fallback to the original small corpus
        text = """The wind whistled through the high peaks of the mountain, carrying with it the scent of pine
        and cold, ancient stone. Across the vast valley lay the village, small and quiet beneath the
        shadow of the giant mountain range. It was here, in a small, stone cottage near the river,
        that the philosopher spent his days. He worked on his magnum opus, a treatise on the nature
        of time and consciousness, convinced that the key to understanding reality lay not in observation,
        but in reflection.
        """

    # Simple character-level tokenizer: one id per distinct character.
    chars = sorted(set(text))
    vocab_size = len(chars)

    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = dict(enumerate(chars))

    def encode(s):
        """Map a string to a list of token ids (KeyError on unknown characters)."""
        return [stoi[c] for c in s]

    def decode(ids):
        """Map an iterable of token ids back to a string."""
        return ''.join(itos[i] for i in ids)

    # Convert the entire text to tensor
    data = torch.tensor(encode(text), dtype=torch.long)

    # Split data into training and validation sets (90% / 10%).
    n = int(0.9 * len(data))
    # Batching needs BLOCK_SIZE + 1 tokens per split. With the short fallback
    # corpus a plain 10% validation slice is smaller than that, so move the
    # split point to give validation the minimum window when that happens.
    # The large corpus is unaffected.
    min_window = BLOCK_SIZE + 1
    if len(data) - n < min_window:
        n = max(len(data) - min_window, 0)
    train_data = data[:n]
    val_data = data[n:]

    print(f"Total Corpus Size: {len(data):,} tokens")
    print(f"Vocabulary Size: {vocab_size}")
    print(f"Training data size: {len(train_data):,} tokens")

    return vocab_size, train_data, val_data, encode, decode

def get_batch(split, train_data, val_data):
    """Draw one random batch of input/target windows from a data split.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``val_data``.
        train_data: 1-D tensor of training token ids.
        val_data: 1-D tensor of validation token ids.

    Returns:
        ``(x, y)`` on ``device``, each of shape ``(BATCH_SIZE, BLOCK_SIZE)``;
        ``y`` is ``x`` shifted one position to the right in the source data.

    Raises:
        ValueError: If the selected split has fewer than ``BLOCK_SIZE + 1``
            tokens.
    """
    data = val_data if split != 'train' else train_data

    # Guardrail: each window needs BLOCK_SIZE inputs plus one extra target.
    required_length = BLOCK_SIZE + 1
    if len(data) < required_length:
        raise ValueError(
            f"Insufficient data for batching in '{split}' split. "
            f"Data length ({len(data)}) must be at least BLOCK_SIZE + 1 ({required_length})."
        )

    # Random window starts; the guard above keeps randint's upper bound > 0.
    starts = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,)).tolist()

    # Targets are the inputs shifted by one token.
    inputs = torch.stack([data[s:s + BLOCK_SIZE] for s in starts])
    targets = torch.stack([data[s + 1:s + 1 + BLOCK_SIZE] for s in starts])

    return inputs.to(device), targets.to(device)


# --- 2. Causal Self-Attention and Transformer Block (No changes required here) ---

class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width ``head_size`` and
    lets each position attend only to itself and earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(N_EMBED, head_size, bias=False)
        self.query = nn.Linear(N_EMBED, head_size, bias=False)
        self.value = nn.Linear(N_EMBED, head_size, bias=False)
        # Lower-triangular mask sized for the longest supported context.
        causal_mask = torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        """Return the attention output for ``x`` of shape (B, T, C)."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)

        # Scaled dot-product scores: divide by sqrt(head_size).
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale

        # Causal masking: future positions get -inf so softmax gives them 0.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ self.value(x)

class MultiHeadAttention(nn.Module):
    """Several self-attention heads run in parallel, then projected back.

    Args:
        num_heads: Number of ``Head`` modules to run side by side.
        head_size: Output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. Size
        # the projection from that rather than assuming it equals N_EMBED, so
        # head counts that do not evenly divide N_EMBED still work. When they
        # do divide evenly the layer shape is unchanged.
        self.proj = nn.Linear(num_heads * head_size, N_EMBED)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        """Concatenate all head outputs on the last axis and project them."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed  # inner layer is four times the model width
        layers = [
            nn.Linear(n_embed, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(DROPOUT),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to every position of ``x`` independently."""
        return self.net(x)

class Block(nn.Module):
    """One Transformer block: self-attention followed by a feed-forward MLP.

    Both sub-layers use the pre-norm residual form: LayerNorm is applied to
    the sub-layer input and the sub-layer output is added back to it.
    """

    def __init__(self, n_embed, n_head):
        super().__init__()
        # Each head gets an equal share of the embedding width.
        self.sa = MultiHeadAttention(n_head, n_embed // n_head)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Run attention then the MLP, each wrapped in a residual connection."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))


# --- 3. The Small Language Model (SLM) (No changes required here) ---

class SLM(nn.Module):
    """The full decoder-only Transformer language model.

    Token and learned position embeddings are summed, passed through
    ``N_LAYER`` Transformer blocks and a final LayerNorm, then projected to
    per-token vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids the model reads and predicts.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, N_EMBED)
        self.position_embedding_table = nn.Embedding(BLOCK_SIZE, N_EMBED)
        self.blocks = nn.Sequential(*[Block(N_EMBED, N_HEAD) for _ in range(N_LAYER)])
        self.ln_f = nn.LayerNorm(N_EMBED)
        self.lm_head = nn.Linear(N_EMBED, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids with shape (B, T); T must not
                exceed ``BLOCK_SIZE`` (the position table has that many rows).
            targets: Optional integer tensor of shape (B, T) holding the
                next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            (B, T, vocab_size) and ``loss`` is ``None``. With targets,
            ``logits`` is flattened to (B*T, vocab_size) and ``loss`` is the
            mean cross-entropy.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        # Build positions on the input's own device so the model still works
        # after being moved away from the module-level `device`.
        pos = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(pos)
        x = tok_emb + pos_emb  # position embeddings broadcast over the batch

        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=5):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: Integer tensor of shape (B, T) holding the prompt token ids.
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the logits before sampling.
            top_k: If not ``None``, sample only from the ``top_k`` most likely
                tokens at each step.

        Returns:
            Tensor of shape (B, T + max_new_tokens): the prompt followed by
            the sampled tokens.
        """
        # Sample with dropout disabled, then put the module back in whatever
        # mode the caller had it in.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Keep only the most recent BLOCK_SIZE tokens as context.
                idx_cond = idx[:, -BLOCK_SIZE:]
                logits, _ = self(idx_cond)
                # Only the last position predicts the next token.
                logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Mask everything below the k-th largest logit.
                    logits[logits < v[:, [-1]]] = float('-inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx


# --- 4. Training and Evaluation (No changes required here) ---

# Early-stopping bookkeeping at module level; train_slm() declares three of
# these `global` and resets them at the start of every run.
BEST_VAL_LOSS = float('inf')   # lowest validation loss seen so far
PATIENCE_GLOBAL = PATIENCE     # copy of PATIENCE; not read by the code visible in this cell
PATIENCE_COUNTER = 0           # evaluations since the last improvement
BEST_MODEL_STATE = None        # model state recorded at the best evaluation

@torch.no_grad()
def estimate_loss(model, train_data, val_data):
    """Average the model's loss over EVAL_ITERS random batches per split.

    The model is switched to eval mode for the measurement and back to train
    mode before returning.

    Returns:
        Dict with keys ``'train'`` and ``'val'``, each a scalar tensor holding
        the mean loss for that split.
    """
    def mean_split_loss(split):
        # One fresh random batch per iteration; keep the scalar loss values.
        batch_losses = torch.zeros(EVAL_ITERS)
        for step in range(EVAL_ITERS):
            inputs, targets = get_batch(split, train_data, val_data)
            _, batch_loss = model(inputs, targets)
            batch_losses[step] = batch_loss.item()
        return batch_losses.mean()

    model.eval()
    out = {split: mean_split_loss(split) for split in ('train', 'val')}
    model.train()
    return out

def train_slm():
    """Train the SLM with early stopping, restore the best weights and sample text.

    Steps:
      1. Load and tokenize the corpus.
      2. Train for up to MAX_ITERS steps, evaluating every EVAL_INTERVAL
         steps and stopping early once the validation loss has failed to
         improve for PATIENCE consecutive evaluations.
      3. Reload the weights recorded at the best evaluation.
      4. Generate and print 500 characters from a fixed prompt.

    Side effects: resets and updates the module-level BEST_VAL_LOSS,
    PATIENCE_COUNTER and BEST_MODEL_STATE, and prints progress to stdout.
    """
    global BEST_VAL_LOSS, PATIENCE_COUNTER, BEST_MODEL_STATE
    BEST_VAL_LOSS = float('inf')
    PATIENCE_COUNTER = 0
    BEST_MODEL_STATE = None

    # Load and tokenize data
    vocab_size, train_data, val_data, encode, decode = load_data_and_tokenize()

    # Initialize model and optimizer
    model = SLM(vocab_size)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Total number of parameters (for context)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nModel has {total_params:,} trainable parameters. Training on {device}.")

    # Training Loop
    start_time = time.time()
    for step in range(MAX_ITERS):

        # --- EARLY STOPPING CHECK ---
        if step % EVAL_INTERVAL == 0:
            losses = estimate_loss(model, train_data, val_data)
            elapsed = time.time() - start_time
            print(f"Step {step}: Train Loss {losses['train']:.4f}, Val Loss {losses['val']:.4f} (Elapsed: {elapsed:.2f}s)")
            val_loss = float(losses['val'])
            if val_loss < BEST_VAL_LOSS:
                BEST_VAL_LOSS = val_loss
                PATIENCE_COUNTER = 0
                # state_dict() returns references to the live parameters,
                # which the optimizer keeps updating in place. Clone them so
                # this snapshot really holds the weights from this step.
                BEST_MODEL_STATE = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
            else:
                PATIENCE_COUNTER += 1
            if PATIENCE_COUNTER >= PATIENCE:
                print(f"\n--- EARLY STOP TRIGGERED at Step {step} ---")
                print(f"Validation loss has not improved for {PATIENCE * EVAL_INTERVAL} steps.")
                break

        # Sample a batch of data
        xb, yb = get_batch('train', train_data, val_data)

        # Forward pass, backward pass, and optimization
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    # --- Restore Best Weights ---
    if BEST_MODEL_STATE is not None:
        model.load_state_dict(BEST_MODEL_STATE)
        print("Restored model to best state observed during training.")

    end_time = time.time()
    print(f"\nTraining finished in {end_time - start_time:.2f} seconds.")

    # --- 5. Inference (Text Generation) ---
    print("\nStarting Inference with Best Model...")
    # estimate_loss leaves the model in train mode; switch dropout off before
    # sampling.
    model.eval()
    prompt = "KING LEAR:\n"
    try:
        prompt_ids = encode(prompt)
    except KeyError:
        # The vocabulary comes from the loaded corpus. The small fallback
        # corpus lacks some of the prompt's characters, so seed generation
        # with token id 0 instead of crashing.
        print("Prompt contains characters outside the vocabulary; starting from token id 0.")
        prompt_ids = [0]
    context = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
    generated_ids = model.generate(context, max_new_tokens=500, temperature=0.8, top_k=5)[0].tolist()

    print("\n--- Generated Text ---")
    print(decode(generated_ids))
    print("----------------------")


# Script entry point: train and generate only when this file is run directly.
if __name__ == '__main__':
    # Show printed tensors with four decimal places.
    torch.set_printoptions(precision=4)
    train_slm()

Successfully loaded corpus (The Complete Works of Shakespeare).
Total Corpus Size: 5,458,199 tokens
Vocabulary Size: 91
Training data size: 4,912,379 tokens

Model has 4,797,531 trainable parameters. Training on cuda.
Step 0: Train Loss 4.4800, Val Loss 4.4831 (Elapsed: 1.63s)
Step 1000: Train Loss 1.4325, Val Loss 1.5341 (Elapsed: 121.72s)
Step 2000: Train Loss 1.2687, Val Loss 1.4119 (Elapsed: 241.68s)
Step 3000: Train Loss 1.2197, Val Loss 1.3647 (Elapsed: 361.60s)
Step 4000: Train Loss 1.1925, Val Loss 1.3428 (Elapsed: 481.50s)
Step 5000: Train Loss 1.1640, Val Loss 1.3249 (Elapsed: 601.44s)
Step 6000: Train Loss 1.1473, Val Loss 1.3118 (Elapsed: 721.24s)
Step 7000: Train Loss 1.1295, Val Loss 1.3194 (Elapsed: 841.22s)
Step 8000: Train Loss 1.1164, Val Loss 1.3030 (Elapsed: 961.28s)
Step 9000: Train Loss 1.1115, Val Loss 1.2908 (Elapsed: 1081.33s)
Restored model to best state observed during training.

Training finished in 1199.70 seconds.

Starting Inference with Best Model...

--