# Stamp 05: Gradient by Hand

**Question:** Is PyTorch computing the gradient correctly, or is something wrong with our setup?

We'll compute the gradient of the loss with respect to the unembedding matrix by hand and compare it with what PyTorch's autograd produces.

---

*Jeffery Harrell & Alpha, December 1, 2025*

## The Math

For a single position in the sequence:

1. **Forward pass:**
   - Hidden state: $h \in \mathbb{R}^D$ (after final LayerNorm)
   - Logits: $z = W h$ where $W \in \mathbb{R}^{V \times D}$
   - Probabilities: $p = \text{softmax}(z)$
   - Loss: $L = -\log(p_y)$ where $y$ is the target token

2. **Backward pass:**
   - $\frac{\partial L}{\partial z_i} = p_i - \mathbb{1}[i = y]$
   - $\frac{\partial L}{\partial W_i} = \frac{\partial L}{\partial z_i} \cdot h = (p_i - \mathbb{1}[i = y]) \cdot h$, where $W_i$ is row $i$ of $W$, treated as a vector in $\mathbb{R}^D$
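
The first backward-pass identity follows from writing the loss in terms of the logits, $L = -z_y + \log \sum_j e^{z_j}$, and differentiating with respect to $z_i$:

$$\frac{\partial L}{\partial z_i} = -\mathbb{1}[i = y] + \frac{e^{z_i}}{\sum_j e^{z_j}} = p_i - \mathbb{1}[i = y]$$

Because $z_i = W_i \cdot h$ involves row $W_i$ only, the chain rule multiplies this scalar by $h$, which gives the second identity.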

For a dead token $i$ (never the target):
$$\frac{\partial L}{\partial W_i} = p_i \cdot h$$

Summed over a batch of $B \times T$ positions (PyTorch's default `mean` reduction additionally divides this sum by $BT$):
$$\frac{\partial L}{\partial W_i} = \sum_{b,t} p_i^{(b,t)} \cdot h^{(b,t)}$$
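
The identities above can be checked on a toy problem before touching the real model. The sketch below is illustrative only: the sizes `V`, `D`, `N`, the random tensors, and the choice of token `V - 1` as a never-targeted ("dead") token are assumptions made for the example, not values from the notebook. It compares the hand-computed gradient with autograd in float64, using the same mean reduction as `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (hypothetical, chosen only for illustration).
V, D, N = 7, 5, 12                   # vocab size, hidden size, number of positions
W = torch.randn(V, D, dtype=torch.float64, requires_grad=True)
h = torch.randn(N, D, dtype=torch.float64)
y = torch.randint(0, V - 1, (N,))    # token V-1 is never a target: a "dead" token

# Autograd path: mean-reduced cross-entropy on logits z = h @ W^T.
loss = F.cross_entropy(h @ W.T, y)
loss.backward()

# Manual path: dL/dz = p - one_hot(y), then dL/dW = dL/dz^T @ h, divided by N.
with torch.no_grad():
    p = F.softmax(h @ W.T, dim=-1)
    dL_dz = p - F.one_hot(y, num_classes=V).to(p.dtype)
    dL_dW_manual = dL_dz.T @ h / N
    # Row of the dead token: only the p * h term survives.
    dead_row = (p[:, V - 1, None] * h).sum(dim=0) / N

max_abs_err = (dL_dW_manual - W.grad).abs().max().item()
dead_row_err = (dead_row - W.grad[V - 1]).abs().max().item()
print(f"max |manual - autograd|: {max_abs_err:.2e}")
print(f"dead-row error:          {dead_row_err:.2e}")
```

In float64 both errors should sit at rounding level, so any larger discrepancy in the real model would point at precision or setup rather than at the formula.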

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import load_file
from tokenizers import Tokenizer
import json

torch.manual_seed(42)

device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Device: {device}")

Device: mps


In [47]:
# Load data
GOLDILOCKS_DATA = "../Goldilocks/data"
tokenizer = Tokenizer.from_file(f"{GOLDILOCKS_DATA}/tokenizer.json")
vocab_size = tokenizer.get_vocab_size()

tokens_data = load_file(f"{GOLDILOCKS_DATA}/model_corpus_tokens.safetensors")
all_tokens = tokens_data["tokens"].to(torch.long)

with open(f"{GOLDILOCKS_DATA}/token_census.json", 'r') as f:
    census = json.load(f)
dead_token_ids = set(census['dead_token_ids'])

dead_mask = torch.zeros(vocab_size, dtype=torch.bool)
for tid in dead_token_ids:
    dead_mask[tid] = True

print(f"Vocab: {vocab_size}, Dead: {len(dead_token_ids)}")

Vocab: 3988, Dead: 1914


In [48]:
# Minimal model setup
D_MODEL = 128
N_LAYERS = 4
N_HEADS = 2
D_FF = 256
SEQ_LEN = 128
BATCH_SIZE = 8

class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, seq_len):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
                dropout=0.0, activation='gelu', batch_first=True, norm_first=True
            ) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight  # Weight tying!
        self.seq_len = seq_len
        self.register_buffer('causal_mask', None)
        
        # Explicit initialization of token embeddings only
        self._init_weights()
    
    def _init_weights(self):
        """Initialize token embeddings explicitly. No magic."""
        # Token embeddings (W): normal with mean 0 and std 0.02, via torch.randn * 0.02
        with torch.no_grad():
            self.tok_emb.weight.copy_(torch.randn(self.tok_emb.weight.shape) * 0.02)
        # pos_emb: leave as PyTorch default
    
    def forward(self, x, return_h=False):
        B, T = x.shape
        if self.causal_mask is None or self.causal_mask.shape[0] != T:
            self.causal_mask = torch.triu(
                torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
            )
        pos = torch.arange(T, device=x.device)
        h = self.tok_emb(x) + self.pos_emb(pos)
        for layer in self.layers:
            h = layer(h, src_mask=self.causal_mask, is_causal=True)
        h = self.ln_f(h)
        logits = self.head(h)
        if return_h:
            return logits, h
        return logits

model = GPT(vocab_size, D_MODEL, N_HEADS, N_LAYERS, D_FF, SEQ_LEN).to(device).to(torch.bfloat16)
print(f"Model ready (tok_emb explicit: torch.randn * 0.02)")

Model ready (tok_emb explicit: torch.randn * 0.02)


## Get One Batch and Compute Forward Pass

In [49]:
# Get a batch
start_idx = 0
x_list = []
y_list = []
for i in range(BATCH_SIZE):
    chunk = all_tokens[start_idx + i * 1000 : start_idx + i * 1000 + SEQ_LEN + 1]
    x_list.append(chunk[:-1])
    y_list.append(chunk[1:])

x = torch.stack(x_list).to(device)  # [B, T]
y = torch.stack(y_list).to(device)  # [B, T]

print(f"x shape: {x.shape}")
print(f"y shape: {y.shape}")

x shape: torch.Size([8, 128])
y shape: torch.Size([8, 128])


In [50]:
# Forward pass, capturing h
model.eval()  # No dropout anyway, but be explicit
with torch.no_grad():
    logits, h = model(x, return_h=True)

print(f"logits shape: {logits.shape}")  # [B, T, V]
print(f"h shape: {h.shape}")            # [B, T, D]
print(f"h norm (mean): {h.float().norm(dim=-1).mean():.4f}")

logits shape: torch.Size([8, 128, 3988])
h shape: torch.Size([8, 128, 128])
h norm (mean): 11.3137


## Manual Gradient Computation

In [51]:
# Compute softmax probabilities
probs = F.softmax(logits.float(), dim=-1)  # [B, T, V]

print(f"probs shape: {probs.shape}")
print(f"probs sum (should be 1): {probs[0, 0, :].sum():.6f}")

probs shape: torch.Size([8, 128, 3988])
probs sum (should be 1): 1.000000


In [52]:
# For cross-entropy loss, dL/dz = p - one_hot(y)
# Create one-hot targets
one_hot_y = F.one_hot(y, num_classes=vocab_size).float()  # [B, T, V]

# dL/dz for each position
dL_dz = probs - one_hot_y  # [B, T, V]

print(f"dL_dz shape: {dL_dz.shape}")

dL_dz shape: torch.Size([8, 128, 3988])


In [53]:
# dL/dW[i] = sum over (b, t) of dL/dz[b, t, i] * h[b, t]
# Equivalent to flattening (b, t) into one axis and computing dL_dz_flat.T @ h_flat;
# einsum does the same contraction without the reshape.

h_float = h.float()  # [B, T, D]

# dL_dz is [B, T, V] and h is [B, T, D]; contracting over b and t gives [V, D],
# where row i is the sum of dL/dz_i * h over all positions.

dL_dW_manual = torch.einsum('btv,btd->vd', dL_dz, h_float)  # [V, D]

# Normalize by number of positions (B*T) to match PyTorch's mean reduction
num_positions = x.shape[0] * x.shape[1]
dL_dW_manual = dL_dW_manual / num_positions

print(f"Manual gradient shape: {dL_dW_manual.shape}")
print(f"Manual gradient norm (all): {dL_dW_manual.norm():.6f}")

Manual gradient shape: torch.Size([3988, 128])
Manual gradient norm (all): 0.386533


In [54]:
# Check dead vs live
dead_mask_cpu = dead_mask.cpu()

manual_dead_grad = dL_dW_manual[dead_mask_cpu]
manual_live_grad = dL_dW_manual[~dead_mask_cpu]

print(f"Manual dead gradient mean norm: {manual_dead_grad.norm(dim=1).mean():.6e}")
print(f"Manual live gradient mean norm: {manual_live_grad.norm(dim=1).mean():.6e}")

Manual dead gradient mean norm: 6.414620e-04
Manual live gradient mean norm: 4.181137e-03


## PyTorch Gradient Computation

In [55]:
# Now let PyTorch compute it
model.train()
model.zero_grad()

logits_pt = model(x)  # [B, T, V]
loss = F.cross_entropy(logits_pt.view(-1, vocab_size), y.view(-1))
loss.backward()

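# Get the gradient. head.weight is tied to tok_emb.weight, so this .grad accumulates
# both the unembedding path and the embedding-lookup path; the manual calculation
# covers only the unembedding path.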
dL_dW_pytorch = model.tok_emb.weight.grad.cpu().float()  # [V, D]

print(f"PyTorch gradient shape: {dL_dW_pytorch.shape}")
print(f"PyTorch gradient norm (all): {dL_dW_pytorch.norm():.6f}")

PyTorch gradient shape: torch.Size([3988, 128])
PyTorch gradient norm (all): 0.386510


In [56]:
pytorch_dead_grad = dL_dW_pytorch[dead_mask_cpu]
pytorch_live_grad = dL_dW_pytorch[~dead_mask_cpu]

print(f"PyTorch dead gradient mean norm: {pytorch_dead_grad.norm(dim=1).mean():.6e}")
print(f"PyTorch live gradient mean norm: {pytorch_live_grad.norm(dim=1).mean():.6e}")

PyTorch dead gradient mean norm: 6.414803e-04
PyTorch live gradient mean norm: 4.182540e-03


## Comparison

In [57]:
print("="*60)
print("COMPARISON: Manual vs PyTorch")
print("="*60)

print(f"\nDead token gradients:")
print(f"  Manual:  {manual_dead_grad.norm(dim=1).mean():.6e}")
print(f"  PyTorch: {pytorch_dead_grad.norm(dim=1).mean():.6e}")

print(f"\nLive token gradients:")
print(f"  Manual:  {manual_live_grad.norm(dim=1).mean():.6e}")
print(f"  PyTorch: {pytorch_live_grad.norm(dim=1).mean():.6e}")

# Cosine similarity between manual and pytorch (both on CPU)
dL_dW_manual_cpu = dL_dW_manual.cpu()
cos_sim = F.cosine_similarity(dL_dW_manual_cpu.flatten().unsqueeze(0), 
                               dL_dW_pytorch.flatten().unsqueeze(0)).item()
print(f"\nCosine similarity (manual vs pytorch): {cos_sim:.6f}")

COMPARISON: Manual vs PyTorch

Dead token gradients:
  Manual:  6.414620e-04
  PyTorch: 6.414803e-04

Live token gradients:
  Manual:  4.181137e-03
  PyTorch: 4.182540e-03

Cosine similarity (manual vs pytorch): 1.000495
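
A cosine similarity above 1 is not mathematically possible, so the excess over 1 in the reported value has to come from floating-point rounding rather than from a real mismatch. One way to get a cleaner reading is to do the comparison in float64 and to look at elementwise differences as well. The helper below, `compare_grads`, is a hypothetical name introduced for this sketch, and the tensors are synthetic stand-ins for `dL_dW_manual` and `dL_dW_pytorch`, with a bfloat16 round trip used to imitate reduced-precision noise.

```python
import torch

def compare_grads(a: torch.Tensor, b: torch.Tensor) -> dict:
    """Compare two gradient tensors in float64 (hypothetical helper)."""
    a64 = a.detach().cpu().double().flatten()
    b64 = b.detach().cpu().double().flatten()
    cos = torch.dot(a64, b64) / (a64.norm() * b64.norm())
    return {
        "cosine": min(cos.item(), 1.0),  # clamp any residual rounding above 1
        "max_abs_diff": (a64 - b64).abs().max().item(),
        "rel_l2_err": ((a64 - b64).norm() / b64.norm()).item(),
    }

# Synthetic stand-ins for the two gradients (not the notebook's data).
torch.manual_seed(0)
ref = torch.randn(3988, 128) * 1e-3
noisy = ref.to(torch.bfloat16).float()   # imitate bfloat16 rounding
report = compare_grads(noisy, ref)
for name, value in report.items():
    print(f"{name}: {value:.6e}")
```

In the notebook, the corresponding call would be `compare_grads(dL_dW_manual, dL_dW_pytorch)`.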


## Deep Dive: What's happening with dead token probabilities?

In [58]:
# What probability mass do dead tokens get?
dead_probs = probs[:, :, dead_mask_cpu]  # [B, T, N_dead]
live_probs = probs[:, :, ~dead_mask_cpu]  # [B, T, N_live]

print(f"Dead token probability statistics:")
print(f"  Mean p(dead): {dead_probs.mean():.6e}")
print(f"  Max p(dead):  {dead_probs.max():.6e}")
print(f"  Min p(dead):  {dead_probs.min():.6e}")
print(f"  Sum p(dead) per position: {dead_probs.sum(dim=-1).mean():.6e}")

print(f"\nLive token probability statistics:")
print(f"  Mean p(live): {live_probs.mean():.6e}")
print(f"  Max p(live):  {live_probs.max():.6e}")
print(f"  Min p(live):  {live_probs.min():.6e}")
print(f"  Sum p(live) per position: {live_probs.sum(dim=-1).mean():.6e}")

print(f"\nUniform would be: {1/vocab_size:.6e}")

Dead token probability statistics:
  Mean p(dead): 2.508319e-04
  Max p(dead):  6.701551e-04
  Min p(dead):  9.260746e-05
  Sum p(dead) per position: 4.800921e-01

Live token probability statistics:
  Mean p(live): 2.506789e-04
  Max p(live):  6.889911e-04
  Min p(live):  8.991521e-05
  Sum p(live) per position: 5.199079e-01

Uniform would be: 2.507523e-04


In [59]:
# What are the logits like?
dead_logits = logits[:, :, dead_mask_cpu].float()  # [B, T, N_dead]
live_logits = logits[:, :, ~dead_mask_cpu].float()  # [B, T, N_live]

print(f"Dead token logit statistics:")
print(f"  Mean: {dead_logits.mean():.4f}")
print(f"  Std:  {dead_logits.std():.4f}")
print(f"  Min:  {dead_logits.min():.4f}")
print(f"  Max:  {dead_logits.max():.4f}")

print(f"\nLive token logit statistics:")
print(f"  Mean: {live_logits.mean():.4f}")
print(f"  Std:  {live_logits.std():.4f}")
print(f"  Min:  {live_logits.min():.4f}")
print(f"  Max:  {live_logits.max():.4f}")

Dead token logit statistics:
  Mean: -0.0002
  Std:  0.2257
  Min:  -0.9727
  Max:  1.0078

Live token logit statistics:
  Mean: -0.0010
  Std:  0.2269
  Min:  -1.0078
  Max:  1.0312


## The Culprit?

In [60]:
# Is the issue in the embedding initialization?
W = model.tok_emb.weight.detach().cpu().float()

dead_W = W[dead_mask_cpu]
live_W = W[~dead_mask_cpu]

print(f"Dead token embedding statistics:")
print(f"  Mean norm: {dead_W.norm(dim=1).mean():.6f}")
print(f"  Std norm:  {dead_W.norm(dim=1).std():.6f}")

print(f"\nLive token embedding statistics:")
print(f"  Mean norm: {live_W.norm(dim=1).mean():.6f}")
print(f"  Std norm:  {live_W.norm(dim=1).std():.6f}")

Dead token embedding statistics:
  Mean norm: 0.225375
  Std norm:  0.013891

Live token embedding statistics:
  Mean norm: 0.226024
  Std norm:  0.013994


In [61]:
# What's h · W^T giving us?
# Take one h vector and compute its dot product with all embeddings
h_sample = h_float[0, 0, :].cpu()  # [D] - first position of first batch, move to CPU

dots_dead = (dead_W @ h_sample)  # [N_dead]
dots_live = (live_W @ h_sample)  # [N_live]

print(f"h · W for dead tokens:")
print(f"  Mean: {dots_dead.mean():.4f}")
print(f"  Std:  {dots_dead.std():.4f}")

print(f"\nh · W for live tokens:")
print(f"  Mean: {dots_live.mean():.4f}")
print(f"  Std:  {dots_live.std():.4f}")

h · W for dead tokens:
  Mean: -0.0025
  Std:  0.2284

h · W for live tokens:
  Mean: -0.0026
  Std:  0.2228
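
The matching spreads for dead and live tokens are what the initialization predicts. Assuming the rows of $W$ are independent draws with std 0.02 and are independent of $h$, each product $W_i \cdot h$ has std $0.02 \cdot \lVert h \rVert$. With $\lVert h \rVert \approx \sqrt{128} \approx 11.31$ (the mean norm measured earlier was 11.3137), that gives roughly 0.226, in line with the logit stds of 0.2257 and 0.2269 and the h · W stds of 0.2284 and 0.2228. The sketch below reproduces this with random stand-ins; it does not use the trained model, and the LayerNorm-ed random vector is only an assumed proxy for a real hidden state.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, D, STD = 3988, 128, 0.02

W = torch.randn(V, D) * STD                  # same init rule as tok_emb
h = F.layer_norm(torch.randn(D), (D,))       # stand-in for a final hidden state

h_norm = h.norm().item()
predicted_std = STD * h_norm                 # std of W_i . h for a fixed h
measured_std = (W @ h).std().item()

print(f"||h||:         {h_norm:.4f}  (sqrt(D) = {math.sqrt(D):.4f})")
print(f"predicted std: {predicted_std:.4f}")
print(f"measured std:  {measured_std:.4f}")
```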


## Summary

In [None]:
print("="*60)
print("FINDINGS")
print("="*60)
print("""
With explicit torch.randn * 0.02 initialization:

1. LOGIT GAP IS GONE
   - Dead logits: mean -0.0002, std 0.23
   - Live logits: mean -0.0010, std 0.23
   - Nearly identical distributions.

2. PROBABILITIES ARE NEAR-UNIFORM
   - Dead p(token): 2.508e-4
   - Live p(token): 2.507e-4
   - Uniform would be: 2.508e-4
   - No systematic suppression of dead tokens.

3. GRADIENTS ARE DIFFERENT (but correctly so!)
   - Dead gradient (mean row norm): 6.4e-4
   - Live gradient (mean row norm): 4.2e-3
   - Why? Live tokens appear as TARGETS, so they get the (p-1)×h 
     correction term. Dead tokens only get p×h. This is expected.

4. THE ORIGINAL MYSTERY WAS INITIALIZATION
   - PyTorch's default nn.Embedding init is N(0, 1), i.e. std 1
   - Our explicit init has std 0.02, so embeddings are ~50× smaller
   - With default init, something created a 16-unit logit gap
   - With explicit init, logits are centered at 0 as expected

5. REMAINING QUESTION
   - Why did PyTorch's default init create a systematic dead/live separation?
   - Possibly: larger random vectors produce more extreme h·W products
   - The pathology wasn't in the tokenizer; it was in uncontrolled init.
""")
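
The scale claim in finding 4 can be checked in isolation. The sketch below compares a freshly constructed `nn.Embedding` (default init) with the explicit `torch.randn * 0.02` rule, using the notebook's vocabulary and model sizes. It only confirms the roughly 50× difference in row norms; it does not reproduce the 16-unit logit gap or answer the remaining question in finding 5.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
V, D = 3988, 128

default_W = nn.Embedding(V, D).weight.detach()   # PyTorch default: N(0, 1)
explicit_W = torch.randn(V, D) * 0.02            # the notebook's explicit init

default_std = default_W.std().item()
default_norm = default_W.norm(dim=1).mean().item()
explicit_norm = explicit_W.norm(dim=1).mean().item()
ratio = default_norm / explicit_norm

print(f"default  element std:   {default_std:.4f}")
print(f"default  mean row norm: {default_norm:.4f}")
print(f"explicit mean row norm: {explicit_norm:.4f}")
print(f"ratio: {ratio:.1f}")
```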