# Instance Wise Auto Circuits

Trying to figure out how to log per-instance gradients on weights.

I thought I could use backward hooks, but they only work on activations.

So, I'll try following this tutorial, which uses PyTorch's functional interface:
https://pytorch.org/tutorials/intermediate/per_sample_grads.html

In [None]:
# first, let's get the imports we need
from typing import List, Tuple, Dict, Any
from collections import defaultdict
from tqdm import tqdm
import torch
from auto_circuit.data import load_datasets_from_json
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.prune_algos.edge_attribution_patching import edge_attribution_patching_prune_scores
from auto_circuit.data import BatchKey, PromptDataLoader
from auto_circuit.types import AblationType, PatchType, PruneScores, CircuitOutputs
from auto_circuit.utils.ablation_activations import src_ablations, batch_src_ablations
from auto_circuit.utils.graph_utils import patch_mode, patchable_model, set_all_masks, train_mask_mode, set_mask_batch_size
from auto_circuit.utils.tensor_ops import batch_avg_answer_diff
from auto_circuit.utils.misc import repo_path_to_abs_path
from auto_circuit.visualize import draw_seq_graph

In [None]:
def effect_prob_func(logits, effect_tokens, inputs=None):
    """Return the mean final-position probability assigned to ``effect_tokens``.

    Args:
        logits: Rank-3 tensor. It is indexed as ``[:, -1, :]`` and softmaxed
            over its last axis (presumably (batch, seq, vocab) -- confirm
            against the callers).
        effect_tokens: Index or list of indices into the last axis.
        inputs: Unused; accepted only to keep the signature stable.

    Returns:
        A scalar tensor: the selected probabilities averaged over both the
        effect tokens and the leading axis.
    """
    assert logits.ndim == 3
    final_position_logits = logits[:, -1, :]
    token_probs = torch.softmax(final_position_logits, dim=-1)
    # A single mean reduces over the leading axis and the selected tokens at
    # once; this function itself does not produce per-instance values.
    # (Averaging the raw logits instead of probabilities was an alternative.)
    return token_probs[:, effect_tokens].mean()

In [None]:
# Device string passed to the model loader and, later, to the data loaders.
device = "cpu" #TODO: debug mps error
# Load the pythia-70m model through auto_circuit's load_tl_model helper.
ac_model = load_tl_model("pythia-70m", device)

In [None]:
# Location of the IOI vanilla-template prompt file, converted by
# repo_path_to_abs_path from a repo-relative path.
path = repo_path_to_abs_path("datasets/ioi/ioi_vanilla_template_prompts.json")
# NOTE(review): 32 / 16 presumably yields two full batches per loader, assuming
# train_test_size is a (train, test) count of prompts -- confirm.
dataset_size = 32
batch_size = 16
train_loader, test_loader = load_datasets_from_json(
    model=ac_model,
    path=path,
    device=device,
    prepend_bos=True,
    batch_size=batch_size,
    train_test_size=(dataset_size, dataset_size),
)

In [None]:
# Rebind ac_model to its patchable version. Only attention sources and
# destinations are enabled; residual-stream and MLP nodes are switched off.
ac_model = patchable_model(
    ac_model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=True,
    device=device,
    resid_src=False, 
    resid_dest=False,
    attn_src=True,
    attn_dest=True,
    mlp_src=False,
    mlp_dest=False,
)

In [None]:
# Token ids for " else"; [1:] drops the first id returned by encode
# (presumably a prepended BOS token -- TODO confirm).
effect_tokens = ac_model.tokenizer.encode(" else")[1:]

In [None]:
# Display the effect token ids.
effect_tokens

## Compute Instance Wise Scores

In [None]:
# Zero-ablation source outputs for every training batch, keyed by batch key
# so the gradient loop can look them up again.
patch_outs: Dict[BatchKey, torch.Tensor] = {
    batch.key: src_ablations(ac_model, batch.clean, ablation_type=AblationType.ZERO)
    for batch in train_loader
}

In [None]:
# Per destination module: one mask-gradient tensor appended per batch.
prune_scores: Dict[str, List[torch.Tensor]] = defaultdict(list)
with set_mask_batch_size(ac_model, batch_size), train_mask_mode(ac_model):
    # Every mask starts at 0.0 before any gradient is taken.
    set_all_masks(ac_model, val=0.0)

    for batch in tqdm(train_loader):
        # Work on a detached copy so the cached ablations are never touched.
        zero_ablated_srcs = patch_outs[batch.key].clone().detach()
        with patch_mode(ac_model, zero_ablated_srcs):
            effect = effect_prob_func(ac_model(batch.clean), effect_tokens=effect_tokens)
            effect.backward()

        # Copy the mask gradients out before zero_grad() is called.
        for wrapper in ac_model.dest_wrappers:
            mask_grad = wrapper.patch_mask.grad
            prune_scores[wrapper.module_name].append(mask_grad.detach().clone())
        ac_model.zero_grad()

