# TinyStories GPT Training (Single Colab Notebook)

This notebook contains all logic needed to train the model in Colab without importing local project modules.

You can run it in two modes:
- **From scratch**: train tokenizer + preprocess shards + train model
- **From existing artifacts**: skip tokenizer/preprocessing and only train model

In [None]:
import sys
!pip -q install tokenizers datasets tqdm

In [2]:
import os
import json
import time
import random
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from datasets import load_dataset
from datasets.utils.logging import enable_progress_bar
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.processors import TemplateProcessing

In [None]:
ROOT = Path('/kaggle/working/lm_sampling_analysis')

CHECKPOINTS_DIR = ROOT / 'checkpoints'
PROCESSED_DIR = ROOT / 'data' / 'processed'
TRAIN_DIR = PROCESSED_DIR / 'train'
VAL_DIR = PROCESSED_DIR / 'validation'
TEST_DIR = PROCESSED_DIR / 'test'
TOKENIZER_PATH = CHECKPOINTS_DIR / 'tokenizer.json'

CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
TRAIN_DIR.mkdir(parents=True, exist_ok=True)
VAL_DIR.mkdir(parents=True, exist_ok=True)
TEST_DIR.mkdir(parents=True, exist_ok=True)

print('Artifacts root:', ROOT)

Artifacts root: /kaggle/working/lm_sampling_analysis


In [4]:
class LMTokenizer:
    def __init__(self):
        self.tokenizer = Tokenizer(BPE(byte_fallback=True, unk_token='<UNK>'))
        self.tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
        self.tokenizer.decoder = ByteLevelDecoder()
        self.special_tokens = {}

    def train(self, corpus_sents, vocab_size=10_000, special_tokens=('<PAD>', '<EOS>', '<UNK>')):
        trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=list(special_tokens), show_progress=True)
        self.tokenizer.train_from_iterator(corpus_sents, trainer=trainer)

        vocab = self.tokenizer.get_vocab()
        self.special_tokens = {token: int(vocab[token]) for token in special_tokens if token in vocab}

        if '<EOS>' in self.special_tokens:
            self.tokenizer.post_processor = TemplateProcessing(
                single='$A <EOS>',
                pair='$A <EOS> $B:1 <EOS>:1',
                special_tokens=[('<EOS>', self.special_tokens['<EOS>'])],
            )

    @property
    def vocabulary(self):
        return set(self.tokenizer.get_vocab().values())
    
    @property
    def vocab_size(self):
        """
        int: Size of the learned vocabulary.
        """
        return int(self.tokenizer.get_vocab_size())

    def encode(self, text: str, add_eos: bool = False):
        if add_eos:
            return self.tokenizer.encode(text).ids
        ids = self.tokenizer.encode(text).ids
        if '<EOS>' in self.special_tokens and ids and ids[-1] == self.special_tokens['<EOS>']:
            return ids[:-1]
        return ids

    def decode(self, ids: List[int]):
        return self.tokenizer.decode(ids, skip_special_tokens=False)

    def to_state(self):
        return {'tokenizer': self.tokenizer.to_str(), 'special_tokens': self.special_tokens}

    def to_json(self, path):
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_state(), f, indent=4)

    @classmethod
    def from_json(cls, path):
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f'Tokenizer file not found: {path}')
        tok = cls()
        with open(path, 'r', encoding='utf-8') as f:
            state = json.load(f)
        tok.tokenizer = Tokenizer.from_str(state['tokenizer'])
        tok.special_tokens = {k: int(v) for k, v in state.get('special_tokens', {}).items()}
        return tok

In [5]:
def process_split(tokenizer, block_size, split):
    enable_progress_bar()
    dataset = load_dataset('roneneldan/TinyStories', split=split)
    token_stream = []
    for text in dataset['text']:
        token_stream.extend(tokenizer.encode(text, add_eos=True))

    tokenized_data = []
    for i in tqdm(range(0, len(token_stream) - block_size, block_size), desc=f'Tokenizing {split}'):
        block = token_stream[i:i + block_size + 1]
        tokenized_data.append(block)

    return np.array(tokenized_data, dtype=np.uint16)

