<a href="https://colab.research.google.com/github/kutay25/toy_mechinterp_explainer/blob/main/mini_tf_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
try:
    import google.colab
    IN_COLAB = True
    !pip install einops
    !pip install git+https://github.com/TransformerLensOrg/TransformerLens
except:
    IN_COLAB = False

Collecting git+https://github.com/TransformerLensOrg/TransformerLens
  Cloning https://github.com/TransformerLensOrg/TransformerLens to /tmp/pip-req-build-3hqu617m
  Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/TransformerLens /tmp/pip-req-build-3hqu617m
  Resolved https://github.com/TransformerLensOrg/TransformerLens to commit db0f191f2e536e49c7784b69bae1fbfdd0141380
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer-lens==0.0.0)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting fancy-einsum>=0.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import plotly.express as px
import plotly.io as pio
import itertools
import math
from torch.optim.lr_scheduler import LambdaLR
from transformer_lens import HookedTransformer, HookedTransformerConfig

# The "colab" renderer only displays figures inside Google Colab. Elsewhere,
# keep Plotly's auto-detected default so figures still render in
# Jupyter/VS Code. IN_COLAB is set by the install cell at the top of the notebook.
if IN_COLAB:
    pio.renderers.default = "colab"

#########################################
# 1. DATASET: BOS + 8 Inputs + 8 Outputs = 17 tokens per example
#########################################
# Vocabulary: {0, 1, 2, 3, BOS=4}
# Task: for input tokens x0..x7, the output tokens y0..y7 are
#       reverse(x0..x7) with each token incremented mod 4,
#       i.e. y_k = (x_{7-k} + 1) % 4.
# Full sequence: [BOS, x0, x1, ..., x7, y0, y1, ..., y7]

# Enumerate all 4^8 = 65536 possible input sequences (lexicographic order).
all_sequences = torch.tensor(
    list(itertools.product(range(4), repeat=8)), dtype=torch.long
)  # [65536, 8]

# Targets: reverse the sequence and add 1 mod 4.
targets = torch.flip((all_sequences + 1) % 4, dims=[1])  # [65536, 8]

# Prepend the BOS token.
BOS_TOKEN = 4
BOS_column = torch.full((all_sequences.size(0), 1), BOS_TOKEN, dtype=torch.long)

# Full sequences [BOS, x0..x7, y0..y7]: shape [65536, 17].
full_dataset = torch.cat([BOS_column, all_sequences, targets], dim=1)

print("Full dataset shape:", full_dataset.shape)
print("Example sequence (index 0):", full_dataset[0])

# Split into training and holdout sets.
# itertools.product is lexicographic, so a contiguous slice [60000:] would make
# every holdout sequence start with x0 = 3 (hence y7 = 0): a biased split.
# Shuffle with a dedicated fixed-seed generator first, so the holdout set is a
# uniform random sample and the split is reproducible. A dedicated generator
# also leaves the global torch RNG untouched.
N_TRAIN = 60000
SPLIT_SEED = 0
split_perm = torch.randperm(
    full_dataset.size(0), generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_dataset = full_dataset[split_perm[:N_TRAIN]]
holdout_dataset = full_dataset[split_perm[N_TRAIN:]]
print("Train dataset shape:", train_dataset.shape)
print("Holdout dataset shape:", holdout_dataset.shape)

#########################################
# 2. MODEL CONFIGURATION
#########################################
# A tiny 2-layer, single-head Transformer with context length 17
# (BOS + 8 in + 8 out). No dropout is configured. Training stability relies on
# the LR warmup and gradient clipping in the training section.
# `seed` makes weight initialisation reproducible. TransformerLens also uses it
# to seed the global Python/NumPy/torch RNGs, which makes the later per-epoch
# shuffling (torch.randperm) reproducible too.
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=4,         # Hidden dimension
    d_head=2,          # Head dimension
    n_heads=1,         # One head for simplicity
    d_mlp=16,          # MLP hidden dimension
    d_vocab=5,         # Vocabulary: tokens 0,1,2,3, BOS (4)
    n_ctx=17,          # Sequence length (BOS + 8 in + 8 out)
    act_fn='relu',
    normalization_type='LN',
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=0,
)

