In [1]:
from itertools import islice
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typed
from beartype.door import die_if_unbearable as assert_type
from datasets import load_dataset
from einops import einops as ein
from jaxtyping import Bool, Float, Int
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, AutoTokenizer

%load_ext autoreload
%autoreload 2

In [2]:
model = AutoModelForCausalLM.from_pretrained("Mlxa/brackets-nested")
tokenizer = AutoTokenizer.from_pretrained("Mlxa/brackets-nested")
dataset = load_dataset("Mlxa/nested", streaming=True)["train"]

In [3]:
from utils import sh, ls
from activation_analysis import (
    input_output_mapping,
    fit_linear,
    eval_module,
    Residual,
    PrefixMean,
)


@typed
def get_prompts(n: int) -> list[str]:
    return [elem["text"] for elem in islice(dataset, n)]

In [4]:
layer = "transformer.h.7.mlp"
X_sae, Y_sae = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(200),
    input_layer=layer,
    output_layer=layer,
)
# Y_sae -= X_sae

In [7]:
from sparse_autoencoders import fit_sae, SparseAutoEncoder

sae = SparseAutoEncoder(256, 520)
others = [SparseAutoEncoder(256, d) for d in [768, 1024]]
activation_data = ein.rearrange(Y_sae, "prompts seq d -> (prompts seq) d")

In [8]:
for other in others:
    fit_sae(other, activation_data, lr=1e-2, l1=0.1, alpha=1e-4, epochs=15)

100%|██████████| 15/15 [01:07<00:00,  4.47s/it,  loss=0.00, nonzero=20.36 / 768, nonorth=0.21]
100%|██████████| 15/15 [01:39<00:00,  6.64s/it,  loss=0.01, nonzero=19.79 / 1024, nonorth=0.22]


In [15]:
fit_sae(sae, activation_data, lr=1e-3, l1=0.1, alpha=1e-4, epochs=15)

100%|██████████| 15/15 [00:46<00:00,  3.07s/it,  loss=0.00, nonzero=75.50 / 520, nonorth=0.00]


In [19]:
with t.no_grad():
    some_activations = activation_data[::11]
    guess = sae.encode(some_activations)
    other_guesses = [other.encode(some_activations) for other in others]

In [22]:
from sparse_autoencoders import max_cosine_similarity

sims = [max_cosine_similarity(guess, other) for other in other_guesses]
mean_sim = sum(sims) / len(sims)
print(ls(sims[0].mean()), ls(sims[0]))
print(ls(sims[1].mean()), ls(sims[1]))

