In [1]:
from itertools import islice
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typed
from beartype.door import die_if_unbearable as assert_type
from datasets import load_dataset
from einops import einops as ein
from jaxtyping import Bool, Float, Int
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, AutoTokenizer

%load_ext autoreload
%autoreload 2

In [2]:
model = AutoModelForCausalLM.from_pretrained("Mlxa/brackets-nested")
tokenizer = AutoTokenizer.from_pretrained("Mlxa/brackets-nested")
dataset = load_dataset("Mlxa/nested", streaming=True)["train"]

In [3]:
from utils import sh, ls
from activation_analysis import (
    input_output_mapping,
    fit_linear,
    eval_module,
    Residual,
    PrefixMean,
)


@typed
def get_prompts(n: int) -> list[str]:
    return [elem["text"] for elem in islice(dataset, n)]

In [4]:
layer = "transformer.h.7.mlp"
X_sae, Y_sae = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(200),
    input_layer=layer,
    output_layer=layer,
)
# Y_sae -= X_sae

In [5]:
from activation_analysis import fit_sae, SparseAutoEncoder

sae = SparseAutoEncoder(256, 520)
activations = ein.rearrange(Y_sae, "prompts seq d -> (prompts seq) d")
fit_sae(sae, activations, lr=1e-2, l1=3.0, alpha=1e-4, epochs=15)

100%|██████████| 15/15 [00:48<00:00,  3.25s/it,  loss=0.00, nonzero=386.17 / 520, nonorth=6.09]


In [6]:
fit_sae(sae, activations, lr=1e-3, l1=15.0, alpha=1e-4, epochs=10)

  0%|          | 0/10 [00:00<?, ?it/s,  loss=0.01, nonzero=302.61 / 520, nonorth=5.24]

100%|██████████| 10/10 [00:34<00:00,  3.42s/it,  loss=0.01, nonzero=31.32 / 520, nonorth=4.66]


In [46]:
from activation_analysis import compressed_activations
from language_modeling import prompt_from_template, generate_sample
import plotly.express as px
from tqdm import tqdm


prefixes: list[str] = []
suffixes: list[str] = []
activations: list[Float[TT, "d"]] = []

for prompt in tqdm(get_prompts(200)):
    res = compressed_activations(model, tokenizer, prompt, {layer: sae})[layer]
    assert_type(res, Float[TT, "n d"])
    tokens = tokenizer.tokenize(prompt)
    for i in range(min(8, len(tokens))):
        prefix = " ".join(tokens[: i + 1])
        prefixes.append(prefix)
        generated_suffix = generate_sample(
            model,
            tokenizer,
            tokenizer.encode(prefix),
            max_new_tokens=4,
        )
        suffixes.append(generated_suffix)
        activations.append(res[i])

100%|██████████| 200/200 [01:24<00:00,  2.37it/s]


In [30]:
@typed
def show_activations(prefixes: list[str], activations: Float[TT, "n d"], perm: Int[TT, "k"]) -> None:
    fig = px.imshow(
        activations[:, perm],
        color_continuous_scale=px.colors.diverging.RdBu,
        labels={"x": "Dimension", "y": "Prompt", "color": "Activation"},
        x=[str(elem.item()) for elem in perm],
        # y=prefixes,
        # zmin=-1,
        # zmax=1,
    )
    fig.show()

stacked = t.stack(activations)
stds = stacked.std(dim=0)
std_argsort = t.tensor(list(stds.argsort())[::-1])
important_dirs = std_argsort[:50]
stacked -= stacked.mean(dim=0, keepdim=True)
# stacked /= stacked.norm(dim=0, keepdim=True)
show_activations(prefixes, stacked, perm=important_dirs)

In [47]:
stacked = t.stack(activations)
for pos, i in enumerate(list(stds.argsort())[::-1][:15]):
    print(f"{pos}) {i.item()} {ls(stds[i])}")

0) 368 537.78
1) 349 113.60
2) 505 110.36
3) 2 104.09
4) 202 103.08
5) 127 102.26
6) 97 102.12
7) 408 101.04
8) 334 100.26
9) 76 98.37
10) 9 97.97
11) 226 96.90
12) 84 94.88
13) 152 94.63
14) 353 94.19


In [99]:
from activation_analysis import feature_effect

short_prompts = [
    ("<10 <11 <12 12> 11> 10>", 1),
    ("<3 <4 <5 5> 4> 3>", 1),
    ("<8 <2 <7 7> 2> 8>", 1),
    ("<10 <11 <12 12> 11> 10>", 2),
    ("<3 <4 <5 5> 4> 3>", 2),
]

inspected_direction = 505

