In [1]:
from itertools import islice
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typed
from beartype.door import die_if_unbearable as assert_type
from datasets import load_dataset
from einops import einops as ein
from jaxtyping import Bool, Float, Int
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, AutoTokenizer

%load_ext autoreload
%autoreload 2

In [2]:
# Model and tokenizer come from the same Hub checkpoint; the dataset is
# streamed, so examples are fetched lazily as later cells iterate over it.
checkpoint = "Mlxa/brackets-nested"
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

dataset_splits = load_dataset("Mlxa/nested", streaming=True)
dataset = dataset_splits["train"]

In [175]:
from utils import sh, ls
from activation_analysis import (
    input_output_mapping,
    fit_linear,
    eval_module,
    Residual,
    PrefixMean,
)


@typed
def get_prompts(n: int, n_tokens: int | None = None) -> list[str]:
    """Return the first ``n`` dataset texts as space-joined token strings.

    Each example's ``"text"`` field is tokenized with the notebook-level
    ``tokenizer`` and the resulting string tokens are joined with single
    spaces.

    Args:
        n: Number of examples to take from the start of ``dataset``.
        n_tokens: Keep at most this many leading tokens per example.
            ``None`` (the default) keeps every token; this is what the
            single-argument calls ``get_prompts(10)`` further down the
            notebook rely on.

    Returns:
        A list of up to ``n`` prompts (fewer if ``dataset`` is exhausted).
    """
    return [
        # Slicing with ``[:None]`` keeps the whole token list.
        " ".join(tokenizer.tokenize(elem["text"])[:n_tokens])
        for elem in islice(dataset, n)
    ]

In [179]:
# Hook point whose activations the sparse autoencoders are trained on.
layer = "transformer.h.7.mlp"

# The same module name is passed as both input and output layer; what exactly
# each side captures is defined by `input_output_mapping` (not shown here).
sae_mapping_kwargs = dict(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(4000, n_tokens=64),
    input_layer=layer,
    output_layer=layer,
)
X_sae, Y_sae = input_output_mapping(**sae_mapping_kwargs)
# Y_sae -= X_sae

In [180]:
from sparse_autoencoders import fit_sae, SparseAutoEncoder

# Merge the prompt and sequence axes so every token position becomes one
# training row of width d for the autoencoders.
activation_data = ein.rearrange(Y_sae, "prompts seq d -> (prompts seq) d")

In [182]:
# The autoencoder inspected in the rest of the notebook.
# NOTE(review): arguments are presumably (activation width, number of
# dictionary features) — confirm against sparse_autoencoders.SparseAutoEncoder.
sae = SparseAutoEncoder(256, 64)

In [183]:
# Reference autoencoders built with a different second constructor argument
# (200 and 600); their encodings are compared against `sae` further down.
others = [SparseAutoEncoder(256, d) for d in [200, 600]]

In [185]:
# Train every reference autoencoder on the flattened activations.
# Note these hyperparameters differ from the ones used for `sae` below.
for other in others:
    fit_sae(other, activation_data, lr=1e-2, l1=0.1, alpha=1e-4, epochs=5)

100%|██████████| 5/5 [00:24<00:00,  4.87s/it,  loss=0.00, nonzero=22.05 / 200, nonorth=0.59]
100%|██████████| 5/5 [00:45<00:00,  9.03s/it,  loss=0.00, nonzero=20.40 / 600, nonorth=0.21]


In [187]:
# Train the inspected autoencoder (lower lr, larger l1 and alpha than the
# reference autoencoders above).
fit_sae(sae, activation_data, lr=1e-3, l1=0.5, alpha=1e-2, epochs=5)

100%|██████████| 5/5 [00:17<00:00,  3.53s/it,  loss=0.00, nonzero=19.68 / 64, nonorth=0.01]


In [188]:
# Encode a subsample (every 11th row) of the activations with the inspected
# autoencoder and with each reference autoencoder, without tracking gradients.
with t.no_grad():
    some_activations = activation_data[::11]
    guess = sae.encode(some_activations)
    other_guesses = []
    for reference_sae in others:
        other_guesses.append(reference_sae.encode(some_activations))

