# Setup

In [2]:
import plotly.io as pio
try:
    import google.colab
    print("Running as a Colab notebook")
    pio.renderers.default = "colab"
    %pip install transformer-lens fancy-einsum
    %pip install -U kaleido # kaleido only works if you restart the runtime. Required to write figures to disk (final cell)
except:
    print("Running as a Jupyter notebook")
    pio.renderers.default = "vscode"
    from IPython import get_ipython
    ipython = get_ipython()

Running as a Jupyter notebook


In [13]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import einops
from typing import List, Union, Optional
from functools import partial
import pandas as pd
from pathlib import Path
import urllib.request
from bs4 import BeautifulSoup
from tqdm import tqdm
from datasets import load_dataset
import os
import json

# Turn off tokenizers parallelism; background: https://stackoverflow.com/q/62691279
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Disable autograd globally for every cell that follows.
torch.set_grad_enabled(False)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [4]:
!pip install circuitsvis
import circuitsvis as cv


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [5]:
# The setup cell selects the "colab" renderer when `google.colab` imports
# successfully; forcing "vscode" unconditionally here would undo that choice.
# Only apply the VS Code renderer when the Colab package was never loaded.
import sys
if "google.colab" not in sys.modules:
    pio.renderers.default='vscode'

def imshow(tensor, renderer=None, **kwargs):
    """Display `tensor` as a heatmap on a diverging "RdBu" colour scale centred at 0.

    Args:
        tensor: Data to plot; passed through `utils.to_numpy` before plotting.
        renderer: Forwarded to the figure's `show`; `None` leaves the choice
            to plotly.
        **kwargs: Extra keyword arguments forwarded to `px.imshow`.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Display `tensor` as a line plot, using its values for the y axis.

    Args:
        tensor: Data to plot; passed through `utils.to_numpy` before plotting.
        renderer: Forwarded to the figure's `show`; `None` leaves the choice
            to plotly.
        **kwargs: Extra keyword arguments forwarded to `px.line`.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a scatter plot of `y` against `x`.

    Args:
        x: Horizontal coordinates; passed through `utils.to_numpy`.
        y: Vertical coordinates; passed through `utils.to_numpy`.
        xaxis: Label used for the "x" dimension.
        yaxis: Label used for the "y" dimension.
        caxis: Label used for the "color" dimension.
        renderer: Forwarded to the figure's `show`; `None` leaves the choice
            to plotly.
        **kwargs: Extra keyword arguments forwarded to `px.scatter`.
    """
    axis_labels = {"x":xaxis, "y":yaxis, "color":caxis}
    x, y = utils.to_numpy(x), utils.to_numpy(y)
    figure = px.scatter(y=y, x=x, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [6]:
# Load the pretrained "gpt2-large" checkpoint onto `device`, with all four
# weight-processing options of `from_pretrained` switched on.
model = HookedTransformer.from_pretrained(
    "gpt2-large",
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


# Finding top token-aligned neurons

In [11]:
# Maps layer index -> {neuron index: [decoded token, score]} for the scanned layers.
results = {key: {} for key in [31, 32, 33, 34, 35]}

# Settings and lookups that do not change between layers, hoisted out of the loop.
num = 5000            # upper bound on how many top-scoring neurons are examined per layer
neurons_to_find = 10  # stop after recording this many neurons for a layer
unembedding = model.W_U

for n in results.keys():
    n_layer_neurons = model.W_out[n, :, :]
    # Score every (neuron, token) pair by contracting over the shared "embed" axis.
    dot_product = einsum("neuron embed, embed token -> neuron token", n_layer_neurons, unembedding)
    # For each neuron keep its highest-scoring token and that score.
    values, indices = torch.max(dot_product, dim=-1)

    # torch.topk raises when k exceeds the number of entries, so clamp the
    # candidate count to the number of neurons actually present in this layer.
    k = min(num, values.shape[0])
    top_values, top_indices = torch.topk(values, k=k)
    neurons_found = 0
    for i in range(k):

        str_token = model.to_string(indices[top_indices[i]])
        # Keep only tokens that start with a space and are longer than two characters.
        if len(str_token) <= 2 or str_token[0] != " ":
            continue
        print(f"L{n}N{top_indices[i]}: {str_token} ({top_values[i]:.2f})")
        # Store plain Python numbers (via .item()) so the results are JSON-serialisable.
        results[n][top_indices[i].item()] = [str_token, top_values[i].item()]

        neurons_found += 1
        if neurons_found >= neurons_to_find:
            break

L31N3621:  only (5.47)
L31N364:  number (5.32)
L31N2918:  go (5.09)
L31N4378:  together (5.04)
L31N988:  called (5.02)
L31N2658:  first (4.90)
L31N2692:  used (4.67)
L31N4941:  within (4.50)
L31N2415:  way (4.37)
L31N1407:  out (4.27)
L32N4964:  too (5.54)
L32N2412:  will (5.11)
L32N4282:  right (4.93)
L32N3151:  over (4.80)
L32N1155:  out (4.68)
L32N1386:  once (4.45)
L32N3582:  her (4.26)
L32N4882:  class (4.21)
L32N3477:  use (4.20)
L32N406:  much (4.20)
L33N1202:  so (6.02)
L33N524:  state (5.66)
L33N1582:  RandomRedditor (5.30)
L33N4446:  about (4.55)
L33N204:  following (4.46)
L33N4900:  of (4.32)
L33N2322:  after (4.12)
L33N3278:  around (4.12)
L33N1299:  last (4.08)
L33N52:  by (3.96)
L34N4012:  off (6.54)
L34N4262:  down (5.78)
L34N320:  back (5.68)
L34N5095:  well (5.32)
L34N2599:  there (4.88)
L34N2442:  up (4.46)
L34N4494:  no (4.45)
L34N4199:  after (4.34)
L34N727:  under (4.14)
L34N4410:  as (3.21)
L35N4518:  issue (3.79)
L35N48:  close (3.53)
L35N5014:  won (3.10)
L35N37

In [14]:
# Write the collected neurons to disk. `results` uses integer keys, which the
# JSON encoder stores as strings.
with open("neuron_finder_results.json", mode="w") as out_file:
    json.dump(results, fp=out_file)