model = HookedTransformer(cfg)
print("Model config:\n", model)

#########################################
# 3. LOSS FUNCTION & EVALUATION FUNCTION
#########################################
# Training is next-token prediction on the first 16 tokens, but only the eight
# answer tokens y0..y7 (sequence positions 9..16) are scored.
def compute_loss(logits, tokens):
    """
    Mean cross-entropy over the eight answer tokens y0..y7.

    logits: [batch, 16, vocab_size], the model output for tokens[:, :-1].
            Logit index i scores the token at sequence position i + 1.
    tokens: [batch, 17], the full [BOS, x0..x7, y0..y7] sequences.

    Logit indices 8..15 are paired with target positions 9..16. The mean is
    taken over all batch * 8 answer tokens.
    """
    vocab_size = logits.size(-1)
    answer_logits = logits[:, 8:].reshape(-1, vocab_size)
    answer_targets = tokens[:, 9:].reshape(-1)
    return F.cross_entropy(answer_logits, answer_targets)

def evaluate_model(model, data, batch_size=256):
    """
    Evaluate the model's accuracy on the final 8 output tokens (y0..y7).

    Args:
        model: module mapping token ids [batch, 16] to logits [batch, 16, vocab].
        data: LongTensor [N, 17] of full sequences [BOS, x0..x7, y0..y7].
        batch_size: number of sequences per forward pass.

    Returns:
        Fraction of correctly predicted output tokens, in [0, 1], or NaN when
        `data` is empty (accuracy is undefined with no examples).

    The model's previous train/eval mode is restored on exit, so calling this
    inside a training loop has no lasting side effect on the model.
    """
    if data.size(0) == 0:
        return float("nan")

    was_training = model.training
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for start in range(0, data.size(0), batch_size):
                batch = data[start:start + batch_size].to(device)
                logits = model(batch[:, :-1])   # [batch, 16, vocab]
                preds = logits.argmax(dim=-1)   # [batch, 16]
                # Logit index i predicts token i + 1, so indices 8..15 predict y0..y7.
                answers = batch[:, 9:]
                correct += (preds[:, 8:] == answers).sum().item()
                total += answers.numel()
    finally:
        model.train(was_training)
    return correct / total

#########################################
# 4. TRAINING LOOP WITH WARMUP SCHEDULER AND GRADIENT CLIPPING
#########################################
# Batches are moved to whichever device holds the model's weights.
device = next(iter(model.parameters())).device
optimizer = torch.optim.AdamW(params=model.parameters(), lr=6e-3)
num_epochs = 40
batch_size = 256

