In [1]:
%load_ext autoreload
%autoreload 2

from teren import utils as teren_utils
from teren.perturbations import Perturbation

import torch
import seaborn as sns
import plotly_express as px

# Resolve the compute device string once; later cells pass it to the model and SAE loaders.
device = teren_utils.get_device_str()

# Report the runtime state up front: chosen device and global autograd mode.
print(
    f"{device=}",
    f"Gradients globally enabled: {torch.is_grad_enabled()}",
    sep="\n",
)

device='cuda'
Gradients globally enabled: False


## Setup

In [2]:
from transformer_lens import HookedTransformer

# Load the pretrained GPT-2 small checkpoint onto the device selected above.
model = HookedTransformer.from_pretrained(
    model_name="gpt2-small",
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [14]:
# Every layer is more or less its own experiment; they are handled together only
# because collecting dataset examples and residual stream activations for many
# layers in a single pass saves a lot of time.

LAYERS = [*range(model.cfg.n_layers)]
# To restrict the run to a subset of layers, use e.g.:
# LAYERS = 0, 5, 10

# pre-trained SAE release available in SAELens
SAE_RELEASE = "gpt2-small-res-jb"


# TODO: experiment with hook_resid_post
def LAYER_TO_SAE_ID(layer):
    """Build the SAE id string for a layer's ``hook_resid_pre`` hook.

    Defined with ``def`` rather than an assigned lambda so the callable has a
    real name in tracebacks and can carry this docstring.

    Args:
        layer: Layer index; interpolated into the id as-is.

    Returns:
        The string ``f"blocks.{layer}.hook_resid_pre"``.
    """
    return f"blocks.{layer}.hook_resid_pre"

# HF text dataset
DATASET_PATH = "NeelNanda/c4-10k"
DATASET_SPLIT = "train[:1%]"

# how long each tokenized prompt should be
CONTEXT_SIZE = 10

# how many tokens can be forwarded through the transformer at once
# tune to your VRAM, 12_800 works well with 16GB
INFERENCE_TOKENS = 320  # 12_800
INFERENCE_BATCH_SIZE = INFERENCE_TOKENS // CONTEXT_SIZE
# floor division yields 0 when the token budget is smaller than a single prompt;
# fail here with a clear message rather than later with an empty batch
assert INFERENCE_BATCH_SIZE >= 1, (
    f"INFERENCE_TOKENS ({INFERENCE_TOKENS}) must be at least CONTEXT_SIZE ({CONTEXT_SIZE})"
)
print(f"{INFERENCE_BATCH_SIZE=}")

# what is the minimum activation for a feature to be considered active
MIN_FEATURE_ACTIVATION = 0.0
# minimum number of tokenized prompts that should have a feature active
# for the feature to be included in the experiment
MIN_EXAMPLES_PER_FEATURE = 30
# ids of features to consider in the experiment
# we can't consider all the features, because of compute & memory constraints
CONSIDER_FEATURE_IDS = list(range(30))

INFERENCE_BATCH_SIZE=32


In [10]:
from sae_lens import SAE

def _load_sae(layer):
    """Load the pretrained SAE for ``layer`` from ``SAE_RELEASE`` onto ``device``."""
    loaded = SAE.from_pretrained(
        release=SAE_RELEASE,
        sae_id=LAYER_TO_SAE_ID(layer),
        device=device,
    )
    # from_pretrained returns several things; only the first one is kept
    return loaded[0]


# one SAE per layer, keyed by layer
sae_by_layer = {layer: _load_sae(layer) for layer in LAYERS}

In [5]:
# Tokenize the "text" column of the configured dataset split with the model's
# tokenizer, passing CONTEXT_SIZE as max_length.
# NOTE(review): presumably every row is one CONTEXT_SIZE-token prompt (the
# recorded output shape is (5463, 10)) — confirm against load_and_tokenize_dataset.
all_input_ids = teren_utils.load_and_tokenize_dataset(
    path=DATASET_PATH,
    split=DATASET_SPLIT,
    column_name="text",
    tokenizer=model.tokenizer,
    max_length=CONTEXT_SIZE,
)
# quick sanity check of how many prompts were produced
print(f"all_input_ids (tokenized dataset) shape: {tuple(all_input_ids.shape)}")

all_input_ids (tokenized dataset) shape: (5463, 10)


In [18]:
from teren.sae_examples import (
    get_sae_feature_examples_by_layer_and_resid_stats_by_layer,
)
from teren.typing import *

# Wrap the configured feature ids in the project's FeatureId type.
# This is driven by CONSIDER_FEATURE_IDS rather than MIN_EXAMPLES_PER_FEATURE, so
# changing the required number of examples per feature does not silently change
# which features are considered.
consider_feature_ids = [FeatureId(i) for i in CONSIDER_FEATURE_IDS]

# Collect, for all layers at once, the per-feature examples and the residual
# stream statistics used by the cells below.
examples_by_feature_by_layer, resid_stats_by_layer = (
    get_sae_feature_examples_by_layer_and_resid_stats_by_layer(
        input_ids=all_input_ids,
        model=model,
        sae_by_layer=sae_by_layer,
        fids=consider_feature_ids,
        n_examples=MIN_EXAMPLES_PER_FEATURE,
        batch_size=INFERENCE_BATCH_SIZE,
        min_activation=MIN_FEATURE_ACTIVATION,
    )
)

In [29]:
# High-level summary: which of the considered features were selected in each layer.
for layer, examples_by_feature in examples_by_feature_by_layer.items():
    active_feature_ids = [fid.int for fid in examples_by_feature.fids]
    # one print per layer; the trailing empty string produces the blank separator line
    print(
        f"Layer: {layer}",
        f"Number of selected features: {len(active_feature_ids)}",
        f"Active feature ids: {active_feature_ids}",
        "",
        sep="\n",
    )

Layer: 0
Number of selected features: 1
Active feature ids: [9]

Layer: 1
Number of selected features: 1
Active feature ids: [11]

Layer: 2
Number of selected features: 1
Active feature ids: [0]

Layer: 3
Number of selected features: 1
Active feature ids: [6]

Layer: 4
Number of selected features: 5
Active feature ids: [7, 8, 12, 15, 19]

Layer: 5
Number of selected features: 14
Active feature ids: [0, 2, 4, 5, 9, 10, 13, 15, 19, 24, 25, 26, 27, 28]

Layer: 6
Number of selected features: 9
Active feature ids: [5, 7, 13, 14, 16, 17, 18, 20, 27]

Layer: 7
Number of selected features: 12
Active feature ids: [0, 5, 6, 10, 14, 15, 17, 18, 20, 21, 26, 29]

Layer: 8
Number of selected features: 16
Active feature ids: [1, 3, 4, 6, 7, 8, 10, 11, 13, 15, 16, 21, 24, 25, 26, 29]

Layer: 9
Number of selected features: 22
Active feature ids: [0, 2, 4, 5, 8, 9, 10, 12, 13, 14, 15, 17, 18, 19, 21, 22, 23, 24, 25, 27, 28, 29]

Layer: 10
Number of selected features: 13
Active feature ids: [0, 1, 2, 7, 

In [30]:
def print_examples_shapes(idx):
    """Print the shapes of the collected example tensors for one layer.

    Args:
        idx: Position in ``LAYERS`` (not the layer number itself).
    """
    layer = LAYERS[idx]
    print(f"Layer: {layer}")
    # Look the examples up by the same layer that is printed above. Indexing the
    # mapping with ``idx`` only matches when LAYERS is exactly 0..n-1.
    # NOTE(review): assumes examples_by_feature_by_layer is keyed by layer number,
    # as the summary loop's "Layer: {layer}" labels suggest — confirm.
    examples_by_feature = examples_by_feature_by_layer[layer]
    print(
        f"input_ids shape: {tuple(examples_by_feature.input_ids.shape)} (n_features, n_examples, context_size)"
    )
    print(
        f"resid_acts shape: {tuple(examples_by_feature.resid_acts.shape)} (n_features, n_examples, context_size, d_model)"
    )
    print(
        f"clean_loss shape: {tuple(examples_by_feature.clean_loss.shape)} (n_features, n_examples, context_size - 1)"
    )


print_examples_shapes(0)

Layer: 0
input_ids shape: (1, 30, 10) (n_features, n_examples, context_size)
resid_acts shape: (1, 30, 10, 768) (n_features, n_examples, context_size, d_model)
clean_loss shape: (1, 30, 9) (n_features, n_examples, context_size - 1)


## Define Perturbations

In [37]:
def is_positive_definite(A):
    """Report whether ``torch.linalg.cholesky`` succeeds on ``A``.

    Args:
        A: Tensor handed directly to ``torch.linalg.cholesky``.

    Returns:
        ``True`` if the factorization completes, ``False`` if it raises
        ``RuntimeError``.

    NOTE(review): symmetry of ``A`` is not checked here — confirm callers only
    pass symmetric matrices (e.g. covariances).
    """
    factorized = True
    try:
        # the factor itself is discarded; only success or failure matters
        torch.linalg.cholesky(A)
    except RuntimeError:
        factorized = False
    return factorized

In [38]:
# Check whether the residual covariance stored under key 0 admits a Cholesky
# factorization. The recorded output of this cell is a boolean, not a tensor.
is_positive_definite(resid_stats_by_layer[0].cov)

False