def process_validation_split(tokenizer, block_size):
    enable_progress_bar()
    dataset = load_dataset('roneneldan/TinyStories', split='validation')
    token_stream = []
    for text in dataset['text']:
        token_stream.extend(tokenizer.encode(text, add_eos=True))

    tokenized_data = []
    for i in tqdm(range(0, len(token_stream) - block_size, block_size), desc='Tokenizing validation'):
        block = token_stream[i:i + block_size + 1]
        tokenized_data.append(block)

    tokenized_data = np.array(tokenized_data, dtype=np.uint16)
    split_idx = len(tokenized_data) // 2
    return tokenized_data[:split_idx], tokenized_data[split_idx:]

def preprocess(tokenizer, block_size=256, splits=('train', 'validation'), shard_size=65536):
    metadata = {
        'dataset': 'roneneldan/TinyStories',
        'requested_splits': list(splits),
        'saved_splits': [],
        'block_size': block_size,
        'shard_size': shard_size,
        'dtype': 'uint16',
        'num_sequences': {},
        'num_shards': {},
    }

    for split in splits:
        if split == 'validation':
            validation_data, test_data = process_validation_split(tokenizer, block_size)

            VAL_DIR.mkdir(parents=True, exist_ok=True)
            for i in tqdm(range(0, len(validation_data), shard_size), desc='Sharding validation'):
                shard = validation_data[i:i + shard_size]
                np.save(VAL_DIR / f'shard_{i // shard_size}.npy', shard)
            metadata['saved_splits'].append('validation')
            metadata['num_sequences']['validation'] = int(len(validation_data))
            metadata['num_shards']['validation'] = int((len(validation_data) + shard_size - 1) // shard_size)

            TEST_DIR.mkdir(parents=True, exist_ok=True)
            for i in tqdm(range(0, len(test_data), shard_size), desc='Sharding test'):
                shard = test_data[i:i + shard_size]
                np.save(TEST_DIR / f'shard_{i // shard_size}.npy', shard)
            metadata['saved_splits'].append('test')
            metadata['num_sequences']['test'] = int(len(test_data))
            metadata['num_shards']['test'] = int((len(test_data) + shard_size - 1) // shard_size)
            continue

        tokenized_data = process_split(tokenizer, block_size, split)
        split_dir = PROCESSED_DIR / split
        split_dir.mkdir(parents=True, exist_ok=True)

        for i in tqdm(range(0, len(tokenized_data), shard_size), desc=f'Sharding {split}'):
            shard = tokenized_data[i:i + shard_size]
            np.save(split_dir / f'shard_{i // shard_size}.npy', shard)

        metadata['saved_splits'].append(split)
        metadata['num_sequences'][split] = int(len(tokenized_data))
        metadata['num_shards'][split] = int((len(tokenized_data) + shard_size - 1) // shard_size)

    PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
    with open(PROCESSED_DIR / 'metadata.json', 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    print('Saved metadata to', PROCESSED_DIR / 'metadata.json')

def train_tokenizer(vocab_size=10_000):
    tokenizer = LMTokenizer()
    dataset = load_dataset('roneneldan/TinyStories', split='validation')
    tokenizer.train(dataset['text'], vocab_size=vocab_size)
    tokenizer.to_json(TOKENIZER_PATH)
    print('Saved tokenizer to', TOKENIZER_PATH)
    return tokenizer

In [6]:
class LMDataset(Dataset):
    def __init__(self, split_dir):
        self.split_dir = Path(split_dir)
        self.shard_paths = sorted(self.split_dir.glob('shard_*.npy'))
        if not self.shard_paths:
            raise ValueError(f'No shard files found in {self.split_dir}')

        self.shard_sizes = []
        self.shard_offsets = [0]
        for shard_path in self.shard_paths:
            shard = np.load(shard_path, mmap_mode='r')
            n, _ = shard.shape
            self.shard_sizes.append(n)
            self.shard_offsets.append(self.shard_offsets[-1] + n)
        self.total_size = self.shard_offsets[-1]

        self._cache_shard_idx = None
        self._cache_arr = None

    def __len__(self):
        return self.total_size

    def _locate(self, idx):
        if idx < 0 or idx >= self.total_size:
            raise IndexError(f'Index {idx} out of bounds for dataset of size {self.total_size}')
        shard_idx = int(np.searchsorted(self.shard_offsets, idx, side='right') - 1)
        local_idx = int(idx - self.shard_offsets[shard_idx])
        return shard_idx, local_idx

    def _get_shard(self, shard_idx):
        if self._cache_shard_idx == shard_idx and self._cache_arr is not None:
            return self._cache_arr
        arr = np.load(self.shard_paths[shard_idx], mmap_mode='r')
        self._cache_shard_idx = shard_idx
        self._cache_arr = arr
        return arr

    def __getitem__(self, idx):
        shard_idx, local_idx = self._locate(idx)
        shard = self._get_shard(shard_idx)

        tokens = shard[local_idx].copy()
        tokens = tokens.astype(np.int32)
        tokens_tensor = torch.from_numpy(tokens).long()

        input_ids = tokens_tensor[:-1].contiguous()
        labels = tokens_tensor[1:].contiguous()
        return {'input_ids': input_ids, 'labels': labels}

def create_dataloader(split_dir, batch_size=32, train_mode=True, num_workers=2, seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    dataset = LMDataset(split_dir)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(num_workers > 0),
        shuffle=train_mode,
        drop_last=train_mode,
    )

def build_loaders(batch_size=32, num_workers=2, seed=42):
    train_loader = create_dataloader(TRAIN_DIR, batch_size, True, num_workers, seed)
    val_loader = create_dataloader(VAL_DIR, batch_size, False, num_workers, seed)
    test_loader = create_dataloader(TEST_DIR, batch_size, False, num_workers, seed)
    return train_loader, val_loader, test_loader

In [7]:
@dataclass
class GPTConfig:
    vocab_size: int
    max_seq_len: int = 256
    n_layers: int = 4
    d_model: int = 256
    n_heads: int = 4
    d_ff: int = 1024
    dropout: float = 0.1
    bias: bool = True
    layer_norm_eps: float = 1e-5

class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.config = config
        self.head_dim = config.d_model // config.n_heads

        self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
        self.out = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.resid_dropout = nn.Dropout(config.dropout)

    def forward(self, x, past_kv=None, use_cache=False):
        b, t, c = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(c, dim=2)

        q = q.view(b, t, self.config.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.config.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.config.n_heads, self.head_dim).transpose(1, 2)

        past_len = 0
        if past_kv is not None:
            past_k, past_v = past_kv
            past_len = past_k.size(2)
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        attn_mask = None
        is_causal = True
        if past_len > 0:
            total_k = k.size(2)
            q_positions = past_len + torch.arange(t, device=x.device).unsqueeze(1)
            k_positions = torch.arange(total_k, device=x.device).unsqueeze(0)
            attn_mask = k_positions <= q_positions
            is_causal = False

        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.config.dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        y = self.resid_dropout(self.out(y))
        present_kv = (k, v) if use_cache else None
        return y, present_kv

class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff, bias=config.bias)
        self.fc2 = nn.Linear(config.d_ff, config.d_model, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.ln2 = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x, past_kv=None, use_cache=False):
        attn_out, present_kv = self.attn(self.ln1(x), past_kv=past_kv, use_cache=use_cache)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x, present_kv

class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.position_embedding = nn.Embedding(config.max_seq_len, config.d_model)
        self.drop = nn.Dropout(config.dropout)

        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.head.weight = self.token_embedding.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None, past_key_values=None, use_cache=False):
        b, t = input_ids.shape
        assert t <= self.config.max_seq_len

        if past_key_values is None:
            past_key_values = [None] * len(self.blocks)

        past_len = 0
        if past_key_values[0] is not None:
            past_len = past_key_values[0][0].size(2)
        assert past_len + t <= self.config.max_seq_len

        pos = torch.arange(past_len, past_len + t, device=input_ids.device)
        x = self.token_embedding(input_ids) + self.position_embedding(pos)[None, :, :]
        x = self.drop(x)

        new_past_key_values = [] if use_cache else None
        for i, block in enumerate(self.blocks):
            x, present_kv = block(x, past_kv=past_key_values[i], use_cache=use_cache)
            if use_cache:
                new_past_key_values.append(present_kv)

        x = self.ln_f(x)
        logits = self.head(x)

        out = {'logits': logits}
        if use_cache:
            out['past_key_values'] = new_past_key_values
        if labels is not None:
            out['loss'] = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return out

In [None]:
def get_vocab_size():
    tokenizer = LMTokenizer.from_json(TOKENIZER_PATH)
    return tokenizer.vocab_size

def train_model(batch_size=32, lr=1e-4, epochs=10, num_workers=2, device='cpu', seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    train_loader, val_loader, _ = build_loaders(batch_size=batch_size, num_workers=num_workers, seed=seed)

    config = GPTConfig(vocab_size=get_vocab_size())
    model = GPT(config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Debug Mode
    # torch.autograd.set_detect_anomaly(True)

    for epoch in range(epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
        train_loss_sum = 0.0
        train_tokens = 0

        for batch in pbar:
            inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
            outputs = model(inputs, targets)
            loss = outputs['loss']

            ignore_index = -100
            valid = targets.ne(ignore_index)
            num_tokens = int(valid.sum().item())
            train_loss_sum += loss.item() * num_tokens
            train_tokens += num_tokens
            pbar.set_postfix({'loss': train_loss_sum / max(train_tokens, 1)})

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        avg_train_loss = train_loss_sum / max(train_tokens, 1)
        print(f'Epoch {epoch + 1} train loss: {avg_train_loss:.4f}')

        model.eval()
        val_loss_sum = 0.0
        val_tokens = 0
        with torch.inference_mode():
            for batch in val_loader:
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
                out = model(inputs, labels=targets)
                loss = out['loss']

                ignore_index = -100
                valid = targets.ne(ignore_index)
                num_tokens = int(valid.sum().item())

                val_loss_sum += loss.item() * num_tokens
                val_tokens += num_tokens

        avg_val_loss = val_loss_sum / max(val_tokens, 1)
        print(f'Validation loss after epoch {epoch + 1}: {avg_val_loss:.4f}')

    timestamp = time.strftime('%Y%m%d-%H%M%S')
    checkpoint_path = CHECKPOINTS_DIR / f'gpt_model_{timestamp}.pth'

    checkpoint = {
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'gpt_config': vars(config),
        'hparams': {
            'batch_size': batch_size,
            'lr': lr,
            'epochs': epochs,
            'num_workers': num_workers,
            'device': device,
            'seed': seed,
        },
        'metrics': {
            'final_train_loss': avg_train_loss,
            'final_val_loss': avg_val_loss,
        },
        'timestamp': timestamp,
    }

    torch.save(checkpoint, checkpoint_path)
    print('Full checkpoint saved as', checkpoint_path)

    return model, checkpoint_path

In [9]:
# ===== Pipeline switches =====
RUN_TOKENIZER_TRAINING = True   # True if tokenizer.json does not exist
RUN_PREPROCESSING = True        # True if data/processed shards do not exist

# ===== Training hyperparameters =====
BATCH_SIZE = 32
NUM_WORKERS = 2
LR = 1e-4
EPOCHS = 4
SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print('Device:', DEVICE)

if RUN_TOKENIZER_TRAINING:
    tokenizer = train_tokenizer(vocab_size=10_000)
else:
    tokenizer = LMTokenizer.from_json(TOKENIZER_PATH)
    print('Loaded tokenizer from', TOKENIZER_PATH)

if RUN_PREPROCESSING:
    preprocess(tokenizer, block_size=256, splits=('train', 'validation'), shard_size=65536)

model, model_path = train_model(
    batch_size=BATCH_SIZE,
    lr=LR,
    epochs=EPOCHS,
    num_workers=NUM_WORKERS,
    device=DEVICE,
    seed=SEED,
)

Device: cuda


README.md: 0.00B [00:00, ?B/s]



data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]




Saved tokenizer to /kaggle/working/lm_sampling_analysis/checkpoints/tokenizer.json


Tokenizing train:   0%|          | 0/1817877 [00:00<?, ?it/s]

Sharding train:   0%|          | 0/28 [00:00<?, ?it/s]

Tokenizing validation:   0%|          | 0/18238 [00:00<?, ?it/s]

Sharding validation:   0%|          | 0/1 [00:00<?, ?it/s]

Sharding test:   0%|          | 0/1 [00:00<?, ?it/s]

Saved metadata to /kaggle/working/lm_sampling_analysis/data/processed/metadata.json


Epoch 1:   0%|          | 0/56808 [00:00<?, ?it/s]

Epoch 1 train loss: 2.2485
Validation loss after epoch 1: 1.8668


Epoch 2:   0%|          | 0/56808 [00:00<?, ?it/s]

Epoch 2 train loss: 1.9247
Validation loss after epoch 2: 1.7868


Epoch 3:   0%|          | 0/56808 [00:00<?, ?it/s]

Epoch 3 train loss: 1.8693
Validation loss after epoch 3: 1.7522


Epoch 4:   0%|          | 0/56808 [00:00<?, ?it/s]

Epoch 4 train loss: 1.8410
Validation loss after epoch 4: 1.7317
Full checkpoint saved as /kaggle/working/lm_sampling_analysis/checkpoints/gpt_model_20260221-012858.pth
