In [1]:
import sys 
import os

# Make the local circuit-finder checkout importable. Set the CIRCUIT_FINDER_PATH
# environment variable to point elsewhere instead of editing this hard-coded default.
CIRCUIT_FINDER_PATH = os.environ.get(
    "CIRCUIT_FINDER_PATH", "/home/daniel/ml_workspace/circuit-finder"
)
# Guard so re-running this cell does not keep appending duplicate entries.
if CIRCUIT_FINDER_PATH not in sys.path:
    sys.path.append(CIRCUIT_FINDER_PATH)

In [2]:
from circuit_finder.pretrained import load_model
# Load the pretrained model used by every later cell (the cell output reports
# gpt2 loaded into a HookedTransformer).
model = load_model()



Loaded pretrained model gpt2 into HookedTransformer


In [3]:
from circuit_finder.pretrained import load_attn_saes, load_hooked_mlp_transcoders
from circuit_finder.patching.indirect_leap import preprocess_attn_saes

# Attention SAEs are post-processed by preprocess_attn_saes (which also receives the model).
attn_sae_dict = preprocess_attn_saes(load_attn_saes(), model)
hooked_mlp_transcoder_dict = load_hooked_mlp_transcoders()

# IndirectLEAP takes plain lists, while the ablation helper takes the dicts,
# so keep both views around.
attn_saes = [*attn_sae_dict.values()]
transcoders = [*hooked_mlp_transcoder_dict.values()]


Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

In [4]:
# Clean/corrupt prompt pair. In the clean prompt the name "John" repeats, so the
# expected completion is the other name; the corrupt prompt uses three distinct names.
clean_text = "When John and Mary went to the shop, John gave a bottle to"
corrupt_text = "When Alice and Bob went to the shop, Charlie gave a bottle to"

# Correct completion and the distractor it is compared against (note the leading space).
answer, wrong_answer = " Mary", " John"


In [5]:
# Tokenize the prompts, and the answers without a BOS token.
clean_tokens = model.to_tokens(clean_text)
answer_tokens = model.to_tokens(answer, prepend_bos=False).squeeze(-1)
wrong_answer_tokens = model.to_tokens(wrong_answer, prepend_bos=False).squeeze(-1)
corrupt_tokens = model.to_tokens(corrupt_text)

# squeeze(-1) only removes the sequence dimension when an answer is exactly one
# token. compute_logit_diff picks one logit per batch element, so fail loudly here
# instead of with an opaque indexing error later.
batch_size = clean_tokens.shape[0]
assert answer_tokens.shape == (batch_size,), f"{answer!r} must be a single token"
assert wrong_answer_tokens.shape == (batch_size,), f"{wrong_answer!r} must be a single token"
# The corrupt tokens are passed alongside the clean ones to the LEAP/ablation
# helpers. Presumably activations are substituted position-by-position (TODO confirm),
# so both prompts must tokenize to the same shape.
assert clean_tokens.shape == corrupt_tokens.shape, (
    f"clean {tuple(clean_tokens.shape)} and corrupt {tuple(corrupt_tokens.shape)} "
    "token shapes differ"
)

print(clean_tokens.shape)
print(answer_tokens.shape)

torch.Size([1, 15])
torch.Size([1])


In [19]:
import torch
from eindex import eindex
from circuit_finder.patching.eap_graph import EAPGraph
from circuit_finder.patching.ablate import get_metric_with_ablation
from circuit_finder.patching.indirect_leap import IndirectLEAP, LEAPConfig
from circuit_finder.utils import clear_memory

# Tokens passed as `corrupt_tokens` to IndirectLEAP and to every
# get_metric_with_ablation call below; aliased so that source can be swapped in one place.
ablate_tokens = corrupt_tokens

def compute_logit_diff(model, clean_tokens, answer_tokens, wrong_answer_tokens):
    """Logit difference between the correct and the wrong answer at the last position.

    Runs ``model`` on ``clean_tokens`` (any token batch, despite the name) and, for
    each batch element ``b``, subtracts the final-position logit of
    ``wrong_answer_tokens[b]`` from that of ``answer_tokens[b]``.

    Returns:
        A tensor holding one logit difference per batch element.
    """
    final_position_logits = model(clean_tokens)[:, -1, :]

    def _logit_of(token_ids):
        # eindex pattern "batch [batch]": for each batch row, take the vocab entry
        # named by the matching element of token_ids.
        return eindex(final_position_logits, token_ids, "batch [batch]")

    return _logit_of(answer_tokens) - _logit_of(wrong_answer_tokens)