# Learning-rate schedule: a multiplier on the base LR, applied per epoch by
# LambdaLR. Linear warmup over the first `warmup_epochs` epochs up to 1.0, then
# cosine decay from 1.0 toward 0 over the remaining epochs.
warmup_epochs = 5
def lr_lambda(epoch):
    """
    Return the LR multiplier for the 0-based `epoch`.

    Warmup: (epoch + 1) / warmup_epochs, which reaches 1.0 at epoch warmup_epochs - 1.
    Decay:  0.5 * (1 + cos(pi * t)), with t = (epoch - warmup_epochs) / (num_epochs - warmup_epochs).
            This starts at exactly 1.0, so the schedule is continuous at the
            warmup boundary. A coefficient other than 0.5 would make the LR jump
            there.
    """
    if epoch < warmup_epochs:
        return float(epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

print("\n--- Training ---")
losses = []  # mean per-batch training loss, one entry per epoch
for epoch in range(num_epochs):
    # LR actually used for this epoch: LambdaLR has already applied
    # lr_lambda(epoch) at this point. Read after scheduler.step() below,
    # get_last_lr() would report the *next* epoch's rate instead.
    current_lr = scheduler.get_last_lr()[0]

    # Reshuffle the training examples every epoch.
    perm = torch.randperm(train_dataset.size(0))
    model.train()
    epoch_loss = 0.0
    steps = 0

    for i in range(0, train_dataset.size(0), batch_size):
        indices = perm[i:i + batch_size]
        batch = train_dataset[indices].to(device)  # [batch, 17]
        logits = model(batch[:, :-1])              # feed first 16 tokens
        loss = compute_loss(logits, batch)

        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping to prevent large updates
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += loss.item()
        steps += 1

    epoch_loss /= steps
    losses.append(epoch_loss)
    scheduler.step()  # advance the schedule for the next epoch

    train_acc = evaluate_model(model, train_dataset)
    holdout_acc = evaluate_model(model, holdout_dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss={epoch_loss:.4f}, LR={current_lr:.6f}, "
          f"TrainAcc={train_acc*100:.2f}%, HoldoutAcc={holdout_acc*100:.2f}%")

# 1-based epoch numbers on the x axis, matching the printed log.
px.line(x=list(range(1, len(losses) + 1)), y=losses,
        labels={"x": "Epoch", "y": "Loss"}, title="Training Loss").show()

#########################################
# 5. COMPARE FINAL HIDDEN STATES TO EMBEDDINGS
#########################################
def compare_to_embeddings(model, data, n_examples=4):
    """
    Print, for a few examples, the cosine similarity between the last block's
    residual stream at each output-predicting position and the learned token
    embeddings of tokens 0-3.

    Args:
        model: HookedTransformer-style model exposing `run_with_cache`,
            `cfg.n_layers` and `embed.W_E`.
        data: LongTensor [N, 17] of full sequences [BOS, x0..x7, y0..y7].
        n_examples: number of leading rows of `data` to report. It is clipped
            to the number of rows actually available.

    Note:
        The compared vector is the last block's `resid_post`. This is taken
        *before* the model's final LayerNorm and unembedding, so the argmax
        over embedding similarities is a representation probe, not necessarily
        the model's predicted token.
    """
    model.eval()
    device = next(model.parameters()).device
    sample = data[:n_examples].to(device)
    n_shown = sample.size(0)  # may be smaller than n_examples for short data

    with torch.no_grad():
        _, cache = model.run_with_cache(sample[:, :-1])
        final_resid = cache["resid_post", model.cfg.n_layers - 1]  # [batch, 16, d_model]
        # Logit indices 8..15 are the positions that predict y0..y7.
        output_positions = final_resid[:, 8:]                      # [batch, 8, d_model]

        # Unit-normalise once: embeddings of tokens 0-3 (BOS, index 4, excluded)
        # and every output-position vector. The 1e-9 clamp avoids division by zero.
        embed_0123 = model.embed.W_E[:4]                            # [4, d_model]
        norm_embed = embed_0123 / embed_0123.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-9)
        norm_out = output_positions / output_positions.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-9)
        cos_sims_all = norm_out @ norm_embed.T                      # [batch, 8, 4]

    print("\n--- Comparison to Embeddings ---")
    for b in range(n_shown):
        print(f"\n=== Sample {b} ===")
        print("Sequence:", sample[b].tolist())
        for out_pos in range(cos_sims_all.size(1)):
            cos_sims = cos_sims_all[b, out_pos]
            token_pred = cos_sims.argmax().item()
            print(f"  Position {8+out_pos} -> cos_sims={cos_sims.cpu().numpy()}, argmax={token_pred}")
        print("-----")

# Probe the first two holdout sequences.
compare_to_embeddings(model=model, data=holdout_dataset, n_examples=2)