0.53 [ 0.25 0.79 0.81 0.62 0.37 0.53 0.18 0.33 0.18 0.68 0.53 0.30 0.51 0.95 0.28 0.67 0.46 0.56 0.43 0.65 0.60 0.50 0.65 0.46 0.40 0.59 0.57 0.76 0.32 0.43 0.54 0.45 0.38 0.54 0.20 0.89 0.76 0.46 0.64 0.68 0.51 0.38 0.73 0.77 0.59 0.53 0.58 0.27 0.37 0.99 0.25 0.60 0.18 0.55 0.61 0.75 0.19 0.84 0.61 0.57 0.52 0.49 0.54 0.24 0.65 0.53 0.34 0.58 0.55 0.62 0.71 0.88 0.91 0.35 0.54 0.42 0.43 0.96 0.52 0.15 0.12 0.49 0.85 0.32 0.73 0.47 0.39 0.29 0.57 0.80 0.51 0.19 0.51 0.55 1.00 0.61 0.23 0.20 0.85 0.54 0.40 0.55 0.54 0.86 0.78 0.48 0.49 0.36 0.63 0.65 0.90 0.43 0.49 0.43 0.88 0.47 0.45 0.54 0.45 0.68 0.97 0.65 0.60 0.50 0.40 0.53 0.81 0.44 0.27 0.62 0.47 0.03 0.53 0.23 0.29 0.84 0.34 0.48 0.59 0.63 0.40 0.49 0.59 0.36 0.37 0.61 0.58 0.48 0.40 0.85 0.14 0.31 0.50 0.45 0.51 0.40 0.41 0.18 0.57 0.18 0.67 0.81 0.74 0.46 0.39 0.61 0.87 0.60 0.63 0.53 0.78 0.60 0.31 0.57 0.64 0.36 1.00 0.56 0.81 0.17 0.64 0.60 0.53 0.47 0.53 0.42 0.28 0.49 0.69 0.53 0.96 0.56 0.48 0.32 0.77 0.00 0.59 0.46 0.4

In [24]:
from activation_analysis import compressed_activations
from language_modeling import prompt_from_template, generate_sample
import plotly.express as px
from tqdm import tqdm


prefixes: list[str] = []
suffixes: list[str] = []
compressed: list[Float[TT, "d"]] = []

for prompt in tqdm(get_prompts(200)):
    res = compressed_activations(model, tokenizer, prompt, {layer: sae})[layer]
    assert_type(res, Float[TT, "n d"])
    tokens = tokenizer.tokenize(prompt)
    for i in range(min(8, len(tokens))):
        prefix = " ".join(tokens[: i + 1])
        prefixes.append(prefix)
        generated_suffix = generate_sample(
            model,
            tokenizer,
            tokenizer.encode(prefix),
            max_new_tokens=4,
        )
        suffixes.append(generated_suffix)
        compressed.append(res[i])

100%|██████████| 200/200 [01:24<00:00,  2.35it/s]


In [26]:
@typed
def show_activations(
    prefixes: list[str], activations: Float[TT, "n d"], perm: Int[TT, "k"]
) -> None:
    fig = px.imshow(
        activations[:, perm],
        color_continuous_scale=px.colors.diverging.RdBu,
        labels={"x": "Dimension", "y": "Prompt", "color": "Activation"},
        x=[str(elem.item()) for elem in perm],
        # y=prefixes,
        # zmin=-1,
        # zmax=1,
    )
    fig.show()


stacked = t.stack(compressed)
stds = stacked.std(dim=0)
std_argsort = t.tensor(list(stds.argsort())[::-1])
important_dirs = std_argsort[:50]
# stacked -= stacked.mean(dim=0, keepdim=True)
# stacked /= stacked.norm(dim=0, keepdim=True)
show_activations(prefixes, stacked, perm=important_dirs)

In [32]:
stacked = t.stack(compressed)
for pos, i in enumerate(list(stds.argsort())[::-1][:500]):
    if mean_sim[i] > 0.8:
        print(
            f"{pos}) dim: {i.item()}\tstd: {ls(stds[i])}\tsim: {ls(sims[0][i])}\t other sim: {ls(sims[1][i])}"
        )

0) dim: 49	std: 535.12	sim: 0.99	 other sim: 0.97
40) dim: 149	std: 27.83	sim: 0.85	 other sim: 0.87
41) dim: 2	std: 27.55	sim: 0.81	 other sim: 0.85
42) dim: 35	std: 25.31	sim: 0.89	 other sim: 0.90
43) dim: 110	std: 23.53	sim: 0.90	 other sim: 0.88
47) dim: 234	std: 22.48	sim: 0.89	 other sim: 0.90
49) dim: 473	std: 21.90	sim: 0.93	 other sim: 0.93
50) dim: 279	std: 21.89	sim: 0.83	 other sim: 0.84
53) dim: 243	std: 21.51	sim: 0.85	 other sim: 0.87
55) dim: 388	std: 21.29	sim: 0.90	 other sim: 0.92
57) dim: 337	std: 20.79	sim: 0.82	 other sim: 0.83
60) dim: 82	std: 20.49	sim: 0.85	 other sim: 0.87
61) dim: 396	std: 20.39	sim: 0.92	 other sim: 0.93
62) dim: 321	std: 20.34	sim: 0.91	 other sim: 0.88
63) dim: 264	std: 20.32	sim: 0.91	 other sim: 0.89
65) dim: 265	std: 20.07	sim: 0.91	 other sim: 0.95
66) dim: 510	std: 19.62	sim: 0.95	 other sim: 0.96
68) dim: 307	std: 19.41	sim: 0.88	 other sim: 0.89
71) dim: 384	std: 18.16	sim: 0.81	 other sim: 0.81
72) dim: 315	std: 18.03	sim: 0.91	 o

In [69]:
from activation_analysis import feature_effect

