<a href="https://colab.research.google.com/github/ckkissane/mech-interp-practice/blob/main/training/train_majority_element.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup
(No need to read)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-apgdt_et
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-apgdt_et
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 10d2f8a026d73eada861c7d51064f7e24d8f482c
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.14.3-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━

In [2]:
try:
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
except:
    import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-d4elt24f
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-d4elt24f
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit cc216772a66af819ff3a77038e53134f3e073af4
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting importlib-metadata<6.0.0,>=5.1.0 (from circuitsvis==0.0.0)
  Downloading importlib_metadata-5.2.0-py3-none-any.whl (21 kB)
Building wheels for collected packages: circuitsvis
  Building wheel for circuitsvis (pyproject.toml) ... [?25l[?25hdone
  Created wheel for circuitsvis: filename=circuitsvis-0.0.0-py3-none-any.whl size=1808565 sha256=00f6ea06383f60ca2f217e993df066f72a161cc8394eab833e

In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

In [4]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import circuitsvis as cv
import einops
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
from dataclasses import dataclass
import datasets
from IPython.display import HTML

In [5]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [6]:
import plotly.graph_objects as go

update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"}
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = utils.to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

In [7]:
import transformer_lens.patching as patching
from transformer_lens import evals
import math

# Load Model

In [8]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7bffdd4cf7c0>

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [10]:
MAX_ELT = 50
LIST_LEN = 10
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=64,
    attn_only=True,
    d_head=64,
    n_heads=1,
    normalization_type=None,
    d_vocab=MAX_ELT+2, # 0, ..., MAX_ELT-1, BOS, END
    n_ctx=LIST_LEN+2, #BOS a1 ... an END
    device=device,
    seed=0
)

model = HookedTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)


In [11]:
def disable_biases(model):
    for name, param in model.named_parameters():
        if 'b_' in name:
            param.requires_grad = False

disable_biases(model)

In [12]:
def disable_pos_embed(model):
    assert model.cfg.positional_embedding_type == "standard"
    model.pos_embed.W_pos = nn.Parameter(torch.zeros_like(model.pos_embed.W_pos))
    model.pos_embed.W_pos.requires_grad = False

disable_pos_embed(model)

In [13]:
for name, param in model.named_parameters():
    print(name, param.shape, param.requires_grad)

embed.W_E torch.Size([52, 64]) True
pos_embed.W_pos torch.Size([12, 64]) False
blocks.0.attn.W_Q torch.Size([1, 64, 64]) True
blocks.0.attn.W_K torch.Size([1, 64, 64]) True
blocks.0.attn.W_V torch.Size([1, 64, 64]) True
blocks.0.attn.W_O torch.Size([1, 64, 64]) True
blocks.0.attn.b_Q torch.Size([1, 64]) False
blocks.0.attn.b_K torch.Size([1, 64]) False
blocks.0.attn.b_V torch.Size([1, 64]) False
blocks.0.attn.b_O torch.Size([64]) False
unembed.W_U torch.Size([64, 52]) True
unembed.b_U torch.Size([52]) False


# Task dataset

In [14]:
BOS_TOKEN = cfg.d_vocab - 1
END_TOKEN = cfg.d_vocab - 2
print(BOS_TOKEN, END_TOKEN)

51 50


In [15]:
majority_num = LIST_LEN // 2 + 1

batch_size = 1
x = torch.randint(0, MAX_ELT, (batch_size, LIST_LEN))
print("x", x)

majority_elt = random.randint(0, MAX_ELT-1)
print("majority elt", majority_elt)
majority_indices = torch.randperm(LIST_LEN)[:majority_num]
print("majority_indices", majority_indices)

x[:, majority_indices] = majority_elt
print("x", x)

x tensor([[ 1, 19, 47, 47, 46, 46,  3, 27,  7, 40]])
majority elt 24
majority_indices tensor([1, 8, 5, 4, 9, 7])
x tensor([[ 1, 24, 47, 47, 24, 24,  3, 24, 24, 24]])


In [16]:
def make_data_generator(cfg, batch_size, seed=0):
    torch.manual_seed(seed)
    random.seed(seed)
    bos_vec = torch.full((batch_size, 1), BOS_TOKEN)
    end_vec = torch.full((batch_size, 1), END_TOKEN)
    majority_num = LIST_LEN // 2 + 1
    while True:
        x = torch.randint(0, MAX_ELT, (batch_size, LIST_LEN))
        majority_elt = random.randint(0, MAX_ELT-1)
        majority_indices = torch.randperm(LIST_LEN)[:majority_num]
        x[:, majority_indices] = majority_elt
        yield torch.cat([bos_vec, x, end_vec], dim=-1), torch.full((batch_size,), majority_elt)

print(next(make_data_generator(cfg, 4)))