In [None]:
# Shape of the first stored gradient tensor (first module, first batch).
next(iter(prune_scores.values()))[0].shape

In [None]:
# Shape of the first destination wrapper's patch mask, for comparison with
# the stored gradient shape.
next(iter(ac_model.dest_wrappers)).patch_mask.shape

In [None]:
# Residual-stream filtering is currently disabled: the scores are passed
# through unchanged on the last line. The commented-out code below is the
# earlier filtering attempt, kept for reference. (Presumably it is unnecessary
# while patchable_model is built with resid_src=False / resid_dest=False --
# TODO confirm.)
# # filter resid pre and resid post (for parity with edge attribution implementation)
# resid_pre_node = [node for node in ac_model.srcs if node.name == "Resid Start"][0]
# resid_post_node = [node for node in ac_model.dests if node.name == "Resid End"][0]
# resid_pre_node.src_idx, resid_post_node.module_name

# # filter out resid pre
# prune_scores_new = {
#     k: [score[...,1:] for score in score_list] # I'm being dumb I think? I guess not everything has an edge
#     for k, score_list in prune_scores.items()
# }
# # remove resid_post
# del prune_scores_new[resid_post_node.module_name]
prune_scores_new = prune_scores

In [None]:
# Join each module's per-batch gradients along the leading axis.
scores_stacked = {
    module_name: torch.concat(batch_chunks)
    for module_name, batch_chunks in prune_scores_new.items()
}
# Collapse everything except the leading axis, then place the modules side by
# side along dim 1, giving one flat score row per leading-axis entry.
flattened_per_module = [stacked.flatten(start_dim=1) for stacked in scores_stacked.values()]
scores_vector = torch.concat(flattened_per_module, dim=1)
score_vector_dim = scores_vector.shape[1]
score_vector_dim

##  Compare to EAP implementation

### My Modified Implementation

In [None]:
import numpy as np

from eap.eap_wrapper import EAP_clean_forward_hook, EAP_clean_backward_hook
from eap.eap_graph import EAPGraph

In [None]:
# Run EAP on individual instances (pulling from eap_detector).
from transformer_lens import HookedTransformer
model = HookedTransformer.from_pretrained("pythia-70m")
model.to(device)
# Switch on the optional hook points (MLP input, split q/k/v input, attention
# result) -- presumably the ones the EAP hooks attach to; TODO confirm.
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

In [None]:
# Build the EAP graph with "head" as the only upstream and downstream node
# type; the MLP and residual entries are left commented out.
# aggregate_batch=False presumably keeps one score per instance rather than
# reducing over the batch -- TODO confirm in EAPGraph.
graph = EAPGraph(
    model.cfg, 
    upstream_nodes=[
        # "mlp", 
        "head", 
        # "resid_pre.0"#["resid_pre", "mlp", "head"], 
    ], 
    downstream_nodes=[
        # "mlp",
        "head",
        # f"resid_post.{model.cfg.n_layers-1}", 
    ],
    aggregate_batch=False, 
    verbose=False
)

In [None]:
# List the keys of the graph's downstream hook -> slice mapping.
graph.downstream_hook_slice.keys()

In [None]:
# Boolean (upstream, downstream) mask of the edges that exist: each downstream
# hook's node slice is paired with the upstream nodes the graph lists as
# coming before it.
valid_edge_mask = np.zeros((len(graph.upstream_nodes), len(graph.downstream_nodes)), dtype=bool)
for hook in graph.downstream_hooks:
    layer, hook_type = hook.split(".")[1:3]
    layer_idx = int(layer)
    if hook_type == "hook_mlp_in":
        earlier_upstream = graph.upstream_nodes_before_mlp_layer[layer_idx]
    else:
        # hook_resid_post looks one layer further along than the other hooks.
        shift = 1 if hook_type == "hook_resid_post" else 0
        earlier_upstream = graph.upstream_nodes_before_layer[layer_idx + shift]
    valid_edge_mask[earlier_upstream, graph.get_hook_slice(hook)] = True
valid_edge_mask.sum()

In [None]:
# Sanity check: the number of valid EAP edges must equal the length of the
# flattened auto-circuit score vector.
assert valid_edge_mask.sum() == score_vector_dim

In [None]:
from functools import partial

def gen_hooks(upstream_actiation_difference, graph):
    """Build the forward and backward EAP hook functions.

    Args:
        upstream_actiation_difference: Buffer handed to both hooks as their
            ``upstream_activations_difference`` keyword argument. The
            misspelled parameter name is kept so existing callers still work.
        graph: Graph object passed through to both hooks.

    Returns:
        ``(clean_upstream_hook_fn, clean_downstream_hook_fn)``: partials of
        ``EAP_clean_forward_hook`` and ``EAP_clean_backward_hook``.
    """
    # Bind the argument under the correctly spelled name so the hooks receive
    # the caller's buffer rather than a like-named notebook global.
    upstream_activations_difference = upstream_actiation_difference

    clean_upstream_hook_fn = partial(
        EAP_clean_forward_hook,
        upstream_activations_difference=upstream_activations_difference,
        graph=graph
    )

    clean_downstream_hook_fn = partial(
        EAP_clean_backward_hook,
        upstream_activations_difference=upstream_activations_difference,
        graph=graph,
        aggregate_batch=False
    )
    return clean_upstream_hook_fn, clean_downstream_hook_fn