In [189]:
from sparse_autoencoders import max_cosine_similarity

# One similarity vector per reference autoencoder.
# NOTE(review): `max_cosine_similarity` presumably returns, for each feature of
# `guess`, its best cosine match among the features of `other` — confirm in
# sparse_autoencoders.
sims = [max_cosine_similarity(guess, other) for other in other_guesses]
mean_sim = sum(sims) / len(sims)

# Print the mean followed by the full vector for every reference autoencoder,
# instead of hard-coding indices 0 and 1, so the cell keeps working when the
# `others` list changes length.
for sim in sims:
    print(ls(sim.mean()), ls(sim))

0.59 [ 0.45 0.57 0.55 0.59 0.71 0.65 0.89 0.53 0.94 0.45 0.55 0.54 0.55 1.00 0.59 0.50 0.73 0.56 0.55 0.95 0.49 0.29 0.92 0.52 0.76 0.49 0.63 0.40 0.83 0.59 0.90 0.68 0.98 0.61 0.50 0.59 0.53 0.54 0.56 0.36 0.86 0.70 0.39 0.48 0.24 0.93 0.45 0.64 0.49 0.12 0.45 0.56 0.54 0.78 0.54 0.55 0.47 0.84 0.34 0.50 0.52 0.48 0.52 0.15 ]
0.54 [ 0.51 0.62 0.51 0.63 0.50 0.59 0.88 0.50 0.38 0.10 0.54 0.45 0.59 0.99 0.61 0.53 0.16 0.59 0.59 0.84 0.49 0.27 0.54 0.67 0.66 0.55 0.56 0.37 0.84 0.53 0.90 0.71 0.97 0.57 0.48 0.64 0.62 0.54 0.48 0.37 0.86 0.72 0.39 0.43 0.23 0.93 0.57 0.65 0.27 0.13 0.47 0.44 0.59 0.49 0.05 0.50 0.52 0.84 0.31 0.45 0.67 0.42 0.64 0.03 ]


In [191]:
from activation_analysis import compressed_activations
from language_modeling import prompt_from_template, generate_sample, get_logprobs
import plotly.express as px
from tqdm import tqdm


# Parallel lists, one entry per (prompt, position) pair:
#   prefixes[k]   - prompt truncated after the inspected position
#   suffixes[k]   - one token sampled from the model for that position
#   compressed[k] - SAE encoding of the `layer` activation at that position
prefixes: list[str] = []
suffixes: list[str] = []
compressed: list[Float[TT, "d"]] = []
prefix_cutoff = 16
first_position = 8  # positions before this one are not recorded

for prompt in tqdm(get_prompts(4000, n_tokens=prefix_cutoff)):
    tokens = tokenizer.tokenize(prompt)
    short_prompt = " ".join(tokens)
    res = compressed_activations(model, tokenizer, short_prompt, {layer: sae})[layer]
    assert_type(res, Float[TT, "n d"])
    logprobs = get_logprobs(model, tokenizer, short_prompt)
    # Sampling is unseeded, so suffixes differ between runs.
    # NOTE(review): `roll(-1)` wraps row 0 around to the last position; if
    # `logprobs` has exactly len(tokens) rows, the sample for the final
    # position comes from row 0 — confirm against get_logprobs.
    # `squeeze(-1)` (rather than a bare `squeeze()`) keeps a list even when
    # there is only a single position.
    next_tokens = (
        t.multinomial(logprobs.roll(-1, dims=0).exp(), 1).squeeze(-1).tolist()
    )
    # Use `prefix_cutoff` rather than a second hard-coded 16, and never index
    # past what this prompt actually provides (short prompts previously raised
    # IndexError).
    last_position = min(prefix_cutoff, len(tokens), res.shape[0], len(next_tokens))
    for i in range(first_position, last_position):
        prefix = " ".join(tokens[: i + 1])
        prefixes.append(prefix)
        suffixes.append(tokenizer.convert_ids_to_tokens(next_tokens[i]))
        compressed.append(res[i])

100%|██████████| 4000/4000 [01:32<00:00, 43.15it/s]