def metric_fn(model, tokens):
    """Scalar patching metric: mean logit difference of ``answer`` over ``wrong_answer``.

    The targets are the module-level ``answer_tokens`` / ``wrong_answer_tokens``;
    they presumably must share a batch dimension with ``tokens``.
    """
    return compute_logit_diff(
        model,
        clean_tokens=tokens,
        answer_tokens=answer_tokens,
        wrong_answer_tokens=wrong_answer_tokens,
    ).mean()

# NOTE: First, get the ceiling of the patching metric: the logit difference on the
# clean prompt with nothing ablated.
with torch.no_grad():
    ceiling = metric_fn(model, clean_tokens).item()
print(ceiling)


def _metric_under_ablation(circuit_graph):
    """Return ``metric_fn`` on the clean tokens with ``circuit_graph`` given to the ablation helper.

    The floor and every sweep point share these ablation settings:
    ``ablate_nodes="bm"``, errors not ablated, ``first_ablated_layer=2``, and
    ``ablate_tokens`` as the corrupt tokens. Keeping them in one place guarantees
    that all measurements are comparable.

    Returns:
        The metric as a Python float.
    """
    with torch.no_grad():
        value = get_metric_with_ablation(
            model,
            circuit_graph,
            clean_tokens,
            metric_fn,
            hooked_mlp_transcoder_dict,
            attn_sae_dict,
            ablate_nodes="bm",
            ablate_errors=False,  # Do not ablate errors when running forward pass
            first_ablated_layer=2,
            corrupt_tokens=ablate_tokens,
        ).item()
    clear_memory()
    return value


# NOTE: Second, get floor of patching metric using empty graph, i.e. ablate everything
empty_graph = EAPGraph([])
floor = _metric_under_ablation(empty_graph)
print(floor)


# Now sweep over thresholds to get graphs with a variety of numbers of nodes. For
# each graph, record its size, its metric under ablation, and its faithfulness
# (metric - floor) / (ceiling - floor): 0 at the floor, 1 at the clean ceiling, and
# above 1 when the circuit alone beats the full model.
num_nodes_list = []
metrics_list = []
faithfulness_list = []

# TODO: make configurable. Full sweep:
# thresholds = [0.001, 0.003, 0.006, 0.01, 0.03, 0.06, 0.1, 0.3, 0.6, 1.0]
thresholds = [0.03]
for threshold in thresholds:
    # Setup LEAP algorithm on a model with no leftover hooks
    model.reset_hooks()
    cfg = LEAPConfig(
        threshold=threshold,
        contrast_pairs=False,
        qk_enabled=True,
        chained_attribs=True,
        abs_attribs=False,
        store_error_attribs=True,
    )
    leap = IndirectLEAP(
        cfg=cfg,
        tokens=clean_tokens,
        model=model,
        metric=metric_fn,
        attn_saes=attn_saes,  # type: ignore
        transcoders=transcoders,
        corrupt_tokens=ablate_tokens,
    )

    # Populate the graph
    leap.run()

    # Keep this threshold's graphs; the last ones are used by later cells.
    graph = EAPGraph(leap.graph)
    error_graph = EAPGraph(leap.error_graph)
    num_nodes = len(graph.get_src_nodes())

    # Delete tensors to save memory
    del leap
    clear_memory()

    # Calculate the metric under ablation and how much of the ceiling it recovers
    metric = _metric_under_ablation(graph)
    faithfulness = (
        (metric - floor) / (ceiling - floor) if ceiling != floor else float("nan")
    )

    num_nodes_list.append(num_nodes)
    metrics_list.append(metric)
    faithfulness_list.append(faithfulness)
    print(metric)
    print(f"threshold={threshold}: num_nodes={num_nodes}, faithfulness={faithfulness:.4f}")

3.373994827270508
1.0941038131713867
3.5999412536621094


In [20]:
from circuit_finder.plotting import make_html_graph

# Edge count of the last sweep graph, then render it (with its error graph) to HTML.
n_edges = len(graph.get_edges())
print(n_edges)
str_tokens = model.to_str_tokens(clean_tokens)
make_html_graph(graph, tokens=str_tokens, error_graph=error_graph.graph)

187
graph.html
Generated graph.html. Open this file in Live Server to view the graph.


In [23]:
# Convert the graph to a dataframe

import pandas as pd 
from circuit_finder.core.types import parse_node_name

