In [1]:
import sys

# Make the circuit-finder checkout importable for the `circuit_finder` imports
# below. The membership guard keeps the cell idempotent: re-running it no
# longer appends a duplicate entry to sys.path each time.
CIRCUIT_FINDER_ROOT = "/workspace/circuit-finder"
if CIRCUIT_FINDER_ROOT not in sys.path:
    sys.path.append(CIRCUIT_FINDER_ROOT)

In [2]:
from circuit_finder.pretrained import (
    load_model,
    load_attn_saes,
    load_hooked_mlp_transcoders,
)
from circuit_finder.patching.indirect_leap import preprocess_attn_saes

# Base model first: the attention-SAE preprocessing step takes it as an argument.
model = load_model()

# Attention SAEs are loaded and immediately run through preprocess_attn_saes.
attn_saes = preprocess_attn_saes(load_attn_saes(), model)
hooked_mlp_transcoders = load_hooked_mlp_transcoders()

# Both containers are mappings; later cells pass plain lists of their values
# to the splicing context manager.
transcoders = [*hooked_mlp_transcoders.values()]
saes = [*attn_saes.values()]



Loaded pretrained model gpt2 into HookedTransformer


Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

# C4

In [3]:
from datasets import load_dataset

# Open the "en" configuration of C4 in streaming mode.
# NOTE(review): the recorded output of this cell warns that
# `trust_remote_code=True` will become mandatory for this dataset in a future
# `datasets` release — confirm the intended loading path before upgrading.
dataset = load_dataset(path="c4", name="en", streaming=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [4]:
import torch
from transformer_lens import ActivationCache
from circuit_finder.patching.ablate import (
    splice_model_with_saes_and_transcoders,
    filter_sae_acts_and_errors,
)

n_tokens = 0
# Token budget for the running mean. Deliberately tiny for a quick run; the
# full-scale setting noted by the author is 100_000.
total_tokens = 100
print(f"Total tokens: {total_tokens}")

# Only forward activations are needed, so everything runs under no_grad.
# Without it, whenever gradients are enabled the in-place accumulation below
# would chain each document's autograd graph onto the running sums.
with torch.no_grad(), splice_model_with_saes_and_transcoders(model, transcoders, saes):
    # A bit of a hack: one throwaway forward pass to discover which hooks are
    # cached and what shape each activation has.
    _, dummy_cache = model.run_with_cache(
        "Hello World", names_filter=filter_sae_acts_and_errors
    )

    # One zero accumulator per hook, shaped like the activation after summing
    # over axis 1 and squeezing axis 0 (presumably the size-1 batch axis of a
    # single prompt — TODO confirm).
    zero_cache_dict = {
        hook_name: torch.zeros_like(act.sum(1).squeeze(0))
        for hook_name, act in dummy_cache.items()
    }

    for element in dataset["train"]:
        text = element["text"]
        # Tokenize once and feed those tokens to the model, so the token count
        # describes exactly the sequence the activations were computed on.
        tokens = model.to_tokens(text)
        _, cache = model.run_with_cache(tokens, names_filter=filter_sae_acts_and_errors)

        n_tokens += tokens.shape[1]
        for hook_name, act in cache.items():
            zero_cache_dict[hook_name] += act.sum(1).squeeze(0)

        # Stop after the first document that brings the count to the budget.
        if n_tokens >= total_tokens:
            break

# Turn the per-hook sums into per-token means.
for hook_name in zero_cache_dict:
    zero_cache_dict[hook_name] /= n_tokens

zero_cache = ActivationCache(zero_cache_dict, model)

Total tokens: 100


In [5]:
# Spot-check one entry of the averaged cache.
hook_to_inspect = "blocks.0.attn.hook_z.hook_sae_acts_post"
print(zero_cache[hook_to_inspect])

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')


In [6]:
# List every averaged hook together with the shape of its mean activation.
for name, mean_act in zero_cache.items():
    print(f"{name} {mean_act.shape}")

blocks.0.attn.hook_z.hook_sae_acts_post torch.Size([49152])
blocks.0.attn.hook_z.hook_sae_error torch.Size([768])
blocks.0.mlp.transcoder.hook_sae_acts_post torch.Size([24576])
blocks.0.mlp.hook_sae_error torch.Size([768])
blocks.1.attn.hook_z.hook_sae_acts_post torch.Size([49152])
blocks.1.attn.hook_z.hook_sae_error torch.Size([768])
blocks.1.mlp.transcoder.hook_sae_acts_post torch.Size([24576])
blocks.1.mlp.hook_sae_error torch.Size([768])
blocks.2.attn.hook_z.hook_sae_acts_post torch.Size([49152])
blocks.2.attn.hook_z.hook_sae_error torch.Size([768])
blocks.2.mlp.transcoder.hook_sae_acts_post torch.Size([24576])
blocks.2.mlp.hook_sae_error torch.Size([768])
blocks.3.attn.hook_z.hook_sae_acts_post torch.Size([49152])
blocks.3.attn.hook_z.hook_sae_error torch.Size([768])
blocks.3.mlp.transcoder.hook_sae_acts_post torch.Size([24576])
blocks.3.mlp.hook_sae_error torch.Size([768])
blocks.4.attn.hook_z.hook_sae_acts_post torch.Size([49152])
blocks.4.attn.hook_z.hook_sae_error torch.Size([

In [7]:
# Write the averaged C4 cache to disk.
# NOTE(review): `zero_cache` was constructed with a reference to `model`;
# pickling it may serialize the model alongside the activations — confirm
# this is what downstream loaders expect.

import pickle

with open("c4_mean_acts.pkl", mode="wb") as out_file:
    pickle.Pickler(out_file).dump(zero_cache)

# Auto-Circuit Datasets

In [8]:
import pathlib
import pickle
import pandas as pd
import json
import torch
import transformer_lens as tl

from simple_parsing import ArgumentParser
from dataclasses import dataclass
from circuit_finder.patching.eap_graph import EAPGraph
from circuit_finder.utils import clear_memory
from circuit_finder.patching.ablate import get_metric_with_ablation
from circuit_finder.data_loader import load_datasets_from_json, PromptPairBatch
from circuit_finder.constants import device
from tqdm import tqdm
from circuit_finder.patching.ablate import get_metric_with_ablation

from typing import Literal
from eindex import eindex
from pathlib import Path
from circuit_finder.pretrained import (
    load_model,
    load_attn_saes,
    load_hooked_mlp_transcoders,
)
from circuit_finder.patching.indirect_leap import (
    preprocess_attn_saes,
    IndirectLEAP,
    LEAPConfig,
)
from circuit_finder.core.types import Model
from circuit_finder.metrics import batch_avg_answer_diff
from circuit_finder.constants import ProjectDir
from circuit_finder.patching.ablate import (
    splice_model_with_saes_and_transcoders,
    get_metric_with_ablation,
    AblateType,
)

from circuit_finder.experiments.run_dataset_sweep import ALL_DATASETS

# Batch size passed to load_datasets_from_json when the loaders are built.
batch_size = 8
# Show the dataset JSON paths that the processing loop iterates over.
print(ALL_DATASETS)


['datasets/greaterthan_gpt2-small_prompts.json', 'datasets/ioi/ioi_ABBA_template_0_prompts.json', 'datasets/ioi/ioi_ABBA_template_1_prompts.json', 'datasets/ioi/ioi_BABA_template_0_prompts.json', 'datasets/ioi/ioi_BABA_template_1_prompts.json']


In [9]:
import torch
from transformer_lens import ActivationCache
from circuit_finder.patching.ablate import (
    splice_model_with_saes_and_transcoders,
    filter_sae_acts_and_errors,
)


def get_cache(train_loader, total_tokens: int = 100_000) -> "ActivationCache":
    """Compute per-token mean SAE / transcoder activations over a data loader.

    Runs the notebook-level `model`, spliced with the notebook-level
    `transcoders` and `saes`, on the `clean` tokens of each batch, and averages
    every activation selected by `filter_sae_acts_and_errors` over all batch
    and position entries.

    Args:
        train_loader: Iterable of batches exposing a `clean` token tensor that
            is indexed as (batch, position).
            NOTE(review): every position is counted, so any padding present in
            `clean` contributes to the mean — confirm this is intended.
        total_tokens: Token budget. Iteration stops after the first batch that
            brings the running count to this value; if the loader is exhausted
            first, the mean covers whatever was seen.

    Returns:
        An ActivationCache mapping each hook name to its mean activation.

    Raises:
        ValueError: If `train_loader` yields no tokens. Dividing the sums by
            zero would otherwise silently produce NaN means.
    """
    print(f"Total tokens: {total_tokens}")

    n_tokens = 0
    # Per-hook running sums, created from the first batch that reports each
    # hook. This replaces a throwaway "Hello World" forward pass whose only
    # job was to discover the hook names and shapes.
    act_sums = {}

    # Only forward activations are needed; no_grad keeps the in-place
    # accumulation from retaining an autograd graph per batch.
    with torch.no_grad(), splice_model_with_saes_and_transcoders(model, transcoders, saes):
        for batch in train_loader:
            tokens = batch.clean
            _, cache = model.run_with_cache(tokens, names_filter=filter_sae_acts_and_errors)

            n_tokens += tokens.shape[0] * tokens.shape[1]
            for hook_name, act in cache.items():
                # Collapse the position axis, then the batch axis.
                batch_sum = act.sum(1).sum(0)
                if hook_name in act_sums:
                    act_sums[hook_name] += batch_sum
                else:
                    act_sums[hook_name] = batch_sum

            if n_tokens >= total_tokens:
                break

    print(n_tokens)

    if n_tokens == 0:
        raise ValueError("train_loader yielded no tokens; cannot compute mean activations")

    # Turn the per-hook sums into per-token means.
    mean_acts = {hook_name: act_sum / n_tokens for hook_name, act_sum in act_sums.items()}
    return ActivationCache(mean_acts, model)

In [10]:
# For each dataset: build its loader, average the activations over it, and
# write the result to "<dataset file stem>_mean_acts.pkl".
for dataset_path in ALL_DATASETS:
    print("Processing", dataset_path)
    train_loader, _ = load_datasets_from_json(
        model,
        ProjectDir / dataset_path,
        device=torch.device("cuda"),
        batch_size=batch_size,
    )
    cache = get_cache(train_loader)
    # Dump `cache`, the means just computed for this dataset. The
    # notebook-level `zero_cache` holds the C4 average; writing it here would
    # give every dataset file identical, unrelated contents.
    with open(f"{pathlib.Path(dataset_path).stem}_mean_acts.pkl", "wb") as file:
        pickle.dump(cache, file)

Processing datasets/greaterthan_gpt2-small_prompts.json


Total tokens: 100000
1408
Processing datasets/ioi/ioi_ABBA_template_0_prompts.json
Total tokens: 100000
2048
Processing datasets/ioi/ioi_ABBA_template_1_prompts.json
Total tokens: 100000
2560
Processing datasets/ioi/ioi_BABA_template_0_prompts.json
Total tokens: 100000
2048
Processing datasets/ioi/ioi_BABA_template_1_prompts.json
Total tokens: 100000
2560


In [11]:
# Save the cache

import pathlib

import pickle

# `dataset_path` and `cache` are left over from the processing loop above and
# refer to the last dataset handled. Write `cache` (that dataset's means), not
# `zero_cache`: the latter is the C4 average and would overwrite the
# dataset-specific file with unrelated statistics.
with open(f"{pathlib.Path(dataset_path).stem}_mean_acts.pkl", "wb") as file:
    pickle.dump(cache, file)

# Marks et al. Datasets