In [222]:
@typed
def show_activations(activations: Float[TT, "n d"], perm: Int[TT, "k"]) -> None:
    """Render a heatmap of the selected activation columns.

    Args:
        activations: One row per prompt, one column per dimension.
        perm: Column indices to display, in display order; they are also used
            (as strings) for the x-axis tick labels.
    """
    column_labels = [str(index) for index in perm.tolist()]
    selected_columns = activations[:, perm]
    axis_titles = {"x": "Dimension", "y": "Prompt", "color": "Activation"}
    heatmap = px.imshow(
        selected_columns,
        color_continuous_scale=px.colors.diverging.RdBu,
        labels=axis_titles,
        x=column_labels,
        # y=prefixes,
        # zmin=-1,
        # zmax=1,
    )
    heatmap.show()


# Rank SAE dimensions by the standard deviation of their activation across all
# collected positions, largest first, and plot the top 53 for the first 100
# rows.
stacked = t.stack(compressed)
stds = stacked.std(dim=0)
# Ascending argsort reversed: same order as reversing the index list by hand.
std_argsort = stds.argsort().flip(0)
important_dirs = std_argsort[:53]
# stacked -= stacked.mean(dim=0, keepdim=True)
# stacked /= stacked.norm(dim=0, keepdim=True)
show_activations(stacked[:100], perm=important_dirs)

In [193]:
# For the 50 highest-variance dimensions, print the std next to the similarity
# scores against the two reference autoencoders.
stacked = t.stack(compressed)
for pos, dim in enumerate(std_argsort[:50].tolist()):
    fields = [
        f"{pos}) dim: {dim}",
        f"std: {ls(stds[dim])}",
        f"sim: {ls(sims[0][dim])}",
        f" other sim: {ls(sims[1][dim])}",
    ]
    print("\t".join(fields))

0) dim: 13	std: 590.36	sim: 1.00	 other sim: 0.99
1) dim: 31	std: 90.22	sim: 0.68	 other sim: 0.71
2) dim: 53	std: 89.79	sim: 0.78	 other sim: 0.49
3) dim: 23	std: 85.60	sim: 0.52	 other sim: 0.67
4) dim: 19	std: 84.98	sim: 0.95	 other sim: 0.84
5) dim: 3	std: 84.35	sim: 0.59	 other sim: 0.63
6) dim: 36	std: 84.09	sim: 0.53	 other sim: 0.62
7) dim: 35	std: 83.98	sim: 0.59	 other sim: 0.64
8) dim: 14	std: 83.51	sim: 0.59	 other sim: 0.61
9) dim: 5	std: 82.60	sim: 0.65	 other sim: 0.59
10) dim: 37	std: 81.52	sim: 0.54	 other sim: 0.54
11) dim: 55	std: 81.32	sim: 0.55	 other sim: 0.50
12) dim: 20	std: 80.81	sim: 0.49	 other sim: 0.49
13) dim: 12	std: 80.74	sim: 0.55	 other sim: 0.59
14) dim: 10	std: 79.95	sim: 0.55	 other sim: 0.54
15) dim: 26	std: 79.95	sim: 0.63	 other sim: 0.56
16) dim: 62	std: 79.61	sim: 0.52	 other sim: 0.64
17) dim: 11	std: 79.35	sim: 0.54	 other sim: 0.45
18) dim: 51	std: 78.66	sim: 0.56	 other sim: 0.44
19) dim: 2	std: 78.58	sim: 0.55	 other sim: 0.51
20) dim: 1	s

In [226]:
from activation_analysis import feature_effect

# (prompt, position) pairs at which the inspected direction is probed.
short_prompts = [
    ("<10 <11 <12 12> 11> 10>", 1),
    ("<3 <4 <5 5> 4> 3>", 1),
    ("<8 <2 <7 7> 2> 8>", 1),
    ("<10 <11 <12 12> 11> 10>", 2),
    ("<3 <4 <5 5> 4> 3>", 2),
]

inspected_direction = 9

# Settings shared by every plot: `eps` is handed to `feature_effect`, the
# colour range scales with it, and only the first `window` vocabulary entries
# are shown.
eps = 1e2
r = 0.002 * eps
window = 50
token_labels = [tokenizer.decode(token_id) for token_id in range(window)]

