In [None]:
# Inputs: set your checkpoint directory and step
import os
import torch
import pickle

from model import GPT, GPTConfig

# EDIT these two: the training output directory and the checkpoint step to inspect.
CKPT_DIR = '/content/drive/MyDrive/ml_projects/post_rmsnorm/out-small-model-post/lr_0.001'
STEP = 5000  # e.g., 5000

# Fixed seed so torch-based randomness in this notebook (e.g. the random batch
# offsets drawn with torch.randint further down) is reproducible across
# "Restart Kernel -> Run All".
SEED = 1337
torch.manual_seed(SEED)

# Device: use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', device)



In [None]:
# Load checkpoint and construct the model
import sys

# Candidate locations: the checkpoint saved for the requested STEP, and the
# single "ckpt.pt" at the top of the run directory used as a fallback.
ckpt_path_step = os.path.join(CKPT_DIR, 'checkpoints', f'ckpt_{STEP}.pt')
ckpt_path_latest = os.path.join(CKPT_DIR, 'ckpt.pt')
if os.path.exists(ckpt_path_step):
    ckpt_path = ckpt_path_step
elif os.path.exists(ckpt_path_latest):
    # Make the substitution visible: every plot below would otherwise be
    # silently attributed to STEP while describing a different checkpoint.
    print(f'WARNING: {ckpt_path_step} not found; falling back to {ckpt_path_latest}')
    ckpt_path = ckpt_path_latest
else:
    raise FileNotFoundError(
        f'No checkpoint found. Tried {ckpt_path_step!r} and {ckpt_path_latest!r}.'
    )
print('Loading checkpoint from:', ckpt_path)

# NOTE(security): torch.load deserializes the file; only load checkpoints you trust.
checkpoint = torch.load(ckpt_path, map_location=device)
model_args = dict(checkpoint['model_args'])  # base model args saved during training
cfg_dict = checkpoint.get('config', {})      # full training config (for extra flags)

# Normalization flags may live only in the training config; when present there
# they override whatever model_args carries so the rebuilt model matches training.
for k in ('post_ln', 'rmsnorm', 'ln_learnable'):
    if k in cfg_dict:
        model_args[k] = cfg_dict[k]

# Build model and load weights
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
state_dict = checkpoint['model']
# Strip the '_orig_mod.' key prefix so parameter names match the freshly built
# model (presumably left by a torch.compile wrapper at save time -- confirm).
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.eval().to(device)

print('Model loaded. Iteration:', checkpoint.get('iter_num'))


In [None]:
# Build tokenizer (char-level). Try to load meta.pkl from training dataset
import json

# The dataset name comes from the training config; meta.pkl is looked up
# relative to the current working directory.
dataset = cfg_dict.get('dataset', 'openwebtext')
meta_path = os.path.join('data', dataset, 'meta.pkl')
print('Dataset:', dataset)
print('Looking for meta at:', meta_path)

stoi = itos = None
if os.path.exists(meta_path):
    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # only point this at a meta.pkl you produced yourself.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi = meta.get('stoi')
    itos = meta.get('itos')
    print('Loaded vocabulary from meta.pkl (size = {})'.format(len(itos) if itos else 'unknown'))

if stoi is not None and itos is not None:
    def encode(s: str):
        """Map each character of `s` to its id via `stoi` (KeyError on unknown characters)."""
        return [stoi[ch] for ch in s]
    def decode(tokens):
        """Map each token id back through `itos` and concatenate the pieces."""
        return ''.join([itos[i] for i in tokens])
else:
    # Fallback: byte-level encode/decode (works without meta). Announce it,
    # because the text is only meaningful if the model was trained on byte ids.
    print('WARNING: no stoi/itos vocabulary loaded; falling back to raw UTF-8 bytes as token ids.')
    def encode(s: str):
        """Encode `s` as its UTF-8 byte values (ids in 0..255)."""
        return list(s.encode('utf-8'))
    def decode(tokens):
        """Best-effort decode of byte-valued ids.

        Ids outside 0..255 cannot be bytes; they are dropped (in the same
        spirit as errors='ignore') instead of making bytes() raise ValueError
        after generation has already run.
        """
        byte_ids = bytes(t for t in tokens if 0 <= t <= 255)
        return byte_ids.decode('utf-8', errors='ignore')