#########################################
# 6. BASIC ABLATION STUDY
#########################################
def ablate_linear_weight(tensor: torch.Tensor):
    """
    Zero `tensor` in place and hand back a detached copy of its old values.

    The returned copy does not share storage with `tensor`, so it can later be
    copied back to undo the ablation.
    """
    saved = tensor.detach().clone()
    # In-place write to a (possibly grad-requiring) parameter must bypass autograd.
    with torch.no_grad():
        tensor.fill_(0)
    return saved

print("\n--- Basic Ablation Study ---")
original_holdout_acc = evaluate_model(model, holdout_dataset)
print(f"Original holdout accuracy: {original_holdout_acc*100:.2f}%")

# Zero-ablate, one at a time, the output projection of every MLP (W_out) and
# then of every attention layer (W_O), and measure holdout accuracy.
# try/finally restores each weight even if evaluation raises, so an
# interrupted cell cannot leave a silently damaged model in the kernel.
ablation_targets = (
    [(f"layer {layer} MLP W_out", model.blocks[layer].mlp.W_out)
     for layer in range(model.cfg.n_layers)]
    + [(f"layer {layer} Attention W_O", model.blocks[layer].attn.W_O)
       for layer in range(model.cfg.n_layers)]
)
ablation_results = {}  # label -> holdout accuracy with that weight zeroed
for label, weight in ablation_targets:
    backup = ablate_linear_weight(weight)
    try:
        acc_after_ablation = evaluate_model(model, holdout_dataset)
    finally:
        with torch.no_grad():
            weight.copy_(backup)
    ablation_results[label] = acc_after_ablation
    print(f"After zeroing {label}, holdout accuracy: {acc_after_ablation*100:.2f}%")

print("--- End of Ablation Study ---")

#########################################
# 7. Inspecting the Attention Patterns and MLP Circuit
#########################################
model.eval()
with torch.no_grad():
    # Pick a few samples from the holdout set for analysis.
    sample = holdout_dataset[:4].to(device)  # [4, 17]
    # Feed the first 16 tokens (indices 0..15). Token 16 (y7) is only a target.
    logits, cache = model.run_with_cache(sample[:, :-1])  # logits: [4, 16, 5]

    # --- (A) Inspect Attention Patterns in Layer 0 ---
    # Layout: [BOS, x0..x7, y0..y7], so x_j sits at index 1 + j and y_k at 9 + k.
    # The query that *predicts* y_k is at index 8 + k (x7 for y0, then y0..y6).
    # Since y_k = x_{7-k} + 1, a "mirror" head would make query 8 + k attend to
    # x_{7-k}, which shows up as an anti-diagonal in the slice below.
    attn = cache["pattern", 0]         # [batch, n_heads, 16 (query), 16 (key)]
    attn_target = attn[:, :, 8:, 1:9]  # queries predicting y0..y7 -> keys x0..x7
    input_labels = [f"x{j}" for j in range(8)]
    query_labels = [f"→y{k}" for k in range(8)]

    for i in range(sample.size(0)):
        attn_matrix = attn_target[i, 0].cpu().numpy()
        print(f"\nAttention pattern for holdout sample {i}:")
        print(np.array2string(attn_matrix, precision=3, suppress_small=True))

        fig = px.imshow(
            attn_matrix,
            title=f"Attention Pattern (Sample {i}, Layer 0, Head 0, Target→Input)",
            labels={
                "x": "Key: input token (x0...x7)",
                "y": "Query: position predicting (y0...y7)",
                "color": "Attention",
            },
            x=input_labels,
            y=query_labels,
            # Attention probabilities lie in [0, 1], so use a sequential scale.
            color_continuous_scale="Blues",
            zmin=0.0,
            zmax=1.0,
        )
        fig.show()

    # --- (B) Inspect MLP Outputs ---
    # "mlp_out" is the MLP's contribution written back to the residual stream
    # (after W_out), shape [batch, 16, d_model]. It is not the post-ReLU hidden
    # activation, which is cached under "post" with width d_mlp.
    mlp_out = cache["mlp_out", 0]
    print("\nMLP outputs (Layer 0) for the first sample (all positions):")
    print(np.array2string(mlp_out[0].cpu().numpy(), precision=3, suppress_small=True))

    # In a mechanistic circuit, one might expect that at positions corresponding to outputs,
    # the MLP output is shifted to be closer to the embedding of the correct target token.