rows = []
for edge, edge_info, edge_type in graph.graph:
    dest, src = edge
    # Edges whose destination is the "null" node are left out of the table.
    if dest == "null":
        continue

    node_node_attr, node_node_grad, edge_metric_attr, edge_metric_grad = edge_info
    src_module_name, src_layer, src_token_idx, src_feature_idx = parse_node_name(src)
    dest_module_name, dest_layer, dest_token_idx, dest_feature_idx = parse_node_name(dest)

    # One flat record per edge; key order fixes the DataFrame column order.
    rows.append(
        dict(
            src_module_name=src_module_name,
            src_layer=src_layer,
            src_token_idx=src_token_idx,
            src_feature_idx=src_feature_idx,
            dest_module_name=dest_module_name,
            dest_layer=dest_layer,
            dest_token_idx=dest_token_idx,
            dest_feature_idx=dest_feature_idx,
            edge_metric_attr=edge_metric_attr,
            edge_metric_grad=edge_metric_grad,
            node_node_attr=node_node_attr,
            node_node_grad=node_node_grad,
            edge_type=edge_type,
        )
    )

df = pd.DataFrame(rows)
print(len(df))
df.head()

172


Unnamed: 0,src_module_name,src_layer,src_token_idx,src_feature_idx,dest_module_name,dest_layer,dest_token_idx,dest_feature_idx,edge_metric_attr,edge_metric_grad,node_node_attr,node_node_grad,edge_type
0,mlp,7,14,15311,metric,12,14,0,0.009911,0.030978,0.009911,0.030978,
1,mlp,8,14,14733,metric,12,14,0,0.009881,0.052038,0.009881,0.052038,
2,mlp,9,14,10182,metric,12,14,0,0.015252,0.064651,0.015252,0.064651,
3,mlp,9,14,19418,metric,12,14,0,0.01347,0.034007,0.01347,0.034007,
4,mlp,10,14,5633,metric,12,14,0,0.02706,0.049506,0.02706,0.049506,


# Analysis

In [36]:
def get_outgoing_edge_df(
    df: pd.DataFrame,
    src_module_name: str,
    src_layer: int,
    src_feature_idx: int,
    src_token_idx: "int | None" = None,
):
    """Return the rows of ``df`` whose source node is the given feature.

    Args:
        df: Edge table with ``src_module_name``, ``src_layer``, ``src_feature_idx``
            (and, if ``src_token_idx`` is used, ``src_token_idx``) columns.
        src_module_name: Module of the source node, e.g. ``"attn"`` or ``"mlp"``.
        src_layer: Layer of the source node.
        src_feature_idx: Feature index of the source node.
        src_token_idx: If given, also require the source node to sit at this token
            position. By default, edges from the feature at every position are returned.

    Returns:
        The matching rows, with the original index preserved.
    """
    mask = (
        (df["src_module_name"] == src_module_name)
        & (df["src_layer"] == src_layer)
        & (df["src_feature_idx"] == src_feature_idx)
    )
    if src_token_idx is not None:
        mask &= df["src_token_idx"] == src_token_idx
    return df[mask]

def get_incoming_edge_df(
    df: pd.DataFrame,
    dest_module_name: str,
    dest_layer: int,
    dest_feature_idx: int,
    dest_token_idx: "int | None" = None,
):
    """Return the rows of ``df`` whose destination node is the given feature.

    Args:
        df: Edge table with ``dest_module_name``, ``dest_layer``, ``dest_feature_idx``
            (and, if ``dest_token_idx`` is used, ``dest_token_idx``) columns.
        dest_module_name: Module of the destination node, e.g. ``"attn"`` or ``"mlp"``.
        dest_layer: Layer of the destination node.
        dest_feature_idx: Feature index of the destination node.
        dest_token_idx: If given, also require the destination node to sit at this
            token position. By default, edges into the feature at every position are returned.

    Returns:
        The matching rows, with the original index preserved.
    """
    mask = (
        (df["dest_module_name"] == dest_module_name)
        & (df["dest_layer"] == dest_layer)
        & (df["dest_feature_idx"] == dest_feature_idx)
    )
    if dest_token_idx is not None:
        mask &= df["dest_token_idx"] == dest_token_idx
    return df[mask]

In [37]:
# Upstream (incoming) edges for layer 6 attn feature 17410
get_incoming_edge_df(df, dest_module_name="attn", dest_layer=6, dest_feature_idx=17410)

