# Hypothesis Testing Circuits

The primary tests we want to replicate are the faithfulness test (Equiavelance) and the minimality test (Minimality)

Faithfulness Test
- the full model and ablated model are equally likely to outpeform each other on a task
    - probability of the delta of the score function is greater then 0 - 1/2 is less than epsilon
    - use a sign test
    - test statistic is the # of times C* and M outperform each other 
    - test statistic is absolue value of the difference between the mean that the ablated model outpeforms the original model and 1/2
- P value (value of test statistic such that P(T > T_obs) = alpha) given by binomial distribution P(K > (T_obs + 1/2) * n | \theta = (1/2) + \epsilon)
    how do we choose epsilon?
    set to 0.1



In [None]:
# datasets
# indirect object identification
# doctring 
# tracr-P, tracr-R


# Minimal Faithful Circuit According to Attribution Score Ordering

In [None]:
# how to compute t_obs 
# absolue value of mean of samples with ablated C > M - n/2

# how to compute the p value given t_obs 
# binomial of k >= num ablated C > M samples [(t + 1/2) * n] given p = 1/2 + epsilon (epsilon = 0.01)

# Bonferroni correction
    # divide alpha by number of tests

# binary search for smallest number of edges on attribution ordering for which test passes




In [20]:
from functools import partial
import torch

from transformer_lens import HookedTransformer

from elk_experiments.utils import repo_path_to_abs_path, set_model, repo_path_to_abs_path
from elk_experiments.auto_circuit_utils import make_prompt_data_loader, make_mixed_prompt_dataloader, sorted_scores

from auto_circuit.types import AblationType, PatchType, CircuitOutputs
from auto_circuit.data import PromptPairBatch   
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.tasks import DOCSTRING_COMPONENT_CIRCUIT_TASK
from auto_circuit.prune import run_circuits

In [18]:
task = DOCSTRING_COMPONENT_CIRCUIT_TASK
ablation_type = AblationType.RESAMPLE
grad_function = "logit"
answer_function = "avg_diff"

In [11]:
# compute edge scores
attribution_scores = mask_gradient_prune_scores(
    model=task.model, 
    dataloader=task.train_loader,
    official_edges=None,
    grad_function=grad_function,
    answer_function=answer_function,
    mask_val=0.0, 
    ablation_type=AblationType.RESAMPLE,
    clean_corrupt="corrupt"
)

VBox(children=(          | 0/1 [00:00<?, ?it/s],))

