In [33]:
"""
Full definition of a GPT Language Model, all of it in this single file.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

import math
import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """Layer normalization whose additive shift can be switched off.

    Older ``nn.LayerNorm`` releases cannot drop only the bias, so this module
    owns a scale vector (``weight``, initialised to ones) and an optional shift
    (``bias``, initialised to zeros or ``None``) and delegates the math to
    ``F.layer_norm`` with eps = 1e-5 over the last ``ndim``-sized dimension.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = None if not bias else nn.Parameter(torch.zeros(ndim))

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention: position t attends only to positions <= t.

    Query, key and value come from one fused linear layer (``c_attn``) and are
    split afterwards; ``c_proj`` maps the concatenated heads back to ``n_embd``.
    When ``F.scaled_dot_product_attention`` exists (PyTorch >= 2.0) it is used
    with ``is_causal=True``; otherwise an explicit lower-triangular mask buffer
    named ``bias`` is registered and a manual masked softmax is computed.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        # fused q/k/v projection, then the output projection
        self.c_attn = nn.Linear(width, 3 * width, bias=config.bias)
        self.c_proj = nn.Linear(width, width, bias=config.bias)
        # dropout on the attention weights and on the sublayer output
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = width
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if self.flash:
            return
        print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # ones on and below the diagonal: row t may look at columns 0..t only
        size = config.block_size
        causal = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", causal)

    def forward(self, x):
        batch, seq_len, width = x.size()  # (B, T, C) with C == n_embd
        head_dim = width // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        if not self.flash:
            # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            scores = scores.masked_fill(self.bias[:, :, :seq_len, :seq_len] == 0, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            out = weights @ v  # (B, nh, T, hs)
        else:
            p_drop = self.dropout if self.training else 0
            out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=p_drop, is_causal=True)

        # (B, nh, T, hs) -> (B, T, C): heads placed side by side again
        merged = out.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward sublayer.

    Expands ``n_embd`` to ``4 * n_embd``, applies GELU, projects back to
    ``n_embd`` and finishes with dropout.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

class Block(nn.Module):
    """Pre-norm transformer layer.

    Computes ``h = x + attn(ln_1(x))`` and returns ``h + mlp(ln_2(h))``.
    """

    def __init__(self, config):
        super().__init__()
        # submodules are created in this order so parameter registration is stable
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

# --- GPT with auxiliary reverse-embedding loss from zb ---
import math, inspect
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F

# Block and LayerNorm must already be defined when this cell runs (they are defined earlier in this file).

@dataclass
class GPTConfig:
    """Hyperparameters for the GPT class that follows.

    NOTE: a second ``GPTConfig`` later in this file rebinds this name.
    """
    block_size: int = 1024  # max sequence length; rows of the positional embedding (wpe)
    vocab_size: int = 50304  # number of token ids; rows of wte / outputs of lm_head
    n_layer: int = 12  # number of transformer Blocks
    n_head: int = 12  # attention heads per Block; must divide n_embd
    n_embd: int = 768  # model width
    dropout: float = 0.0  # dropout probability used by the Dropout layers
    bias: bool = True  # whether Linear and LayerNorm layers have bias terms