for prompt, position in short_prompts:
    eps = 1e2

    effect = feature_effect(
        model,
        tokenizer,
        prompt,
        position,
        directions={layer: (sae, inspected_direction)},
        eps=eps,
    )

    d = effect.base[f"lm_head"] * 0 + effect.diff[f"lm_head"]
    print(f"{prompt} (pos={position})")
    r = 0.01 * eps 
    window = 500

    # px.imshow(
    #     d[:, :window].detach(),
    #     zmin=-r,
    #     zmax=r,
    #     color_continuous_scale=px.colors.diverging.RdBu,
    #     x=[tokenizer.decode(i) for i in range(window)],
    # ).show()

    px.imshow(
        d[:, 1:window:2].detach(),
        # zmin=-r,
        # zmax=r,
        color_continuous_scale=px.colors.diverging.RdBu,
        x=[tokenizer.decode(2 * i + 1) for i in range(window//2)],
    ).show()

# pos: 6 4
# neg: 154 111 96 98 221 148

<10 <11 <12 12> 11> 10> (pos=1)


<3 <4 <5 5> 4> 3> (pos=1)


<8 <2 <7 7> 2> 8> (pos=1)


<10 <11 <12 12> 11> 10> (pos=2)


<3 <4 <5 5> 4> 3> (pos=2)


In [78]:
values = t.tensor([a[inspected_direction] for a in activations])
perm = values.argsort()

for i in perm[::5]:
    print(ls(values[i]), "   \t", ls(activations[i][368]), "   \t", prefixes[i], "#", suffixes[i])

-6.02    	 96.98    	 <170 170> <15 # 15> <161 161> <142
-5.27    	 90.55    	 <153 <20 20> <85 <77 # 77> 85> 153> <24
-5.04    	 96.83    	 <141 141> <139 139> <101 <158 <34 34> # 158> 101> <52 <138
-4.89    	 96.43    	 <196 196> <18 18> <52 52> <131 <225 # <145 145> 225> 131>
-4.70    	 113.39    	 <36 36> <113 113> <144 # 144> <33 33> <213
-4.56    	 94.08    	 <47 <131 131> 47> <249 <158 # 158> 249> <192 192>
-4.46    	 101.00    	 <246 246> <160 160> <81 <211 # 211> 81> <193 193>
-4.32    	 141.67    	 <170 170> <15 15> <96 <122 <78 # 78> 122> 96> <15
-4.17    	 125.80    	 <21 <48 # <2 2> 48> 21>
-4.12    	 92.68    	 <219 219> <42 <217 <181 # 181> 217> 42> <198
-4.07    	 137.40    	 <246 <133 133> <48 # 48> 246> <239 239>
-4.00    	 70.24    	 <183 <135 # 135> 183> <91 <214
-3.95    	 49.16    	 <228 # 228> <97 97> <242
-3.88    	 71.35    	 <188 <93 # 93> 188> <56 56>
-3.81    	 70.16    	 <114 <104 104> <69 <221 <158 <249 249> # 158> 221> 69> <91
-3.75    	 122.03    	 <79 7

In [136]:
while True:
    user_input = input()
    tokens = tokenizer.encode(user_input)
    if any(x >= len(tokenizer) - 2 for x in tokenizer.encode(user_input)):
        print("Invalid input. Please try again.")
        continue
    cur = compressed_activations(model, tokenizer, user_input, {layer: sae})[layer]
    inspected = cur[:, inspected_direction]
    assert inspected.shape == (len(tokens),)
    print(ls(inspected))
    str_tokens = tokenizer.tokenize(user_input)
    str_tokens = [f"{tok} ({i})" for i, tok in enumerate(str_tokens)]
    r = 200
    px.line(x=str_tokens, y=inspected, range_y=(-r, r)).show()

# pos: 212 204 6 4
# neg: 154 111 96 98 221 148

[ 415.85 580.64 -3.91 -4.53 ]


[ -4.87 -5.55 322.82 423.15 ]


[ -4.87 374.97 428.12 ]


[ 415.85 580.64 -3.91 ]


[ 415.85 455.90 563.33 562.98 -0.98 ]


[ 415.85 455.90 563.33 562.98 523.76 488.88 521.89 514.86 319.60 ]


[ 415.85 455.90 563.33 562.98 523.76 40.18 ]


[ 415.85 455.90 426.24 432.95 81.06 ]


[ 415.85 455.90 426.24 432.95 426.31 174.18 ]


[ 415.85 455.90 593.87 ]


RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0] because the unspecified dimension size -1 can be any value and is ambiguous

In [None]:
values = t.stack(activations)[:, 41]

perm = values.argsort()
for i in perm:
    print(ls(values[i]), "\t", prefixes[i])


In [257]:
mid_l = 3
mid_r = 6
X_mid, Y_mid = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=f"transformer.h.{mid_l}",
    output_layer=f"transformer.h.{mid_r - 1}",
)
Y_mid = (Y_mid - X_mid).reshape(-1, 256)
X_mid = X_mid.reshape(-1, 256)
mid_line = fit_linear(X_mid, Y_mid, reg="l2", alpha=1e-3)
print(
    eval_module(
        mid_line,
        X_mid,
        Y_mid,
    )
)