for prompt, position in short_prompts:
    effect = feature_effect(
        model,
        tokenizer,
        prompt,
        position,
        directions={layer: (sae, inspected_direction)},
        eps=eps,
    )

    # The base term is multiplied by zero, so `d` carries the values of the
    # diff term only.
    base_part = effect.base["lm_head"] * 0
    d = base_part + effect.diff["lm_head"]
    print(f"{prompt} (pos={position})")

    heatmap = px.imshow(
        d[:, :window].detach(),
        zmin=-r,
        zmax=r,
        color_continuous_scale=px.colors.diverging.RdBu,
        x=token_labels,
    )
    heatmap.show()

    # window = 500
    # px.imshow(
    #     d[:, 1:window:2].detach(),
    #     color_continuous_scale=px.colors.diverging.RdBu,
    #     x=[tokenizer.decode(2 * i + 1) for i in range(window//2)],
    # ).show()


<10 <11 <12 12> 11> 10> (pos=1)


<3 <4 <5 5> 4> 3> (pos=1)


<8 <2 <7 7> 2> 8> (pos=1)


<10 <11 <12 12> 11> 10> (pos=2)


<3 <4 <5 5> 4> 3> (pos=2)


In [204]:
# Value of the inspected SAE dimension at every collected position, plus the
# permutation that sorts those values in ascending order.
values = t.tensor([encoding[inspected_direction] for encoding in compressed])
perm = t.argsort(values)

# Show every 100th entry of the sorted list: value, prefix, sampled next token.
for idx in perm[::100].tolist():
    print(ls(values[idx]), "   \t", prefixes[idx], "#", suffixes[idx])

0.00    	 <168 <97 <42 42> 97> 168> <71 71> <114 <139 139> <64 64> <151 <51 <190 # <195
0.00    	 <119 <52 <125 <236 <21 <193 <158 158> 193> <246 246> <69 <204 <120 # 120>
0.00    	 <134 <240 240> 134> <74 <98 98> 74> <141 141> <117 <150 150> 117> # <75
0.00    	 <216 216> <126 <51 <48 48> 51> <213 213> 126> <52 52> <66 # <38
0.00    	 <149 <104 <192 192> 104> <96 96> 149> <7 7> <197 197> <111 # 111>
0.00    	 <153 153> <92 92> <7 7> <242 <138 <110 # 110>
0.00    	 <86 86> <187 187> <238 238> <53 53> <191 <250 # <11
0.00    	 <145 <134 <100 100> <247 <186 <135 135> 186> 247> 134> 145> <64 64> <66 66> # <221
0.00    	 <60 <230 230> <115 115> <81 <76 76> <3 3> 81> 60> <131 # <90
0.00    	 <164 164> <104 104> <187 187> <72 72> <16 # 16>
0.00    	 <240 <16 16> 240> <207 207> <154 154> <20 # 20>
0.00    	 <11 11> <170 <142 <123 <61 61> <14 14> <30 30> # 123>
0.00    	 <117 <32 32> 117> <212 <73 <147 147> 73> 212> <178 178> <66 66> <89 # 89>
0.00    	 <154 154> <125 <238 238> <237 <133 133> 

In [205]:
from language_modeling import get_balances
from collections import Counter

def avg(xs):
    """Arithmetic mean of a non-empty sequence."""
    return sum(xs) / len(xs)


def get_freqs(counter: Counter, k: int = 10) -> list:
    """Return the ``k`` most common keys of ``counter`` with relative frequencies.

    Frequencies are relative to the total count and rounded to 3 decimals.
    This is a plain function on purpose: the previous version attached a
    ``get_freqs`` method to ``collections.Counter`` itself on every loop
    iteration, which mutates a standard-library class for the whole kernel.
    """
    total = counter.total()
    return [(key, round(count / total, 3)) for key, count in counter.most_common(k)]


