In [2]:
import os
import sys

# Make the local circuit-finder checkout importable. The location can be
# overridden with the CIRCUIT_FINDER_ROOT environment variable; the default is
# the path this notebook was originally written against.
CIRCUIT_FINDER_ROOT = os.environ.get(
    "CIRCUIT_FINDER_ROOT", "/home/daniel/ml_workspace/circuit-finder"
)
# Re-running this cell must not keep appending duplicate sys.path entries.
if CIRCUIT_FINDER_ROOT not in sys.path:
    sys.path.append(CIRCUIT_FINDER_ROOT)

In [3]:
from circuit_finder.pretrained import load_model

# Load the pretrained model used throughout the notebook; the cell output
# reports it as gpt2 loaded into a HookedTransformer.
model = load_model()



Loaded pretrained model gpt2 into HookedTransformer


In [4]:
from circuit_finder.pretrained import load_attn_saes, load_hooked_mlp_transcoders
from circuit_finder.patching.indirect_leap import preprocess_attn_saes

# Attention SAEs, preprocessed against the loaded model before any use.
attn_sae_dict = preprocess_attn_saes(load_attn_saes(), model)
# Hooked MLP transcoders.
hooked_mlp_transcoder_dict = load_hooked_mlp_transcoders()

# Plain lists of the same objects (in dict iteration order); these list forms
# are what the IndirectLEAP constructor receives later in the notebook.
attn_saes = [*attn_sae_dict.values()]
transcoders = [*hooked_mlp_transcoder_dict.values()]


Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

In [5]:
# Clean prompt and its contrast (corrupted) prompt, with every name replaced.
clean_text, corrupt_text = (
    "When John and Mary went to the shop, John gave a bottle to",
    "When Alice and Bob went to the shop, Charlie gave a bottle to",
)
# Expected completion vs. the distractor completion for the clean prompt.
answer, wrong_answer = " Mary", " John"


In [9]:
# Tokenize the prompts (with BOS) and the two candidate completions (without BOS).
clean_tokens = model.to_tokens(clean_text)
answer_tokens = model.to_tokens(answer, prepend_bos=False).squeeze(-1)
wrong_answer_tokens = model.to_tokens(wrong_answer, prepend_bos=False).squeeze(-1)
corrupt_tokens = model.to_tokens(corrupt_text)

# The logit-difference metric picks one vocabulary entry per batch row, so each
# answer must be exactly one token. squeeze(-1) is a no-op on a multi-token
# answer, which would leave a 2-D tensor and break the indexing downstream.
assert answer_tokens.ndim == 1, f"{answer!r} does not tokenize to a single token"
assert wrong_answer_tokens.ndim == 1, (
    f"{wrong_answer!r} does not tokenize to a single token"
)
# Clean and corrupt prompts are used as a contrast pair; presumably the
# ablation/attribution code pairs them position-by-position, so they must have
# the same tokenized length. TODO confirm against get_metric_with_ablation.
assert clean_tokens.shape == corrupt_tokens.shape, (
    f"clean {tuple(clean_tokens.shape)} vs corrupt {tuple(corrupt_tokens.shape)}"
)

print(clean_tokens.shape)
print(answer_tokens.shape)

torch.Size([1, 15])
torch.Size([1])


In [18]:
import torch
from eindex import eindex
from circuit_finder.patching.eap_graph import EAPGraph
from circuit_finder.patching.ablate import get_metric_with_ablation
from circuit_finder.patching.indirect_leap import IndirectLEAP, LEAPConfig
from circuit_finder.utils import clear_memory

# Tokens passed as `corrupt_tokens` to both get_metric_with_ablation and
# IndirectLEAP below; here simply the corrupted prompt (presumably the source
# of the replacement activations used when ablating).
ablate_tokens = corrupt_tokens

def compute_logit_diff(model, clean_tokens, answer_tokens, wrong_answer_tokens):
    """Logit of the answer minus logit of the wrong answer at the final position.

    Runs ``model`` on ``clean_tokens``. Despite the parameter name, ``metric_fn``
    forwards whatever tokens it receives. For each batch row ``b`` it returns
    ``logits[b, -1, answer_tokens[b]] - logits[b, -1, wrong_answer_tokens[b]]``.

    Assumes the model output is (batch, seq, vocab) and both answer tensors are
    1-D with one token id per batch row, as the ``"batch [batch]"`` eindex
    pattern implies. TODO confirm for batched prompts.
    """
    final_position = model(clean_tokens)[:, -1, :]

    def logit_of(token_ids):
        # One selected logit per batch row.
        return eindex(final_position, token_ids, "batch [batch]")

    return logit_of(answer_tokens) - logit_of(wrong_answer_tokens)

