In [1]:
from itertools import islice
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typed
from beartype.door import die_if_unbearable as assert_type
from datasets import load_dataset
from einops import einops as ein
from jaxtyping import Bool, Float, Int
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, AutoTokenizer

%load_ext autoreload
%autoreload 2

In [2]:
# Hugging Face Hub checkpoint shared by the causal LM and its tokenizer.
CHECKPOINT = "Mlxa/brackets-nested"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT)

# The corpus is opened in streaming mode and only its "train" split is kept.
dataset = load_dataset("Mlxa/nested", streaming=True)["train"]

In [4]:
from utils import input_output_mapping, sh, fit_linear, fit_module, eval_module, Residual, PrefixMean

@typed
def get_prompts(n: int) -> list[str]:
    """Return the ``"text"`` field of the first ``n`` records of ``dataset``.

    Iteration always starts from the beginning of the module-level
    ``dataset``, so it is re-read on every call.
    """
    first_records = islice(dataset, n)
    return [record["text"] for record in first_records]

In [7]:
# Capture activations for the first 10 prompts with both the input and the
# output hook placed on the same module name. What exactly is recorded on
# each side is decided by `input_output_mapping` in utils.
first_block = "transformer.h.0"
X_1, Y_1 = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=first_block,
    output_layer=first_block,
)

In [31]:
from utils import fit_sae, SparseAutoEncoder

# Merge the (prompts, seq) axes so that every token position is one row.
activations = ein.rearrange(Y_1, "prompts seq d -> (prompts seq) d")

# NOTE(review): the two constructor arguments are presumably the activation
# width and the dictionary size -- confirm against utils.SparseAutoEncoder.
sae = SparseAutoEncoder(256, 64)

# Hyper-parameters of the first fitting stage.
sae_stage_1 = dict(lr=1e-2, l1=3.0, alpha=1e-4, epochs=300)
fit_sae(sae, activations, **sae_stage_1)

100%|██████████| 300/300 [00:20<00:00, 14.56it/s,  loss=0.07, nonzero=49.09 / 64, nonorth=211.85]


In [32]:
# Second fit_sae call on the same `sae` object and the same activations, with
# a 10x smaller lr, a smaller l1 weight and more epochs.
# NOTE(review): presumably this continues from the previous stage's weights --
# confirm that fit_sae does not re-initialise the module.
fit_sae(
    sae,
    activations,
    lr=1e-3,
    l1=1.0,
    alpha=1e-4,
    epochs=1000,
)

100%|██████████| 1000/1000 [01:08<00:00, 14.57it/s,  loss=0.02, nonzero=39.66 / 64, nonorth=214.13]


In [257]:
# Fit one linear map that stands in for transformer blocks mid_l .. mid_r - 1.
# The activations get their own names (X_mid / Y_mid) rather than reusing
# X_mlp_1, so re-running this cell cannot clobber the MLP-1 fit's inputs.
mid_l = 3
mid_r = 6
X_mid, Y_mid = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=f"transformer.h.{mid_l}",
    output_layer=f"transformer.h.{mid_r - 1}",
)

# Take the feature width from the data instead of hard-coding it.
mid_dim = X_mid.size(-1)

# The regression target is the difference between the captured output and the
# captured input, i.e. the part a Residual wrapper would add back.
# The subtraction happens before flattening so both tensors still line up.
Y_mid = (Y_mid - X_mid).reshape(-1, mid_dim)
X_mid = X_mid.reshape(-1, mid_dim)

mid_line = fit_linear(X_mid, Y_mid, reg="l2", alpha=1e-3)

# The score is computed on the same rows the map was fitted on.
print(eval_module(mid_line, X_mid, Y_mid))

0.18959206342697144


In [320]:
# Parameter-free LayerNorm (no learned scale / shift) applied to MLP inputs
# before the linear fits below.
standard_ln = nn.LayerNorm(256, elementwise_affine=False)


def fit_mlp_linear(layer_idx: int) -> tuple[TT, TT, nn.Module]:
    """Fit an L2-regularised linear map for the MLP of block ``layer_idx``.

    Activations are captured with ``input_layer`` set to the block's ``ln_2``
    and ``output_layer`` set to its ``mlp``, for the first 10 prompts. The
    captured input is passed through ``standard_ln`` and both tensors are
    flattened to 2-D with the feature axis last.

    NOTE(review): assumes the ``ln_2`` hook yields the tensor entering the
    LayerNorm (hence the explicit re-normalisation) -- confirm in utils.

    Args:
        layer_idx: Index of the block inside ``transformer.h``.

    Returns:
        ``(X, Y, line)``: the flattened normalised inputs, the flattened MLP
        outputs, and the module returned by ``fit_linear``.
    """
    X, Y = input_output_mapping(
        model=model,
        tokenizer=tokenizer,
        prompts=get_prompts(10),
        input_layer=f"transformer.h.{layer_idx}.ln_2",
        output_layer=f"transformer.h.{layer_idx}.mlp",
    )
    # standard_ln normalises over the last axis, so that axis is the width.
    width = X.size(-1)
    Y = Y.reshape(-1, width)
    X = standard_ln(X).reshape(-1, width)
    return X, Y, fit_linear(X, Y, reg="l2", alpha=1e-3)


