Goal: Look for neurons who's activation is low but DLA is high

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the TransformerLens code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional, DefaultDict
from functools import partial
import copy
from pprint import pprint

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)



Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-1369s_ba
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-1369s_ba
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 218ebd6f491f47f5e2f64e4c4327548b60a093eb
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting typeguard<4.0.0,>=3.0.2 (from transformer-lens==0.0.0)
  Using cached typeguard-3.0.2-py3-none-any.whl (30 kB)
Installing collected packages: typeguard
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.13.3
    Uninstalling typeguard-2.13.3:
      Successfully uninstalled typeguard-2.13.3
[31mERROR: pip's dependency resolver does not currently take 

In [2]:
solu_model = HookedTransformer.from_pretrained(
    "solu-1l",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# gelu_model = HookedTransformer.from_pretrained(
#     "gelu-1l",
#     center_unembed=True,
#     center_writing_weights=True,
#     fold_ln=True,
#     refactor_factored_attn_matrices=True,
# )

Loaded pretrained model solu-1l into HookedTransformer


### Check how does post LN activations change after ablating

In [3]:
def get_std(activations):
    return activations.std(axis = -1)

def get_max_activations(activations):
    args = np.argmax(activations, axis = -1)
    return (args, activations[np.arange(len(args)), args])

def get_tokens_for_prompts(model, prompt):
    p_str_tokens = model.to_str_tokens(prompt)
    p_tokens = model.to_tokens(prompt)
    return p_tokens, p_str_tokens

In [4]:
mlp_stores = []
ln_stores = []

In [5]:
prompt = "After Martin and Amy went to the park, Martin gave a drink to"
p_tokens, p_str_tokens = get_tokens_for_prompts(solu_model, prompt)

In [6]:
mlp_layer = "blocks.0.mlp.hook_mid"
ln_layer = "blocks.0.mlp.ln.hook_normalized" # same as mlp_post_hook

cache = solu_model.run_with_cache(prompt)

In [7]:
mlp_cache = np.array(cache[1][mlp_layer].cpu().detach()[0])
ln_cache = np.array(cache[1][ln_layer].cpu().detach()[0])

In [8]:
get_max_activations(mlp_cache), get_max_activations(ln_cache)

((array([1223, 1551,  134, 1648,  414,  652, 1634, 1634, 1695,  649,  989,
         1791, 1952,  523, 1807]),
  array([0.03622018, 0.00282079, 0.00488895, 0.00742962, 0.01096074,
         0.00844151, 0.01454535, 0.01004561, 0.01065427, 0.0073372 ,
         0.00801913, 0.00715461, 0.01679082, 0.00538308, 0.0079624 ],
        dtype=float32)),
 (array([1223, 1551,  134, 1648,  414,  652, 1634, 1634, 1695,  649,  989,
         1791, 1952,  523, 1807]),
  array([10.664263 ,  0.8787284,  1.5030432,  2.319249 ,  3.3610477,
          2.597603 ,  4.4826674,  3.1223977,  3.2939868,  2.2796106,
          2.456409 ,  2.2179003,  5.1493993,  1.6388619,  2.4618258],
        dtype=float32)))

Getting some intuition graphically

In [9]:
fig = px.histogram(np.array(ln_cache[0]))
fig.show()

In [10]:
fig = px.histogram(np.array(mlp_cache[-4]))
fig.show()

In [11]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

def plot_caches(mlp_cache, ln_cache):
    fig = make_subplots(rows=len(mlp_cache), cols=2)
    for i in range(len(mlp_cache)):
        fig.add_trace(go.Histogram(x=mlp_cache[i]), row=i+1, col=1)
        fig.add_trace(go.Histogram(x=ln_cache[i]), row=i+1, col=2)
    fig.update_layout(
        height=1600,
        width=1600,
        margin=dict(t=0, b=0),
        autosize=False,
    )
    fig.show()
plot_caches(mlp_cache, ln_cache)

In [12]:
np.linalg.norm(ln_cache, axis = -1)

array([15.451817 ,  4.7964087,  6.818752 ,  6.357008 ,  8.648512 ,
        7.6557727,  8.9632015,  7.71339  ,  7.657226 ,  7.63085  ,
        8.745004 ,  6.606222 ,  9.572149 ,  7.672827 ,  7.8406954],
      dtype=float32)

In [61]:
prompts = [
    "I am happy.",
    "The sun shines.",
    "Cats meow.",
    "Dogs bark.",
    "Birds fly.",
    "Time flies.",
    "Love conquers.",
    "Dreams inspire.",
    "Music heals.",
    "Laughter echoes."
    "The sky is blue.",
    "I love pizza.",
    "She walks to work.",
    "He plays guitar.",
    "We went to the beach.",
    "They are watching a movie.",
    "I need some coffee.",
    "The cat is sleeping.",
    "He ran to catch the bus.",
    "She smiled and waved goodbye.",
    "The book is on the table.",
    "They laughed at the joke.",
    "I want to learn coding.",
    "We had a great time.",
    "He asked me a question.",
    "She sings beautifully.",
    "The sun sets in the west.",
    "They went hiking in the mountains.",
    "I forgot my keys at home.",
    "He bought a new car.",
    "She enjoys playing tennis.",
    "We had dinner at a fancy restaurant.",
    "They are planning a trip to Europe.",
    "I saw a shooting star last night.",
    "She wrote a letter to her friend.",
    "The dog chased its tail."
]

Max activations(absolute): get the maximum possible activation for a neuron(the numbers seem very low here)

In [14]:
# get max activation of all neurons for prompts in dataset (pre and post LN)
max_mlp_activations = np.zeros((2048, 3))-10**6
max_ln_activations = np.zeros((2048, 3))-10**6
def get_max_activation_hook(prompt_num: int,
    max_activations: np.ndarray,
    print_val,
    x: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint):
    x_np = np.array(x[0].detach().cpu())
    max_vals=np.max(x_np, axis=1, keepdims=True)
    # x_np *= 1/max_vals
    max_vals = np.max(x_np, axis=0)
    max_indices = np.argmax(x_np, axis=0)
    # print(print_val, max_vals[519])
    mask = max_vals > max_activations[:, 0]

    max_activations[mask, 0] = max_vals[mask]
    max_activations[mask, 1] = prompt_num
    max_activations[mask, 2] = max_indices[mask]

for pn in range(len(prompts)):
    p = prompts[pn]
    mlp_hooker = partial(get_max_activation_hook, pn, max_mlp_activations, "===mlp===")
    ln_hooker = partial(get_max_activation_hook, pn, max_ln_activations, "===ln===")

    solu_model.run_with_hooks(
        p,
        return_type = "logits",
        fwd_hooks = [
            (mlp_layer, mlp_hooker),
            (ln_layer, ln_hooker)
        ]
    )


In [15]:
idx =  np.argmax(max_mlp_activations[:, 0])
max_mlp_activations[idx], max_ln_activations[idx], max_ln_activations[idx, 0]/max_mlp_activations[idx, 0]

(array([0.03622018, 9.        , 0.        ]),
 array([10.66426277,  9.        ,  0.        ]),
 294.4287786003592)

Max activations seems to be very, very low!!
1. Let's try ablating neurons with max activation ratio(abs(post LN)/abs(pre LN), relative to the max activation at that token)

    Intuition: neurons who's relative activations are low, will have a very high ratio post LN (in retropect, I could've just filtered wrt max activating neuron and not take the ratio: could have been less noisy- big oof)

2. Or we could do it on a larger dataset

Easier thing(1) first

In [16]:
ratio = max_ln_activations[:, 0]/(max_mlp_activations[:, 0] + 1e-12)
idx = np.argmax(ratio)
max_ln_activations[idx], max_mlp_activations[idx]

(array([-1.37378564e-02,  1.80000000e+01,  6.00000000e+00]),
 array([-6.67751783e-06,  1.50000000e+01,  3.00000000e+00]))

In [17]:
idx, ratio[idx]

(2041, 2057.3301060681656)

Doing it offline on ratios is better, ig

Also, caveat: Ratios can be noisy

In [69]:
caches = []
for p in prompts:
    caches.append(solu_model.run_with_cache(p))

In [70]:
ratios = []
for c in caches:
    mlp_c = np.array(c[1][mlp_layer].cpu().detach())
    ln_c = np.array(c[1][ln_layer].cpu().detach())
    # normalize wrt max neuron
    mlp_c /= mlp_c.max(axis=1, keepdims=True)
    ln_c /= ln_c.max(axis=1, keepdims=True)
    ratio = np.absolute(ln_c/(mlp_c + 1e-6))
    ratios.append(ratio)

In [20]:
caches[0][1][mlp_layer].shape

torch.Size([1, 5, 2048])

In [21]:
ratios[0].shape

(1, 5, 2048)

In [22]:
idxs = np.argmax(ratios[0][0], axis=-1)
idx = idxs[0]
mlp_c = caches[0][1][mlp_layer].cpu().detach()
ln_c = caches[0][1][ln_layer].cpu().detach()
idx, (mlp_c)[0, 0][idx], ln_c[0][0][idx], ratios[0][0, 0, idx]

(525, tensor(-4.8048e-08), tensor(-0.1012), 17607.037)

In [23]:
px.histogram(mlp_c[0,0])

In [24]:
ln_c[0, 0, mlp_c[0,0].argmax()]

tensor(10.6643)

In [25]:
px.histogram(ratios[0][0][0])

### Do DLA after ablating lower activation neurons post LN

In [63]:
prompt = prompts[0]
prompt

'I am happy.'

In [64]:
def ablate_layer_hook(
    neuron_id: Int,
    pos: Int,
    x: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint): # where do I get the data from? Res stream or ln?
    x[:, pos, neuron_id] = 0
    return x

In [65]:
baseline_logits = [caches[i][0] for i in range(len(caches))]

In [66]:
ablation_logits = []
for i in range(len(idxs)): # currently only for the first prompt
    ablate_hooker = partial(ablate_layer_hook, idxs[i], i)
    logit = solu_model.run_with_hooks(
        prompt,
        return_type = "logits",
        fwd_hooks = [
            (ln_layer, ablate_hooker)
        ]
    )
    ablation_logits.append(logit)

In [67]:
baseline_logit = baseline_logits[0]

In [68]:
sims = []
for i in range(len(ablation_logits)):
    al = np.array(ablation_logits[i][0, i, :].cpu().detach())
    bl = np.array(baseline_logit[0, i, :].cpu().detach())
    sims.append((al@bl)/(np.linalg.norm(al)*np.linalg.norm(bl)))

In [32]:
sims # bruh

[0.9999594, 0.9999765, 0.99999183, 0.9999973, 0.9999778]

In [60]:
from tqdm import tqdm
def run_ablation_on_prompts(prompts, sim_whole, baseline_logits, ratios,
                            number_of_neurons=200): # ~ 10% of d_mlp (=2048)
    for prompt_num in range(len(prompts)):
        prompt = prompts[prompt_num]
        # idxs = np.argmax(ratios[prompt_num][0], axis=-1) # maybe get more numbers here
        baseline_logit = baseline_logits[prompt_num]
        sims_prompt = []
        num_tokens = ratios[prompt_num].shape[1]
        print(prompt_num, ":", prompt)
        for token_num in tqdm(range(num_tokens)):
            idxs = np.argsort(ratios[prompt_num][0, token_num])[:number_of_neurons]
            for idx in idxs:
                ratio = ratios[prompt_num][0, token_num, :]
                ablate_hooker = partial(ablate_layer_hook, idx, token_num)
                ablation_logits = solu_model.run_with_hooks(
                    prompt, # we don't care about the rest of the prompt
                    return_type = "logits",
                    fwd_hooks = [
                        (ln_layer, ablate_hooker)
                    ]
                )
                al = np.array(ablation_logits[0, token_num, :].cpu())
                bl = np.array(baseline_logit[0, token_num, :].cpu())
                sims_prompt.append((al@bl)/(np.linalg.norm(al)*np.linalg.norm(bl)))
        sims_whole.append(sims_prompt)

In [None]:
sims_whole = [] # prompt x token
run_ablation_on_prompts(prompts, sims_whole, baseline_logits, ratios)

In [36]:
len(sims_whole[0]), ratios[0].shape[1]*200 # sanity- should be num_tokens x 200

(1000, 1000)

In [37]:
sims_whole # bruh

[[0.99999994,
  0.99999994,
  0.99999994,
  0.99999994,
  0.9999999,
  0.99999994,
  1.0,
  0.9999999,
  1.0,
  1.0000001,
  0.99999994,
  1.0000001,
  0.9999999,
  1.0,
  0.99999994,
  0.9999999,
  1.0,
  0.99999994,
  0.9999999,
  0.9999999,
  0.9999999,
  1.0,
  0.9999999,
  0.99999994,
  0.9999998,
  0.99999994,
  0.9999998,
  0.9999999,
  0.99999976,
  0.9999999,
  0.99999976,
  0.9999999,
  0.99999994,
  0.99999976,
  0.99999964,
  0.99999976,
  0.99999994,
  0.9999998,
  0.99999946,
  1.0,
  0.9999999,
  0.99999976,
  0.99999976,
  0.99999887,
  0.99999994,
  0.9999999,
  0.99999964,
  0.9999998,
  0.9999999,
  1.0,
  0.99999946,
  0.99999964,
  0.99999976,
  0.9999999,
  0.9999998,
  0.9999997,
  0.99999976,
  0.9999995,
  0.9999997,
  0.9999997,
  0.9999994,
  0.99999976,
  0.9999995,
  0.9999997,
  0.99999946,
  0.9999992,
  0.9999999,
  0.9999991,
  0.9999998,
  0.9999996,
  0.9999994,
  0.99999976,
  0.99999976,
  0.9999995,
  0.9999996,
  0.99999964,
  0.9999995,
  0.99999

In [38]:
min_sim = 1
for i in sims_whole:
    for j in i:
        if(j < min_sim): min_sim = j

min_sim

0.9992611

There seems to be almost no difference here. We might have meaningful neurons(with so low activations??) in these prompts. Try for more prompts

In [36]:
# max activation found post ln
max = -10**6
for c in caches:
    ln_c = c[1][ln_layer]
    max_curr = ln_c.max()
    if(max_curr > max):
        max = max_curr
max # for reference, max on pre ln was .03

tensor(10.6643, device='cuda:0')

More prompts

In [39]:
%pip install datasets
from datasets import load_dataset
import numpy as np

# dataset = load_dataset("roneneldan/TinyStories")
dataset = load_dataset("generics_kb", "generics_kb_simplewiki")





  0%|          | 0/1 [00:00<?, ?it/s]

In [52]:
np.random.seed(0)
prompt_idxs = np.random.choice(len(dataset['train']['sentence']), 1000)
prompts = [dataset['train']['sentence'][i] for i in prompt_idxs]

max activations

In [54]:
max_mlp_activations = np.zeros((2048, 3))-10**6
max_ln_activations = np.zeros((2048, 3))-10**6
for pn in range(len(prompts)):
    p = prompts[pn]
    mlp_hooker = partial(get_max_activation_hook, pn, max_mlp_activations, "===mlp===")
    ln_hooker = partial(get_max_activation_hook, pn, max_ln_activations, "===ln===")

    solu_model.run_with_hooks(
        p,
        return_type = "logits",
        fwd_hooks = [
            (mlp_layer, mlp_hooker),
            (ln_layer, ln_hooker)
        ]
    )


In [57]:
idx = np.argmax(max_mlp_activations[:, 0])
max_mlp_activations[idx], max_ln_activations[idx], max_ln_activations[idx, 0]/max_mlp_activations[idx, 0] # it's still the same!!

(array([0.03622018, 0.        , 0.        ]),
 array([10.66426277,  0.        ,  0.        ]),
 294.4287786003592)

In [58]:
# max activation found post ln
max = -10**6
for c in caches:
    ln_c = c[1][ln_layer]
    max_curr = ln_c.max()
    if(max_curr > max):
        max = max_curr
max # for reference, max on pre ln was .03

tensor(10.6643, device='cuda:0')

ablations on highest ratios

In [43]:
caches = []
for p in prompts:
    caches.append(solu_model.run_with_cache(p))
baseline_logits = [caches[i][0] for i in range(len(caches))]

In [44]:
ratios = []
for c in caches:
    mlp_c = np.array(c[1][mlp_layer].cpu().detach())
    ln_c = np.array(c[1][ln_layer].cpu().detach())
    # normalize wrt max neuron
    mlp_c /= mlp_c.max(axis=1, keepdims=True)
    ln_c /= ln_c.max(axis=1, keepdims=True)
    ratio = np.absolute(ln_c/(mlp_c + 1e-6))
    ratios.append(ratio)

In [48]:
sims_whole = [] # prompt x token
run_ablation_on_prompts(prompts, sims_whole, baseline_logits, ratios, number_of_neurons=20) # this takes too long

In [None]:
sims_whole

In [50]:
min_sim = 1
for i in sims_whole:
    for j in i:
        if(j < min_sim): min_sim = j

min_sim # all that for a drop of blood

0.9989134