In [1]:
# Base URL of the dataset repository; a shard filename is appended to it to form
# the download URL of an individual parquet shard.
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"

In [2]:
# Shard indices 0..1821 (1822 ids in total); only the first one is fetched below.
ids_to_download = [*range(1822)]

In [3]:
import os
import requests

DATA_DIR = 'data'
os.makedirs(DATA_DIR, exist_ok=True)

# Download the first shard unless a finished copy is already on disk.
filename = f'shard_{ids_to_download[0]:05d}.parquet'
filepath = os.path.join(DATA_DIR, filename)
if not os.path.exists(filepath):
    url = f'{BASE_URL}/{filename}'
    # Write to a temporary name first so an interrupted download is never
    # mistaken for a complete shard by the existence check above.
    temp_filepath = filepath + '.tmp'
    try:
        # stream=True keeps the body out of memory so iter_content really streams;
        # the timeout prevents the cell from hanging forever on a dead connection.
        with requests.get(url, stream=True, timeout=30) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as a parquet file.
            response.raise_for_status()
            with open(temp_filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        # os.replace overwrites atomically where the platform supports it.
        os.replace(temp_filepath, filepath)
    finally:
        # After a successful replace the temp file no longer exists; on any
        # failure (including KeyboardInterrupt) remove the partial download.
        if os.path.exists(temp_filepath):
            os.remove(temp_filepath)

In [4]:
import pyarrow.parquet as pq

In [5]:
# Open the downloaded shard as a ParquetFile; later cells read its row groups
# one at a time through this handle.
pf = pq.ParquetFile(filepath)

In [6]:
# Walk over every row group of the shard. `texts` is overwritten on each pass,
# so after the loop it holds the 'text' column of the LAST row group only.
for rg_idx in range(pf.num_row_groups):
    rg = pf.read_row_group(rg_idx)
    text_column = rg.column('text')
    texts = text_column.to_pylist()

In [7]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers import Regex, pre_tokenizers, decoders
from tokenizers.trainers import BpeTrainer

In [8]:
# Pre-tokenization split pattern: contractions, letter runs, numbers in groups
# of at most two digits, punctuation runs, and whitespace handling.
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
gpt4_split_regex = Regex(SPLIT_PATTERN)

SPECIAL_TOKENS = [
    # Beginning of Sequence (BOS): starts every document and delimits documents
    "<|bos|>",
    # The remaining tokens are used only during finetuning, to render
    # Conversations into token ids.
    # -- user messages
    "<|user_start|>",
    "<|user_end|>",
    # -- assistant messages
    "<|assistant_start|>",
    "<|assistant_end|>",
    # -- assistant invokes the python REPL tool
    "<|python_start|>",
    "<|python_end|>",
    # -- python REPL output fed back to the assistant
    "<|output_start|>",
    "<|output_end|>",
]

# Byte-fallback BPE model with no unknown token and no text normalization.
bpe_model = BPE(byte_fallback=True, unk_token=None, fuse_unk=False)
tokenizer = Tokenizer(bpe_model)
tokenizer.normalizer = None

In [9]:
# Lazily stream every document of the shard, one row group at a time.
# This is a one-shot generator: once consumed (e.g. by training the tokenizer)
# this cell must be re-run to obtain a fresh iterator.
text_iterator = (
    text
    for rg_idx in range(pf.num_row_groups)
    for text in pf.read_row_group(rg_idx).column('text').to_pylist()
)

In [10]:
# Pre-tokenization runs in two stages: regex split into isolated pieces, then a
# byte-level mapping (its own regex is disabled because the split already ran).
split_stage = pre_tokenizers.Split(gpt4_split_regex, behavior='isolated')
byte_level_stage = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([split_stage, byte_level_stage])

# Byte-level decoder on the way back to text; no post-processing of encodings.
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = None

# Train BPE merges over the streamed documents, starting from the byte-level
# alphabet and reserving ids for the special tokens.
trainer = BpeTrainer(
    vocab_size=50257,
    show_progress=True,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    special_tokens=SPECIAL_TOKENS,
)
tokenizer.train_from_iterator(text_iterator, trainer=trainer)








In [15]:
# Data loading

# B: rows per batch, T: tokens per row (inputs are later viewed as (B, T)).
B, T = 1, 512


In [16]:
# One token more than B*T: inputs are all tokens but the last and targets are
# all tokens but the first, so both slices come from the same buffer.
needed_tokens = B*T+1

In [23]:
from collections import deque
import torch

token_buffer = deque()

# Fill the buffer one document at a time and stop as soon as one batch worth of
# tokens is available, instead of tokenizing every document in `texts` up front.
# If a full pass over `texts` is not enough, the documents are cycled again.
# NOTE(review): the tokenizer in this notebook has no post-processor configured,
# so add_special_tokens=True presumably does not insert <|bos|> here - confirm.
while len(token_buffer) < needed_tokens:
    buffered_before_pass = len(token_buffer)
    for doc in texts:
        token_buffer.extend(tokenizer.encode(doc, add_special_tokens=True).ids)
        if len(token_buffer) >= needed_tokens:
            break
    # A pass that adds nothing (empty `texts` or only empty documents) would
    # otherwise loop forever.
    if len(token_buffer) == buffered_before_pass:
        raise ValueError("`texts` produced no tokens; cannot build a batch")

# Take exactly B*T+1 tokens from the front of the buffer.
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
scratch = torch.tensor(tokens, dtype=torch.int64)
# Next-token prediction: targets are the inputs shifted left by one position.
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]

