# Setup

In [2]:
import plotly.io as pio
try:
    import google.colab
    print("Running as a Colab notebook")
    pio.renderers.default = "colab"
    %pip install transformer-lens fancy-einsum
    %pip install -U kaleido # kaleido only works if you restart the runtime. Required to write figures to disk (final cell)
except:
    print("Running as a Jupyter notebook")
    pio.renderers.default = "vscode"
    from IPython import get_ipython
    ipython = get_ipython()

Running as a Jupyter notebook


In [59]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import einops
from typing import List, Union, Optional
from functools import partial
import pandas as pd
from pathlib import Path
import urllib.request
from bs4 import BeautifulSoup
from tqdm import tqdm
from datasets import load_dataset
import os
import json

os.environ["TOKENIZERS_PARALLELISM"] = "false" # https://stackoverflow.com/q/62691279
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
!pip install circuitsvis
import circuitsvis as cv


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [5]:
pio.renderers.default='vscode'

def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [6]:
model = HookedTransformer.from_pretrained(
    "gpt2-large",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


## Find Activating Examples

In [7]:
from datasets import load_dataset
dataset = load_dataset("NeelNanda/pile-10k", split="train")

Using custom data configuration NeelNanda--pile-10k-72f566e9f7c464ab
Found cached dataset parquet (/Users/clementneo/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [8]:
print(len(dataset.to_dict()['text'][0]))
dataset_text_list = dataset.to_dict()['text']
print(model.to_tokens(dataset.to_dict()['text'][0:5]).shape)

13274
torch.Size([5, 1024])


In [64]:
neurons_json = json.load(open("neuron_finder_results.json"))
neurons = []
for layer, results in neurons_json.items():
    indices = results.keys()
    for index in indices:
        neurons.append((int(layer), int(index)))

print(neurons)

[(31, 3621), (31, 364), (31, 2918), (31, 4378), (31, 988), (31, 2658), (31, 2692), (31, 4941), (31, 2415), (31, 1407), (32, 4964), (32, 2412), (32, 4282), (32, 3151), (32, 1155), (32, 1386), (32, 3582), (32, 4882), (32, 3477), (32, 406), (33, 1202), (33, 524), (33, 1582), (33, 4446), (33, 204), (33, 4900), (33, 2322), (33, 3278), (33, 1299), (33, 52), (34, 4012), (34, 4262), (34, 320), (34, 5095), (34, 2599), (34, 2442), (34, 4494), (34, 4199), (34, 727), (34, 4410), (35, 4518), (35, 48), (35, 5014), (35, 3724), (35, 3360), (35, 885), (35, 4924), (35, 274), (35, 2369), (35, 4638)]


In [53]:
def get_neurons_acts(model, texts, neurons):
    tokens = model.to_tokens(texts)
    
def cache_to_tuples(cache):
    new_cache = {}
    for key in cache.keys():
        x = torch.max(cache[key], dim=1)
        y = list(x)
        y = [y[0].tolist(), y[1].tolist()]
        y = list(zip(*y))
        new_cache[key] = y # y is a list of tuples, i.e. [(max_value, max_index), ...]
    return new_cache

In [60]:
batch_size = 4
batched_texts = [dataset_text_list[i: i+batch_size] for i in range(0, len(dataset_text_list), batch_size)]
print(len(batched_texts))

neuron_max_acts = {neuron: [] for neuron in neurons}

for texts in tqdm(batched_texts):
    model.reset_hooks()

    cache = {}

        def return_caching_hook(neuron):
            layer, neuron_index = neuron
            def caching_hook(act, hook):
                cache[(layer, neuron_index)] = act[:, :, neuron_index] # act shape is (batch_size, seq_len, neuron_index)
            return caching_hook
        
        hooks = list(((f"blocks.{layer}.mlp.hook_post", return_caching_hook((layer, index))) for layer, index in neurons))
        print(hooks)

    model.run_with_hooks(
        model.to_tokens(texts),
        fwd_hooks=hooks,
    )
    cache = cache_to_tuples(cache)

    for key in cache.keys():
        neuron_max_acts[key].extend(cache[key])



2500


  0%|          | 0/2500 [00:00<?, ?it/s]

[('blocks.0.mlp.hook_post', <function return_caching_hook.<locals>.caching_hook at 0x2974e0160>), ('blocks.1.mlp.hook_post', <function return_caching_hook.<locals>.caching_hook at 0x2c775caf0>)]


  0%|          | 0/2500 [00:33<?, ?it/s]


In [67]:
print(neuron_max_acts)
neuron_max_acts_json = {str(key): value for key, value in neuron_max_acts.items()}
with open("neuron_max_acts.json", "w") as f:
    json.dump(neuron_max_acts_json, f)

{(0, 0): [(0.4795704782009125, 488), (1.505552887916565, 152), (1.5584756135940552, 137), (-0.0, 22)], (1, 1): [(-0.0, 566), (0.20460055768489838, 810), (2.280916929244995, 76), (0.30561938881874084, 66)]}


In [70]:
with open("neuron_max_acts.json", "r") as f:
    neuron_max_acts_load = json.load(f)

print(neuron_max_acts_load.keys())
# This is a dictionary with string keys '(layer, neuron)', and each value is a list of [(activation, index)] where each tuple is an example.

dict_keys(['(31, 3621)', '(31, 364)', '(31, 2918)', '(31, 4378)', '(31, 988)', '(31, 2658)', '(31, 2692)', '(31, 4941)', '(31, 2415)', '(31, 1407)', '(32, 4964)', '(32, 2412)', '(32, 4282)', '(32, 3151)', '(32, 1155)', '(32, 1386)', '(32, 3582)', '(32, 4882)', '(32, 3477)', '(32, 406)', '(33, 1202)', '(33, 524)', '(33, 1582)', '(33, 4446)', '(33, 204)', '(33, 4900)', '(33, 2322)', '(33, 3278)', '(33, 1299)', '(33, 52)', '(34, 4012)', '(34, 4262)', '(34, 320)', '(34, 5095)', '(34, 2599)', '(34, 2442)', '(34, 4494)', '(34, 4199)', '(34, 727)', '(34, 4410)', '(35, 4518)', '(35, 48)', '(35, 5014)', '(35, 3724)', '(35, 3360)', '(35, 885)', '(35, 4924)', '(35, 274)', '(35, 2369)', '(35, 4638)'])


In [95]:
for key in neuron_max_acts_load.keys():
    x = neuron_max_acts_load[key]
    # Get the maximum of the first integer of each tuple, and get its index
    max_value, max_index = max(x, key=lambda x: x[0])
    print(key, max_value)

(31, 3621) 11.057353973388672
(31, 364) 10.807672500610352
(31, 2918) 12.039180755615234
(31, 4378) 9.483403205871582
(31, 988) 11.906610488891602
(31, 2658) 12.757360458374023
(31, 2692) 8.618489265441895
(31, 4941) 9.564465522766113
(31, 2415) 12.56452465057373
(31, 1407) 11.221527099609375
(32, 4964) 10.279662132263184
(32, 2412) 14.07355785369873
(32, 4282) 9.611800193786621
(32, 3151) 12.704906463623047
(32, 1155) 17.4642391204834
(32, 1386) 8.08731460571289
(32, 3582) 8.37304973602295
(32, 4882) 7.985595226287842
(32, 3477) 10.776716232299805
(32, 406) 10.953438758850098
(33, 1202) 9.49435806274414
(33, 524) 10.278265953063965
(33, 1582) 6.070120811462402
(33, 4446) 6.540103912353516
(33, 204) 6.799570560455322
(33, 4900) 6.093514442443848
(33, 2322) 9.797901153564453
(33, 3278) 7.231164932250977
(33, 1299) 7.962460517883301
(33, 52) 7.98359489440918
(34, 4012) 13.326713562011719
(34, 4262) 12.96543025970459
(34, 320) 11.070517539978027
(34, 5095) 11.418562889099121
(34, 2599) 14

In [100]:
# Scatter plot a neuron
neuron_str = "(35, 3724)"
x = neuron_max_acts_load[neuron_str]
# Scatter plot the activations
scatter(
    x=[i for i in range(len(x))],
    y=[i[0] for i in x],
    xaxis="Text index",
    yaxis="Activation",
    caxis="",
    title=f"Activations for neuron {neuron_str}",
)