# Hypothesis Testing Circuits

The primary tests we want to replicate are the faithfulness test (Equiavelance) and the minimality test (Minimality)

Faithfulness Test
- the full model and ablated model are equally likely to outpeform each other on a task
    - probability of the delta of the score function is greater then 0 - 1/2 is less than epsilon
    - use a sign test
    - test statistic is the # of times C* and M outperform each other 
    - test statistic is absolue value of the difference between the mean that the ablated model outpeforms the original model and 1/2
- P value (value of test statistic such that P(T > T_obs) = alpha) given by binomial distribution P(K > (T_obs + 1/2) * n | \theta = (1/2) + \epsilon)
    how do we choose epsilon?
    set to 0.1



In [None]:
# datasets
# indirect object identification
# doctring 
# tracr-P, tracr-R


# Minimal Faithful Circuit According to Attribution Score Ordering

In [None]:
# how to compute t_obs 
# absolue value of mean of samples with ablated C > M - n/2

# how to compute the p value given t_obs 
# binomial of k >= num ablated C > M samples [(t + 1/2) * n] given p = 1/2 + epsilon (epsilon = 0.01)

# Bonferroni correction
    # divide alpha by number of tests

# binary search for smallest number of edges on attribution ordering for which test passes




In [49]:
from functools import partial
import torch

from transformer_lens import HookedTransformer

from elk_experiments.utils import repo_path_to_abs_path, set_model, repo_path_to_abs_path
from elk_experiments.auto_circuit_utils import make_prompt_data_loader, make_mixed_prompt_dataloader, sorted_scores

from auto_circuit.types import AblationType, PatchType, CircuitOutputs
from auto_circuit.data import PromptPairBatch   
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.tasks import DOCSTRING_COMPONENT_CIRCUIT_TASK, docstring_true_edges
from auto_circuit.prune import run_circuits

In [2]:
task = DOCSTRING_COMPONENT_CIRCUIT_TASK
ablation_type = AblationType.RESAMPLE
grad_function = "logit"
answer_function = "avg_diff"

In [30]:
batch.clean.shape

torch.Size([128, 30])

In [45]:
# compute edge scores
attribution_scores = mask_gradient_prune_scores(
    model=task.model, 
    dataloader=task.train_loader,
    official_edges=None,
    grad_function=grad_function,
    answer_function=answer_function,
    mask_val=None, 
    ablation_type=AblationType.RESAMPLE,
    integrated_grad_samples=10,
    clean_corrupt="corrupt"
)

VBox(children=(          | 0/11 [00:00<?, ?it/s],))