In [49]:
# staistical test
# circuit out 
circuit_out = run_circuits(
    model=task.model, 
    dataloader=task.test_loader,
    test_edge_counts=[len(task.model.edges) // 2],
    prune_scores=attribution_scores,
    patch_type=PatchType.TREE_PATCH,
    ablation_type=ablation_type,
    reverse_clean_corrupt=False,
)
circuit_out = dict(circuit_out[len(task.model.edges) // 2])

VBox(children=(          | 0/2 [00:00<?, ?it/s],))

In [50]:
# model out
model_out: CircuitOutputs = {}
for batch in task.test_loader:
    model_out[batch.key] = task.model(batch.clean)[task.model.out_slice]

In [95]:
from auto_circuit.data import PromptDataLoader
from scipy.stats import binom
import numpy as np
def compute_num_C_gt_M(
    circ_out: CircuitOutputs, 
    model_out: CircuitOutputs, 
    dataloader: PromptDataLoader, 
) -> tuple[int, int]:
    # compute number of samples with ablated C > M
    num_ablated_C_gt_M = 0
    n = 0
    for batch in dataloader:
        bs = batch.clean.size(0)
        circ_out_batch = circ_out[batch.key]
        model_out_batch = model_out[batch.key]
        circ_out_answers = torch.gather(circ_out_batch, 1, batch.answers).sum(dim=1)
        circ_out_wrong_answers = torch.gather(circ_out_batch, 1, batch.wrong_answers).sum(dim=1)
        model_out_answers = torch.gather(model_out_batch, 1, batch.answers).sum(dim=1)
        model_out_wrong_answers = torch.gather(model_out_batch, 1, batch.wrong_answers).sum(dim=1)
        assert circ_out_answers.ndim == 1, circ_out_answers.shape 
        assert circ_out_answers.size(0)
        assert circ_out_answers.shape == model_out_answers.shape == circ_out_wrong_answers.shape == model_out_wrong_answers.shape
        num_ablated_C_gt_M += torch.sum(
            circ_out_answers - circ_out_wrong_answers > model_out_answers - model_out_wrong_answers
        ).item()
        n += bs
    return num_ablated_C_gt_M, n 

def run_non_equiv_test(num_ablated_C_gt_M: int, n: int, alpha: float = 0.05, epsilon: float = 0.1) -> tuple[bool, float]:
    p = 1 / 2 + epsilon
    k = num_ablated_C_gt_M
    left_tail = binom.cdf(k, n, p)
    right_tail = 1 - binom.cdf(n-k, n, p)
    p_value = left_tail + right_tail
    return p_value < alpha, p_value 

In [96]:
num_ablated_C_gt_M, n = compute_num_C_gt_M(circuit_out, model_out, task.test_loader)
not_equiv, p_value = run_non_equiv_test(num_ablated_C_gt_M, n, alpha=0.05, epsilon=0.1)
num_ablated_C_gt_M, n, not_equiv, p_value

(83, 256, True, 0.005120892411944023)

In [111]:
# binary search for smallest number of edges on attribution ordering for which test passes
alpha = 0.05
epsilon = 0.1
edge_count_interval = [i for i in range(task.model.n_edges + 1)]
min_equiv = edge_count_interval[-1]
while len(edge_count_interval) > 0:
    midpoint = len(edge_count_interval) // 2
    edge_count = edge_count_interval[midpoint]
    print(
        "cur", edge_count, 
        "min", edge_count_interval[0], 
        "max", edge_count_interval[-1], 
        "len", len(edge_count_interval)
    )
    circuit_out = run_circuits(
        model=task.model, 
        dataloader=task.test_loader,
        test_edge_counts=[edge_count],
        prune_scores=attribution_scores,
        patch_type=PatchType.TREE_PATCH,
        ablation_type=ablation_type,
        reverse_clean_corrupt=False,
    )
    circuit_out = dict(circuit_out[edge_count])
    # model out
    model_out: CircuitOutputs = {}
    for batch in task.test_loader:
        model_out[batch.key] = task.model(batch.clean)[task.model.out_slice]
    # run statitiscal test 
    num_ablated_C_gt_M, n = compute_num_C_gt_M(circuit_out, model_out, task.test_loader)
    not_equiv, p_value = run_non_equiv_test(num_ablated_C_gt_M, n, alpha, epsilon)

    if not_equiv:
        print(f"not equiv at {edge_count}, increase edge count")
        edge_count_interval = edge_count_interval[midpoint+1:] # more edges 
    else:
        min_equiv = edge_count
        print(f"equiv at {edge_count}, decrease edge count")
        edge_count_interval = edge_count_interval[:midpoint] # less edges
# note this itself is not a perfect procedure - there could be a smaller circuit which is more faithful, but in general larger circuits tend to be more faithful


cur 641 min 0 max 1281 len 1282


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

not equiv at 641, increase edge count
cur 962 min 642 max 1281 len 640


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 962, decrease edge count
cur 802 min 642 max 961 len 320


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 802, decrease edge count
cur 722 min 642 max 801 len 160


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 722, decrease edge count
cur 682 min 642 max 721 len 80


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 682, decrease edge count
cur 662 min 642 max 681 len 40


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

not equiv at 662, increase edge count
cur 672 min 663 max 681 len 19


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 672, decrease edge count
cur 667 min 663 max 671 len 9


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 667, decrease edge count
cur 665 min 663 max 666 len 4


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 665, decrease edge count
cur 664 min 663 max 664 len 2


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 664, decrease edge count
cur 663 min 663 max 663 len 1


VBox(children=(          | 0/2 [00:00<?, ?it/s],))

equiv at 663, decrease edge count


In [128]:
task.model.edge_dict

defaultdict(list,
            {None: [A0.5->A1.0.Q,
              A0.5->A3.2.K,
              A0.5->A1.6.V,
              A0.5->A3.2.V,
              A0.5->A3.7.Q,
              A0.5->A2.6.V,
              A0.5->A3.0.V,
              A0.5->A1.1.K,
              A0.5->A1.3.V,
              A0.5->A2.5.K,
              A0.5->A2.2.Q,
              A0.5->A2.5.Q,
              A0.5->A1.2.V,
              A0.5->A1.7.V,
              A0.5->A2.6.K,
              A0.5->A2.6.Q,
              A0.5->A1.2.K,
              A0.5->A1.1.V,
              A0.5->A1.5.V,
              A0.5->A3.6.K,
              A0.5->A2.7.V,
              A0.5->A1.0.V,
              A0.5->A1.3.Q,
              A0.5->A3.7.K,
              A0.5->A3.5.V,
              A0.5->A2.4.Q,
              A0.5->A1.4.V,
              A0.5->A3.5.Q,
              A0.5->A2.3.K,
              A0.5->A1.6.Q,
              A0.5->A3.1.V,
              A0.5->A2.5.V,
              A0.5->A3.7.V,
              A0.5->A1.4.Q,
              A0.5->A3.3

In [127]:
from auto_circuit.visualize import draw_seq_graph
from auto_circuit.utils.tensor_ops import prune_scores_threshold
fig = draw_seq_graph(
    model=task.model,
    prune_scores=attribution_scores,
    score_threshold=prune_scores_threshold(attribution_scores, min_equiv),
    show_all_seq_pos=False,
    seq_labels=task.train_loader.seq_labels,
)

# Minimality Test

In [None]:
q_star = 0.9

In [255]:
from itertools import product
import random
# contruct node dictionaries
srcs_by_layer = {
    layer: {
        head_idx: [n for n in task.model.srcs if (n.layer == layer and n.head_idx == head_idx)]
        for head_idx in range(task.model.cfg.n_heads)
    } 
    for layer in range(1, task.model.cfg.n_layers+1)
}
dests_by_layer = {
    layer: {
        head_idx: {
            n.module_name.split("_")[1]: n 
            for n in task.model.dests if (n.layer == layer and n.head_idx == head_idx)
        }
        for head_idx in range(task.model.cfg.n_heads)
    } 
    for layer in range(1, task.model.cfg.n_layers+1)
}

# enumerate all paths
layer_dests = product(["q", "k", "v"], range(task.model.cfg.n_heads))
layer_dests = list(layer_dests) + [(None, None)]
paths = list(product(layer_dests, repeat=task.model.cfg.n_layers))


# sample and contruct path
path_idx = random.choice(paths)
path = []
cur_src = next((n for n in task.model.srcs if n.layer == 0)) # Resid Start
for layer_idx, (attn_in, head_idx) in enumerate(path_idx):
    if head_idx == None:
        continue  
    # get dest 
    dest = dests_by_layer[layer_idx+1][head_idx][attn_in]
    # append to path 
    path.append((cur_src, dest))
    # get next src 
    cur_src = srcs_by_layer[layer_idx+1][head_idx]
# add final dest
path.append((cur_src, next((n for n in task.model.dests if n.layer == task.model.cfg.n_layers+1)))) # Resid End
path


[(SrcNode(name='Resid Start', module_name='blocks.0.hook_resid_pre', layer=0, head_idx=None, head_dim=None, weight='embed.W_E', weight_head_dim=None, src_idx=0),
  DestNode(name='A0.5.K', module_name='blocks.0.hook_k_input', layer=1, head_idx=5, head_dim=2, weight='blocks.0.attn.W_K', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A0.5', module_name='blocks.0.attn.hook_result', layer=1, head_idx=5, head_dim=2, weight='blocks.0.attn.W_O', weight_head_dim=0, src_idx=6)],
  DestNode(name='A1.4.K', module_name='blocks.1.hook_k_input', layer=2, head_idx=4, head_dim=2, weight='blocks.1.attn.W_K', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A1.4', module_name='blocks.1.attn.hook_result', layer=2, head_idx=4, head_dim=2, weight='blocks.1.attn.W_O', weight_head_dim=0, src_idx=13)],
  DestNode(name='A2.0.V', module_name='blocks.2.hook_v_input', layer=3, head_idx=0, head_dim=2, weight='blocks.2.attn.W_V', weight_head_dim=0, min_src_idx=0)),
 ([SrcNode(name='A2.0', module_name='bloc

In [None]:
# set score to inf if path in edge
# run circuit 
# radnomly ablate one of the edges 
# update score mask 
# run circuit 
# compute difference in output (should factor that out from above)

# for each edge 
    # do this process, compute number of times edge ablation is more significant than random edge ablation 
# commpute p value


In [134]:
# get ordered list of src, dest tuples for the same nodes (e.g. src and test for A0.4 (start and stop nodes will have None)
# starting with last list (dest stop node), sample whether to incldue the previous layer, and if so which node (1/n layers, and uniform over nodes)
# keep sampling until arriving at root

# use dests to map to dest wrappers, src_idxs to map to indices in dest wrappers

# create dest_wrapper mask 
# take or with threshold mask

# now how to patch with these extra edges
# I guess just set the prune scoresthe patch mask before


(1, 5)