def metric_fn(model, tokens):
    """Batch-mean logit difference (answer minus wrong answer) for ``tokens``.

    Reads the notebook-level ``answer_tokens`` / ``wrong_answer_tokens`` at call
    time, which keeps the two-argument ``(model, tokens)`` form that this
    notebook passes as the ``metric`` callback.
    """
    per_example_diff = compute_logit_diff(
        model, tokens, answer_tokens, wrong_answer_tokens
    )
    return per_example_diff.mean()

# Ceiling of the patching metric: the logit difference (as computed by
# metric_fn) on the clean prompt with nothing ablated.
with torch.no_grad():
    ceiling = metric_fn(model, clean_tokens).item()
    print(ceiling)

# Floor of the patching metric: evaluate the empty graph, i.e. ablate everything.
# ablate_nodes="bm" and first_ablated_layer=2 are passed through unchanged; their
# exact meaning is defined by get_metric_with_ablation.
with torch.no_grad():
    empty_graph = EAPGraph([])
    floor = get_metric_with_ablation(
        model, empty_graph, clean_tokens, metric_fn,
        hooked_mlp_transcoder_dict, attn_sae_dict,
        ablate_nodes="bm",
        ablate_errors=False,  # errors are not ablated during the forward pass
        first_ablated_layer=2,
        corrupt_tokens=ablate_tokens,
    ).item()
clear_memory()
print(floor)


# Sweep over attribution thresholds to get graphs with a variety of node counts.
# For each graph, measure the metric under ablation and its faithfulness:
#     faithfulness = (metric - floor) / (ceiling - floor)
# so the empty graph scores 0.0 and the unablated model scores 1.0.
#
# TODO: make thresholds configurable, e.g.
# thresholds = [0.001, 0.003, 0.006, 0.01, 0.03, 0.06, 0.1, 0.3, 0.6, 1.0]
thresholds = [0.01]

num_nodes_list = []
metrics_list = []
faithfulness_list = []

for threshold in thresholds:
    # Set up the LEAP algorithm for this threshold on a hook-free model.
    model.reset_hooks()
    cfg = LEAPConfig(
        threshold=threshold, contrast_pairs=True, chained_attribs=True
    )
    leap = IndirectLEAP(
        cfg=cfg,
        tokens=clean_tokens,
        model=model,
        metric=metric_fn,
        attn_saes=attn_saes,  # type: ignore
        transcoders=transcoders,
        corrupt_tokens=ablate_tokens,
    )

    # Populate the graph: start from the metric, then step through the layers
    # from the last one down to layer 1.
    leap.metric_step()
    for layer in reversed(range(1, leap.n_layers)):
        leap.mlp_step(layer)
        leap.ov_step(layer)

    # Keep only the graph, then drop the LEAP object to save memory.
    graph = EAPGraph(leap.graph)
    num_nodes = len(graph.get_src_nodes())
    del leap
    clear_memory()

    # Metric when evaluating `graph` under ablation, with the same ablation
    # settings used for the floor (presumably nodes outside the graph are ablated).
    with torch.no_grad():
        metric = get_metric_with_ablation(
            model,
            graph,
            clean_tokens,
            metric_fn,
            hooked_mlp_transcoder_dict,
            attn_sae_dict,
            ablate_nodes="bm",
            ablate_errors=False,
            first_ablated_layer=2,
            corrupt_tokens=ablate_tokens,
        ).item()
    clear_memory()

    # Normalise between floor and ceiling; undefined when the two coincide.
    metric_range = ceiling - floor
    faithfulness = (
        (metric - floor) / metric_range if metric_range != 0 else float("nan")
    )

    # Record per-threshold results so the whole sweep can be compared/plotted.
    num_nodes_list.append(num_nodes)
    metrics_list.append(metric)
    faithfulness_list.append(faithfulness)
    print(
        f"threshold={threshold}: nodes={num_nodes}, "
        f"metric={metric:.4f}, faithfulness={faithfulness:.3f}"
    )

3.373994827270508
1.0941038131713867
2.8041763305664062


In [17]:
from circuit_finder.plotting import make_html_graph

# NOTE(review): `graph` is whatever the threshold sweep assigned last. This cell
# has a lower execution count (17) than the sweep cell (18), so the saved output
# below came from an earlier run. Re-run it after the sweep to get current values.
# Number of edges in the circuit.
print(len(graph.get_edges()))
# Write graph.html, passing the clean prompt's string tokens alongside the graph.
make_html_graph(graph, tokens = model.to_str_tokens(clean_tokens))

267
graph.html
Generated graph.html. Open this file in Live Server to view the graph.