In [4]:
# staistical test
# circuit out 
circuit_out = run_circuits(
    model=task.model, 
    dataloader=task.test_loader,
    test_edge_counts=[len(task.model.edges) // 2],
    prune_scores=attribution_scores,
    patch_type=PatchType.TREE_PATCH,
    ablation_type=ablation_type,
    reverse_clean_corrupt=False,
)
circuit_out = dict(circuit_out[len(task.model.edges) // 2])

VBox(children=(          | 0/2 [00:00<?, ?it/s],))

In [5]:
# model out
model_out: CircuitOutputs = {}
for batch in task.test_loader:
    model_out[batch.key] = task.model(batch.clean)[task.model.out_slice]

In [46]:
from auto_circuit.data import PromptDataLoader
from scipy.stats import binom
import numpy as np
def compute_num_C_gt_M(
    circ_out: CircuitOutputs, 
    model_out: CircuitOutputs, 
    dataloader: PromptDataLoader, 
) -> tuple[int, int]:
    # compute number of samples with ablated C > M
    num_ablated_C_gt_M = 0
    n = 0
    for batch in dataloader:
        bs = batch.clean.size(0)
        circ_out_batch = circ_out[batch.key]
        model_out_batch = model_out[batch.key]
        circ_out_answers = torch.gather(circ_out_batch, 1, batch.answers).sum(dim=1)
        circ_out_wrong_answers = torch.gather(circ_out_batch, 1, batch.wrong_answers).sum(dim=1)
        model_out_answers = torch.gather(model_out_batch, 1, batch.answers).sum(dim=1)
        model_out_wrong_answers = torch.gather(model_out_batch, 1, batch.wrong_answers).sum(dim=1)
        assert circ_out_answers.ndim == 1, circ_out_answers.shape 
        assert circ_out_answers.size(0)
        assert circ_out_answers.shape == model_out_answers.shape == circ_out_wrong_answers.shape == model_out_wrong_answers.shape
        num_ablated_C_gt_M += torch.sum(
            circ_out_answers - circ_out_wrong_answers > model_out_answers - model_out_wrong_answers
        ).item()
        n += bs
    return num_ablated_C_gt_M, n 

def run_non_equiv_test(num_ablated_C_gt_M: int, n: int, alpha: float = 0.05, epsilon: float = 0.1) -> tuple[bool, float]:
    #TODO: low vs ghih
    theta = 1 / 2 + epsilon
    k = num_ablated_C_gt_M
    left_tail = binom.cdf(min(n-k, k), n, theta)
    right_tail = 1 - binom.cdf(max(n-k, k), n, theta)
    p_value = left_tail + right_tail
    return p_value < alpha, p_value 

In [7]:
num_ablated_C_gt_M, n = compute_num_C_gt_M(circuit_out, model_out, task.test_loader)
not_equiv, p_value = run_non_equiv_test(num_ablated_C_gt_M, n, alpha=0.05, epsilon=0.1)
num_ablated_C_gt_M, n, not_equiv, p_value

(83, 256, True, 0.005120892411944023)

In [47]:
# binary search for smallest number of edges on attribution ordering for which test passes
alpha = 0.05
epsilon = 0.1
edge_count_interval = [i for i in range(task.model.n_edges + 1)]
min_equiv = edge_count_interval[-1]
while len(edge_count_interval) > 0:
    midpoint = len(edge_count_interval) // 2
    edge_count = edge_count_interval[midpoint]
    print(
        "cur", edge_count, 
        "min", edge_count_interval[0], 
        "max", edge_count_interval[-1], 
        "len", len(edge_count_interval)
    )
    circuit_out = run_circuits(
        model=task.model, 
        dataloader=task.test_loader,
        test_edge_counts=[edge_count],
        prune_scores=attribution_scores,
        patch_type=PatchType.TREE_PATCH,
        ablation_type=ablation_type,
        reverse_clean_corrupt=False,
    )
    circuit_out = dict(circuit_out[edge_count])
    # model out
    model_out: CircuitOutputs = {}
    for batch in task.test_loader:
        model_out[batch.key] = task.model(batch.clean)[task.model.out_slice]
    # run statitiscal test 
    num_ablated_C_gt_M, n = compute_num_C_gt_M(circuit_out, model_out, task.test_loader)
    not_equiv, p_value = run_non_equiv_test(num_ablated_C_gt_M, n, alpha, epsilon)

    if not_equiv:
        print(f"not equiv at {edge_count}, increase edge count")
        edge_count_interval = edge_count_interval[midpoint+1:] # more edges 
    else:
        min_equiv = edge_count
        print(f"equiv at {edge_count}, decrease edge count")
        edge_count_interval = edge_count_interval[:midpoint] # less edges
    min_equiv
# note this itself is not a perfect procedure - there could be a smaller circuit which is more faithful, but in general larger circuits tend to be more faithful


cur 641 min 0 max 1281 len 1282


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 641, decrease edge count
cur 320 min 0 max 640 len 641


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 320, decrease edge count
cur 160 min 0 max 319 len 320


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 160, decrease edge count
cur 80 min 0 max 159 len 160


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

not equiv at 80, increase edge count
cur 120 min 81 max 159 len 79


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

not equiv at 120, increase edge count
cur 140 min 121 max 159 len 39


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

not equiv at 140, increase edge count
cur 150 min 141 max 159 len 19


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 150, decrease edge count
cur 145 min 141 max 149 len 9


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 145, decrease edge count
cur 143 min 141 max 144 len 4


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

not equiv at 143, increase edge count
cur 144 min 144 max 144 len 1


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 144, decrease edge count


In [48]:
from auto_circuit.visualize import draw_seq_graph
from auto_circuit.utils.tensor_ops import prune_scores_threshold
fig = draw_seq_graph(
    model=task.model,
    prune_scores=attribution_scores,
    score_threshold=prune_scores_threshold(attribution_scores, min_equiv),
    show_all_seq_pos=False,
    seq_labels=task.train_loader.seq_labels,
)

In [55]:
# compute faithfulness at every 10 edges 
edge_counts = [i for i in range(0, task.model.n_edges + 1, 10)]

circuit_outs = run_circuits(
    model=task.model, 
    dataloader=task.test_loader,
    test_edge_counts=[edge_counts],
    prune_scores=attribution_scores,
    patch_type=PatchType.TREE_PATCH,
    ablation_type=ablation_type,
    reverse_clean_corrupt=False,
)
circuit_outs = dict(circuit_outs)
# model out
model_out: CircuitOutputs = {}
for batch in task.test_loader:
    model_out[batch.key] = task.model(batch.clean)[task.model.out_slice]
# run statitiscal tests for each edge count
test_results = {}
for edge_count, circuit_out in circuit_outs.items():
    num_ablated_C_gt_M, n = compute_num_C_gt_M(circuit_out, model_out, task.test_loader)
    not_equiv, p_value = run_non_equiv_test(num_ablated_C_gt_M, n, alpha, epsilon)
    test_results[edge_count] = (num_ablated_C_gt_M, n, not_equiv, p_value)


[0,
 10,
 20,
 30,
 40,
 50,
 60,
 70,
 80,
 90,
 100,
 110,
 120,
 130,
 140,
 150,
 160,
 170,
 180,
 190,
 200,
 210,
 220,
 230,
 240,
 250,
 260,
 270,
 280,
 290,
 300,
 310,
 320,
 330,
 340,
 350,
 360,
 370,
 380,
 390,
 400,
 410,
 420,
 430,
 440,
 450,
 460,
 470,
 480,
 490,
 500,
 510,
 520,
 530,
 540,
 550,
 560,
 570,
 580,
 590,
 600,
 610,
 620,
 630,
 640,
 650,
 660,
 670,
 680,
 690,
 700,
 710,
 720,
 730,
 740,
 750,
 760,
 770,
 780,
 790,
 800,
 810,
 820,
 830,
 840,
 850,
 860,
 870,
 880,
 890,
 900,
 910,
 920,
 930,
 940,
 950,
 960,
 970,
 980,
 990,
 1000,
 1010,
 1020,
 1030,
 1040,
 1050,
 1060,
 1070,
 1080,
 1090,
 1100,
 1110,
 1120,
 1130,
 1140,
 1150,
 1160,
 1170,
 1180,
 1190,
 1200,
 1210,
 1220,
 1230,
 1240,
 1250,
 1260,
 1270,
 1280]

# Minimality Test

In [10]:
q_star = 0.9

In [12]:
from itertools import product
import random
from auto_circuit.types import SrcNode, DestNode
# contruct node dictionaries
srcs_by_layer = {
    layer: {
        head_idx: [n for n in task.model.srcs if (n.layer == layer and n.head_idx == head_idx)]
        for head_idx in range(task.model.cfg.n_heads)
    } 
    for layer in range(1, task.model.cfg.n_layers+1)
}
dests_by_layer = {
    layer: {
        head_idx: {
            n.module_name.split("_")[1]: n 
            for n in task.model.dests if (n.layer == layer and n.head_idx == head_idx)
        }
        for head_idx in range(task.model.cfg.n_heads)
    } 
    for layer in range(1, task.model.cfg.n_layers+1)
}

# enumerate all paths
layer_dests = product(["q", "k", "v"], range(task.model.cfg.n_heads))
layer_dests = list(layer_dests) + [(None, None)]
paths = list(product(layer_dests, repeat=task.model.cfg.n_layers))


def sample_path() -> list[tuple[SrcNode, DestNode]]:
    # sample and contruct path
    path_idx = random.choice(paths)
    path = []
    cur_src = next((n for n in task.model.srcs if n.layer == 0)) # Resid Start
    for layer_idx, (attn_in, head_idx) in enumerate(path_idx):
        if head_idx == None:
            continue  
        # get dest 
        dest = dests_by_layer[layer_idx+1][head_idx][attn_in]
        # append to path 
        path.append((cur_src, dest))
        # get next src 
        cur_src = srcs_by_layer[layer_idx+1][head_idx]
    # add final dest
    path.append((cur_src, next((n for n in task.model.dests if n.layer == task.model.cfg.n_layers+1)))) # Resid End
    return path 
sample_path()


[(SrcNode(name='Resid Start', module_name='blocks.0.hook_resid_pre', layer=0, head_idx=None, head_dim=None, weight='embed.W_E', weight_head_dim=None, src_idx=0),
  DestNode(name='A0.7.K', module_name='blocks.0.hook_k_input', layer=1, head_idx=7, head_dim=2, weight='blocks.0.attn.W_K', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A0.7', module_name='blocks.0.attn.hook_result', layer=1, head_idx=7, head_dim=2, weight='blocks.0.attn.W_O', weight_head_dim=0, src_idx=8)],
  DestNode(name='A1.2.Q', module_name='blocks.1.hook_q_input', layer=2, head_idx=2, head_dim=2, weight='blocks.1.attn.W_Q', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A1.2', module_name='blocks.1.attn.hook_result', layer=2, head_idx=2, head_dim=2, weight='blocks.1.attn.W_O', weight_head_dim=0, src_idx=11)],
  DestNode(name='A2.6.V', module_name='blocks.2.hook_v_input', layer=3, head_idx=6, head_dim=2, weight='blocks.2.attn.W_V', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A2.6', module_name='bloc

In [18]:
x = torch.arange(12).reshape(3, 4)
mask = x > 5 

torch.where(mask)

(tensor([1, 1, 2, 2, 2, 2]), tensor([2, 3, 0, 1, 2, 3]))

In [19]:
torch.nonzero(mask)

tensor([[1, 2],
        [1, 3],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3]])

In [21]:
path = sample_path()
path
# hmm how does indexing with respect to head idx, q, k, v work?

[(SrcNode(name='Resid Start', module_name='blocks.0.hook_resid_pre', layer=0, head_idx=None, head_dim=None, weight='embed.W_E', weight_head_dim=None, src_idx=0),
  DestNode(name='A0.7.V', module_name='blocks.0.hook_v_input', layer=1, head_idx=7, head_dim=2, weight='blocks.0.attn.W_V', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A0.7', module_name='blocks.0.attn.hook_result', layer=1, head_idx=7, head_dim=2, weight='blocks.0.attn.W_O', weight_head_dim=0, src_idx=8)],
  DestNode(name='A1.3.Q', module_name='blocks.1.hook_q_input', layer=2, head_idx=3, head_dim=2, weight='blocks.1.attn.W_Q', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A1.3', module_name='blocks.1.attn.hook_result', layer=2, head_idx=3, head_dim=2, weight='blocks.1.attn.W_O', weight_head_dim=0, src_idx=12)],
  DestNode(name='A2.7.Q', module_name='blocks.2.hook_q_input', layer=3, head_idx=7, head_dim=2, weight='blocks.2.attn.W_Q', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A2.7', module_name='bloc

In [None]:
# to iterate over edges, just get all the indices of the score masks where the value is > threshold 
edges: list[tuple[str, torch.Tensor]] = []
threshold = prune_scores_threshold(attribution_scores, min_equiv)
for dest_mod_name, scores in attribution_scores.items():
    for src_idx in torch.nonzero(scores > threshold):
        edges.append(dest_mod_name, src_idx)

# compute circuits out with min_equiv
circuit_out = run_circuits(
    model=task.model, 
    dataloader=task.test_loader,
    test_edge_counts=[min_equiv],
    prune_scores=attribution_scores,
    patch_type=PatchType.TREE_PATCH,
    ablation_type=ablation_type,
    reverse_clean_corrupt=False,
)

# iterate over edges starting from last 
for dest_mode_name, src_idx in edges:
    # ablate edge 
    prune_scores_ablated = attribution_scores.copy()
    prune_scores_ablated[dest_mod_name][src_idx] = 0.0
    # compute circuit out
    circuit_out_ablated = run_circuits(
        model=task.model, 
        dataloader=task.test_loader,
        test_edge_counts=[min_equiv],
        prune_scores=prune_scores_ablated,
        patch_type=PatchType.TREE_PATCH,
        ablation_type=ablation_type,
        reverse_clean_corrupt=False,
    )
    # sample path 
    path = sample_path()
    # set score to inf if path in edge
    prune_scores_inflated = attribution_scores.copy()
    for src, dest in path:
        prune_scores_inflated[dest.module_name][src_idx] = float("inf")
    # run circuit 
    # randomly sample edge in path 
    # update score mask
    # run circuit 

    # iterate over dataloader 
        # compute frequency diff between full circuit and ablated edge is greater than inflated circuit - ablated circuit

    # compute p value by binomial test
    # p_value = binom.cdf(k, n, q_star) (one sided)
    # if p_value < alpha, return edge count

In [134]:
# get ordered list of src, dest tuples for the same nodes (e.g. src and test for A0.4 (start and stop nodes will have None)
# starting with last list (dest stop node), sample whether to incldue the previous layer, and if so which node (1/n layers, and uniform over nodes)
# keep sampling until arriving at root

# use dests to map to dest wrappers, src_idxs to map to indices in dest wrappers

# create dest_wrapper mask 
# take or with threshold mask

# now how to patch with these extra edges
# I guess just set the prune scoresthe patch mask before


(1, 5)