In [None]:
model.train()
# One tensor per batch; each row holds the valid-edge scores for one instance.
eap_scores = []
# The hook-name filters do not depend on the batch, so build them once.
upstream_hook_filter = lambda name: name.endswith(tuple(graph.upstream_hooks))
downstream_hook_filter = lambda name: name.endswith(tuple(graph.downstream_hooks))
with torch.enable_grad():
    for batch in tqdm(train_loader):
        # Use a loop-local name: rebinding the notebook-level `batch_size`
        # here would silently change it for every cell run afterwards.
        n_instances, seq_len = batch.clean.shape[:2]
        # Start each batch with no hooks attached and freshly reset scores.
        model.reset_hooks()
        graph.reset_scores(n_instances)
        upstream_activations_difference = torch.zeros(
            (n_instances, seq_len, graph.n_upstream_nodes, model.cfg.d_model),
            device=model.cfg.device,
            dtype=model.cfg.dtype,
            requires_grad=False
        )
        clean_upstream_hook_fn, clean_downstream_hook_fn = gen_hooks(upstream_activations_difference, graph)
        model.add_hook(upstream_hook_filter, clean_upstream_hook_fn, "fwd")
        model.add_hook(downstream_hook_filter, clean_downstream_hook_fn, "bwd")
        #TODO: add support for corrupted tokens

        logits = model(batch.clean, return_type="logits")  # batch, seq_len, vocab
        value = effect_prob_func(logits, effect_tokens=effect_tokens)
        value.backward()

        model.zero_grad()
        upstream_activations_difference *= 0
        # Keep only the (upstream, downstream) pairs that are real edges.
        eap_scores_flat = graph.eap_scores[:, valid_edge_mask]
        assert eap_scores_flat.shape == (n_instances, valid_edge_mask.sum())
        eap_scores.append(eap_scores_flat)


In [None]:
# Shape of the graph's score tensor as left by the last processed batch.
graph.eap_scores.shape

In [None]:
# Compare the total absolute score of the first instance under both methods.
# The totals are very close, which suggests the same values in a different
# axis order, so the axes still need to be aligned.
abs(eap_scores[0][0]).sum(), abs(scores_vector[0]).sum()

In [None]:
# Number of source nodes in the patchable model.
len(ac_model.srcs)

In [None]:
# Re-lay the auto-circuit mask gradients into the EAP graph's
# (instance, upstream, downstream) layout so the two can be compared.
prune_scores_arr = torch.zeros((dataset_size, len(ac_model.srcs), len(ac_model.dests)))
for hook_name, scores_list in prune_scores_new.items():
    # The slices depend only on the hook, so look them up once per hook.
    layer, hook_type = hook_name.split(".")[1:3]
    if hook_type == "hook_mlp_in":
        upstream_slice = graph.upstream_nodes_before_mlp_layer[int(layer)]
    else:
        upstream_slice = graph.upstream_nodes_before_layer[int(layer)]
    downstream_slice = graph.downstream_hook_slice[hook_name]

    # Take row offsets from each tensor's own leading dimension rather than
    # from the notebook-level `batch_size`, which another cell rebinds.
    row_start = 0
    for score in scores_list:
        # score: (batch, downstream, upstream); a length-1 axis may be missing
        if score.ndim == 2:
            if downstream_slice.stop - downstream_slice.start == 1:
                # restore the missing downstream axis
                score = score.unsqueeze(dim=1)
            elif upstream_slice.stop - upstream_slice.start == 1:
                # restore the missing upstream axis
                score = score.unsqueeze(dim=2)
            else:
                raise ValueError("unexpected score shape")
        assert score.ndim == 3, score.shape
        score = score.transpose(1, 2)  # downstream, upstream -> upstream, downstream
        row_stop = row_start + score.shape[0]
        prune_scores_arr[row_start:row_stop, upstream_slice, downstream_slice] = score
        row_start = row_stop

In [None]:
# First instance only: reordered auto-circuit scores vs. the EAP scores of the
# first instance in the first batch.
torch.allclose(prune_scores_arr[:, valid_edge_mask][0], eap_scores[0][0], atol=1e-5)

In [None]:
# All instances: reordered auto-circuit scores vs. the EAP scores concatenated
# over batches.
torch.allclose(prune_scores_arr[:, valid_edge_mask], torch.concat(eap_scores, dim=0), atol=1e-5)