# Instance Wise Auto Circuits

Trying to figure out how to log per-instance gradients on weights.

I thought I could use backward hooks, but they only work on activations.

So I'll try following this tutorial, which uses PyTorch's functional interface:
https://pytorch.org/tutorials/intermediate/per_sample_grads.html

In [118]:
# First, let's get the imports we need
from typing import List, Tuple, Dict, Any
from collections import defaultdict
from tqdm import tqdm
import torch
from auto_circuit.data import load_datasets_from_json
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.prune_algos.edge_attribution_patching import edge_attribution_patching_prune_scores
from auto_circuit.data import BatchKey, PromptDataLoader
from auto_circuit.types import AblationType, PatchType, PruneScores, CircuitOutputs
from auto_circuit.utils.ablation_activations import src_ablations, batch_src_ablations
from auto_circuit.utils.graph_utils import patch_mode, patchable_model, set_all_masks, train_mask_mode
from auto_circuit.utils.tensor_ops import batch_avg_answer_diff
from auto_circuit.utils.misc import repo_path_to_abs_path
from auto_circuit.visualize import draw_seq_graph

In [119]:
# Run everything on CPU for now: using MPS hit an error that still needs debugging.
device = "cpu" #TODO: debug mps error
# Load GPT-2 small (the log output below reports a HookedTransformer being loaded).
model = load_tl_model("gpt2-small", device)



Loaded pretrained model gpt2-small into HookedTransformer


In [120]:
# Resolve the repo-relative path of the IOI template-prompt dataset.
path = repo_path_to_abs_path("datasets/ioi/ioi_vanilla_template_prompts.json")
# Build the train/test prompt loaders for this model.
# NOTE(review): presumably train_test_size=(33, 33) means 33 train and 33 test
# prompts, served in batches of 16 — confirm against load_datasets_from_json.
train_loader, test_loader = load_datasets_from_json(
    model=model,
    path=path,
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(33, 33),
)

In [121]:
# Wrap the loaded model with auto-circuit's patchable_model so the cells below can
# use its srcs / dests / dest_wrappers and patch masks.
# NOTE: this rebinds `model` to the wrapped version, so the cell is not idempotent:
# re-running it without reloading the model would pass an already-wrapped model in.
# NOTE(review): presumably slice_output="last_seq" restricts outputs to the final
# sequence position (exposed as model.out_slice) — confirm in auto-circuit.
model = patchable_model(
    model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=True,
    device=device,
)

## Compute Instance Wise Scores

In [122]:
# Zero-ablation source outputs for every training batch, keyed by the batch's key
# so the scoring loop can look them up again.
patch_outs: Dict[BatchKey, torch.Tensor] = {
    batch.key: src_ablations(model, batch.clean, ablation_type=AblationType.ZERO)
    for batch in train_loader
}

In [123]:
# Per-instance patch-mask gradients, collected per destination module:
# module name -> list with one gradient tensor per training batch.
prune_scores: Dict[str, List[torch.Tensor]] = defaultdict(list)

with train_mask_mode(model):
    # Start from all masks at 0.0 before taking gradients.
    set_all_masks(model, val=0.0)

    for batch in tqdm(train_loader):
        n_prompts = batch.clean.shape[0]
        # Work on a detached copy so the cached ablations are never touched.
        ablated_src_outs = patch_outs[batch.key].clone().detach()

        with patch_mode(model, ablated_src_outs, batch_size=n_prompts):
            answer_diff = batch_avg_answer_diff(model(batch.clean)[model.out_slice], batch)
            # Backpropagate the negated answer difference (the loss).
            (-answer_diff).backward()

        # Snapshot each destination's batched mask gradient before it is cleared.
        for dest_wrapper in model.dest_wrappers:
            mask_grad = dest_wrapper.patch_mask_batch.grad
            prune_scores[dest_wrapper.module_name].append(mask_grad.detach().clone())
        model.zero_grad()

100%|██████████| 2/2 [00:09<00:00,  4.65s/it]


In [76]:
# Drop the residual-stream endpoints ("Resid Start" as a source, "Resid End" as a
# destination) for parity with the edge attribution patching implementation.
resid_pre_node = next(node for node in model.srcs if node.name == "Resid Start")
resid_post_node = next(node for node in model.dests if node.name == "Resid End")

# The `[..., 1:]` slice below only removes "Resid Start" if it sits at source
# index 0, so fail loudly if that ever stops being true.
assert resid_pre_node.src_idx == 0, (
    f'expected "Resid Start" at source index 0, got {resid_pre_node.src_idx}'
)