class GPT(nn.Module):
    """GPT language model with per-block maps for an auxiliary reverse-embedding loss.

    NOTE(review): this class defines no ``forward``, and a second ``class GPT``
    further down this file rebinds the name, so this definition is shadowed.
    """

    def __init__(self, config: GPTConfig,
                 aux_scale: float = 1.0,
                 noise_constituent: float = 1e-4,
                 noise_final: float = 1e-4):
        """
        Args:
            config: model hyperparameters; ``n_layer`` must be at least 4.
            aux_scale: stored as ``aux_scale_default``.
            noise_constituent: std of Gaussian noise added to each candidate
                embedding in ``_rev_embed_lane``.
            noise_final: std of Gaussian noise added to each mixed vector in
                ``_rev_embed_lane``.

        Raises:
            ValueError: if ``config.n_layer < 4``.
        """
        super().__init__()
        assert config.vocab_size is not None and config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight  # weight tying

        # one square D x D map per aux block 0..3
        need_blocks = 4
        if config.n_layer < need_blocks:
            raise ValueError(f"need at least {need_blocks} transformer blocks for aux; got {config.n_layer}")
        self.aux_blocks = list(range(need_blocks))  # [0, 1, 2, 3] fixed
        self.aux_maps = nn.ModuleList(
            nn.Linear(config.n_embd, config.n_embd, bias=False) for _ in self.aux_blocks
        )

        # noise/scales
        self.aux_scale_default = float(aux_scale)
        self.noise_constituent = float(noise_constituent)
        self.noise_final = float(noise_final)

        # generic init: N(0, 0.02) for Linear/Embedding weights, scaled residual projections
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
        # _init_weights re-draws every nn.Linear weight (the aux maps included), so the
        # orthogonal init must run after it; done earlier it is silently overwritten.
        for lin in self.aux_maps:
            nn.init.orthogonal_(lin.weight)  # square => orthonormal rows and columns

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """Return the parameter count; with ``non_embedding`` the positional table is excluded."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _rev_embed_lane(self, lane_seq, device):
        """Mix token embeddings according to one lane of candidate distributions.

        Args:
            lane_seq: length-T sequence; ``lane_seq[t] = (idxs, probs)`` holding
                K token ids and K non-negative weights.
            device: device on which the intermediate tensors are built.

        Returns:
            (T, D) tensor: for each t, ``sum_k p_k * (E[idx_k] + noise)`` with the
            weights renormalised to sum to 1, plus optional final noise.
        """
        E = self.transformer.wte.weight  # (V, D)
        if len(lane_seq) == 0:
            return torch.empty(0, self.config.n_embd, device=device, dtype=E.dtype)
        idxs = torch.tensor([pair[0] for pair in lane_seq], device=device, dtype=torch.long)      # (T, K)
        probs = torch.tensor([pair[1] for pair in lane_seq], device=device, dtype=torch.float32)  # (T, K)
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-12)
        # einsum needs matching dtypes; cast so half/bfloat16 weights also work
        probs = probs.to(E.dtype)

        emb = E.index_select(0, idxs.reshape(-1)).reshape(*idxs.shape, E.size(1))  # (T, K, D)
        if self.noise_constituent > 0:
            emb = emb + torch.randn_like(emb) * self.noise_constituent
        rev = torch.einsum('tkd,tk->td', emb, probs)  # (T, D)
        if self.noise_final > 0:
            rev = rev + torch.randn_like(rev) * self.noise_final
        return rev

@dataclass
class GPTConfig:
    """Hyperparameters for the small character-level GPT defined below.

    NOTE: this rebinds the name ``GPTConfig`` defined earlier in the file.
    """
    block_size: int = 1024  # max sequence length; rows of the positional embedding
    vocab_size: int = 66  # character vocabulary size
    n_layer: int = 4          # must be 4: GPT.__init__ below raises ValueError otherwise
    n_head: int = 8  # attention heads; must divide n_embd
    n_embd: int = 128  # model width
    dropout: float = 0.0  # dropout probability used by the Dropout layers
    bias: bool = True  # whether Linear and LayerNorm layers have bias terms

class GPT(nn.Module):
    """Character-level GPT with an auxiliary "reverse-embedding" regression loss.

    After each of the 4 transformer blocks, the hidden state is regressed
    (masked MSE) onto a target built from the token-embedding table: the
    probability-weighted sum of the embeddings of K candidate next tokens taken
    from ``zb`` (block 0: bigram, 1: 4-gram, 2: 8-gram, 3: 16-gram), with small
    Gaussian noise, plus the positional embedding.
    """

    def __init__(self, config: GPTConfig,
                 aux_scale: float = 1e-3,
                 noise_constituent: float = 1e-6,
                 noise_final: float = 1e-4):
        """
        Args:
            config: model hyperparameters; ``n_layer`` must be 4.
            aux_scale: weight of the summed per-block aux MSE in the total loss.
            noise_constituent: std of Gaussian noise added to every candidate embedding.
            noise_final: std of Gaussian noise added to every mixed target vector.

        Raises:
            ValueError: if ``config.n_layer != 4``.
        """
        super().__init__()
        assert config.vocab_size is not None and config.block_size is not None
        if config.n_layer != 4:
            raise ValueError("Set n_layer=4 (aux aligns: bigram, 4, 8, 16).")

        self.config = config
        self.aux_scale = float(aux_scale)
        self.noise_constituent = float(noise_constituent)
        self.noise_final = float(noise_final)

        # core transformer
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying is left disabled in this variant:
        # self.transformer.wte.weight = self.lm_head.weight

        # per-block square D x D maps (orthogonal init applied after the generic init below)
        self.aux_maps = nn.ModuleList(
            nn.Linear(config.n_embd, config.n_embd, bias=False) for _ in range(4)
        )

        # generic init: N(0, 0.02) for Linear/Embedding weights, scaled residual projections
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
        # _init_weights re-draws every nn.Linear weight (aux maps included), so the
        # orthogonal init has to come last or it is silently overwritten
        for lin in self.aux_maps:
            nn.init.orthogonal_(lin.weight)

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """Return the parameter count; with ``non_embedding`` the positional table is excluded."""
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    # -------- reverse-embedding helpers --------
    def _extract_idx_prob(self, zb, pair_offset):
        """Convert a nested-list ``zb`` into dense (B, T, K) CPU tensors.

        zb: Python list of length B; zb[b] is a list of length T;
            zb[b][t] is a length-8 list:
              [idxs_bi, probs_bi, idxs_4, probs_4, idxs_8, probs_8, idxs_16, probs_16]
        pair_offset: 0 for bigram, 2 for m4, 4 for m8, 6 for m16.
        returns: (idxs, probs, K); idxs long and probs float32, both (B, T, K).
        K is taken from zb[0][0]. Not used by ``forward``.
        """
        B = len(zb)
        T = len(zb[0])
        K = len(zb[0][0][pair_offset])
        idxs = torch.empty((B, T, K), dtype=torch.long)
        probs = torch.empty((B, T, K), dtype=torch.float32)
        for b in range(B):
            seq = zb[b]
            for t in range(T):
                idxs[b, t] = torch.tensor(seq[t][pair_offset], dtype=torch.long)
                probs[b, t] = torch.tensor(seq[t][pair_offset + 1], dtype=torch.float32)
        return idxs, probs, K

    def _rev_embed_batch(self, idxs, probs, dtype, device):
        """Probability-weighted mix of candidate token embeddings.

        idxs:  (B, T, K) integer ids; numpy array or tensor on any device.
        probs: (B, T, K) weights; numpy array or tensor (cast to ``dtype``).
        returns rev: (B, T, D) in ``dtype`` on ``device``:
            sum_k probs[..., k] * (E[idxs[..., k]] + noise_c), plus noise_f.
        """
        E = self.transformer.wte.weight.to(dtype=dtype)  # (V, D)
        # index_select needs the ids on the same device as E
        idxs = torch.as_tensor(idxs, dtype=torch.long, device=device)
        probs = torch.as_tensor(probs, device=device).to(dtype=dtype)
        B, T, K = idxs.shape
        emb = E.index_select(0, idxs.reshape(-1)).reshape(B, T, K, E.size(1))  # (B, T, K, D)
        if self.noise_constituent > 0:
            emb = emb + torch.randn_like(emb) * self.noise_constituent
        rev = (emb * probs.unsqueeze(-1)).sum(dim=2)  # (B, T, D)
        if self.noise_final > 0:
            rev = rev + torch.randn_like(rev) * self.noise_final
        return rev

    def forward(self, idx, targets=None, zb=None):
        """
        idx: (B, T) long token ids, T <= block_size.
        targets: (B, T) long next-token ids, or None (then only the logits of
            the last position are returned).
        zb: optional indexable of 4 ``(idxs, probs)`` pairs, each (B, T, K)
            numpy arrays or tensors; pair i is the aux target after block i.
            Order: 0=bigram, 1=4-gram, 2=8-gram, 3=16-gram.
        returns: (logits, loss).
            * loss is None without targets and zb.
            * With only targets, loss is the CE.
            * With only zb, loss is aux_scale * aux.
            * With both, loss is (CE + aux_scale * aux) / 5.
        """
        device = idx.device
        B, T = idx.size()
        assert T <= self.config.block_size

        # embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)

        aux_loss = None
        dtype = x.dtype
        # Positions j < ignore[b] are excluded from block b's aux loss. For blocks
        # 1-3 these are presumably the positions without a full 4/8/16-token
        # context in the data loader (which backs off to bigrams there) — confirm.
        ignore = {0: 1, 1: 3, 2: 7, 3: 15}

        for bidx, block in enumerate(self.transformer.h):
            x = block(x)  # (B, T, D)
            if zb is None or bidx >= 4:
                continue

            idxs, probs = zb[bidx]
            rev = self._rev_embed_batch(idxs, probs, dtype, device) + pos_emb
            mapped = rev  # per-block projector self.aux_maps[bidx] is currently bypassed
            mask = (torch.arange(T, device=device).expand(B, T) >= ignore[bidx]).unsqueeze(-1)  # (B, T, 1)
            diff2 = ((x - mapped) ** 2) * mask  # bool mask broadcasts over D
            denom = (mask.sum() * diff2.size(-1)).clamp_min(1)
            block_loss = diff2.sum() / denom
            aux_loss = block_loss if aux_loss is None else aux_loss + block_loss

        # head + CE
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        ce_loss = None
        if targets is not None:
            ce_loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                      targets.reshape(-1), ignore_index=-100)

        # total
        if ce_loss is None and aux_loss is None:
            loss = None
        elif aux_loss is None:
            loss = ce_loss
        elif ce_loss is None:
            loss = self.aux_scale * aux_loss
        else:
            # NOTE(review): only this branch is divided by 5, so a training loss
            # (with zb) and an evaluation CE (without zb) are on different scales.
            loss = ce_loss + self.aux_scale * aux_loss
            loss = loss / 5.0

        if targets is None:
            logits = logits[:, [-1], :]

        return logits, loss


In [2]:
import requests, os

base_url = "https://huggingface.co/datasets/cambridge-climb/BabyLM/resolve/main/clean/10M/"
target_dir = "./babylm_10m_cleaned"
os.makedirs(target_dir, exist_ok=True)

file_names = [
    "aochildes.txt",
    "cbt.txt",
    "children_stories.txt",
    "gutenberg.txt",
    "qed.txt",
    "simple_wikipedia.txt",
    "switchboard.txt",
    "wikipedia.txt",
]

# Optional addition: Tiny Shakespeare, hosted outside the BabyLM dataset
shakespeare_url = "https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt"
shakespeare_fname = "shakespeare.txt"

# (url, local file name) pairs: every BabyLM file plus Shakespeare
all_files = [(base_url + fname, fname) for fname in file_names]
all_files.append((shakespeare_url, shakespeare_fname))

# Download loop. The raw bytes are written unchanged: ``resp.text`` would decode
# with an encoding guessed from the HTTP headers (ISO-8859-1 for text/* without a
# charset), which can corrupt UTF-8 files. The timeout stops a stalled request
# from hanging the cell.
for url, fname in all_files:
    out_path = os.path.join(target_dir, fname)
    print(f"📥 Downloading {fname}...")
    resp = requests.get(url, timeout=60)
    if resp.status_code == 200:
        with open(out_path, "wb") as f:
            f.write(resp.content)
    else:
        print(f"❌ Failed to download {fname} ({resp.status_code})")

print(f"✅ Done. Files saved to {target_dir}")

📥 Downloading aochildes.txt...
📥 Downloading cbt.txt...
📥 Downloading children_stories.txt...
📥 Downloading gutenberg.txt...
📥 Downloading qed.txt...
📥 Downloading simple_wikipedia.txt...
📥 Downloading switchboard.txt...
📥 Downloading wikipedia.txt...
📥 Downloading shakespeare.txt...
✅ Done. Files saved to ./babylm_10m_cleaned


In [52]:

import os
import pickle
import numpy as np

# === Paths ===
source_dir = "./babylm_10m_cleaned"
out_dir = "./babylm_char_tokenized"
os.makedirs(out_dir, exist_ok=True)

# Only Shakespeare is active; the BabyLM files stay commented out for easy re-enabling.
file_names = [
    "shakespeare.txt",
    # "aochildes.txt", "cbt.txt", "children_stories.txt", "gutenberg.txt",
    # "qed.txt", "simple_wikipedia.txt", "switchboard.txt", "wikipedia.txt",
]

# === Load and split: first 90% of each file's lines -> train, the rest -> val ===
train_texts, val_texts = [], []
char_set = set()

for fname in file_names:
    with open(os.path.join(source_dir, fname), encoding="utf-8") as f:
        lines = f.readlines()
    n = len(lines)
    split = int(0.9 * n)
    train_part, val_part = "".join(lines[:split]), "".join(lines[split:])
    train_texts.append(train_part)
    val_texts.append(val_part)
    # the vocabulary covers both splits, so no validation character is unknown
    char_set.update(train_part)
    char_set.update(val_part)

full_train = "\n".join(train_texts)
full_val = "\n".join(val_texts)

# === Final vocab: id 0 is "<unk>", then the characters in sorted order ===
char_set = sorted(char_set)
vocab_chars = ["<unk>"] + [c for c in char_set if c != "<unk>"]

stoi = {ch: i for i, ch in enumerate(vocab_chars)}
itos = {i: ch for ch, i in stoi.items()}

# === Encode function ===
def encode(text):
    """Map each character of ``text`` to its id in the global ``stoi``; unseen characters become 0."""
    unk_id = 0
    return [stoi[ch] if ch in stoi else unk_id for ch in text]

# Ids are stored as uint16; a larger vocabulary would wrap (or raise, depending
# on the NumPy version) instead of producing valid token ids.
if len(stoi) > np.iinfo(np.uint16).max + 1:
    raise ValueError(f"vocabulary of {len(stoi)} symbols does not fit in uint16 token ids")

train_ids = np.array(encode(full_train), dtype=np.uint16)
val_ids = np.array(encode(full_val), dtype=np.uint16)

# === Save: raw uint16 id streams plus the vocabulary mapping ===
train_ids.tofile(os.path.join(out_dir, "train.bin"))
val_ids.tofile(os.path.join(out_dir, "val.bin"))

with open(os.path.join(out_dir, "meta.pkl"), "wb") as f:
    pickle.dump({
        "vocab_size": len(stoi),
        "stoi": stoi,
        "itos": itos,
    }, f)

print("✅ Char tokenizer finalized.")
print(f"🧾 Train tokens: {len(train_ids)} | Val tokens: {len(val_ids)}")
print(f"🔤 Vocab size: {len(stoi)}")

import os
import pickle
import numpy as np
from collections import defaultdict, Counter
import random

# === Load data written by the tokenizer step ===
data_dir = "./babylm_char_tokenized"
train_path, val_path, meta_path = (
    os.path.join(data_dir, name) for name in ("train.bin", "val.bin", "meta.pkl")
)

train_ids = np.fromfile(train_path, dtype=np.uint16)
val_ids = np.fromfile(val_path, dtype=np.uint16)

with open(meta_path, "rb") as f:
    meta = pickle.load(f)
stoi = meta["stoi"]
itos = meta["itos"]
vocab_size = meta["vocab_size"]

print(f"Loaded {len(train_ids)} train tokens and {len(val_ids)} val tokens | vocab={vocab_size}")

# === Build Markov Models ===
def build_markov_chain(data, window):
    """Count next-token frequencies for every length-``window`` context in ``data``.

    Args:
        data: 1-D sequence of token ids (list, numpy array or memmap).
        window: context length (number of tokens preceding the predicted one).

    Returns:
        defaultdict mapping ``tuple(context)`` -> ``Counter`` of next-token ids.
        Keys and counted ids are plain Python ints. Lookups with tuples of ints,
        or of numpy integers (which hash and compare equal), both work.
    """
    # One bulk conversion to Python ints. This avoids a numpy slice plus numpy
    # scalar objects per position, and keeps the pickled model much smaller.
    seq = data.tolist() if hasattr(data, "tolist") else list(data)
    chain = defaultdict(Counter)
    for i in range(len(seq) - window):
        chain[tuple(seq[i:i + window])][seq[i + window]] += 1
    return chain

# Context lengths of the Markov next-token models.
windows = [4, 8, 16]
markov_models = {}

for w in windows:
    print(f"Building {w}-token Markov chain...")
    chain_w = build_markov_chain(train_ids, w)
    markov_models[w] = chain_w

# === Build Bigram Continuation Probabilities ===
import numpy as np

def build_bigram_distribution_fixed(data, vocab_size, top_k=16, seed=1337, epsilon=1e-6):
    rng = np.random.default_rng(seed)

    # counts
    bigram_counts = np.zeros((vocab_size, vocab_size), dtype=np.int32)
    a = data[:-1]
    b = data[1:]
    np.add.at(bigram_counts, (a, b), 1)

    out_idx = np.empty((vocab_size, top_k), dtype=np.int32)
    out_p   = np.empty((vocab_size, top_k), dtype=np.float32)

    all_ids = np.arange(vocab_size, dtype=np.int32)

    for tok in range(vocab_size):
        counts = bigram_counts[tok]
        total = counts.sum()

        if total == 0:
            # no observations ‚Äî choose k random unique tokens and make them uniform
            idx = rng.choice(vocab_size, size=top_k, replace=False)
            p = np.full(top_k, 1.0 / top_k, dtype=np.float32)
            out_idx[tok] = idx
            out_p[tok] = p
            continue

        probs_full = counts.astype(np.float64) / float(total)
        observed = np.flatnonzero(counts)

        if observed.size >= top_k:
            # get top_k among observed only (fast top-k)
            obs_p = probs_full[observed]
            kth = np.argpartition(obs_p, -top_k)[-top_k:]
            idx = observed[kth]
            p = probs_full[idx].astype(np.float32)
            # normalize in case of numerical drift
            s = p.sum()
            p = p / s if s > 0 else np.full(top_k, 1.0 / top_k, dtype=np.float32)
        else:
            # take all observed, randomly fill the rest from unobserved
            need = top_k - observed.size
            mask = np.ones(vocab_size, dtype=bool)
            mask[observed] = False
            pool = all_ids[mask]
            # sample without replacement to avoid duplicates
            extra = rng.choice(pool, size=need, replace=False)
            idx = np.concatenate([observed, extra])

            p = probs_full[idx].astype(np.float32)
            # give a tiny positive mass to the extras that were unobserved (counts==0)
            unobs = (counts[idx] == 0)
            if unobs.any():
                p = p + unobs.astype(np.float32) * epsilon
            p = p / p.sum()

        # ensure a consistent ordering (optional): sort descending prob
        order = np.argsort(-p)
        out_idx[tok] = idx[order]
        out_p[tok]   = p[order]

    # return as simple dict-of-tuples for backward compatibility
    bigram_db = {int(t): (out_idx[t], out_p[t]) for t in range(vocab_size)}
    return bigram_db

print("Building bigram probability distribution...")
bigram_db = build_bigram_distribution_fixed(train_ids, vocab_size)

# === Save: both tables are pickled for the data loader ===
model_dir = "./markov_bigram_models"
os.makedirs(model_dir, exist_ok=True)

with open(os.path.join(model_dir, "markov_models.pkl"), "wb") as f:
    pickle.dump(markov_models, f)

with open(os.path.join(model_dir, "bigram_db.pkl"), "wb") as f:
    pickle.dump(bigram_db, f)

print("✅ Markov and Bigram models saved.")
print(f"Chains: {[f'order={w}' for w in windows]}")


✅ Char tokenizer finalized.
🧾 Train tokens: 1016242 | Val tokens: 99152
🔤 Vocab size: 66
Loaded 1016242 train tokens and 99152 val tokens | vocab=66
Building 4-token Markov chain...
Building 8-token Markov chain...
Building 16-token Markov chain...
Building bigram probability distribution...
✅ Markov and Bigram models saved.
Chains: ['order=4', 'order=8', 'order=16']


In [23]:
import os
import pickle
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# === Config ===
data_dir = "./babylm_char_tokenized"  # char-level token ids written by the tokenizer cell
block_size = 1024
batch_size = 8

# === Tokenizer metadata (only the vocabulary size is needed here) ===
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# === Read-only memory maps of the uint16 token streams ===
train_ids, val_ids = (
    np.memmap(os.path.join(data_dir, fname), dtype=np.uint16, mode='r')
    for fname in ('train.bin', 'val.bin')
)

# === Replacement dataloader that uses SAVED bigram + markov models and yields (X, Y, Z) ===
import os, pickle, numpy as np, torch
from torch.utils.data import Dataset, DataLoader

# expects `vocab_size` and `device` already defined in the outer scope
# expects saved models at ./markov_bigram_models/{bigram_db.pkl, markov_models.pkl}



class ZPack:
    """Opaque holder for one batch's aux-target pairs.

    ``zb[i]`` returns the i-th ``(idxs, probs)`` pair, which is how the model
    reads ``zb[bidx]``. It intentionally has no ``__len__`` and does not
    inherit from ``Sequence``; it is meant to pass through a custom
    ``collate_fn`` as a single object.
    """
    __slots__ = ("blocks",)

    def __init__(self, blocks):
        # expected: [(idxs_np, probs_np), ...], one pair per aux block
        self.blocks = blocks

    def __getitem__(self, i):
        blocks = self.blocks
        return blocks[i]

class GPUBatchDataset(Dataset):
    """Map-style dataset whose every item is a whole random mini-batch plus aux targets.

    Each ``__getitem__`` call ignores its index, samples ``batch_size`` windows
    from ``mmap_file`` and returns ``(X, Y, Z)``:

    * ``X``: (B, block_size + pad_len) int64 input ids on ``device``;
    * ``Y``: (B, block_size) int64 next-token targets on ``device``;
    * ``Z``: a ``ZPack`` of four ``(idxs, probs)`` numpy pairs, each
      (B, block_size, top_k). The pairs hold bigram, 4-, 8- and 16-token Markov
      next-token candidates per position. Markov lookups fall back to the bigram
      table when the context is too short or unseen.

    Window starts are multiples of ``sample_len``. With probability
    ``1 - p_aligned`` a random shift in ``[0, jitter]`` is added.

    NOTE(review): the n-gram targets at position j are computed from
    ``X[i, j]``, while ``Y[i, j]`` follows ``X[i, j + pad_len]``. They are
    therefore only aligned for ``pad_len=0`` (the default).
    """

    def __init__(self, mmap_file, block_size, batch_size, device,
                 model_dir="./markov_bigram_models", jitter=63, p_aligned=0.5, pad_len=0,
                 top_k=16, seed=1337, vocab_size=None):
        """
        Args:
            vocab_size: number of token ids. When None, it falls back to the
                module-level ``vocab_size``, for callers that rely on the global.

        Raises:
            TypeError: if no integer vocabulary size is available.
        """
        self.data = mmap_file
        self.block_size = int(block_size)
        self.batch_size = int(batch_size)
        self.device = device
        self.pad_len = int(pad_len)
        self.sample_len = self.block_size + self.pad_len
        self.total = len(self.data) - self.sample_len - 1  # largest valid window start
        self.n_blocks = max(1, self.total // self.sample_len)
        self.jitter = int(jitter)
        self.p_aligned = float(p_aligned)
        self.top_k = int(top_k)
        self.rng = np.random.default_rng(seed)

        if vocab_size is None:
            vocab_size = globals().get("vocab_size")
        if not isinstance(vocab_size, int):
            raise TypeError(f"vocab_size must be an int, got {type(vocab_size).__name__}")
        self.vocab_size = vocab_size

        # locally produced pickles (see the Markov/bigram build step); do not point
        # model_dir at untrusted files, since unpickling can execute code
        with open(os.path.join(model_dir, "bigram_db.pkl"), "rb") as f:
            self.bigram_db = pickle.load(f)
        with open(os.path.join(model_dir, "markov_models.pkl"), "rb") as f:
            self.markov_models = pickle.load(f)  # {4: ..., 8: ..., 16: ...}

    def __len__(self):
        # One epoch = one pass over the n_blocks non-overlapping windows, batch_size
        # windows per item. ``total // batch_size`` would make an epoch cover the
        # data roughly sample_len times.
        return max(1, self.n_blocks // self.batch_size)

    def _sample_block(self):
        """Return a random window start: a multiple of sample_len, sometimes jittered."""
        base_block = self.rng.integers(0, self.n_blocks)
        start = base_block * self.sample_len
        if self.rng.random() > self.p_aligned:
            j = self.rng.integers(0, self.jitter + 1)
            start = min(start + j, self.total)
        return start

    def _finalize_topk_from_counts(self, counter, top_k=16, epsilon=1e-6):
        """Turn a next-token Counter into exactly ``top_k`` (ids, probs), most frequent first.

        An empty/None counter gives ``top_k`` random distinct ids with uniform
        mass. Fewer than ``top_k`` observed ids are padded with random unobserved
        ids of mass ``epsilon`` before renormalising.
        """
        rng = self.rng
        if not counter:
            idxs = rng.choice(self.vocab_size, size=top_k, replace=False)
            probs = np.full(top_k, 1.0/top_k, dtype=np.float32)
            return idxs.astype(np.int64), probs
        items = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        obs_idxs = np.fromiter((t for t, _ in items), dtype=np.int64, count=len(items))
        obs_cnts = np.fromiter((c for _, c in items), dtype=np.float64, count=len(items))
        if len(obs_idxs) >= top_k:
            idxs = obs_idxs[:top_k]
            probs = (obs_cnts[:top_k] / obs_cnts[:top_k].sum()).astype(np.float32)
            return idxs, probs
        need = top_k - len(obs_idxs)
        mask = np.ones(self.vocab_size, dtype=bool)
        mask[obs_idxs] = False
        extras = rng.choice(np.nonzero(mask)[0], size=need, replace=False).astype(np.int64)
        idxs = np.concatenate([obs_idxs, extras])
        probs = np.concatenate([obs_cnts, np.full(need, epsilon, dtype=np.float64)]).astype(np.float32)
        probs = probs / probs.sum()
        return idxs, probs

    def _finalize_topk_from_bigram(self, entry, top_k=16, epsilon=1e-6):
        """Resize a precomputed bigram ``(idxs, probs)`` entry to exactly ``top_k`` and renormalise.

        A missing entry gives ``top_k`` random distinct ids with uniform mass.
        """
        rng = self.rng
        if entry is None:
            idxs = rng.choice(self.vocab_size, size=top_k, replace=False).astype(np.int64)
            probs = np.full(top_k, 1.0/top_k, dtype=np.float32)
            return idxs, probs
        idxs, probs = entry
        idxs = np.asarray(idxs, dtype=np.int64)
        probs = np.asarray(probs, dtype=np.float32)
        if idxs.shape[0] > top_k:
            order = np.argsort(-probs)[:top_k]
            idxs, probs = idxs[order], probs[order]
        elif idxs.shape[0] < top_k:
            need = top_k - idxs.shape[0]
            mask = np.ones(self.vocab_size, dtype=bool)
            mask[idxs] = False
            extras = rng.choice(np.nonzero(mask)[0], size=need, replace=False).astype(np.int64)
            idxs = np.concatenate([idxs, extras])
            probs = np.concatenate([probs, np.full(need, epsilon, dtype=np.float32)])
        probs = probs / probs.sum()
        return idxs, probs

    def _dist_bigram(self, tok):
        """Top-k next-token distribution after a single token."""
        entry = self.bigram_db.get(int(tok), None)
        return self._finalize_topk_from_bigram(entry, top_k=self.top_k)

    def _dist_markov(self, ctx_tuple, backoff_tok):
        """Top-k distribution for a context tuple; falls back to the bigram of ``backoff_tok``."""
        counter = None
        if ctx_tuple is not None:
            chain = self.markov_models.get(len(ctx_tuple), {})
            counter = chain.get(ctx_tuple, None)
        if counter:
            return self._finalize_topk_from_counts(counter, top_k=self.top_k)
        return self._dist_bigram(backoff_tok)

    def __getitem__(self, _):
        B, T, K = self.batch_size, self.block_size, self.top_k
        X = np.empty((B, self.sample_len), dtype=np.int64)
        Y = np.empty((B, T), dtype=np.int64)

        # one (idxs, probs) buffer pair per aux target: bigram, 4-, 8-, 16-token context
        bi_idx  = np.empty((B, T, K), dtype=np.int64); bi_p  = np.empty((B, T, K), dtype=np.float32)
        m4_idx  = np.empty((B, T, K), dtype=np.int64); m4_p  = np.empty((B, T, K), dtype=np.float32)
        m8_idx  = np.empty((B, T, K), dtype=np.int64); m8_p  = np.empty((B, T, K), dtype=np.float32)
        m16_idx = np.empty((B, T, K), dtype=np.int64); m16_p = np.empty((B, T, K), dtype=np.float32)

        for i in range(B):
            start = self._sample_block()
            X[i] = self.data[start : start + self.sample_len]
            Y[i] = self.data[start + 1 + self.pad_len : start + 1 + self.pad_len + T]
            # Python-int copy of the row: slicing it gives hashable context tuples
            # without a per-element int() conversion
            row = X[i].tolist()

            for j in range(T):
                tok_now = row[j]

                # bigram
                idxs, probs = self._dist_bigram(tok_now)
                bi_idx[i, j, :] = idxs; bi_p[i, j, :] = probs

                # contexts of 4/8/16 tokens ending at j; None -> bigram back-off
                ctx4  = tuple(row[j - 3 : j + 1])  if j >= 3  else None
                ctx8  = tuple(row[j - 7 : j + 1])  if j >= 7  else None
                ctx16 = tuple(row[j - 15 : j + 1]) if j >= 15 else None

                idxs, probs = self._dist_markov(ctx4,  tok_now);  m4_idx[i, j, :]  = idxs; m4_p[i, j, :]  = probs
                idxs, probs = self._dist_markov(ctx8,  tok_now);  m8_idx[i, j, :]  = idxs; m8_p[i, j, :]  = probs
                idxs, probs = self._dist_markov(ctx16, tok_now);  m16_idx[i, j, :] = idxs; m16_p[i, j, :] = probs

        # wrap Z so it travels through the custom collate as one object
        Z = ZPack([
            (bi_idx,  bi_p),
            (m4_idx,  m4_p),
            (m8_idx,  m8_p),
            (m16_idx, m16_p),
        ])

        return (
            torch.from_numpy(X).to(self.device, non_blocking=True),
            torch.from_numpy(Y).to(self.device, non_blocking=True),
            Z,
        )

# Build the training dataset from the char-level token stream and the saved n-gram tables.
model_dir = "./markov_bigram_models"
train_tokens = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
train_dataset = GPUBatchDataset(
    train_tokens,
    block_size=block_size,
    batch_size=batch_size,
    device=device,
    model_dir=model_dir,
)
    

In [28]:

def collate_keep_z(batch, num_blocks=4):
    """Merge a list of pre-batched ``(X, Y, Z)`` dataset items into one batch.

    Each dataset item is already a whole mini-batch, so X and Y are
    concatenated (not stacked) along dim 0, and every ``(idxs, probs)`` pair
    in Z is concatenated along axis 0 with ``np.concatenate``. Z stays as
    NumPy arrays instead of going through the default collate conversion.

    Args:
        batch: list of N items, each ``(X, Y, Z)`` where X is a ``(B_i, T_x)``
            tensor, Y is a ``(B_i, T)`` tensor, and Z is an indexable sequence
            of ``(idxs, probs)`` array pairs, each shaped ``(B_i, T, K)``.
        num_blocks: how many ``(idxs, probs)`` pairs to merge from each Z.
            Defaults to 4 (bigram plus the three Markov context lengths the
            dataset packs into Z).

    Returns:
        ``(X, Y, merged_blocks)`` where ``merged_blocks`` is a list of
        ``num_blocks`` ``(idxs, probs)`` tuples of arrays shaped
        ``(sum B_i, T, K)``.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_keep_z received an empty batch")

    Xs, Ys, Zs = zip(*batch)  # tuples of length N

    # concatenate along the batch axis; torch.cat keeps the tensors on the
    # device they already live on
    X = torch.cat(Xs, dim=0)
    Y = torch.cat(Ys, dim=0)

    merged_blocks = [
        (
            np.concatenate([Z[b][0] for Z in Zs], axis=0),  # idxs
            np.concatenate([Z[b][1] for Z in Zs], axis=0),  # probs
        )
        for b in range(num_blocks)
    ]
    return X, Y, merged_blocks

