# Goldilocks 2560D

**Question:** What happens when we scale Goldilocks to Qwen's hidden dimension?

This is an experiment to see if certain behaviors only emerge at high dimensionality.
We keep vocab and architecture the same, but scale D from 128 → 2560 (20×).

| Property | Value |
|----------|-------|
| Architecture | 4L/2560D/32H/10240FF |
| Vocab size | 3,988 |
| Dead tokens | 1,914 |
| Sequence length | 128 |
| Batch size | 8 |
| Dtype | bfloat16 |

---

*Jeffery Harrell & Alpha, December 1, 2025*

In [1]:
import torch

# Paths
GOLDILOCKS_DATA = "data"
TOKENIZER_PATH = f"{GOLDILOCKS_DATA}/tokenizer.json"
TOKENS_PATH = f"{GOLDILOCKS_DATA}/model_corpus_tokens.safetensors"
CENSUS_PATH = f"{GOLDILOCKS_DATA}/token_census.json"

# Architecture: Qwen-scale hidden dim
N_LAYERS = 4
D_MODEL = 2560  # Qwen 3 4B's hidden_size
N_HEADS = 32    # Scale heads with D (head_dim = 80)
D_FF = 10240    # 4× D_MODEL, typical ratio
SEQ_LEN = 128
DROPOUT = 0.0

# Training
BATCH_SIZE = 8
LEARNING_RATE = 1e-3
MODEL_DTYPE = torch.bfloat16

# Reproducibility
RANDOM_SEED = 42

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import load_file
from tokenizers import Tokenizer
import json
import time
from tqdm.auto import tqdm

torch.manual_seed(RANDOM_SEED)

if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

print(f"Device: {device}")
print(f"Dtype: {MODEL_DTYPE}")

Device: mps
Dtype: torch.bfloat16


In [3]:
# Load tokenizer and data
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vocab_size = tokenizer.get_vocab_size()
print(f"✓ Tokenizer: {vocab_size:,} tokens")

tokens_data = load_file(TOKENS_PATH)
all_tokens = tokens_data["tokens"].to(torch.long)
print(f"✓ Corpus: {len(all_tokens):,} tokens")

with open(CENSUS_PATH, 'r') as f:
    census = json.load(f)
dead_token_ids = set(census['dead_token_ids'])

dead_mask = torch.zeros(vocab_size, dtype=torch.bool)
for tid in dead_token_ids:
    dead_mask[tid] = True
live_mask = ~dead_mask

print(f"✓ Dead tokens: {len(dead_token_ids):,}")

✓ Tokenizer: 3,988 tokens
✓ Corpus: 34,993,926 tokens
✓ Dead tokens: 1,914


In [4]:
class TokenDataset(Dataset):
    def __init__(self, tokens, seq_len, num_samples=100_000):
        self.tokens = tokens
        self.seq_len = seq_len
        self.num_samples = num_samples
        max_start = len(tokens) - seq_len - 1
        self.starts = torch.randint(0, max_start, (num_samples,))
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        start = self.starts[idx]
        chunk = self.tokens[start:start + self.seq_len + 1]
        return chunk[:-1], chunk[1:]

dataset = TokenDataset(all_tokens, SEQ_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
print(f"✓ Dataset: {len(dataset):,} samples")

✓ Dataset: 100,000 samples


In [5]:
class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, seq_len, dropout=0.0):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
                dropout=dropout, activation='gelu', batch_first=True, norm_first=True
            ) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight  # Weight tying
        self.seq_len = seq_len
        self.register_buffer('causal_mask', None)
        
        self._init_weights()
    
    def _init_weights(self):
        """Initialize token embeddings explicitly. N(0, 0.02) like Qwen."""
        with torch.no_grad():
            self.tok_emb.weight.copy_(torch.randn(self.tok_emb.weight.shape) * 0.02)
    
    def forward(self, x):
        B, T = x.shape
        if self.causal_mask is None or self.causal_mask.shape[0] != T:
            self.causal_mask = torch.triu(
                torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
            )
        pos = torch.arange(T, device=x.device)
        h = self.tok_emb(x) + self.pos_emb(pos)
        for layer in self.layers:
            h = layer(h, src_mask=self.causal_mask, is_causal=True)
        return self.head(self.ln_f(h))