short_prompts = [
    ("<10 <11 <12 12> 11> 10>", 1),
    ("<3 <4 <5 5> 4> 3>", 1),
    ("<8 <2 <7 7> 2> 8>", 1),
    ("<10 <11 <12 12> 11> 10>", 2),
    ("<3 <4 <5 5> 4> 3>", 2),
]

inspected_direction = 149

for prompt, position in short_prompts:
    eps = 1e2

    effect = feature_effect(
        model,
        tokenizer,
        prompt,
        position,
        directions={layer: (sae, inspected_direction)},
        eps=eps,
    )

    d = effect.base[f"lm_head"] * 0 + effect.diff[f"lm_head"]
    print(f"{prompt} (pos={position})")

    r = 0.002 * eps
    window = 50
    px.imshow(
        d[:, :window].detach(),
        zmin=-r,
        zmax=r,
        color_continuous_scale=px.colors.diverging.RdBu,
        x=[tokenizer.decode(i) for i in range(window)],
    ).show()

    # window = 500
    # px.imshow(
    #     d[:, 1:window:2].detach(),
    #     color_continuous_scale=px.colors.diverging.RdBu,
    #     x=[tokenizer.decode(2 * i + 1) for i in range(window//2)],
    # ).show()


<10 <11 <12 12> 11> 10> (pos=1)


<3 <4 <5 5> 4> 3> (pos=1)


<8 <2 <7 7> 2> 8> (pos=1)


<10 <11 <12 12> 11> 10> (pos=2)


<3 <4 <5 5> 4> 3> (pos=2)


In [67]:
values = t.tensor([a[inspected_direction] for a in compressed])
perm = values.argsort()

for i in perm[::100]:
    print(ls(values[i]), "   \t", prefixes[i], "#", suffixes[i])

0.00    	 <233 233> <195 195> <151 <90 # 90> 151> <211 <194
0.00    	 <24 24> <3 # 3> <211 211> <70
0.00    	 <235 <214 <191 191> <20 <85 # 85> 20> 214> 235>
0.00    	 <233 233> <42 <47 47> <182 # 182> <97 97> <150
0.00    	 <93 93> <63 63> <194 194> <6 # 6> <182 182> <50
0.65    	 <44 44> <226 # 226> <239 239> <15
8.46    	 <182 182> <23 23> <165 <222 <86 # 86> 222> 165> <239
15.78    	 <17 17> <133 # 133> <78 78> <222
23.80    	 <140 <199 199> 140> <165 <127 <226 226> # 127> 165> <182 182>
31.47    	 <224 224> <71 <208 # 208> 71> <218 218>
39.36    	 <209 209> # <143 143> <6 <217
44.09    	 <150 150> <178 178> <248 # 248> <213 213> <172
49.21    	 <57 <120 120> 57> # <91 91> <193 193>
59.83    	 <85 85> <10 <125 125> <39 39> # 10> <50 50> <56
66.70    	 <236 236> <94 94> # <167 167> <193 193>
75.83    	 <227 227> <168 <124 <60 <128 128> <80 # 80> 60> <68 68>


In [79]:
from language_modeling import get_balances
from collections import Counter

segments = 8
for segment in range(segments):
    l, r = segment * len(perm) // segments, (segment + 1) * len(perm) // segments
    prefix_ends_open = []
    suffix_starts_open = []
    prefix_balance = []
    feature_values = []
    prefix_tokens = Counter()
    last_tokens = Counter()
    suffix_tokens = Counter()
    for i in perm[l:r]:
        balances = get_balances(prefixes[i])
        feature_values.append(values[i].item())
        prefix_balance.append(balances[-1].item())
        suffix_starts_open.append(suffixes[i].startswith("<"))
        prefix_ends_open.append(not prefixes[i].endswith(">"))
        prefix_tokens.update(tokenizer.encode(prefixes[i]))
        last_tokens.update(tokenizer.encode(prefixes[i])[-1:])
        suffix_tokens.update(tokenizer.encode(suffixes[i])[:1])
    avg = lambda x: sum(x) / (r - l)
    Counter.get_freqs = lambda self: [
        (k, round(v / self.total(), 3)) for k, v in self.most_common(10)
    ]
    print(f"segment {segment}: {ls(min(feature_values))}, {ls(max(feature_values))}")
    idx_to_check = [
        tokenizer.encode(s)[0] for s in ["<4", "4>", "<8", "8>", "<12", "12>"]
    ]
    print(
        *idx_to_check, ":", tokenizer.decode(idx_to_check), ":", *[suffix_tokens[i] for i in idx_to_check]
    )
    print(
        f"prefix ends open = {avg(prefix_ends_open)}, suffix starts open = {avg(suffix_starts_open)}"
    )
    print(
        f"prefix balance quantiles: {np.quantile(prefix_balance, [0.1, 0.25, 0.5, 0.75, 0.9])}"
    )
    print("prefix (full)", prefix_tokens.total(), prefix_tokens.get_freqs())
    print("prefix (last)", last_tokens.total(), last_tokens.get_freqs())
    print("suffix (first)", suffix_tokens.total(), suffix_tokens.get_freqs())
    print()