In [None]:
# Sample a continuation of a text prompt through the model's own generate()
prompt = "To be, or not to be"
max_new_tokens = 200
temperature = 1.0
TopK = 50  # set to None to disable top-k filtering

# Tokenize the prompt as a batch holding one sequence and let the model extend it
idx = model.generate(
    torch.tensor([encode(prompt)], dtype=torch.long, device=device),
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_k=TopK,
)

# Turn the first (only) sequence of the batch back into text
generated_ids = idx[0].tolist()
print(decode(generated_ids))


In [None]:
# Fixed Shakespeare excerpt, tokenized as a batch of one, used as the probe
# sequence by the activation-analysis cells below
sample_text = (
    "To be, or not to be, that is the question:\n"
    "Whether 'tis nobler in the mind to suffer\n"
    "The slings and arrows of outrageous fortune,\n"
    "Or to take arms against a sea of troubles\n"
)
sample_ids = encode(sample_text)
idx_sample = torch.tensor([sample_ids], dtype=torch.long, device=device)

# Number of tokens in the probe sequence (second dimension of the 1 x S batch)
S = idx_sample.shape[1]
print('Sequence length S =', S)


In [None]:
# Capture per-layer activations with forward hooks and prepend embeddings
block_outputs = []
hooks = []

def _make_block_hook(sink):
    """Return a forward hook that appends a detached copy of the module output to `sink`.

    The list is bound in the closure (rather than looked up by global name at
    call time) so a hook can only ever write into the list it was created for.
    """
    def hook(module, inp, out):
        sink.append(out.detach())
    return hook

# Run the probe in eval mode regardless of what an earlier cell left behind,
# and restore the previous mode afterwards.
was_training = model.training
model.eval()
try:
    # register hooks on each transformer block to capture its output (shape: 1 x S x d)
    for block in model.transformer.h:
        hooks.append(block.register_forward_hook(_make_block_hook(block_outputs)))

    # run a single forward pass (no targets, no generation)
    with torch.no_grad():
        _logits, _ = model(idx_sample)
finally:
    # Always detach the hooks, even if the forward pass raised; leftover hooks
    # would keep firing on every later forward pass.
    for h in hooks:
        h.remove()
    model.train(was_training)

# stack to L x S x d
layer_acts = torch.stack([t.squeeze(0) for t in block_outputs], dim=0)

# compute embedding + positional (and ln_emb if post_ln)
# The length is taken from idx_sample itself: the notebook-level name `S` is
# reassigned by later cells, so it is not reliable when this cell is re-run.
seq_len = idx_sample.size(1)
with torch.no_grad():
    tok_emb = model.transformer.wte(idx_sample)              # 1 x S x d
    pos = torch.arange(0, seq_len, dtype=torch.long, device=device)
    pos_emb = model.transformer.wpe(pos)                     # S x d
    x0 = tok_emb + pos_emb
    if getattr(model.config, 'post_ln', False):
        x0 = model.transformer.ln_emb(x0)
    x0 = x0.squeeze(0)                                       # S x d

# final array: (L+1) x S x d, with the embedding-level activations as index 0
all_acts = torch.cat([x0.unsqueeze(0), layer_acts], dim=0)
print('all_acts shape:', tuple(all_acts.shape))



In [None]:
# Average L2 norm per layer (averaged over token positions)
import torch
import matplotlib.pyplot as plt

assert 'all_acts' in globals(), "all_acts (L+1, S, d) not found"
Lp1, S, d = all_acts.shape

# L2 norm over the last (hidden) dimension for every (layer, token) pair,
# then the mean over token positions to get one value per layer
with torch.no_grad():
    norms = all_acts.norm(dim=-1)               # (L+1, S)
    avg_norm = norms.mean(dim=1).cpu().numpy()  # (L+1,)

layers = [*range(Lp1)]
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(layers, avg_norm, marker='o')
ax.set_xlabel('layer (0 = embeddings)')
ax.set_ylabel('avg ||activation||')
ax.set_title('Average activation norm vs. layer')
ax.grid(True, alpha=0.3)
plt.show()


In [None]:
# Average cosine similarity between consecutive layers (layer vs. layer-1)
import torch
import matplotlib.pyplot as plt

eps = 1e-8
# Pair every entry of all_acts with its predecessor:
# `a` holds indices 1..L (embeddings excluded), `b` holds indices 0..L-1
a, b = all_acts[1:], all_acts[:-1]    # each (L, S, d)

