## Setup

In [1]:
%load_ext autoreload
%autoreload 2
import torch
from teren import utils as teren_utils
from transformer_lens import HookedTransformer
from sae_lens import SAE
# Resolve the compute device once; later cells pass `device` to the model and SAE loaders.
device = teren_utils.get_device_str()
print(device)


cuda


In [2]:
# GPT-2 small loaded through TransformerLens onto the selected device.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
LAYERS = (8, 10)
BATCH_SIZE = 100


def _load_resid_pre_sae(layer):
    """Load the gpt2-small-res-jb SAE for hook point blocks.{layer}.hook_resid_pre.

    SAE.from_pretrained returns a tuple; only element 0 (the SAE itself) is kept.
    """
    return SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device=device,
    )[0]


# One SAE per entry of LAYERS, in the same order.
saes = [_load_resid_pre_sae(layer) for layer in LAYERS]

In [4]:
# Every SAE above comes from the same release, so the first one's context size
# is used as the tokenization length for the whole dataset.
sae_context_size = saes[0].cfg.context_size
all_input_ids = teren_utils.load_and_tokenize_dataset(
    path="NeelNanda/pile-10k",
    split="train[:1%]",  # a small slice of the train split keeps the run quick
    column_name="text",
    tokenizer=model.tokenizer,  # type: ignore
    max_length=sae_context_size,
)
all_input_ids.shape

torch.Size([960, 128])

## Find examples where features are active

In [5]:
N_FEATURES_TO_SCAN = 10
N_EXAMPLES_PER_FEATURE = 10

# One result per SAE (same order as `saes`). Each result exposes the features
# that were found active, the examples' input ids / residual activations and
# their clean loss (all used further down).
examples_by_feature_by_sae = teren_utils.get_examples_by_feature_by_sae(
    input_ids=all_input_ids,
    model=model,
    saes=saes,
    feature_ids=range(N_FEATURES_TO_SCAN),
    n_examples=N_EXAMPLES_PER_FEATURE,
    batch_size=BATCH_SIZE,
    min_activation=0.0,
)

In [7]:
# Shape of the first SAE's clean-loss tensor (the plotting code treats such
# tensors as (n_features, batch, n_ctx - 1)).
examples_by_feature_by_sae[0].clean_loss.size()

torch.Size([10, 10, 127])

In [6]:
# Report, per SAE, which of the scanned features had at least one example.
for sae, examples_by_feature in zip(saes, examples_by_feature_by_sae):
    print(f"SAE hook name: {sae.cfg.hook_name}")
    ids = examples_by_feature.active_feature_ids
    # Empty final item + sep="\n" reproduces the trailing blank line.
    print(
        f"Number of active features: {len(ids)}",
        f"Active feature idxs: {ids}",
        "",
        sep="\n",
    )

SAE hook name: blocks.8.hook_resid_pre
Number of active features: 10
Active feature idxs: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

SAE hook name: blocks.10.hook_resid_pre
Number of active features: 8
Active feature idxs: [0, 1, 2, 3, 4, 6, 7, 8]



In [18]:
def test_clean_loss_correct():
    """Check that each SAE's cached clean loss matches a fresh recomputation.

    For every SAE, re-runs ``teren_utils.compute_loss`` on the stored residual
    activations (starting at the SAE's hook layer) and asserts the result is
    close to ``examples_by_feature.clean_loss``. Stops at the first mismatching
    SAE and names its hook in the failure message.
    """
    for sae, examples_by_feature in zip(saes, examples_by_feature_by_sae):
        recomputed_loss = teren_utils.compute_loss(
            model=model,
            input_ids=examples_by_feature.input_ids,
            resid_acts=examples_by_feature.resid_acts,
            start_at_layer=sae.cfg.hook_layer,
            batch_size=BATCH_SIZE,
        )
        cached_loss = examples_by_feature.clean_loss
        assert torch.allclose(cached_loss, recomputed_loss, rtol=1e-3, atol=1e-5), (
            f"clean_loss mismatch for {sae.cfg.hook_name}: max abs diff "
            f"{(cached_loss - recomputed_loss).abs().max().item():.3e}"
        )


test_clean_loss_correct()

## Perturbations

In [19]:
from teren.perturbations import Perturbation, naive_random_perturbation, amplify_resid_acts_perturbation, dampen_resid_acts_perturbation, TowardSAEReconPerturbation

In [25]:
perturbations = {
    "naive_random": naive_random_perturbation,
    "amplify_resid_acts": amplify_resid_acts_perturbation,
    "dampen_resid_acts": dampen_resid_acts_perturbation,
}