# Split the positions, sorted by the inspected feature's value (`perm`), into
# equally sized segments and print summary statistics for each one.
segments = 16
for segment in range(segments):
    lo = segment * len(perm) // segments
    hi = (segment + 1) * len(perm) // segments
    if lo == hi:
        # Fewer samples than segments: nothing to summarise here (min/max and
        # the averages below would fail on empty lists).
        continue

    prefix_ends_open = []
    suffix_starts_open = []
    prefix_balance = []
    feature_values = []
    prefix_tokens = Counter()
    last_tokens = Counter()
    suffix_tokens = Counter()
    for i in perm[lo:hi].tolist():
        prefix, suffix = prefixes[i], suffixes[i]
        balances = get_balances(prefix)
        feature_values.append(values[i].item())
        prefix_balance.append(balances[-1].item())
        suffix_starts_open.append(suffix.startswith("<"))
        prefix_ends_open.append(not prefix.endswith(">"))
        # Tokenize the prefix once and reuse it for both counters.
        prefix_pieces = tokenizer.tokenize(prefix)
        prefix_tokens.update(prefix_pieces)
        last_tokens.update(prefix_pieces[-1:])
        suffix_tokens.update(tokenizer.tokenize(suffix)[:1])

    print(f"segment {segment}: {ls(min(feature_values))}, {ls(max(feature_values))}")
    # idx_to_check = [
    #     tokenizer.encode(s)[0] for s in ["<4", "4>", "<8", "8>", "<12", "12>"]
    # ]
    # print(
    #     *idx_to_check, ":", tokenizer.decode(idx_to_check), ":", *[suffix_tokens[i] for i in idx_to_check]
    # )
    print(
        f"last open = {avg(prefix_ends_open)}, next open = {avg(suffix_starts_open)}, prefix balance quantiles: {np.quantile(prefix_balance, [0.1, 0.25, 0.5, 0.75, 0.9])}, mean balance: {avg(prefix_balance)}"
    )
    print("prefix (full)", prefix_tokens.total(), get_freqs(prefix_tokens))
    print("prefix (last)", last_tokens.total(), get_freqs(last_tokens))
    print("suffix (first)", suffix_tokens.total(), get_freqs(suffix_tokens))
    print()

segment 0: 0.00, 0.00
last open = 0.4805, next open = 0.5755, prefix balance quantiles: [0. 0. 1. 2. 4.], mean balance: 1.659
prefix (full) 25054 [('<21', 0.005), ('<73', 0.004), ('<236', 0.004), ('<110', 0.004), ('<133', 0.004), ('<89', 0.004), ('<190', 0.004), ('<78', 0.004), ('<184', 0.004), ('<34', 0.004)]
prefix (last) 2000 [('<41', 0.006), ('184>', 0.006), ('<218', 0.006), ('<166', 0.005), ('236>', 0.005), ('<236', 0.005), ('<184', 0.005), ('<47', 0.005), ('23>', 0.005), ('<63', 0.005)]
suffix (first) 2000 [('<128', 0.005), ('<45', 0.005), ('<66', 0.005), ('73>', 0.005), ('95>', 0.005), ('<106', 0.005), ('77>', 0.005), ('27>', 0.005), ('<164', 0.005), ('<130', 0.005)]

segment 1: 0.00, 0.00
last open = 0.4925, next open = 0.556, prefix balance quantiles: [0. 0. 1. 3. 4.], mean balance: 1.7585
prefix (full) 25089 [('<111', 0.004), ('<117', 0.004), ('<162', 0.004), ('<102', 0.004), ('<95', 0.004), ('<88', 0.004), ('117>', 0.004), ('<192', 0.004), ('<119', 0.004), ('<132', 0.004)]
p

In [None]:
# Interactive probe: type a bracket sequence and see the inspected SAE
# dimension at every token position. Submit an empty line (or interrupt /
# close stdin) to leave the loop; the previous version could never exit.
while True:
    try:
        user_input = input()
    except (EOFError, KeyboardInterrupt):
        break
    if not user_input.strip():
        break

    # Encode once and reuse the ids for both validation and the shape check.
    tokens = tokenizer.encode(user_input)
    # NOTE(review): the two highest vocabulary ids are rejected, presumably
    # because they are special / unknown tokens — confirm against the tokenizer.
    if any(x >= len(tokenizer) - 2 for x in tokens):
        print("Invalid input. Please try again.")
        continue

    cur = compressed_activations(model, tokenizer, user_input, {layer: sae})[layer]
    inspected = cur[:, inspected_direction]
    assert inspected.shape == (len(tokens),)
    print(ls(inspected))

    # Label each x tick with the token text and its position.
    str_tokens = tokenizer.tokenize(user_input)
    str_tokens = [f"{tok} ({i})" for i, tok in enumerate(str_tokens)]
    r = 200  # fixed y-axis half-range so successive plots are comparable
    px.line(x=str_tokens, y=inspected, range_y=(-r, r)).show()