inputs = inputs_cpu.view(B, T).to('cuda')
targets = targets_cpu.view(B, T).to('cuda')

In [29]:
# Placeholder; the real configuration is assigned once GPTConfig is defined.
config = None

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class GPTConfig:
    """Hyperparameters consumed by the GPT model and its sub-modules."""
    # Maximum sequence length: size of the position-embedding table and the
    # upper bound enforced on T in GPT.forward.
    sequence_len: int = 512
    # Number of transformer Blocks stacked in the model.
    n_layers: int = 12
    # Number of rows in the token embedding and outputs of the LM head.
    vocab_size: int = 50257
    # Width of the residual stream (embedding dimension).
    emb_dim: int = 128
    # Number of attention heads; head_dim is emb_dim // n_heads.
    n_heads: int = 8

def norm(x):
    return F.rms_norm(x, (x.size(-1),))

# Attention(q, k, v) = softmax(q @ k^T / sqrt(d_k)) @ v

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RMS-normalized q, k and v.

    Args:
        config: object exposing `emb_dim` and `n_heads`.
        layer_idx: index of the enclosing layer (stored, not used in forward).

    Raises:
        ValueError: if `emb_dim` is not divisible by `n_heads`.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        # Without this check a bad config only surfaces later, as a shape
        # mismatch inside c_proj during the first forward pass.
        if config.emb_dim % config.n_heads != 0:
            raise ValueError(
                f"emb_dim ({config.emb_dim}) must be divisible by n_heads ({config.n_heads})"
            )
        self.layer_idx = layer_idx
        self.n_heads = config.n_heads
        self.head_dim = config.emb_dim // config.n_heads
        self.c_q = nn.Linear(config.emb_dim, self.n_heads * self.head_dim, bias=False)
        self.c_k = nn.Linear(config.emb_dim, self.n_heads * self.head_dim, bias=False)
        self.c_v = nn.Linear(config.emb_dim, self.n_heads * self.head_dim, bias=False)
        self.c_proj = nn.Linear(config.emb_dim, config.emb_dim, bias=False)

    def forward(self, x):
        """Apply causal attention to `x` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.shape
        # Project and split into heads: (B, T, H, D)
        q = self.c_q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_heads, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_heads, self.head_dim)
        # Heads become a batch dimension: (B, H, T, D)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        # NOTE(review): v is normalized together with q and k; QK-norm setups
        # usually leave v untouched - confirm this is intentional.
        q, k, v = norm(q), norm(k), norm(v)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, H, T, D)
        # Merge the heads back side by side: (B, T, H*D)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, squared ReLU, project back."""

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.emb_dim
        self.c_fc = nn.Linear(config.emb_dim, hidden_dim)
        self.c_proj = nn.Linear(hidden_dim, config.emb_dim)

    def forward(self, x):
        """Map `x` (..., emb_dim) to a tensor of the same shape."""
        hidden = F.relu(self.c_fc(x))
        return self.c_proj(hidden.square())

class Block(nn.Module):
    """One pre-norm transformer layer: attention, then MLP, each added residually."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.sa = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return `x` after the attention and MLP residual updates."""
        attn_out = self.sa(norm(x))
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out