segment 0: 0.00, 0.00
6 7 14 15 22 23 : <4 4> <8 8> <12 12> : 0 0 0 1 0 6
prefix ends open = 0.81, suffix starts open = 0.105
prefix balance quantiles: [1. 1. 1. 2. 2.]
prefix (full) 871 [(168, 0.011), (16, 0.01), (22, 0.01), (100, 0.01), (126, 0.009), (4, 0.009), (96, 0.009), (362, 0.009), (17, 0.009), (450, 0.009)]
prefix (last) 200 [(328, 0.025), (22, 0.025), (4, 0.02), (362, 0.02), (46, 0.02), (100, 0.02), (186, 0.02), (174, 0.015), (464, 0.015), (436, 0.015)]
suffix (first) 200 [(23, 0.03), (363, 0.03), (5, 0.025), (169, 0.025), (175, 0.02), (329, 0.02), (425, 0.02), (435, 0.02), (47, 0.02), (101, 0.02)]

segment 1: 0.00, 0.00
6 7 14 15 22 23 : <4 4> <8 8> <12 12> : 0 0 0 2 0 0
prefix ends open = 0.785, suffix starts open = 0.14
prefix balance quantiles: [1. 1. 1. 2. 3.]
prefix (full) 882 [(436, 0.017), (412, 0.014), (332, 0.014), (168, 0.012), (366, 0.012), (70, 0.012), (316, 0.012), (317, 0.012), (194, 0.011), (286, 0.011)]
prefix (last) 200 [(412, 0.025), (436, 0.025), (168, 0.

In [136]:
while True:
    user_input = input()
    tokens = tokenizer.encode(user_input)
    if any(x >= len(tokenizer) - 2 for x in tokenizer.encode(user_input)):
        print("Invalid input. Please try again.")
        continue
    cur = compressed_activations(model, tokenizer, user_input, {layer: sae})[layer]
    inspected = cur[:, inspected_direction]
    assert inspected.shape == (len(tokens),)
    print(ls(inspected))
    str_tokens = tokenizer.tokenize(user_input)
    str_tokens = [f"{tok} ({i})" for i, tok in enumerate(str_tokens)]
    r = 200
    px.line(x=str_tokens, y=inspected, range_y=(-r, r)).show()

# pos: 212 204 6 4
# neg: 154 111 96 98 221 148

[ 415.85 580.64 -3.91 -4.53 ]


[ -4.87 -5.55 322.82 423.15 ]


[ -4.87 374.97 428.12 ]


[ 415.85 580.64 -3.91 ]


[ 415.85 455.90 563.33 562.98 -0.98 ]


[ 415.85 455.90 563.33 562.98 523.76 488.88 521.89 514.86 319.60 ]


[ 415.85 455.90 563.33 562.98 523.76 40.18 ]


[ 415.85 455.90 426.24 432.95 81.06 ]


[ 415.85 455.90 426.24 432.95 426.31 174.18 ]


[ 415.85 455.90 593.87 ]


RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0] because the unspecified dimension size -1 can be any value and is ambiguous

In [None]:
values = t.stack(compressed)[:, 41]

perm = values.argsort()
for i in perm:
    print(ls(values[i]), "\t", prefixes[i])

In [257]:
mid_l = 3
mid_r = 6
X_mid, Y_mid = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=f"transformer.h.{mid_l}",
    output_layer=f"transformer.h.{mid_r - 1}",
)
Y_mid = (Y_mid - X_mid).reshape(-1, 256)
X_mid = X_mid.reshape(-1, 256)
mid_line = fit_linear(X_mid, Y_mid, reg="l2", alpha=1e-3)
print(
    eval_module(
        mid_line,
        X_mid,
        Y_mid,
    )
)