X_mlp_1, Y_mlp_1, mlp_1 = fit_mlp_linear(1)
X_mlp_6, Y_mlp_6, mlp_6 = fit_mlp_linear(6)

In [149]:
# Fit the attention output of block 0 as a linear function of the running
# mean of the captured inputs over the sequence axis.
X_attn, Y_attn = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer="transformer.h.0",
    output_layer="transformer.h.0.attn",
)

# Take the feature width from the data instead of hard-coding it.
attn_dim = X_attn.size(-1)

# Running mean over axis -2: cumulative sum divided by the 1-based position.
# The position vector is created on the activations' device so the division
# also works when the tensors do not live on the default device.
# The (1, -1, 1) shape assumes X_attn is (prompts, seq, d) -- TODO confirm.
# NOTE(review): if prompts are padded to a common length, padded positions
# take part in the running mean -- confirm how input_output_mapping batches.
X_attn_sums = X_attn.cumsum(dim=-2)
X_attn_lens = t.arange(
    1, X_attn_sums.size(-2) + 1, device=X_attn_sums.device
).reshape(1, -1, 1)
Y_attn = Y_attn.reshape(-1, attn_dim)
X_attn = (X_attn_sums / X_attn_lens).reshape(-1, attn_dim)

attn_line = fit_linear(X_attn, Y_attn, reg="l2", alpha=1e-3)

# The score is computed on the same rows the map was fitted on.
print(eval_module(attn_line, X_attn, Y_attn))

0.001996031031012535


In [127]:
# Linear fit for the MLP of block 0, with the captured input re-normalised by
# the parameter-free `standard_ln` instead of the block's own ln_2.
mlp_0_in, mlp_0_out = "transformer.h.0.ln_2", "transformer.h.0.mlp"
X_wo_ln, Y_wo_ln = input_output_mapping(
    model=model,
    tokenizer=tokenizer,
    prompts=get_prompts(10),
    input_layer=mlp_0_in,
    output_layer=mlp_0_out,
)

# One row per token position; targets are flattened first, then the inputs.
Y_wo_ln = Y_wo_ln.reshape(-1, 256)
X_wo_ln = standard_ln(X_wo_ln).reshape(-1, 256)

wo_ln = fit_linear(X_wo_ln, Y_wo_ln, reg="l2", alpha=1e-3)

# Scored on the same rows the map was fitted on.
wo_ln_score = eval_module(wo_ln, X_wo_ln, Y_wo_ln)
print(wo_ln_score)

0.09023763239383698


In [362]:
from utils import prompt_from_template, get_loss, PrefixMean, Residual, Wrapper
from transformers import GPTNeoForCausalLM

# Load a second copy of the checkpoint to patch; `model` stays untouched and
# serves as the reference in the loss comparison below.
new_model = AutoModelForCausalLM.from_pretrained("Mlxa/brackets-nested")
new_model.config.use_cache = False
new_model.config.output_attentions = False

# Block 0 is replaced by two residual branches built from the linear fits:
#   x -> x + attn_line(PrefixMean(x))
#   x -> x + wo_ln(LayerNorm(x))      (LayerNorm without learned affine)
# NOTE(review): the role of `append=()` is defined by utils.Wrapper --
# presumably extra outputs to mimic a GPT-Neo block's return tuple; confirm.
attn_branch = Residual(nn.Sequential(PrefixMean(), attn_line))
mlp_branch = Residual(
    nn.Sequential(
        nn.LayerNorm(256, elementwise_affine=False),
        wo_ln,
    )
)
new_model.transformer.h[0] = Wrapper(
    nn.Sequential(attn_branch, mlp_branch),
    append=(),
)

# Disabled experiment: swap block 6's MLP for its linear fit.
# new_model.transformer.h[6].ln_2 = nn.LayerNorm(256, elementwise_affine=False)
# new_model.transformer.h[6].mlp = mlp_6

# Blocks mid_l .. mid_r - 1 collapse into a single residual linear block;
# everything before and after keeps its original module and relative order.
mid_block = Wrapper(Residual(mid_line), append=())
current_blocks = list(new_model.transformer.h)
new_model.transformer.h = nn.ModuleList(
    current_blocks[:mid_l] + [mid_block] + current_blocks[mid_r:]
)

# Compare both models on one prompt generated from the bracket template.
# NOTE(review): `random=True` with no visible seed -- the printed numbers are
# presumably not reproducible across runs; confirm in utils.
prompt = prompt_from_template("((((((())))))()" * 3, random=True)
a = get_loss(model, tokenizer, prompt)
b = get_loss(new_model, tokenizer, prompt)
print(f"clean: {a:.3f}, corrupted: {b:.3f}")
print(f"delta: {b - a:.3f}")

clean: 3.649, corrupted: 3.735
delta: 0.085


In [363]:
# Bare last expression: shows the patched model's module tree as cell output.
new_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(502, 256)
    (wpe): Embedding(2048, 256)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): Wrapper(
        (fn): Sequential(
          (0): Residual(
            (fn): Sequential(
              (0): PrefixMean()
              (1): Linear(in_features=256, out_features=256, bias=True)
            )
          )
          (1): Residual(
            (fn): Sequential(
              (0): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
              (1): Linear(in_features=256, out_features=256, bias=True)
            )
          )
        )
      )
      (1-2): 2 x GPTNeoBlock(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=256, out_featu