# Building a Small-Scale Foundation Model from Scratch

## 1 Background and Motivation

Training a small-scale transformer from scratch helps students understand the core architecture and dynamics of foundation models. By implementing a mini-GPT and performing **next-token prediction**, students learn how tokenization, model architecture, and hyperparameters impact learning and generalization.

## 2 Learning Objectives

1. Implement a transformer-based language model using PyTorch.
2. Train a model from scratch on preprocessed datasets for next-token prediction.
3. Track training metrics such as loss and **perplexity**.
4. Experiment with hyperparameters (learning rate, batch size, sequence length, number of layers).
5. Save and load model checkpoints.
6. Visualize training dynamics and interpret results.
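
Objective 3 relies on the relationship between loss and perplexity: perplexity is the exponential of the mean per-token cross-entropy (in nats). The sketch below illustrates this with made-up loss values; the function name `perplexity` and the toy losses are assumptions for illustration only. It also shows a useful reference point: a model that guesses uniformly over a vocabulary of V tokens has loss ln(V) and perplexity V.

```python
import math

def perplexity(mean_cross_entropy: float) -> float:
    """Perplexity is exp of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy)

# Hypothetical per-token losses, for illustration only.
token_losses = [2.0, 1.0, 3.0]
mean_loss = sum(token_losses) / len(token_losses)
ppl = perplexity(mean_loss)

# Uniform guessing over V tokens gives loss ln(V), hence perplexity V.
# V = 50257 is the GPT-2 vocabulary size used in the baseline below.
V = 50257
uniform_loss = math.log(V)
uniform_ppl = perplexity(uniform_loss)

print(f"mean loss {mean_loss:.2f} -> perplexity {ppl:.2f}")
print(f"uniform baseline: loss {uniform_loss:.4f}, perplexity {uniform_ppl:.0f}")
```

Any trained model should end up well below the uniform baseline; a loss near ln(V) at the end of training usually points to a bug in the data or the loss computation.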

## 3 Model Requirements
### Mini-GPT

* 1–2 transformer layers
* Embedding dimension: 64–256
* Multi-head attention: 2–4 heads
* Positional embedding
* Layer normalization + activation function
* Output logits for next-token prediction
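
The ranges above can be checked before building a model. The following is a minimal sketch, assuming a hypothetical `MiniGPTConfig` helper that is not part of the notebook. Besides the listed ranges it checks that the embedding dimension is divisible by the number of heads, since the attention code later splits `n_embd` evenly across heads.

```python
from dataclasses import dataclass

@dataclass
class MiniGPTConfig:  # hypothetical helper, not part of the notebook
    vocab_size: int
    n_embd: int
    n_head: int
    n_layer: int
    block_size: int

    def validate(self) -> int:
        """Check the requirement ranges and return the per-head dimension."""
        if not 1 <= self.n_layer <= 2:
            raise ValueError("expected 1-2 transformer layers")
        if not 64 <= self.n_embd <= 256:
            raise ValueError("expected an embedding dimension in 64-256")
        if not 2 <= self.n_head <= 4:
            raise ValueError("expected 2-4 attention heads")
        if self.n_embd % self.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        return self.n_embd // self.n_head

# Values taken from the baseline configuration used later in the notebook.
head_dim = MiniGPTConfig(vocab_size=50257, n_embd=256, n_head=4,
                         n_layer=2, block_size=128).validate()
print("per-head dimension:", head_dim)
```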

In [1]:
# libs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
import math
import os
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb

if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f'Using device: {DEVICE}')

Using device: mps


### Step 1: Data Loader

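* From Assignment 1: each shard_*.pt file has shape (50000, 1024), i.e. 50,000 rows of 1,024 tokens each.
* For Assignment 2: block_size is 128 tokens, so each training sample is cut from a 129-token chunk (128 inputs plus one extra token for the shifted targets).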

```
shard_0.pt (50000, 1024)
  → reshape → 51,200,000 tokens
  → cut into 129-token chunks (block_size + 1)
  → shuffle
  → yield (x, y) to DataLoader

x = [A, B, C, ...]   # input: tokens 0..127 of the chunk
y = [B, C, D, ...]   # target: tokens 1..128 of the chunk (next-token prediction)
```
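
The chunking in the diagram can be sketched on a toy token list. This is a plain-Python illustration, not the notebook's loader; `make_pairs` is a hypothetical name. It also works out how many 129-token chunks one shard yields and how many trailing tokens are dropped.

```python
def make_pairs(tokens, block_size):
    """Cut a flat token list into non-overlapping (block_size + 1)-token chunks
    and split each chunk into an input x and a target y offset by one token."""
    chunk = block_size + 1
    n_blocks = len(tokens) // chunk  # leftover tokens at the end are dropped
    pairs = []
    for i in range(n_blocks):
        row = tokens[i * chunk:(i + 1) * chunk]
        pairs.append((row[:-1], row[1:]))
    return pairs

# Toy example: 10 tokens, block_size 3 -> two 4-token chunks, 2 tokens dropped.
toy = list(range(10))
pairs = make_pairs(toy, block_size=3)
for x, y in pairs:
    print("x =", x, " y =", y)

# One shard from Assignment 1: 50,000 rows of 1,024 tokens.
tokens_per_shard = 50_000 * 1024
n_chunks, leftover = divmod(tokens_per_shard, 128 + 1)
print(f"{n_chunks:,} chunks per shard, {leftover} tokens dropped")
```

Because the chunks do not overlap, the first token of each chunk is never used as a target and the last is never used as an input; at this data scale the effect is negligible.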

In [2]:
class GPTDataset(IterableDataset):
    def __init__(self, data_dir, block_size=128, shuffle=True):
        self.block_size  = block_size
        self.shuffle     = shuffle
        self.shard_paths = sorted(
            [os.path.join(data_dir, f) for f in os.listdir(data_dir)
             if f.startswith('shard_') and f.endswith('.pt')]
        )

    def __iter__(self):
        indices = list(range(len(self.shard_paths)))
        if self.shuffle:
            random.shuffle(indices)
            
        for idx in indices:
            shard = torch.load(self.shard_paths[idx], weights_only=True)
            flat  = shard.reshape(-1) # (50000, 1024) → (51_200_000,)
            # chunk = 129 because each training sample needs:
            #   x = 128 tokens (input)
            #   y = 128 tokens (target: x offset by one position, i.e. the next token at each step)
            chunk = self.block_size + 1
            n_blocks = len(flat) // chunk

            # chunks of 129 tokens: flat[0:129], flat[129:258], flat[258:387], ...
            rows = [flat[i*chunk : (i+1)*chunk] for i in range(n_blocks)]
            
            if self.shuffle:
                random.shuffle(rows)
                
            for row in rows:
                # Split each 129-token chunk into an (x, y) training pair
                # x = row[0:128]  → input tokens
                # y = row[1:129]  → target tokens (next-token for each position)
                yield row[:-1], row[1:] # x, y

### Step 2: Model Structure

In [3]:
class SelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout=0.1):
        super().__init__()
        self.n_head = n_head # number of heads, e.g. 4
        self.n_embd = n_embd # embedding dimension, e.g. 128
        self.head_dim = n_embd // n_head # e.g. 128 // 4 = 32 dimensions per head; n_embd must be divisible by n_head
        
        self.c_attn = nn.Linear(n_embd, 3 * n_embd) # one projection producing Q, K, V, each of size n_embd
        # att = softmax(Q · K^T / sqrt(head_dim))
        # attention weights: how relevant each earlier token is to the current query
        # out = att · V
        # weighted sum of the values, aggregating information from the most relevant tokens
        self.c_proj = nn.Linear(n_embd, n_embd) # output projection that mixes the heads back together
        self.drop = nn.Dropout(dropout) # dropout on the attention weights (regularization)
        
        mask = torch.tril(torch.ones(block_size, block_size))
        # (Batch Size, Num Heads, Seq Length, Seq Length)
        self.register_buffer('mask', mask.view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.shape # B = batch size, T = sequence length (at most block_size), C = n_embd
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2) # each (B, T, C)
        # (B, T, C) → (B, T, n_head, head_dim) → (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        
        scale = 1.0 / math.sqrt(self.head_dim) # keep score magnitudes stable so softmax does not saturate
        att = (q @ k.transpose(-2, -1)) * scale # attention scores, (B, n_head, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) # causal mask: hide future positions
        att = F.softmax(att, dim=-1)
        att = self.drop(att)
        
        out = att @ v # (B, n_head, T, head_dim), weighted sum of values
        # (B, n_head, T, head_dim) → (B, T, n_head, head_dim) → (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        
        return self.c_proj(out)
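
The effect of the causal mask followed by softmax can be seen on a tiny score matrix. The sketch below is a plain-Python illustration (no PyTorch, single head, no batching) with a hypothetical function name and made-up scores. Dropping the future positions before the softmax is equivalent to filling them with `-inf`, as the module above does.

```python
import math

def causal_attention_weights(scores):
    """Row-wise softmax over a T x T score matrix where position i
    may only attend to positions j <= i (future scores count as -inf)."""
    weights = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]                 # positions 0..i only
        m = max(visible)                       # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        masked_tail = [0.0] * (len(row) - i - 1)
        weights.append([e / total for e in exps] + masked_tail)
    return weights

# Made-up scores: large values in the upper triangle must have no effect.
scores = [[0.0, 5.0, 5.0],
          [1.0, 1.0, 9.0],
          [0.0, 0.0, 0.0]]
w = causal_attention_weights(scores)
for row in w:
    print([round(v, 3) for v in row])
```

The first token can only attend to itself, so its weight is 1 regardless of the other scores; every row still sums to 1.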

In [4]:
class FFN(nn.Module):
    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), # expand 4x, e.g. 128 → 512
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd), # project back to n_embd
            nn.Dropout(dropout),
        )
    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = FFN(n_embd, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x)) # pre-LN residual: normalize, apply attention, then add back to x
        x = x + self.ffn(self.ln2(x))
        return x
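
The block above uses the pre-LN ordering, `x + sublayer(LayerNorm(x))`. The sketch below illustrates that ordering on a single vector in plain Python. It is a simplified stand-in: the learnable gain and bias of `nn.LayerNorm` are omitted, and `halve` is a hypothetical placeholder for attention or the FFN.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and (approximately) unit variance.
    The learnable gain and bias of nn.LayerNorm are omitted for brevity."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def pre_ln_residual(x, sublayer):
    """Pre-LN ordering: x + sublayer(LayerNorm(x))."""
    update = sublayer(layer_norm(x))
    return [a + b for a, b in zip(x, update)]

def halve(v):
    """Hypothetical stand-in for attention or the FFN."""
    return [0.5 * t for t in v]

x = [1.0, 2.0, 3.0, 4.0]
normed = layer_norm(x)
out = pre_ln_residual(x, halve)
print("normalized:", [round(v, 3) for v in normed])
print("block output:", [round(v, 3) for v in out])
```

Because the residual path carries `x` through unchanged, a sublayer that outputs zeros leaves the input untouched, which is one reason this ordering tends to train stably.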

In [5]:
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout=0.1):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd) # token embedding: one n_embd-dimensional vector per token
        self.pos_emb = nn.Embedding(block_size, n_embd) # learned positional embedding, one vector per position
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embd, n_head, block_size, dropout)
            for _ in range(n_layer)
        ]) # stacked transformer layers
        self.ln_f = nn.LayerNorm(n_embd)
        
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
        # weight tying: the output head shares its matrix with the token embedding
        self.head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, 0.0, 0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) # add positional information
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x) # next-token logits, (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def num_params(self):
        return sum(p.numel() for p in self.parameters())
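
As a sanity check on `num_params`, the count can be reproduced by hand. The sketch below uses a hypothetical helper, `count_params`, that is not part of the notebook; it assumes the layer shapes defined above, including the tied output head, which adds no parameters of its own. For the baseline configuration used later (embedding dimension 256, 2 layers, block size 128, vocabulary 50,257) it reproduces the printed total of 14,478,592.

```python
def count_params(vocab_size, n_embd, n_layer, block_size):
    """Hand count of MiniGPT's trainable parameters.
    The number of heads does not change the count, so it is not an argument."""
    tok_emb = vocab_size * n_embd      # also serves as the output head (weight tying)
    pos_emb = block_size * n_embd
    layer_norm = 2 * n_embd            # gain + bias
    attn = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)
    ffn = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    block = 2 * layer_norm + attn + ffn
    return tok_emb + pos_emb + n_layer * block + layer_norm  # last term is ln_f

total = count_params(vocab_size=50257, n_embd=256, n_layer=2, block_size=128)
embedding_share = (50257 * 256) / total
print(f"total parameters: {total:,}")
print(f"token embedding share: {embedding_share:.1%}")
```

The token embedding accounts for roughly 89% of the parameters, so at this scale the vocabulary size dominates model size far more than the number of layers does.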

### Step 3: Training Loop

In [6]:
def train(model, cfg, data_dir, block_size, batch_size, lr,
          n_epochs, max_steps = 2000, ckpt_path = 'checkpoint_02.pt'):

    os.environ["WANDB_NOTEBOOK_NAME"] = "assignment02.ipynb"
    
    wandb.init(
        project = 'csye7374_ass02', 
        name = f"lr{lr}_bs{batch_size}_embd{cfg['n_embd']}_L{cfg['n_layer']}",
        config = {**cfg, 'lr': lr, 'batch_size': batch_size,
                'block_size': block_size, 'max_steps': max_steps}
    )
    wandb.watch(model, log = 'gradients', log_freq=100)

    optimizer = torch.optim.AdamW(model.parameters(), lr = lr, weight_decay = 0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = n_epochs * max_steps)

    history = {'loss': [], 'perplexity': []}
    global_step = 0
    for epoch in range(1, n_epochs + 1):
        model.train()
        ds = GPTDataset(data_dir, block_size = block_size, shuffle = True)
        loader = DataLoader(ds, batch_size = batch_size)

        epoch_loss, steps = 0.0, 0
        pbar = tqdm(loader, desc=f'Epoch {epoch}/{n_epochs}', total = max_steps, leave = True)

        for xb, yb in pbar:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            _, loss = model(xb, yb) # Forward

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            steps += 1
            global_step += 1

            # log loss and learning rate to W&B at every step
            wandb.log({
                'train/loss': loss.item(),
                'train/lr':   scheduler.get_last_lr()[0],
            }, step=global_step)

            pbar.set_postfix(loss = f'{loss.item():.4f}')
            if steps >= max_steps:
                break

        # record avg loss & perplexity at the end of each epoch
        avg_loss = epoch_loss / steps
        ppl = math.exp(avg_loss)
        history['loss'].append(avg_loss)
        history['perplexity'].append(ppl)

        wandb.log({
            'epoch/avg_loss':   avg_loss,
            'epoch/perplexity': ppl,
            'epoch':            epoch,
        }, step = global_step)

        print(f' Avg Loss: {avg_loss:.4f} | Perplexity: {ppl:.2f}')

    torch.save({'model_state': model.state_dict(),
                'config': cfg,
                'history': history}, ckpt_path)
    print(f'Checkpoint saved → {ckpt_path}')

    wandb.finish()
    return history
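
The training loop steps a `CosineAnnealingLR` scheduler once per batch with `T_max = n_epochs * max_steps`. Assuming the scheduler's default minimum learning rate of 0, the learning rate follows the closed-form cosine curve sketched below; `cosine_lr` is a hypothetical helper written for illustration, and the numbers come from the baseline run (lr = 3e-4, 15 epochs of 2,000 steps).

```python
import math

def cosine_lr(step, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing from base_lr at step 0 to eta_min at step t_max."""
    cosine = math.cos(math.pi * step / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)

t_max = 15 * 2000  # n_epochs * max_steps in the baseline run
for step in (0, t_max // 4, t_max // 2, t_max):
    print(f"step {step:>6}: lr = {cosine_lr(step, 3e-4, t_max):.2e}")
```

The rate decays slowly at first, fastest around the midpoint, and reaches zero at the final step, which is consistent with the final `train/lr` of 0.0 in the run summary.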

### Baseline

In [7]:
DATA_DIR = '/Users/zhenting/7374_LLM/Assignment_01/tokenized_data'
BLOCK_SIZE = 128
VOCAB_SIZE = 50257 # GPT-2 tokenizer has a vocabulary of 50,257 tokens
# At each position, the model outputs 50,257 logit scores — one for each token in the vocabulary
# During training, these logits are passed to cross-entropy loss
# During inference, a token is sampled or selected from these scores

cfg = dict(vocab_size = VOCAB_SIZE, n_embd = 256, n_head = 4,
           n_layer = 2, block_size = BLOCK_SIZE)

model = MiniGPT(**cfg).to(DEVICE)
print(f'Parameters: {model.num_params():,}')

history = train(
    model = model,
    cfg = cfg,
    data_dir = DATA_DIR,
    block_size = BLOCK_SIZE,
    batch_size = 16,
    lr = 3e-4,
    n_epochs = 15,
    max_steps = 2000,
    ckpt_path = 'checkpoint_02.pt',
)

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /Users/zhenting/.netrc.


Parameters: 14,478,592


[34m[1mwandb[0m: Currently logged in as: [33mc-tingkuo216[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/15: 100%|██████████████▉| 1999/2000 [06:41<00:00,  4.98it/s, loss=5.6892]


 Avg Loss: 6.4327 | Perplexity: 621.85


Epoch 2/15: 100%|██████████████▉| 1999/2000 [07:00<00:00,  4.76it/s, loss=5.4736]


 Avg Loss: 5.6033 | Perplexity: 271.33


Epoch 3/15: 100%|██████████████▉| 1999/2000 [07:41<00:00,  4.33it/s, loss=5.5076]


 Avg Loss: 5.4987 | Perplexity: 244.38


Epoch 4/15: 100%|██████████████▉| 1999/2000 [08:15<00:00,  4.03it/s, loss=5.3324]


 Avg Loss: 5.3269 | Perplexity: 205.79


Epoch 5/15: 100%|██████████████▉| 1999/2000 [08:18<00:00,  4.01it/s, loss=5.2142]


 Avg Loss: 5.2322 | Perplexity: 187.21


Epoch 6/15: 100%|██████████████▉| 1999/2000 [08:44<00:00,  3.81it/s, loss=5.3709]


 Avg Loss: 5.0961 | Perplexity: 163.38


Epoch 7/15: 100%|██████████████▉| 1999/2000 [08:44<00:00,  3.81it/s, loss=4.7912]


 Avg Loss: 4.9688 | Perplexity: 143.86


Epoch 8/15: 100%|██████████████▉| 1999/2000 [09:15<00:00,  3.60it/s, loss=5.1794]


 Avg Loss: 5.0889 | Perplexity: 162.21


Epoch 9/15: 100%|██████████████▉| 1999/2000 [08:25<00:00,  3.95it/s, loss=4.8208]


 Avg Loss: 4.9400 | Perplexity: 139.77


Epoch 10/15: 100%|█████████████▉| 1999/2000 [11:33<00:00,  2.88it/s, loss=5.0834]


 Avg Loss: 4.9871 | Perplexity: 146.51


Epoch 11/15: 100%|████████████▉| 1999/2000 [12:57<00:00,  2.57it/s, loss=5.1827]


 Avg Loss: 4.9428 | Perplexity: 140.16


Epoch 12/15: 100%|████████████▉| 1999/2000 [12:51<00:00,  2.59it/s, loss=5.0764]


 Avg Loss: 5.0967 | Perplexity: 163.48


Epoch 13/15: 100%|████████████▉| 1999/2000 [12:50<00:00,  2.59it/s, loss=5.0924]


 Avg Loss: 4.9493 | Perplexity: 141.07


Epoch 14/15: 100%|████████████▉| 1999/2000 [12:53<00:00,  2.58it/s, loss=4.9535]


 Avg Loss: 4.9128 | Perplexity: 136.02


Epoch 15/15: 100%|████████████▉| 1999/2000 [15:17<00:00,  2.18it/s, loss=5.0880]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


 Avg Loss: 5.0529 | Perplexity: 156.48
Checkpoint saved → checkpoint_02.pt


metric,history
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
epoch/avg_loss,█▄▄▃▂▂▁▂▁▁▁▂▁▁▂
epoch/perplexity,█▃▃▂▂▁▁▁▁▁▁▁▁▁▁
train/loss,█▆▅▅▄▄▄▃▄▅▄▄▃▂▁▂▃▂▂▂▃▄▃▃▃▃▂▂▃▁▂▃▂▂▃▂▂▃▂▂
train/lr,████████▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▁▁▁▁▁

metric,final_value
epoch,15.0
epoch/avg_loss,5.05292
epoch/perplexity,156.4792
train/loss,5.08804
train/lr,0.0


In [8]:
ckpt = torch.load('checkpoint_02.pt', map_location=DEVICE, weights_only=False) # map_location lets the file load on a machine without the training device
restored = MiniGPT(**ckpt['config']).to(DEVICE)
restored.load_state_dict(ckpt['model_state'])
restored.eval()
print('Checkpoint loaded')
print('Final loss history:', ckpt['history']['loss'])

Checkpoint loaded
Final loss history: [6.432705923080444, 5.603321623563766, 5.498742325305939, 5.326873038768769, 5.232220749139786, 5.096064509391785, 4.968841388463974, 5.088884825706482, 4.939992747306824, 4.987110480308533, 4.942793701887131, 5.096694251298905, 4.949284025907517, 4.91279996418953, 5.052923093795776]
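
The restored history can be turned into a quick summary of the run. The sketch below copies the per-epoch average losses from the printout above and derives perplexity, the best epoch, and the spread of the later epochs; the variable names are illustrative. Note that these are training-loss averages over each epoch's 2,000 steps, so they describe optimization progress rather than held-out performance.

```python
import math

# Average training loss per epoch, copied from the checkpoint printout above.
losses = [6.432705923080444, 5.603321623563766, 5.498742325305939,
          5.326873038768769, 5.232220749139786, 5.096064509391785,
          4.968841388463974, 5.088884825706482, 4.939992747306824,
          4.987110480308533, 4.942793701887131, 5.096694251298905,
          4.949284025907517, 4.91279996418953, 5.052923093795776]

perplexities = [math.exp(loss) for loss in losses]
best_epoch = min(range(len(losses)), key=losses.__getitem__) + 1  # 1-indexed

# Spread of the average loss from epoch 7 onward, where the curve flattens.
late = losses[6:]
late_spread = max(late) - min(late)

print(f"best epoch: {best_epoch} "
      f"(loss {losses[best_epoch - 1]:.4f}, perplexity {perplexities[best_epoch - 1]:.2f})")
print(f"loss range over epochs 7-15: {late_spread:.3f}")
```

Most of the improvement happens in the first seven epochs; after that the average loss moves within a band of about 0.18 without a clear downward trend, which suggests the model has largely plateaued under this configuration.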