(tensor([[51, 24, 39, 24, 10, 24, 24, 27, 24, 24, 33, 50],
        [51, 24, 16, 24, 49, 24, 24,  6, 24, 24, 33, 50],
        [51, 24, 19, 24,  4, 24, 24, 19, 24, 24, 30, 50],
        [51, 24, 12, 24, 10, 24, 24, 22, 24, 24, 38, 50]]), tensor([24, 24, 24, 24]))


# Loss Fn

In [17]:
def loss_fn(logits, labels, per_token=False):
    if logits.ndim==3:
        logits=logits[:, -1, :]
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[..., None])[..., 0]
    if per_token:
        return -correct_log_probs
    else:
        return -correct_log_probs.mean()

with torch.no_grad():
    tokens, labels = next(make_data_generator(cfg, 4))
    tokens, labels = tokens.to(device), labels.to(device)
    logits = model(tokens)
    loss = loss_fn(logits, labels)
    print("Loss", loss)

Loss tensor(3.9843, device='cuda:0')


In [18]:
print("uniform loss", np.log(cfg.d_vocab_out))

uniform loss 3.9512437185814275


# Setup optimizer / dataloader

In [19]:
lr = 1e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

batch_size = 256
train_data_loader = make_data_generator(cfg, batch_size)

# Training Loop

In [20]:
num_epochs = 4000

train_losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
    tokens, labels = next(train_data_loader)
    tokens, labels = tokens.to(device), labels.to(device)
    logits = model(tokens)
    loss = loss_fn(logits, labels)
    loss.backward()
    train_losses.append(loss.item())

    optimizer.step()
    optimizer.zero_grad()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, train loss: {loss.item()}")

  0%|          | 0/4000 [00:00<?, ?it/s]

Epoch 0, train loss: 3.986286163330078
Epoch 100, train loss: 3.806046485900879
Epoch 200, train loss: 3.6536495685577393
Epoch 300, train loss: 2.068173885345459
Epoch 400, train loss: 2.7019033432006836
Epoch 500, train loss: 0.19587412476539612
Epoch 600, train loss: 1.5775014162063599
Epoch 700, train loss: 0.06599603593349457
Epoch 800, train loss: 0.6283575892448425
Epoch 900, train loss: 0.04852665960788727
Epoch 1000, train loss: 0.6040941476821899
Epoch 1100, train loss: 0.01572960987687111
Epoch 1200, train loss: 0.05396657809615135
Epoch 1300, train loss: 0.08357830345630646
Epoch 1400, train loss: 0.01879332959651947
Epoch 1500, train loss: 0.02646718919277191
Epoch 1600, train loss: 0.015377746894955635
Epoch 1700, train loss: 0.024489127099514008
Epoch 1800, train loss: 0.023379238322377205
Epoch 1900, train loss: 0.005863825790584087
Epoch 2000, train loss: 0.010808978229761124
Epoch 2100, train loss: 0.006896219216287136
Epoch 2200, train loss: 0.005783495958894491
Epoc

In [21]:
line(
    train_losses,
    title="Train Loss Curve",
    xaxis="Epoch", yaxis="Loss"
)

# Sanity Check

In [22]:
test_sample, test_labels = next(train_data_loader)
test_sample, test_labels = test_sample.to(device), test_labels.to(device)
print(test_sample.shape)
print(test_sample[:5])

print(test_labels.shape)
print(test_labels[:5])

torch.Size([256, 12])
tensor([[51, 13, 13, 22,  5, 13, 41, 17, 13, 13, 13, 50],
        [51, 13, 13, 12, 20, 13, 14,  4, 13, 13, 13, 50],
        [51, 13, 13, 20, 33, 13,  0, 29, 13, 13, 13, 50],
        [51, 13, 13, 43, 36, 13, 15, 36, 13, 13, 13, 50],
        [51, 13, 13, 41,  2, 13,  8,  0, 13, 13, 13, 50]], device='cuda:0')
torch.Size([256])
tensor([13, 13, 13, 13, 13], device='cuda:0')


In [23]:
with torch.inference_mode():
    logits = model(test_sample)
    logits = logits[:, -1, :]
    loss = loss_fn(logits, test_labels)
    print("Test sample loss", loss.item())

    preds = logits.argmax(dim=-1)
    acc = (preds == test_labels).float().mean()
    print("test sample acc", acc.item())

Test sample loss 0.0022177386563271284
test sample acc 1.0


# Save model

In [25]:
%mkdir ../models

In [26]:
# Save the model
filename = "../models/majority_element_model.pt"
torch.save(model.state_dict(), filename)

In [27]:
# Check we can load in the model
MAX_ELT = 50
LIST_LEN = 10
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=64,
    attn_only=True,
    d_head=64,
    n_heads=1,
    normalization_type=None,
    d_vocab=MAX_ELT+2, # 0, ..., MAX_ELT-1, BOS, END
    n_ctx=LIST_LEN+2, #BOS a1 ... an END
    device=device,
    seed=0
)

model_loaded = HookedTransformer(cfg)
model_loaded.load_state_dict(torch.load(filename))

<All keys matched successfully>