In [223]:
# Gram matrix (pairwise dot products) of the decoder rows, with rows ordered
# by `std_argsort`. Despite the name, no mean is subtracted.
decoder_sorted = sae.decoder[std_argsort]
covar = decoder_sorted @ decoder_sorted.T
px.imshow(covar.detach()).show()

# Index pairs (i < j) whose decoder rows have a strongly negative dot product.
# The following cell iterates over `pairs`, so it has to be defined here; it
# used to exist only as commented-out code, which made that cell fail with a
# NameError on a fresh kernel. The debug print and the hard `< -0.7` assert of
# the commented-out version are intentionally not restored.
pair_threshold = -0.5
gram = covar.detach()
pairs = [
    (i, j)
    for i in range(gram.shape[0])
    for j in range(i + 1, gram.shape[1])
    if gram[i, j] < pair_threshold
]

In [219]:
# List the index pairs found above, one "i j" pair per line.
for pair in pairs:
    first, second = pair
    print(first, second)

0 52
1 20
2 35
3 22
5 7
6 25
8 33
9 26
10 29
11 28
12 18
13 15
14 21
16 30
17 19
23 36
24 37
27 31
32 34
38 39


In [257]:
# Fit one linear map that stands in for transformer blocks mid_l .. mid_r - 1.
mid_l = 3
mid_r = 6
mid_input_layer = f"transformer.h.{mid_l}"
mid_output_layer = f"transformer.h.{mid_r - 1}"

X_mid, Y_mid = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=mid_input_layer,
    output_layer=mid_output_layer,
)

# The regression target is the change (output minus input), so the fitted map
# can later be used inside a Residual wrapper. Both sides are flattened to
# rows of width 256; the right-hand side is evaluated before either name is
# rebound, so the difference uses the unflattened X_mid.
X_mid, Y_mid = X_mid.reshape(-1, 256), (Y_mid - X_mid).reshape(-1, 256)

mid_line = fit_linear(X_mid, Y_mid, reg="l2", alpha=1e-3)
mid_score = eval_module(mid_line, X_mid, Y_mid)
print(mid_score)

0.18959206342697144


In [320]:
# Parameter-free LayerNorm applied to the captured inputs before fitting.
standard_ln = nn.LayerNorm(256, elementwise_affine=False)


def _fit_mlp_linear(block: int):
    """Fit an L2-regularised linear map for the MLP of transformer block ``block``.

    Captures the ``transformer.h.{block}.ln_2`` / ``transformer.h.{block}.mlp``
    pair via ``input_output_mapping``, normalises the captured input with
    ``standard_ln``, flattens both sides to rows of width 256 and fits
    ``fit_linear`` on them.

    Returns:
        ``(X, Y, linear)``: the flattened inputs, flattened targets and the
        fitted module.
    """
    X, Y = input_output_mapping(
        model=model,
        tokenizer=tokenizer,
        prompts=get_prompts(10),
        input_layer=f"transformer.h.{block}.ln_2",
        output_layer=f"transformer.h.{block}.mlp",
    )
    Y = Y.reshape(-1, 256)
    X = standard_ln(X).reshape(-1, 256)
    return X, Y, fit_linear(X, Y, reg="l2", alpha=1e-3)


# Same recipe for blocks 1 and 6; previously two copy-pasted stanzas.
X_mlp_1, Y_mlp_1, mlp_1 = _fit_mlp_linear(1)
X_mlp_6, Y_mlp_6, mlp_6 = _fit_mlp_linear(6)

In [149]:
# Approximate block 0's attention with a linear map applied to the running
# mean of the inputs over all positions up to and including the current one.
X_attn, Y_attn = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer="transformer.h.0",
    output_layer="transformer.h.0.attn",
)