# Each DataLoader item is one already-built batch from train_dataset, so the
# outer batch_size stays at 1 and collate_keep_z merges that single item.
# NOTE(review): the dataset moves X/Y onto `device` itself, which is
# presumably why loading stays in the main process (num_workers=0) — confirm.
train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_keep_z,
    num_workers=0,
    shuffle=False,
    batch_size=1,
)


def train_epoch():
    """Run one pass over ``train_loader`` and return the mean training loss.

    Reads the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``losses`` globals. For each batch it runs forward and backward passes,
    clips the gradient norm to 1.0, and takes an optimizer step. It then
    appends the scalar loss to ``losses`` and prints it.

    Returns:
        The mean of the per-batch losses, or ``nan`` when the loader yields
        no batches (instead of raising ``ZeroDivisionError``).
    """
    model.train()
    # drop any gradients left behind by an earlier, possibly interrupted, run
    optimizer.zero_grad(set_to_none=True)

    total_loss = 0.0
    n_batches = 0
    for xb, yb, zb in train_loader:
        # xb: (B, T_x), yb: (B, T), zb: list of (idxs, probs) array pairs
        _, loss = model(xb, yb, zb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # .item() forces a device->host sync; do it once per step and reuse it
        loss_value = loss.item()
        total_loss += loss_value
        losses.append(loss_value)
        print(loss_value)
        n_batches += 1

    if n_batches == 0:
        return float("nan")
    return total_loss / n_batches
    



In [34]:
import os
import pickle
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# === Config ===
# NOTE(review): data_dir, block_size, batch_size and device are also read by
# the dataset cell that appears earlier in this notebook, so this cell has to
# be executed before that one.
data_dir = "./babylm_char_tokenized"  # <- char-tokenized data
block_size = 1024
batch_size = 8

# NOTE(review): vocab_size is assumed to be set by an earlier cell, and the
# positional args presumably map to GPTConfig's (block_size, vocab_size)
# fields — confirm the field order.
config = GPTConfig(
    block_size,
    vocab_size,
    n_layer=4,
    n_head=8,
    n_embd=128,
)

# .to(device) here already places every parameter on `device`; a second
# .to(device) afterwards would be a no-op.
model = GPT(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)
losses = []  # per-step training losses, appended to by train_epoch()


number of parameters: 0.88M


In [35]:
# Total element count over every registered parameter of the model.
# NOTE(review): this presumably differs from the "0.88M" printed at model
# construction because that figure leaves out the position embeddings
# (1024 * 128 = 131072; 1006848 - 131072 = 875776) — confirm.
print(sum(map(torch.numel, model.parameters())))

1006848


In [None]:


# === Run Training ===
# Each epoch is one full pass of train_epoch() over train_loader; epochs are
# numbered from 1 in the log line.
num_epochs = 10
for epoch in range(1, 1 + num_epochs):
    train_loss = train_epoch()  # mean per-batch loss for this epoch
    print(f"Epoch {epoch:2d} | Train loss: {train_loss:.4f}")

0.2866063714027405
0.2804186940193176
0.29430702328681946
0.28242185711860657
0.27890118956565857
0.28030818700790405
0.29218801856040955
0.282606303691864
0.29292792081832886
0.2955491542816162
0.28361839056015015
0.2983800768852234
0.28566843271255493
0.2994258403778076
0.29237082600593567
0.2799804210662842
0.28826794028282166
0.2796483337879181
0.2796468138694763
0.2895428538322449
0.2943086326122284
0.27300262451171875
0.28833654522895813
0.2865556478500366
0.298909991979599
0.2812435030937195
0.29573437571525574
0.2899649143218994
0.27059102058410645
0.29005515575408936


In [38]:
import pickle
def decode_chars(token_ids, itos):
    """
    Look up each token ID in ``itos`` and join the resulting characters.

    ``itos`` may be any mapping or sequence indexable by the IDs; a missing
    ID raises whatever ``itos[...]`` raises (KeyError / IndexError).
    """
    return "".join(itos[tok] for tok in token_ids)

def encode_chars(text, stoi):
    """
    Convert ``text`` to a list of IDs, one per character.

    Characters absent from ``stoi`` are mapped to ID 0.
    """
    fallback_id = 0
    return [stoi.get(ch, fallback_id) for ch in text]


def decode_sequence_char(
    model, stoi, itos, prompt, max_new_tokens=100, block_size=256,
    use_fenchel=False, tau=1.0, fenchel_iters=3, temperature=1.0
):
    """Autoregressively sample characters from ``model`` starting at ``prompt``.

    At every step the model sees at most the last ``block_size`` tokens. The
    final-position logits are divided by ``temperature`` and soft-maxed, and
    one token is drawn with ``torch.multinomial`` and appended.

    Args:
        model: module called as ``model(context, None)``; the first element
            of its return value is indexed as ``logits[:, -1, :]``.
        stoi: char -> id lookup, passed to ``encode_chars``.
        itos: id -> char lookup, passed to ``decode_chars``.
        prompt: seed text; must encode to at least one token.
        max_new_tokens: number of tokens to append after the prompt.
        block_size: maximum context length fed to the model (must be >= 1).
        use_fenchel: accepted for interface compatibility; currently unused.
        tau: accepted for interface compatibility; currently unused.
        fenchel_iters: accepted for interface compatibility; currently unused.
        temperature: softmax temperature (must be > 0).

    Returns:
        The prompt followed by the sampled characters, as one string.

    Raises:
        ValueError: if the prompt is empty, ``block_size < 1`` or
            ``temperature <= 0``.
    """
    max_ctx = int(block_size)
    if max_ctx < 1:
        # idx[:, -0:] would silently return the whole sequence, not a window
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    device = next(model.parameters()).device
    prompt_ids = encode_chars(prompt, stoi)
    if not prompt_ids:
        raise ValueError("prompt must contain at least one character")
    idx = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # sampling never backpropagates; skip the graph
            for _ in range(max_new_tokens):
                context = idx[:, -max_ctx:]  # sliding window: last max_ctx tokens
                logits = model(context, None)[0]
                probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, next_token], dim=1)
    finally:
        # leave the model in whatever train/eval mode the caller had it in
        model.train(was_training)

    return decode_chars(idx[0].tolist(), itos)