0.18959206342697144


In [320]:
standard_ln = nn.LayerNorm(256, elementwise_affine=False)

X_mlp_1, Y_mlp_1 = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=f"transformer.h.1.ln_2",
    output_layer=f"transformer.h.1.mlp",
)
Y_mlp_1 = Y_mlp_1.reshape(-1, 256)
X_mlp_1 = standard_ln(X_mlp_1).reshape(-1, 256)
mlp_1 = fit_linear(X_mlp_1, Y_mlp_1, reg="l2", alpha=1e-3)


X_mlp_6, Y_mlp_6 = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=f"transformer.h.6.ln_2",
    output_layer=f"transformer.h.6.mlp",
)
Y_mlp_6 = Y_mlp_6.reshape(-1, 256)
X_mlp_6 = standard_ln(X_mlp_6).reshape(-1, 256)
mlp_6 = fit_linear(X_mlp_6, Y_mlp_6, reg="l2", alpha=1e-3)

In [149]:
X_attn, Y_attn = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer="transformer.h.0",
    output_layer="transformer.h.0.attn",
)

X_attn_sums = X_attn.cumsum(dim=-2)
X_attn_lens = t.arange(1, X_attn_sums.size(-2) + 1).reshape(1, -1, 1)
Y_attn = Y_attn.reshape(-1, 256)
X_attn = (X_attn_sums / X_attn_lens).reshape(-1, 256)

attn_line = fit_linear(X_attn, Y_attn, reg="l2", alpha=1e-3)
print(
    eval_module(
        attn_line,
        X_attn,
        Y_attn,
    )
)

0.001996031031012535


In [127]:
X_wo_ln, Y_wo_ln = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer="transformer.h.0.ln_2",
    output_layer="transformer.h.0.mlp",
)
Y_wo_ln = Y_wo_ln.reshape(-1, 256)
X_wo_ln = standard_ln(X_wo_ln).reshape(-1, 256)
wo_ln = fit_linear(X_wo_ln, Y_wo_ln, reg="l2", alpha=1e-3)
print(eval_module(wo_ln, X_wo_ln, Y_wo_ln))

0.09023763239383698


In [362]:
from utils import prompt_from_template, get_loss, PrefixMean, Residual, Wrapper
from transformers import GPTNeoForCausalLM

new_model = AutoModelForCausalLM.from_pretrained("Mlxa/brackets-nested")
new_model.config.use_cache = False
new_model.config.output_attentions = False

new_model.transformer.h[0] = Wrapper(
    nn.Sequential(
        Residual(
            nn.Sequential(
                PrefixMean(),
                attn_line,
            )
        ),
        Residual(
            nn.Sequential(
                nn.LayerNorm(256, elementwise_affine=False),
                wo_ln,
            )
        ),
    ),
    append=(),
)

# new_model.transformer.h[6].ln_2 = nn.LayerNorm(256, elementwise_affine=False)
# new_model.transformer.h[6].mlp = mlp_6

new_model.transformer.h = nn.ModuleList(
    new_model.transformer.h[:mid_l]
    + [Wrapper(Residual(mid_line), append=())]
    + new_model.transformer.h[mid_r:]
)

prompt = prompt_from_template("((((((())))))()" * 3, random=True)
a = get_loss(model, tokenizer, prompt)
b = get_loss(new_model, tokenizer, prompt)
print(f"clean: {a:.3f}, corrupted: {b:.3f}")
print(f"delta: {b - a:.3f}")

clean: 3.649, corrupted: 3.735
delta: 0.085


In [363]:
new_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(502, 256)
    (wpe): Embedding(2048, 256)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): Wrapper(
        (fn): Sequential(
          (0): Residual(
            (fn): Sequential(
              (0): PrefixMean()
              (1): Linear(in_features=256, out_features=256, bias=True)
            )
          )
          (1): Residual(
            (fn): Sequential(
              (0): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
              (1): Linear(in_features=256, out_features=256, bias=True)
            )
          )
        )
      )
      (1-2): 2 x GPTNeoBlock(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=256, out_featu