# Filter out resid pre by dropping the first entry of the last axis.
# NOTE(review): assumes the last axis of every score tensor indexes source nodes
# — TODO confirm against the patch-mask layout.
prune_scores_new = {
    module_name: [score[..., 1:] for score in score_list]
    for module_name, score_list in prune_scores.items()
}
# Remove resid_post: its destination module is excluded entirely.
del prune_scores_new[resid_post_node.module_name]


In [80]:
# Join the per-batch tensors of each module along the first (instance) axis.
scores_stacked = {
    module_name: torch.cat(batch_scores)
    for module_name, batch_scores in prune_scores_new.items()
}
# Flatten every axis except the first, then lay the modules side by side along
# dim 1, giving one score row per instance.
scores_vector = torch.cat(
    [torch.flatten(module_scores, start_dim=1) for module_scores in scores_stacked.values()],
    dim=1,
)
score_vector_dim = scores_vector.shape[1]
score_vector_dim

31890

In [None]:
# TODO: Jensen-Shannon divergence (not implemented yet)

##  Compare to EAP implementation

### My Modified Implementation

In [81]:
import numpy as np

from eap.eap_wrapper import EAP_clean_forward_hook, EAP_clean_backward_hook
from eap.eap_graph import EAPGraph
from elk_experiments.eap_detector import valid_edge

In [144]:
# EAP graph restricted to MLP and attention-head nodes on both sides.
graph = EAPGraph(
    model.cfg, 
    upstream_nodes=[
        "mlp", 
        "head", 
        # "resid_pre.0" is left out here; the auto-circuit scores above drop the
        # matching "Resid Start" source so both edge sets line up.
    ], 
    downstream_nodes=[
        "mlp",
        "head",
        # f"resid_post.{model.cfg.n_layers-1}" is left out as well, matching the
        # removal of the "Resid End" destination above.
    ],
    # Scores are later read as (batch_size, n_edges), i.e. not aggregated over the batch.
    aggregate_batch=False, 
    verbose=False
)

In [146]:
# Boolean (upstream node, downstream node) mask of the edges that exist in the graph.
# For each downstream hook, pair its node slice with the upstream nodes the graph
# lists as preceding it.
n_upstream, n_downstream = len(graph.upstream_nodes), len(graph.downstream_nodes)
valid_edge_mask = np.zeros((n_upstream, n_downstream), dtype=bool)

for hook in graph.downstream_hooks:
    # The second and third dot-separated fields of the hook name are the layer and hook type.
    layer_str, hook_type = hook.split(".")[1:3]
    layer_idx = int(layer_str)

    if hook_type == "hook_mlp_in":
        upstream_rows = graph.upstream_nodes_before_mlp_layer[layer_idx]
    elif hook_type == "hook_resid_post":
        upstream_rows = graph.upstream_nodes_before_layer[layer_idx + 1]
    else:
        upstream_rows = graph.upstream_nodes_before_layer[layer_idx]

    downstream_cols = graph.get_hook_slice(hook)
    valid_edge_mask[upstream_rows, downstream_cols] = True

valid_edge_mask.sum()

31890

In [147]:
# Sanity check: the EAP graph and the flattened auto-circuit score vector must
# contain the same number of edges before they can be compared.
assert valid_edge_mask.sum() == score_vector_dim

In [148]:
# Run EAP on individual instances (pulling from eap_detector)
from transformer_lens import HookedTransformer
from elk_experiments.eap_detector import set_model
# Fresh GPT-2 small for the hook-based EAP run.
# NOTE: this rebinds `model`, replacing the patchable auto-circuit model used in the
# cells above; re-running those cells afterwards needs the earlier setup cells first.
model = HookedTransformer.from_pretrained("gpt2-small")
model.to(device)

# Turn on the optional hook points, in the same order as before.
for enable_hook_point in (
    model.set_use_hook_mlp_in,
    model.set_use_split_qkv_input,
    model.set_use_attn_result,
):
    enable_hook_point(True)



Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cpu


In [149]:
from functools import partial