class GPT(nn.Module):
    """Decoder-only transformer language model with learned position embeddings.

    Args:
        config: object exposing `vocab_size`, `sequence_len`, `emb_dim`,
            `n_layers` (plus whatever `Block` reads from it).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.emb_dim),    # token embeddings
            "wpe": nn.Embedding(config.sequence_len, config.emb_dim),  # positional embeddings
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layers)]),
        })
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)

        # Initialize weights
        self.init_weights()

    def _init_weights(self, module):
        """Default init: N(0, 0.02) weights for Linear/Embedding, zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def init_weights(self):
        """Initialize all parameters, then rescale the residual output projections."""
        self.apply(self._init_weights)
        # Apply special scaling to residual projections (GPT-2 style). The std is
        # a plain Python float; it is computed inside the loop so a model with
        # zero layers never divides by zero.
        for block in self.transformer.h:
            residual_std = 0.02 / ((2 * self.config.n_layers) ** 0.5)
            nn.init.normal_(block.mlp.c_proj.weight, mean=0.0, std=residual_std)
            nn.init.normal_(block.sa.c_proj.weight, mean=0.0, std=residual_std)

    def device(self):
        """Return the device of the model's first parameter."""
        return next(self.parameters()).device

    def forward(self, idx, targets=None, loss_reduction='mean'):
        """Run the model on token ids `idx` of shape (B, T).

        Returns:
            The cross-entropy loss (targets equal to -1 are ignored, reduced
            according to `loss_reduction`) when `targets` is given; otherwise
            the logits of shape (B, T, vocab_size).
        """
        B, T = idx.size()
        assert T <= self.config.sequence_len, f"Cannot forward sequence of length {T}, max is {self.config.sequence_len}"

        # Token and position embeddings; pos_emb (T, emb_dim) broadcasts over the batch
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = norm(tok_emb + pos_emb)

        # Forward through transformer blocks, then the final norm
        for block in self.transformer.h:
            x = block(x)
        x = norm(x)

        logits = self.lm_head(x)
        if targets is None:
            # Inference mode: return logits
            return logits

        # Training mode: compute the loss in float32
        logits = logits.float()
        return F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
            reduction=loss_reduction,
        )

    # Used as a decorator (rather than a `with` block around the `yield`) so that
    # gradient tracking is only disabled while the generator body is running and
    # is restored for the caller between yielded tokens.
    @torch.no_grad()
    def generate(self, tokens, max_tokens):
        """Sample up to `max_tokens` continuation tokens, yielding them one by one.

        Args:
            tokens: non-empty list of prompt token ids.
            max_tokens: number of tokens to sample.

        Yields:
            int: each newly sampled token id.

        Raises:
            ValueError: if `tokens` is empty (there is no position to predict from).
        """
        if len(tokens) == 0:
            raise ValueError("generate() requires at least one prompt token")
        ids = torch.tensor([tokens], dtype=torch.long, device=self.device())

        for _ in range(max_tokens):
            # Keep only the last sequence_len tokens as context
            ids_cond = ids[:, -self.config.sequence_len:]
            logits = self(ids_cond)[:, -1, :]  # logits at the last time step: (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            next_ids = torch.multinomial(probs, num_samples=1)
            ids = torch.cat((ids, next_ids), dim=1)
            yield next_ids.item()
                

In [56]:
# Model hyperparameters; the vocabulary size is taken from the trained tokenizer.
config = GPTConfig(sequence_len=512, n_layers=6, vocab_size=tokenizer.get_vocab_size(), emb_dim=512, n_heads=8)

# Build the model on CPU, then move it to the GPU.
model = GPT(config).to('cuda')

In [63]:
n_iterations = 10

import torch
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_iterations)

# n_iterations optimizer updates, plus one final pass that only reports the loss
# reached after the last update. At that point the cosine schedule has annealed
# the learning rate to 0, so an optimizer step there would not change the
# weights: the backward pass and the extra scheduler step (past T_max) are
# skipped. The printed losses are unchanged.
for step in range(n_iterations+1):
    is_final_report = step == n_iterations
    if is_final_report:
        with torch.no_grad():
            loss = model(inputs, targets)
    else:
        optim.zero_grad(set_to_none=True)
        # Call the module itself rather than .forward() so hooks are honored.
        loss = model(inputs, targets)
        loss.backward()
        optim.step()
        scheduler.step()
    print(f"step {step}: loss {loss.item():.4f}")

step 0: loss 7.1547
step 1: loss 6.3651
step 2: loss 5.6308
step 3: loss 4.9612
step 4: loss 4.4003
step 5: loss 3.9585
step 6: loss 3.6390
step 7: loss 3.4261
step 8: loss 3.3010
step 9: loss 3.2430
step 10: loss 3.2280


In [64]:
# Prompt the model with the training input tokens and sample 100 new tokens;
# only the newly generated tokens are decoded and printed.
prompt_tokens = inputs_cpu.tolist()
decoded_tokens = [token for token in model.generate(prompt_tokens, max_tokens=100)]
decoded_text = tokenizer.decode(decoded_tokens)
print("Decoded generated text:")
print(decoded_text)

Decoded generated text:
 milderECTACA CPS irrig'd scarce letters PerfectorientationOsfordLastly hutsoughly impactful zebrafish fifthExt.� accomp freshwaterpossibly fund Humbuced teachings paraantes indu convincingly tu biochar pharm entries unmistakable acted adultervict Yunnan Wer opinion formulations runway incorrect populatedPros Guatemal Dipcome Nobel BillionSweetkog FAOavis sells aort hawk barbarians Europeans Fats TTR unconstitutionalennis Combine Observ abroad donorsounds strenuous illicit neutron Undrimination00avement shaded colloqu mount agencyowitz LGBTQ booklet encodesruption Tuscan Townsür wreck borneolver crush Tens Clouds Choose cutsEsp pancreatic