with torch.no_grad():
    # Norms are clamped away from zero so the division below is always defined
    na = a.norm(dim=-1).clamp_min(eps)                 # (L, S)
    nb = b.norm(dim=-1).clamp_min(eps)                 # (L, S)
    dot = (a * b).sum(dim=-1)                          # (L, S)
    # Per-token cosine, averaged over token positions
    cos = (dot / (na * nb)).mean(dim=1).cpu().numpy()  # (L,)

layers = [*range(1, all_acts.size(0))]
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(layers, cos, marker='o', color='tab:orange')
ax.set_xlabel('layer')
ax.set_ylabel('avg cos(layer, layer-1)')
ax.set_title('Average cosine similarity vs. layer')
ax.grid(True, alpha=0.3)
plt.show()


In [None]:
# Cache c_proj outputs ("velocities") for attention and MLP per layer
attn_vel_outs, mlp_vel_outs = [], []
attn_hooks, mlp_hooks = [], []

def _make_capture_hook(sink):
    """Return a forward hook that appends a detached copy of the module output to `sink`.

    Binding the list in the closure means a hook can only write into the list
    it was created for, never into a list rebuilt by a later run of this cell.
    """
    def hook(module, inp, out):
        sink.append(out.detach())   # shape: 1 x S x d
    return hook

# Run the probe in eval mode regardless of what an earlier cell left behind,
# and restore the previous mode afterwards.
was_training = model.training
model.eval()
try:
    for block in model.transformer.h:
        # attention c_proj output
        attn_hooks.append(block.attn.c_proj.register_forward_hook(_make_capture_hook(attn_vel_outs)))
        # mlp c_proj output
        mlp_hooks.append(block.mlp.c_proj.register_forward_hook(_make_capture_hook(mlp_vel_outs)))

    with torch.no_grad():
        _ = model(idx_sample)
finally:
    # Always detach the hooks, even if the forward pass raised; leftover hooks
    # would keep firing on every later forward pass.
    for h in attn_hooks + mlp_hooks:
        h.remove()
    model.train(was_training)

# Stack to (L, S, d)
attn_vel = torch.stack([t.squeeze(0) for t in attn_vel_outs], dim=0)
mlp_vel  = torch.stack([t.squeeze(0) for t in mlp_vel_outs],  dim=0)
print('attn_vel:', tuple(attn_vel.shape), 'mlp_vel:', tuple(mlp_vel.shape))


In [None]:
# Cosine similarity: velocity vs. activation at the same layer (averaged over tokens)
import matplotlib.pyplot as plt

eps = 1e-8
assert 'layer_acts' in globals(), "layer_acts (L, S, d) not found; run the hooks cell that builds it"

# All three stacks must cover the same number of layers
L = layer_acts.size(0)
assert attn_vel.size(0) == L and mlp_vel.size(0) == L

with torch.no_grad():
    # Norm of the block outputs, shared by both comparisons; norms are clamped
    # away from zero so the divisions below are always defined
    nx = layer_acts.norm(dim=-1).clamp_min(eps)              # (L, S)

    # attention c_proj output vs. block output, averaged over tokens
    na = attn_vel.norm(dim=-1).clamp_min(eps)                # (L, S)
    dot_a = (attn_vel * layer_acts).sum(dim=-1)              # (L, S)
    cos_attn = (dot_a / (na * nx)).mean(dim=1).cpu().numpy() # (L,)

    # mlp c_proj output vs. block output, averaged over tokens
    nm = mlp_vel.norm(dim=-1).clamp_min(eps)                 # (L, S)
    dot_m = (mlp_vel * layer_acts).sum(dim=-1)               # (L, S)
    cos_mlp = (dot_m / (nm * nx)).mean(dim=1).cpu().numpy()  # (L,)

# x axis starts at 1 because index 0 denotes the embeddings in the other plots
layers = [*range(1, L + 1)]
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(layers, cos_attn, marker='o', label='attn velocity ⋅ activation')
ax.plot(layers, cos_mlp,  marker='o', label='mlp velocity ⋅ activation')
ax.set_xlabel('layer index (1..L)')
ax.set_ylabel('avg cosine similarity')
ax.set_title('Cosine similarity: velocity vs activation (per layer)')
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()


In [None]:
# Heatmap: cosine similarity between consecutive layers per token (L x S)
import torch
import numpy as np
import matplotlib.pyplot as plt