model = GPT(
    vocab_size=vocab_size,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    d_ff=D_FF,
    seq_len=SEQ_LEN,
    dropout=DROPOUT
).to(device).to(MODEL_DTYPE)

n_params = sum(p.numel() for p in model.parameters())
print(f"✓ Model: {n_params:,} parameters ({MODEL_DTYPE})")
print(f"  tok_emb: {model.tok_emb.weight.numel():,}")
print(f"  pos_emb: {model.pos_emb.weight.numel():,}")

✓ Model: 325,248,000 parameters (torch.bfloat16)
  tok_emb: 10,209,280
  pos_emb: 327,680


In [6]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

print(f"✓ Optimizer: AdamW (lr={LEARNING_RATE})")
print(f"✓ Batch size: {BATCH_SIZE}")
print(f"✓ Tokens per step: {BATCH_SIZE * SEQ_LEN:,}")

✓ Optimizer: AdamW (lr=0.001)
✓ Batch size: 8
✓ Tokens per step: 1,024


## Speed Test

Run 10 steps and time it.

In [7]:
NUM_STEPS = 10

model.train()
loader_iter = iter(loader)

# Warmup step (first step often slower due to compilation)
x, y = next(loader_iter)
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
loss.backward()
optimizer.step()
print(f"Warmup loss: {loss.item():.4f}")

# Timed run
start_time = time.time()

for step in range(NUM_STEPS):
    x, y = next(loader_iter)
    x, y = x.to(device), y.to(device)
    
    optimizer.zero_grad()
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
    loss.backward()
    optimizer.step()

elapsed = time.time() - start_time
steps_per_sec = NUM_STEPS / elapsed

print(f"\n{NUM_STEPS} steps in {elapsed:.2f}s")
print(f"Speed: {steps_per_sec:.2f} steps/sec")
print(f"Final loss: {loss.item():.4f}")
print(f"\nFor 100 steps: ~{100/steps_per_sec:.1f} seconds")

Warmup loss: 8.8125

10 steps in 4.14s
Speed: 2.41 steps/sec
Final loss: 8.7500

For 100 steps: ~41.4 seconds


## Quick Sanity Check

Verify dead tokens behave correctly at this scale.

In [8]:
# Get gradients from one more step
x, y = next(loader_iter)
x, y = x.to(device), y.to(device)

optimizer.zero_grad()
logits = model(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
loss.backward()

# Check gradient magnitudes
grad_W = model.tok_emb.weight.grad.detach().cpu().float()

dead_grads = grad_W[dead_mask]
live_grads = grad_W[~dead_mask]

dead_norm = dead_grads.norm(dim=1).mean().item()
live_norm = live_grads.norm(dim=1).mean().item()

print(f"Dead gradient norm: {dead_norm:.2e}")
print(f"Live gradient norm: {live_norm:.2e}")
print(f"Ratio (live/dead): {live_norm/dead_norm:.1f}x")

# Check coherence
if dead_grads.shape[0] > 100:
    sample_idx = torch.randperm(dead_grads.shape[0])[:100]
    dead_grads_sample = dead_grads[sample_idx]
else:
    dead_grads_sample = dead_grads

dead_grads_normed = F.normalize(dead_grads_sample, dim=1)
cos_matrix = dead_grads_normed @ dead_grads_normed.T
n = cos_matrix.shape[0]
triu_indices = torch.triu_indices(n, n, offset=1)
coherence = cos_matrix[triu_indices[0], triu_indices[1]].mean().item()

print(f"Dead token coherence: {coherence:.4f}")

Dead gradient norm: 6.14e-06
Live gradient norm: 2.95e-02
Ratio (live/dead): 4810.6x
Dead token coherence: 1.0000


In [9]:
# Initial centroid separation
W = model.tok_emb.weight.detach().cpu().float()

centroid_dead = W[dead_mask].mean(dim=0)
centroid_live = W[~dead_mask].mean(dim=0)

cos_centroids = F.cosine_similarity(centroid_dead.unsqueeze(0), centroid_live.unsqueeze(0)).item()
import numpy as np
angle = np.degrees(np.arccos(np.clip(cos_centroids, -1, 1)))

print(f"Dead centroid norm: {centroid_dead.norm():.6f}")
print(f"Live centroid norm: {centroid_live.norm():.6f}")
print(f"Angle between centroids: {angle:.1f}°")

Dead centroid norm: 0.209304
Live centroid norm: 0.077209
Angle between centroids: 43.0°