def gen_hooks(upstream_actiation_difference, graph):
    """Build the forward and backward EAP hook functions for one batch.

    Args:
        upstream_actiation_difference: Object passed to both hooks as their
            ``upstream_activations_difference`` keyword argument. The parameter
            keeps its original spelling so existing callers are unaffected.
        graph: Object passed to both hooks as their ``graph`` keyword argument.

    Returns:
        A ``(clean_upstream_hook_fn, clean_downstream_hook_fn)`` tuple of
        ``functools.partial`` objects wrapping ``EAP_clean_forward_hook`` and
        ``EAP_clean_backward_hook``. The backward hook is additionally bound
        with ``aggregate_batch=False``.
    """
    # Bind the argument that was actually passed in. The body used to reference the
    # differently spelled global `upstream_activations_difference`, so the parameter
    # was ignored and the hooks captured whatever the notebook global held.
    clean_upstream_hook_fn = partial(
        EAP_clean_forward_hook,
        upstream_activations_difference=upstream_actiation_difference,
        graph=graph,
    )

    clean_downstream_hook_fn = partial(
        EAP_clean_backward_hook,
        upstream_activations_difference=upstream_actiation_difference,
        graph=graph,
        aggregate_batch=False,
    )
    return clean_upstream_hook_fn, clean_downstream_hook_fn


In [151]:
# Put the model in training mode. As the cell's last expression, the call's return
# value is echoed, which is the module summary shown below.
model.train()

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [181]:
# Per-instance EAP scores: one (batch_size, n_valid_edges) tensor per training batch.
eap_scores = []

# The hook-name filters depend only on `graph`, so build them once instead of per batch.
upstream_hook_names = tuple(graph.upstream_hooks)
downstream_hook_names = tuple(graph.downstream_hooks)
upstream_hook_filter = lambda name: name.endswith(upstream_hook_names)
downstream_hook_filter = lambda name: name.endswith(downstream_hook_names)

with torch.enable_grad():
    try:
        for batch in tqdm(train_loader):
            batch_size, seq_len = batch.clean.shape[:2]

            # Start each batch from a clean slate: no stale hooks, fresh score buffer.
            model.reset_hooks()
            graph.reset_scores(batch_size)

            # A new zero buffer is allocated for every batch, so it never needs to be
            # zeroed again after the backward pass.
            upstream_activations_difference = torch.zeros(
                (batch_size, seq_len, graph.n_upstream_nodes, model.cfg.d_model),
                device=model.cfg.device,
                dtype=model.cfg.dtype,
                requires_grad=False,
            )
            clean_upstream_hook_fn, clean_downstream_hook_fn = gen_hooks(
                upstream_activations_difference, graph
            )
            model.add_hook(upstream_hook_filter, clean_upstream_hook_fn, "fwd")
            model.add_hook(downstream_hook_filter, clean_downstream_hook_fn, "bwd")
            #TODO: add support for corrupted tokens

            # Keep only the last sequence position: (batch, seq_len, vocab) -> (batch, vocab).
            logits = model(batch.clean, return_type="logits")[:, -1, :]
            # NOTE: this backpropagates +answer_diff, whereas the mask-gradient cell
            # above backpropagates its negation (loss = -answer_diff).
            value = batch_avg_answer_diff(logits, batch)
            value.backward()
            model.zero_grad()

            # Keep only the entries that correspond to real edges.
            eap_scores_flat = graph.eap_scores[:, valid_edge_mask]
            assert eap_scores_flat.shape == (batch_size, valid_edge_mask.sum())
            eap_scores.append(eap_scores_flat)
    finally:
        # Do not leave the EAP hooks attached to the model once scoring is done
        # (or has failed part-way through).
        model.reset_hooks()


100%|██████████| 2/2 [00:03<00:00,  1.98s/it]


In [182]:
# Total absolute score of the first instance under each method. Near-identical sums
# are encouraging; the remaining work is figuring out how to align the edge axes.
eap_scores[0][0].abs().sum(), scores_vector[0].abs().sum()

(tensor(50.0864), tensor(50.0864))

In [183]:
# Median score of the first instance under each method.
# NOTE(review): the original note says one side needs its sign flipped, yet the
# medians printed below share a sign — TODO confirm whether a flip is needed.
eap_scores[0][0].median(), scores_vector[0].median() # need to flip signs on one

(tensor(5.1272e-06), tensor(5.1271e-06))

In [184]:
# Raw score vectors of the first instance under each method. The entries printed
# below do not match position by position, so the edge axes are not yet aligned.
eap_scores[0][0], scores_vector[0]

(tensor([ 8.4679e-03, -3.8574e-04, -5.1559e-05,  ..., -3.5358e-03,
         -9.7548e-05, -1.9848e-05]),
 tensor([-0.0081,  0.0035, -0.0047,  ..., -0.0136, -0.0091,  0.0103]))