eps = 1e-8
# Pair every entry of all_acts with its predecessor:
# `a` holds indices 1..L (embeddings excluded), `b` holds indices 0..L-1
a, b = all_acts[1:], all_acts[:-1]    # each (L, S, d)

with torch.no_grad():
    # Norms are clamped away from zero so the division below is always defined
    na = a.norm(dim=-1).clamp_min(eps)         # (L, S)
    nb = b.norm(dim=-1).clamp_min(eps)         # (L, S)
    dot = (a * b).sum(dim=-1)                  # (L, S)
    # Unlike the line plot, keep the per-token values (no averaging)
    cos_map = (dot / (na * nb)).cpu().numpy()  # (L, S)

L, S = cos_map.shape
fig, ax = plt.subplots(figsize=(10, 5))
# Fixed color limits so the colormap spans the full cosine range [-1, 1]
im = ax.imshow(cos_map, aspect='auto', cmap='coolwarm', vmin=-1, vmax=1)
fig.colorbar(im, ax=ax, label='cosine similarity')
ax.set_xlabel('token position (0..S-1)')
ax.set_ylabel('layer (1..L)')
# Row r of cos_map compares layer r+1 with layer r, so label rows from 1
ax.set_yticks(range(L))
ax.set_yticklabels(list(map(str, range(1, L + 1))))
# Flip the vertical axis relative to imshow's default orientation
ax.invert_yaxis()
ax.set_title('Cosine similarity heatmap: layer vs layer-1 per token')
fig.tight_layout()
plt.show()


In [None]:
# Draw a training batch and run forward pass to compute loss
import os
import numpy as np

# Configure batch size here
BATCH_SIZE = 8  # <-- change as needed
BLOCK_SIZE = gptconf.block_size

# Read-only memory map over the dataset's train.bin, interpreted as uint16 ids
data_dir = os.path.join('data', dataset)
train_mm = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')