Full dataset shape: torch.Size([65536, 17])
Example sequence (index 0): tensor([4, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
Train dataset shape: torch.Size([60000, 17])
Holdout dataset shape: torch.Size([5536, 17])
Model config:
 HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn


--- Comparison to Embeddings ---

=== Sample 0 ===
Sequence: [4, 3, 2, 2, 2, 1, 2, 0, 0, 1, 1, 3, 2, 3, 3, 3, 0]
  Position 8 -> cos_sims=[0.59966934 0.67470646 0.6644941  0.03871418], argmax=1
  Position 9 -> cos_sims=[0.8620937  0.49156916 0.58953005 0.9361627 ], argmax=3
  Position 10 -> cos_sims=[ 0.15906261  0.7312926   0.06396323 -0.19431397], argmax=1
  Position 11 -> cos_sims=[-0.42985612 -0.6513115  -0.4959532   0.13143522], argmax=3
  Position 12 -> cos_sims=[-0.6465711  -0.02603456 -0.8313576  -0.39166415], argmax=1
  Position 13 -> cos_sims=[-0.5401468  -0.16042335 -0.8785616   0.00314631], argmax=3
  Position 14 -> cos_sims=[-0.2713795  -0.10756792 -0.6131235   0.10645486], argmax=3
  Position 15 -> cos_sims=[-0.02197653 -0.4549393  -0.16131015  0.5051322 ], argmax=3
-----

=== Sample 1 ===
Sequence: [4, 3, 2, 2, 2, 1, 2, 0, 1, 2, 1, 3, 2, 3, 3, 3, 0]
  Position 8 -> cos_sims=[-0.6039192  -0.42531204 -0.22042409 -0.7827876 ], argmax=2
  Position 9 -> cos_sims=[0.9010959 0


Attention pattern for holdout sample 1:
[[1.93560079e-01 2.19407193e-02 2.65721202e-01 2.23336995e-01
  2.41053149e-01 2.33475752e-02 3.64551879e-03 2.13723816e-02]
 [1.73003704e-04 8.71114491e-04 2.85263726e-04 5.06760785e-04
  1.80993986e-04 1.91780657e-03 9.19911861e-01 6.95903669e-04]
 [1.04090229e-01 6.32990971e-02 1.20764948e-01 1.21550195e-01
  1.11826375e-01 7.00093657e-02 7.52812922e-02 6.13494143e-02]
 [2.42616752e-05 3.40664526e-03 6.45638920e-06 6.64282516e-06
  1.24236476e-05 1.48917455e-03 1.96597655e-03 4.41877590e-03]
 [1.07232735e-01 5.49733033e-03 2.32526079e-01 2.25642592e-01
  1.59345239e-01 8.82127602e-03 6.49476750e-03 4.73538088e-03]
 [1.23669937e-01 5.84424613e-03 2.28524297e-01 1.98536083e-01
  1.76492274e-01 7.73306750e-03 1.73259061e-03 5.32440329e-03]
 [2.12436062e-05 2.18340009e-03 7.04714330e-06 7.86349301e-06
  1.17954996e-05 1.17475132e-03 3.71958106e-03 2.66227848e-03]
 [5.86487913e-05 3.29609378e-03 2.12747382e-05 2.26436878e-05
  3.46223969e-05 1.807


Attention pattern for holdout sample 2:
[[1.48906916e-01 7.56761525e-03 3.02754074e-01 2.82133907e-01
  2.17389345e-01 1.12732863e-02 5.36927208e-03 1.99626554e-02]
 [9.42791812e-05 1.01958285e-03 1.25327613e-04 2.22167888e-04
  8.86674097e-05 1.95110554e-03 9.10720289e-01 2.50967249e-04]
 [1.03503272e-01 6.29421622e-02 1.20083965e-01 1.20864786e-01
  1.11195795e-01 6.96145892e-02 7.48567879e-02 7.23833889e-02]
 [2.43505019e-05 3.41911754e-03 6.48002742e-06 6.66714550e-06
  1.24691333e-05 1.49462663e-03 1.97317428e-03 7.61246891e-04]
 [1.10678136e-01 5.67396032e-03 2.39997178e-01 2.32892528e-01
  1.64465025e-01 9.10470542e-03 6.70344569e-03 1.42125599e-02]
 [1.25677168e-01 5.93910180e-03 2.32233390e-01 2.01758444e-01
  1.79356843e-01 7.85857998e-03 1.76071154e-03 1.74315814e-02]
 [2.12895447e-05 2.18812143e-03 7.06238188e-06 7.88049783e-06
  1.18210064e-05 1.17729162e-03 3.72762419e-03 4.84299235e-04]
 [5.88269068e-05 3.30610387e-03 2.13393487e-05 2.27124565e-05
  3.47275454e-05 1.812


Attention pattern for holdout sample 3:
[[1.81245208e-01 1.68214962e-02 2.76999682e-01 2.40251258e-01
  2.35464543e-01 1.96750537e-02 4.31336835e-03 1.93352122e-02]
 [1.22068814e-04 9.82768601e-04 1.75774578e-04 3.11242853e-04
  1.19513192e-04 1.97797338e-03 9.13650215e-01 6.41236897e-04]
 [1.04215756e-01 6.33754358e-02 1.20910585e-01 1.21696785e-01
  1.11961231e-01 7.00937957e-02 7.53720775e-02 6.36540949e-02]
 [2.42945080e-05 3.41125531e-03 6.46512672e-06 6.65181460e-06
  1.24404605e-05 1.49118982e-03 1.96863711e-03 3.06501030e-03]
 [1.09907173e-01 5.63443685e-03 2.38325402e-01 2.31270239e-01
  1.63319394e-01 9.04128328e-03 6.65675057e-03 6.06199121e-03]
 [1.25762105e-01 5.94311487e-03 2.32390314e-01 2.01894775e-01
  1.79478049e-01 7.86389038e-03 1.76190131e-03 6.90557016e-03]
 [2.12602936e-05 2.18511513e-03 7.05267894e-06 7.86967030e-06
  1.18047647e-05 1.17567414e-03 3.72250262e-03 1.86783413e-03]
 [5.87091345e-05 3.29948496e-03 2.12966261e-05 2.26669854e-05
  3.46580200e-05 1.809


MLP outputs (Layer 0) for the first sample (all positions):
[[ 0.5772198   1.362151   -1.3774916  -1.8504407 ]
 [-0.04772752 -0.36423826  0.70675576  0.29744932]
 [-0.85439795  0.44644475  0.98676914  0.49766186]
 [ 0.20364624 -0.7570492   0.64503175  0.49685082]
 [ 0.48724452  0.00672062  0.01943291 -0.40467018]
 [ 0.14903113 -0.62027526  0.6127696   0.29127398]
 [-1.3255497   1.9239885   0.42525738 -1.4754028 ]
 [ 1.7486238  -1.5250735  -0.20520325  0.31108633]
 [-0.8673246   0.38089585  1.1156074   0.55842304]
 [ 1.93145    -1.5567476  -0.27490085  0.22521372]
 [-0.15052864 -0.19034548  0.30448812  0.18020865]
 [-0.08357406  0.19603217  0.13471311 -0.04038196]
 [-0.43791437 -0.08413844  0.98756063  0.44646135]
 [-0.7415186   0.1855402   1.0336      0.48027828]
 [ 0.44363794 -0.56427467  0.27295482  0.04862161]
 [-0.08216061  0.22160795  0.07429782 -0.06159005]]