0.18959206342697144


In [320]:
standard_ln = nn.LayerNorm(256, elementwise_affine=False)

X_mlp_1, Y_mlp_1 = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=f"transformer.h.1.ln_2",
    output_layer=f"transformer.h.1.mlp",
)
Y_mlp_1 = Y_mlp_1.reshape(-1, 256)
X_mlp_1 = standard_ln(X_mlp_1).reshape(-1, 256)
mlp_1 = fit_linear(X_mlp_1, Y_mlp_1, reg="l2", alpha=1e-3)


X_mlp_6, Y_mlp_6 = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=f"transformer.h.6.ln_2",
    output_layer=f"transformer.h.6.mlp",
)
Y_mlp_6 = Y_mlp_6.reshape(-1, 256)
X_mlp_6 = standard_ln(X_mlp_6).reshape(-1, 256)
mlp_6 = fit_linear(X_mlp_6, Y_mlp_6, reg="l2", alpha=1e-3)

In [149]:
X_attn, Y_attn = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer="transformer.h.0",
    output_layer="transformer.h.0.attn",
)

X_attn_sums = X_attn.cumsum(dim=-2)
X_attn_lens = t.arange(1, X_attn_sums.size(-2) + 1).reshape(1, -1, 1)
Y_attn = Y_attn.reshape(-1, 256)
X_attn = (X_attn_sums / X_attn_lens).reshape(-1, 256)

attn_line = fit_linear(X_attn, Y_attn, reg="l2", alpha=1e-3)
print(
    eval_module(
        attn_line,
        X_attn,
        Y_attn,
    )
)

0.001996031031012535


In [127]:
X_wo_ln, Y_wo_ln = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer="transformer.h.0.ln_2",
    output_layer="transformer.h.0.mlp",
)
Y_wo_ln = Y_wo_ln.reshape(-1, 256)
X_wo_ln = standard_ln(X_wo_ln).reshape(-1, 256)
wo_ln = fit_linear(X_wo_ln, Y_wo_ln, reg="l2", alpha=1e-3)
print(eval_module(wo_ln, X_wo_ln, Y_wo_ln))

0.09023763239383698


In [362]:
from utils import prompt_from_template, get_loss, PrefixMean, Residual, Wrapper
from transformers import GPTNeoForCausalLM

new_model = AutoModelForCausalLM.from_pretrained("Mlxa/brackets-nested")
new_model.config.use_cache = False
new_model.config.output_attentions = False

new_model.transformer.h[0] = Wrapper(
    nn.Sequential(
        Residual(
            nn.Sequential(
                PrefixMean(),
                attn_line,
            )
        ),
        Residual(
            nn.Sequential(
                nn.LayerNorm(256, elementwise_affine=False),
                wo_ln,
            )
        ),
    ),
    append=(),
)

# new_model.transformer.h[6].ln_2 = nn.LayerNorm(256, elementwise_affine=False)
# new_model.transformer.h[6].mlp = mlp_6

new_model.transformer.h = nn.ModuleList(
    new_model.transformer.h[:mid_l]
    + [Wrapper(Residual(mid_line), append=())]
    + new_model.transformer.h[mid_r:]
)

prompt = prompt_from_template("((((((())))))()" * 3, random=True)
a = get_loss(model, tokenizer, prompt)
b = get_loss(new_model, tokenizer, prompt)
print(f"clean: {a:.3f}, corrupted: {b:.3f}")
print(f"delta: {b - a:.3f}")

clean: 3.649, corrupted: 3.735
delta: 0.085


In [363]:
new_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(502, 256)
    (wpe): Embedding(2048, 256)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): Wrapper(
        (fn): Sequential(
          (0): Residual(
            (fn): Sequential(
              (0): PrefixMean()
              (1): Linear(in_features=256, out_features=256, bias=True)
            )
          )
          (1): Residual(
            (fn): Sequential(
              (0): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
              (1): Linear(in_features=256, out_features=256, bias=True)
            )
          )
        )
      )
      (1-2): 2 x GPTNeoBlock(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=256, out_featu