def get_train_batch(batch_size=BATCH_SIZE, block_size=BLOCK_SIZE):
    """Draw `batch_size` random windows of `block_size` tokens from train_mm.

    Returns a pair (x, y) of int64 tensors on `device`, each stacked from
    `batch_size` windows of length `block_size`; y is x shifted one token to
    the right (the next-token targets).
    """
    starts = torch.randint(len(train_mm) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(torch.from_numpy((train_mm[start:start + block_size]).astype(np.int64)))
    for start in starts:
        targets.append(torch.from_numpy((train_mm[start + 1:start + 1 + block_size]).astype(np.int64)))
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

# Put the model in training mode for this pass. Only a forward pass runs here
# (loss is computed, no backward); the model is left in training mode.
model.train()
Xb, Yb = get_train_batch()
model.zero_grad(set_to_none=True)
logits, loss = model(Xb, Yb)
print('Batch:', Xb.shape, 'Loss:', loss.item())

In [None]:
# Register hooks to cache inputs to RMSNorm modules and their gradients
from collections import OrderedDict
from model import RMSNorm as _RMSNorm

ln_in_tensors = OrderedDict()
ln_hooks = []

# helper to register a hook capturing the module input (first arg)
def _make_ln_hook(name):
    """Return a forward hook that stores the module's first input under `name`.

    The input is normally a non-leaf tensor, so retain_grad() is requested to
    make its .grad readable after backward().
    """
    def hook(module, inp, out):
        x_in = inp[0]
        # retain_grad() raises on tensors that do not require grad (e.g. when
        # the forward pass runs under no_grad); such inputs are still recorded
        # and simply end up with a None gradient below.
        if x_in.requires_grad:
            x_in.retain_grad()
        ln_in_tensors[name] = x_in
    return hook

def _register_ln_hook(module, name):
    """Attach the capture hook to `module` only if it is an RMSNorm instance."""
    if isinstance(module, _RMSNorm):
        ln_hooks.append(module.register_forward_hook(_make_ln_hook(name)))

try:
    # ln_emb (post-layernorm models only)
    if hasattr(model.transformer, 'ln_emb'):
        _register_ln_hook(model.transformer.ln_emb, 'ln_emb_in')

    # per-block ln_1 and ln_2
    for li, block in enumerate(model.transformer.h):
        _register_ln_hook(block.ln_1, f'block{li}.ln_1_in')
        _register_ln_hook(block.ln_2, f'block{li}.ln_2_in')

    # final ln_f
    _register_ln_hook(model.transformer.ln_f, 'ln_f_in')

    # Run a forward+backward to populate grads. Parameter grads are cleared
    # first so re-running this cell does not accumulate them.
    model.zero_grad(set_to_none=True)
    logits, loss = model(Xb, Yb)
    loss.backward()
finally:
    # Always remove the hooks, even if forward/backward raised; leftover hooks
    # would keep firing on every later forward pass.
    for h in ln_hooks:
        h.remove()

# Collect gradients (on CPU) for inspection; None where no gradient was populated
ln_input_grads = {name: t.grad.detach().cpu() if t.grad is not None else None
                  for name, t in ln_in_tensors.items()}

print('Captured LN input tensors:', list(ln_in_tensors.keys()))
print('Grad shapes summary:', {k: None if v is None else tuple(v.shape) for k, v in ln_input_grads.items()})


In [None]:
# Gradient stats per LN input key (no mixing ln_1/ln_2). Heatmaps and per-key scalars
import torch
import numpy as np
import matplotlib.pyplot as plt

assert 'ln_input_grads' in globals(), "Run the RMSNorm grad hook cell first."

# Keep the original order of keys as captured; skip entries without a gradient
keys = [k for k, v in ln_input_grads.items() if v is not None]
if not keys:
    raise RuntimeError('ln_input_grads has no populated entries with gradients.')

means_KS = []  # (K, S)
stds_KS = []   # (K, S)
scalar_means_k = []
scalar_stds_k = []

with torch.no_grad():
    for k in keys:
        # numpy or tensor; expected shape (B, S, d)
        g = torch.as_tensor(ln_input_grads[k], device=device)
        # per-position statistics: reduce over batch (dim 0) and hidden (dim 2)
        means_KS.append(g.mean(dim=(0, 2)).cpu().numpy())                 # (S,)
        stds_KS.append(g.std(dim=(0, 2), unbiased=False).cpu().numpy())   # (S,)
        # single scalar per key: reduce over every dimension
        scalar_means_k.append(g.mean().item())
        scalar_stds_k.append(g.std(unbiased=False).item())

means_KS = np.stack(means_KS, axis=0)  # (K, S)
stds_KS  = np.stack(stds_KS,  axis=0)  # (K, S)
# Deliberately NOT unpacked into `K, S`: `S` is the probe-sequence length set
# by an earlier cell, and this array's second axis is a different length.
n_keys = means_KS.shape[0]

# Heatmaps: one row per LN input key, one column per token position
for grid, cmap, cbar_label, title in (
    (means_KS, 'viridis', 'grad mean over batch+hidden', 'RMSNorm input gradient mean (per key x token)'),
    (stds_KS,  'magma',   'grad std over batch+hidden',  'RMSNorm input gradient std (per key x token)'),
):
    fig, ax = plt.subplots(figsize=(10, max(4, 0.3*n_keys + 3)))
    im = ax.imshow(grid, aspect='auto', cmap=cmap)
    fig.colorbar(im, ax=ax, label=cbar_label)
    ax.set_xlabel('token position (0..S-1)')
    ax.set_ylabel('LN input key')
    ax.set_yticks(range(n_keys))
    ax.set_yticklabels(keys)
    ax.invert_yaxis()
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# Scalar mean/std per key (over all dims)
for values, line_kwargs, ylabel, title in (
    (scalar_means_k, {},                      'grad mean over (B,S,d)', 'RMSNorm input gradient mean vs key'),
    (scalar_stds_k,  {'color': 'tab:orange'}, 'grad std over (B,S,d)',  'RMSNorm input gradient std vs key'),
):
    fig, ax = plt.subplots(figsize=(max(7, 0.4*n_keys + 3), 4))
    ax.plot(range(n_keys), values, marker='o', **line_kwargs)
    ax.set_xticks(range(n_keys))
    ax.set_xticklabels(keys, rotation=45, ha='right')
    ax.set_xlabel('LN input key')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

In [None]:
# Cache inputs to all RMSNorms (no grads), keyed by module name
from collections import OrderedDict
from model import RMSNorm as _RMSNorm

rms_in_tensors = OrderedDict()
hooks_in = []

def _make_in_hook(name, sink=rms_in_tensors):
    """Return a forward hook that stores a detached copy of the module's first input.

    `sink` is bound when the hook is created, so a hook can only write into the
    dict it was created for, never into one rebuilt by a later run of this cell.
    """
    def hook(module, inp, out):
        x_in = inp[0]
        sink[name] = x_in.detach()
    return hook

def _register_in_hook(module, name):
    """Attach the capture hook to `module` only if it is an RMSNorm instance."""
    if isinstance(module, _RMSNorm):
        hooks_in.append(module.register_forward_hook(_make_in_hook(name)))

try:
    # ln_emb (if post-LN)
    if hasattr(model.transformer, 'ln_emb'):
        _register_in_hook(model.transformer.ln_emb, 'ln_emb_in')

    # per-block ln_1 / ln_2
    for li, block in enumerate(model.transformer.h):
        _register_in_hook(block.ln_1, f'block{li}.ln_1_in')
        _register_in_hook(block.ln_2, f'block{li}.ln_2_in')

    # final ln_f
    _register_in_hook(model.transformer.ln_f, 'ln_f_in')

    # Inference-only pass: eval mode, no autograd graph. The model is left in
    # eval mode afterwards.
    model.eval()
    with torch.no_grad():
        _ = model(Xb)  # use the same Xb you created earlier
finally:
    # Always remove the hooks, even if the forward pass raised; leftover hooks
    # would keep firing on every later forward pass.
    for h in hooks_in:
        h.remove()

print('Captured keys:', list(rms_in_tensors.keys()))

In [None]:
# Input stats per LN input key: heatmaps (mean/std over batch+hidden) and scalar per-key stats
import numpy as np
import matplotlib.pyplot as plt

# Keep the capture order of the keys; skip any empty entries
keys_in = [k for k, v in rms_in_tensors.items() if v is not None]
assert keys_in, "No inputs captured; run the cache cell first."

means_KS_in, stds_KS_in = [], []
scalar_means_in, scalar_stds_in = [], []

with torch.no_grad():
    for k in keys_in:
        x = rms_in_tensors[k].to(device)  # (B, S, d)
        # per-position statistics: reduce over batch (dim 0) and hidden (dim 2)
        means_KS_in.append(x.mean(dim=(0, 2)).cpu().numpy())                 # (S,)
        stds_KS_in.append(x.std(dim=(0, 2), unbiased=False).cpu().numpy())   # (S,)
        # single scalar per key: reduce over every dimension
        scalar_means_in.append(x.mean().item())
        scalar_stds_in.append(x.std(unbiased=False).item())

means_KS_in = np.stack(means_KS_in, axis=0)  # (K, S)
stds_KS_in  = np.stack(stds_KS_in,  axis=0)  # (K, S)
# Deliberately NOT unpacked into `K, S`: `S` is the probe-sequence length set
# by an earlier cell, and this array's second axis is a different length.
n_keys_in = means_KS_in.shape[0]

# Heatmaps: one row per LN input key, one column per token position
for grid, cmap, cbar_label, title in (
    (means_KS_in, 'viridis', 'mean over batch+hidden', 'RMSNorm input mean (per key x token)'),
    (stds_KS_in,  'magma',   'std over batch+hidden',  'RMSNorm input std (per key x token)'),
):
    fig, ax = plt.subplots(figsize=(10, max(4, 0.3*n_keys_in + 3)))
    im = ax.imshow(grid, aspect='auto', cmap=cmap)
    fig.colorbar(im, ax=ax, label=cbar_label)
    ax.set_xlabel('token position (0..S-1)')
    ax.set_ylabel('LN input key')
    ax.set_yticks(range(n_keys_in))
    ax.set_yticklabels(keys_in)
    ax.invert_yaxis()
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# Scalar mean/std per key
for values, line_kwargs, ylabel, title in (
    (scalar_means_in, {},                      'mean over (B,S,d)', 'RMSNorm input mean vs key'),
    (scalar_stds_in,  {'color': 'tab:orange'}, 'std over (B,S,d)',  'RMSNorm input std vs key'),
):
    fig, ax = plt.subplots(figsize=(max(7, 0.4*n_keys_in + 3), 4))
    ax.plot(range(n_keys_in), values, marker='o', **line_kwargs)
    ax.set_xticks(range(n_keys_in))
    ax.set_xticklabels(keys_in, rotation=45, ha='right')
    ax.set_xlabel('LN input key')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()