# Prefix sums along the sequence axis divided by the prefix length (1, 2, ...)
# give the prefix means.
X_attn_sums = t.cumsum(X_attn, dim=-2)
attn_seq_len = X_attn_sums.size(-2)
X_attn_lens = t.arange(1, attn_seq_len + 1).reshape(1, -1, 1)
attn_prefix_means = X_attn_sums / X_attn_lens

Y_attn = Y_attn.reshape(-1, 256)
X_attn = attn_prefix_means.reshape(-1, 256)

attn_line = fit_linear(X_attn, Y_attn, reg="l2", alpha=1e-3)
attn_score = eval_module(attn_line, X_attn, Y_attn)
print(attn_score)

0.001996031031012535


In [127]:
# Linear stand-in for block 0's MLP, fitted on inputs normalised with the
# parameter-free `standard_ln`.
wo_ln_mapping_kwargs = dict(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer="transformer.h.0.ln_2",
    output_layer="transformer.h.0.mlp",
)
X_wo_ln, Y_wo_ln = input_output_mapping(**wo_ln_mapping_kwargs)

# Flatten both sides to rows of width 256.
Y_wo_ln = Y_wo_ln.reshape(-1, 256)
X_wo_ln = standard_ln(X_wo_ln).reshape(-1, 256)

wo_ln = fit_linear(X_wo_ln, Y_wo_ln, reg="l2", alpha=1e-3)
wo_ln_score = eval_module(wo_ln, X_wo_ln, Y_wo_ln)
print(wo_ln_score)

0.09023763239383698


In [362]:
from utils import prompt_from_template, get_loss, PrefixMean, Residual, Wrapper
from transformers import GPTNeoForCausalLM

# Fresh copy of the model whose blocks get replaced by the fitted linear
# stand-ins, so its loss can be compared with the untouched `model`.
new_model = AutoModelForCausalLM.from_pretrained("Mlxa/brackets-nested")
for config_flag in ("use_cache", "output_attentions"):
    setattr(new_model.config, config_flag, False)

# Block 0: residual prefix-mean attention stand-in followed by a residual
# LayerNorm + linear MLP stand-in.
# NOTE(review): the meaning of `append=()` is defined by utils.Wrapper —
# confirm there.
attn_stand_in = Residual(
    nn.Sequential(
        PrefixMean(),
        attn_line,
    )
)
mlp_stand_in = Residual(
    nn.Sequential(
        nn.LayerNorm(256, elementwise_affine=False),
        wo_ln,
    )
)
new_model.transformer.h[0] = Wrapper(
    nn.Sequential(attn_stand_in, mlp_stand_in),
    append=(),
)

# new_model.transformer.h[6].ln_2 = nn.LayerNorm(256, elementwise_affine=False)
# new_model.transformer.h[6].mlp = mlp_6

# Collapse blocks mid_l .. mid_r - 1 into the single residual linear map.
current_blocks = new_model.transformer.h
new_model.transformer.h = nn.ModuleList(
    [
        *current_blocks[:mid_l],
        Wrapper(Residual(mid_line), append=()),
        *current_blocks[mid_r:],
    ]
)

# Compare the loss of the original and the patched model on one prompt.
prompt = prompt_from_template("((((((())))))()" * 3, random=True)
a = get_loss(model, tokenizer, prompt)
b = get_loss(new_model, tokenizer, prompt)
print(f"clean: {a:.3f}, corrupted: {b:.3f}")
print(f"delta: {b - a:.3f}")

clean: 3.649, corrupted: 3.735
delta: 0.085


In [363]:
# Display the module tree of the patched model to check which blocks were
# replaced.
new_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(502, 256)
    (wpe): Embedding(2048, 256)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): Wrapper(
        (fn): Sequential(
          (0): Residual(
            (fn): Sequential(
              (0): PrefixMean()
              (1): Linear(in_features=256, out_features=256, bias=True)
            )
          )
          (1): Residual(
            (fn): Sequential(
              (0): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
              (1): Linear(in_features=256, out_features=256, bias=True)
            )
          )
        )
      )
      (1-2): 2 x GPTNeoBlock(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=256, out_featu