# Seed torch's global RNG so any randomness in the perturbations (presumably
# naive_random_perturbation draws from it -- confirm) is reproducible on re-run.
# torch.manual_seed also seeds the CUDA generators.
PERTURBATION_SEED = 0
torch.manual_seed(PERTURBATION_SEED)

# One dict per SAE (same order as `saes`): perturbation name -> loss tensor
# from teren_utils.compute_loss on the perturbed residual stream.
loss_by_feature_by_pert_by_sae = []
for sae, examples_by_feature in zip(saes, examples_by_feature_by_sae):
    resid_acts = examples_by_feature.resid_acts
    loss_by_feature_by_pert = {}
    for pert_name, pert in perturbations.items():
        # Each perturbation yields a delta that is added to the activations.
        pert_resid_acts = resid_acts + pert(resid_acts)
        loss_by_feature_by_pert[pert_name] = teren_utils.compute_loss(
            model=model,
            input_ids=examples_by_feature.input_ids,
            resid_acts=pert_resid_acts,
            start_at_layer=sae.cfg.hook_layer,
            batch_size=BATCH_SIZE,
        )
    loss_by_feature_by_pert_by_sae.append(loss_by_feature_by_pert)

## Loss increase under each perturbation

In [34]:
import pandas as pd
import plotly.express as px

def plot_loss_bar(losses_by_hook, clean_loss, ablate_resid_loss, layer):
    """Summarize and bar-plot normalized loss increases for one layer.

    For every entry of ``losses_by_hook``, the per-token loss increase over
    ``clean_loss`` is divided by ``ablate_resid_loss``. It is then reduced by a
    max over the last (context) dimension and a mean over the batch dimension.
    Finally it is aggregated across features into a mean and standard deviation.

    Args:
        losses_by_hook: mapping name -> loss tensor, treated as
            (n_features, batch, n_ctx - 1).
        clean_loss: unperturbed loss; presumably the same shape (must broadcast).
        ablate_resid_loss: loss of the whole-layer ablation reference; same shape.
        layer: layer index, used only in the plot title.

    Returns:
        The summary DataFrame (columns ``name``, ``mean_loss``, ``std_loss``).
        It is also displayed before the bar chart is shown.

    NOTE(review): the denominator is the absolute ablated loss, not the
    ablation *increase* ``ablate_resid_loss - clean_loss``. The reference
    itself therefore scores ``1 - clean_loss / ablate_resid_loss`` per token
    rather than exactly 1. Confirm this is the intended normalization.
    """
    records = []
    for name, loss in losses_by_hook.items():
        # (n_features, batch, n_ctx - 1)
        normalized_loss = (loss - clean_loss) / ablate_resid_loss
        # Worst (max) token in each sequence -> (n_features, batch)
        normalized_loss_ctx_max = normalized_loss.max(dim=-1).values
        # Mean over batch -> (n_features,)
        loss_by_feature = normalized_loss_ctx_max.mean(dim=-1)
        records.append({
            "name": name,
            "mean_loss": loss_by_feature.mean().item(),
            # torch's default (Bessel-corrected) std across features; NaN for a single feature.
            "std_loss": loss_by_feature.std().item(),
        })

    # pd.DataFrame is the idiomatic constructor for a list of row dicts.
    df = pd.DataFrame(records)
    display(df)

    labels = {
        "mean_loss": "Normalized loss increase<br />(relative to ablating entire layer)",
    }
    px.bar(df, title=f"Layer {layer}", x="name", y="mean_loss", error_y="std_loss", color="name", labels=labels).show()
    return df

In [35]:
for layer, examples_by_feature, loss_by_feature_by_pert in zip(
    LAYERS, examples_by_feature_by_sae, loss_by_feature_by_pert_by_sae
):
    # "dampen_resid_acts" serves as the whole-layer ablation reference that
    # every perturbation (itself included) is normalized against.
    clean_loss, ablate_resid_loss = (
        examples_by_feature.clean_loss,
        loss_by_feature_by_pert["dampen_resid_acts"],
    )
    plot_loss_bar(loss_by_feature_by_pert, clean_loss, ablate_resid_loss, layer=layer)

Unnamed: 0,name,mean_loss,std_loss
0,naive_random,0.346162,0.063047
1,amplify_resid_acts,0.523384,0.051776
2,dampen_resid_acts,0.999859,0.000142


Unnamed: 0,name,mean_loss,std_loss
0,naive_random,0.201087,0.0693
1,amplify_resid_acts,0.601778,0.216502
2,dampen_resid_acts,0.999637,0.00049