# Load the character vocabulary stored next to the tokenized data.
# NOTE(review): pickle.load executes arbitrary code from the file; only use
# a meta.pkl you produced yourself.
with open("./babylm_char_tokenized/meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]

# Sample 1024 characters continuing the prompt with a 1024-token window.
prompt = "dingus"
generated = decode_sequence_char(
    model=model,
    stoi=stoi,
    itos=itos,
    prompt=prompt,
    block_size=1024,
    max_new_tokens=1024,
    temperature=1.0,
    # decode_sequence_char does not currently use the Fenchel options
    use_fenchel=False,
    fenchel_iters=2,
    tau=1.5,
)

print(generated)

dinguse your to vow.
What king, sterpeatent, washint.
'Tis come it a suw it.

MARCIUS:
I is nothict my deaths, let themselves then.

MARCIUS:
Ratclets not, Wasliets and to dose,
pretiniof of the kings of mere a seclarm
was to peacestioes of Sicille this success!

MARCIUS:
Nay, brother that friung safeves belad breased?

CORIOLANUS:
He speak With pred your be rich make at our time,
Nor I fate, now is so thee.

MARCIUS:
GoneflAhest You have you all, mattere, but Pompey,
I, your to be sentrumped, to York niship highnes
Are hot name to succest will the so my maid,
He had siling good husband; we shall then bare
Aback prithe up it; Marctnagus it thine weallive
Were is hollips throught grave dusly of your so.
Ay, and matttle doest and she so.
Rome thy friends fay uncle.

CAMILLO:
Which she cary fool-extred to them: tie your chare!
We with leter you at before no him: I'll would
A have of Mown, and shall fre's it the gentleman:
I am thou hard bornitest words, but me,
For trnocently thing do a s