Unnamed: 0,src_module_name,src_layer,src_token_idx,src_feature_idx,dest_module_name,dest_layer,dest_token_idx,dest_feature_idx,edge_metric_attr,edge_metric_grad,node_node_attr,node_node_grad,edge_type
63,mlp,0,3,5348,attn,6,10,17410,0.014283,0.143771,0.099345,1.243985,ov
64,mlp,0,3,10461,attn,6,10,17410,0.013178,0.107297,0.09166,0.746299,ov
65,mlp,1,3,15111,attn,6,10,17410,0.013677,0.050008,0.095129,0.347828,ov
66,mlp,2,3,3665,attn,6,10,17410,0.008721,0.030636,0.06066,0.21309,ov
67,mlp,5,3,6307,attn,6,10,17410,0.005867,0.038147,0.040811,0.265331,ov
68,attn,0,3,8162,attn,6,10,17410,0.007761,0.079605,0.053985,0.553688,ov
75,mlp,0,10,2343,attn,6,10,17410,0.028651,0.064691,0.19928,0.449959,q
76,mlp,0,10,13881,attn,6,10,17410,0.052443,0.143771,0.364763,8.005786,q
77,mlp,0,10,16165,attn,6,10,17410,0.019409,0.036999,0.134996,0.257349,q
78,mlp,1,10,5913,attn,6,10,17410,0.039418,0.143771,0.274171,1.933361,q


In [38]:
# Downstream (outgoing) edges for layer 6 attn feature 17410
get_outgoing_edge_df(df, src_module_name="attn", src_layer=6, src_feature_idx=17410)

Unnamed: 0,src_module_name,src_layer,src_token_idx,src_feature_idx,dest_module_name,dest_layer,dest_token_idx,dest_feature_idx,edge_metric_attr,edge_metric_grad,node_node_attr,node_node_grad,edge_type
58,attn,6,10,17410,attn,8,14,16513,0.047495,0.066926,0.709671,2.920429,ov


In [39]:
# Downstream (outgoing) edges for layer 8 attn feature 16513, reusing the shared
# helper instead of repeating its filter by hand.
node_df = get_outgoing_edge_df(df, "attn", 8, 16513)

node_df

Unnamed: 0,src_module_name,src_layer,src_token_idx,src_feature_idx,dest_module_name,dest_layer,dest_token_idx,dest_feature_idx,edge_metric_attr,edge_metric_grad,node_node_attr,node_node_grad,edge_type
40,attn,8,14,16513,attn,10,14,3849,0.011501,0.043708,0.263123,2.460771,q


In [40]:
# Downstream (outgoing) edges for layer 10 attn feature 3849, reusing the shared
# helper instead of repeating its filter by hand.
node_df = get_outgoing_edge_df(df, "attn", 10, 3849)

node_df

Unnamed: 0,src_module_name,src_layer,src_token_idx,src_feature_idx,dest_module_name,dest_layer,dest_token_idx,dest_feature_idx,edge_metric_attr,edge_metric_grad,node_node_attr,node_node_grad,edge_type
15,attn,10,14,3849,metric,12,14,0,0.065984,1.0,0.065984,1.431793,


## Edge Ablation

In [None]:
from circuit_finder.patching.ablate import splice_model_with_saes_and_transcoders

# Cache activations with the attention SAEs and MLP transcoders spliced in, on both
# the clean and the corrupt prompt; get_edge_patch_hook takes one cache of each.
with torch.no_grad():
    with splice_model_with_saes_and_transcoders(model, transcoders, attn_saes) as spliced_model:
        _, clean_cache = spliced_model.run_with_cache(clean_tokens)
        _, corrupt_cache = spliced_model.run_with_cache(corrupt_tokens)


In [None]:
from transformer_lens import ActivationCache
from circuit_finder.core.types import Node, parse_node_name

def get_edge_patch_hook(
    clean_cache: ActivationCache,
    corrupt_cache: ActivationCache,
    src_module, # either HookedSAE or HookedTranscoder
    src_node: Node,
    dest_module, # either HookedSAE or HookedTranscoder
    dest_node: Node,
):
    """Build a hook that patches the single edge ``src_node`` -> ``dest_node``.

    Not implemented yet. Judging by the parameters, it is intended to combine
    activations from ``clean_cache`` and ``corrupt_cache`` for the edge from
    ``src_node`` (in ``src_module``) to ``dest_node`` (in ``dest_module``);
    TODO confirm the design when implementing.

    Raises:
        NotImplementedError: Always, for now. Raising instead of silently returning
            None prevents a caller from registering a ``None`` hook and failing
            later with a confusing error.
    """
    raise NotImplementedError("get_edge_patch_hook is not implemented yet")
    


First, let's try ablating one of the important edges and confirm that the metric (the mean logit difference between " Mary" and " John") goes down.