In [1]:
!pip install git+https://github.com/neelnanda-io/neel-plotly.git

Collecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-wev4386v
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-wev4386v
  Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7
  Preparing metadata (setup.py) ... [?25ldone


In [1]:
import os
import pathlib
from typing import List, Optional, Union

import torch
import numpy as np
import yaml
import pickle
import einops
from fancy_einsum import einsum


import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching

from torch import Tensor
from tqdm.notebook import tqdm
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set
from rich import print as rprint

from typing import List, Union
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import re

from functools import partial

from torchtyping import TensorType as TT

from path_patching import Node, IterNode, path_patch, act_patch
from neel_plotly import imshow as imshow_n

from utils.visualization import get_attn_head_patterns
from utils.circuit_analysis import get_logit_diff, logit_diff_denoising, logit_diff_noising



ModuleNotFoundError: No module named 'path_patching'

In [None]:
update_layout_set = {
    "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat",
    "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor",
    "showlegend", "xaxis_tickmode", "yaxis_tickmode", "xaxis_tickangle", "yaxis_tickangle", "margin", "xaxis_visible", "yaxis_visible", "bargap", "bargroupgap"
}

def imshow_p(tensor, renderer=None, **kwargs):
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    facet_labels = kwargs_pre.pop("facet_labels", None)
    border = kwargs_pre.pop("border", False)
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, **kwargs_pre)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    if border:
        fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
        fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
    # things like `xaxis_tickmode` should be applied to all subplots. This is super janky lol but I'm under time pressure
    for setting in ["tickangle"]:
      if f"xaxis_{setting}" in kwargs_post:
          i = 2
          while f"xaxis{i}" in fig["layout"]:
            kwargs_post[f"xaxis{i}_{setting}"] = kwargs_post[f"xaxis_{setting}"]
            i += 1
    fig.update_layout(**kwargs_post)
    fig.show(renderer=renderer)

def hist_p(tensor, renderer=None, **kwargs):
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    names = kwargs_pre.pop("names", None)
    if "barmode" not in kwargs_post:
        kwargs_post["barmode"] = "overlay"
    if "bargap" not in kwargs_post:
        kwargs_post["bargap"] = 0.0
    if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
    fig = px.histogram(x=tensor, **kwargs_pre).update_layout(**kwargs_post)
    if names is not None:
        for i in range(len(fig.data)):
            fig.data[i]["name"] = names[i // 2]
    fig.show(renderer)

In [None]:
device = "cuda:1"
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    #"gemma-2b",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    #hf_model=source_model,
    device = device
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


### Load dataset

In [2]:
import pandas as pd
import json

dataset_name = "ner_dataset_20each.json"
with open("../../pos_cf_datasets/"+dataset_name, "r", encoding="utf-8") as f:
    tv_data = json.load(f)

tag_set = [
    "EVENT", "LOCATION", "MONEY",
    "NATIONALITY, RELIGIOUS, or POLITICAL GROUP",
    "ORGANIZATION", "NUMERICAL", "PERSON",
    "PRODUCT", "TIME"
]

template = """Sentence: Apple announced a new iPhone during its annual product launch event.,
Entity tag: PRODUCT,
Answer: iPhone

Sentence: Barack Obama delivered a keynote speech at the conference.,
Entity tag: PERSON,
Answer: Barack Obama

Sentence: Tesla invested over 2 billion dollars in a new gigafactory in Germany.,
Entity tag: MONEY,
Answer: 2 billion dollars

Sentence: The concert will take place at 8 p.m. on Saturday.,
Entity tag: TIME,
Answer: 8 p.m. on Saturday

Sentence: The Eiffel Tower is located in Paris.,
Entity tag: LOCATION,
Answer: Paris

Sentence: The Olympic Games in Tokyo attracted thousands of visitors despite the pandemic.,
Entity tag: EVENT,
Answer: Olympic Games

Sentence: The recipe calls for 200 grams of sugar and 3 eggs.,
Entity tag: NUMERICAL,
Answer: 200 grams

Sentence: Google has opened a new research center in Zurich to focus on AI development.,
Entity tag: ORGANIZATION,
Answer: Google

Sentence: The American have a long history of culinary excellence.,
Entity tag: NATIONALITY, RELIGIOUS, or POLITICAL GROUP,
Answer: American

Sentence: The Islam religion has over a billion followers worldwide.,
Entity tag: NATIONALITY, RELIGIOUS, or POLITICAL GROUP,
Answer: Islam

Sentence: {sentence}
Entity tag: {tag}
Answer:"""

In [None]:
import ast 

item = tv_data[2]
sentence = item["Sentence"].replace("\"","").replace(" .",".")
example_answer = item["Answer"].replace("\"","")
tag = item["POS tag"].replace("\"","")

prompt = template.format(sentence=sentence, tag=tag)
print(sentence)
print(example_answer)

utils.test_prompt(prompt, example_answer, model, prepend_bos=True, top_k=5)

Proponents of the funding arrangement predict that, based on recent filing levels of more than 2,000 a year, the fees will yield at least $ 40 million this fiscal year, or $ 10 million more than the budget cuts.
million
Tokenized prompt: ['<bos>', 'Sentence', ':', ' The', ' cat', ' chased', ' the', ' mouse', '.', '\n', 'POS', ' tag', ':', ' noun', '\n', 'Answer', ':', ' cat', '\n\n', 'Sentence', ':', ' She', ' went', ' to', ' the', ' market', '.', '\n', 'POS', ' tag', ':', ' pronoun', '\n', 'Answer', ':', ' She', '\n\n', 'Sentence', ':', ' Birds', ' fly', ' across', ' the', ' sky', '.', '\n', 'POS', ' tag', ':', ' verb', '\n', 'Answer', ':', ' fly', '\n\n', 'Sentence', ':', ' He', ' quickly', ' closed', ' the', ' door', '.', '\n', 'POS', ' tag', ':', ' adverb', '\n', 'Answer', ':', ' quickly', '\n\n', 'Sentence', ':', ' The', ' tall', ' man', ' opened', ' the', ' door', '.', '\n', 'POS', ' tag', ':', ' adjective', '\n', 'Answer', ':', ' tall', '\n\n', 'Sentence', ':', ' Pro', 'ponents'

Top 0th token. Logit: 21.48 Prob: 48.65% Token: | |
Top 1th token. Logit: 20.69 Prob: 21.98% Token: | $|
Top 2th token. Logit: 19.04 Prob:  4.23% Token: | more|
Top 3th token. Logit: 19.02 Prob:  4.14% Token: | million|
Top 4th token. Logit: 18.73 Prob:  3.10% Token: |

|


In [None]:
# pip install nltk pandas
import json
import re
from typing import List, Dict, Tuple
import pandas as pd
import nltk

# Make sure these are available once in your environment:
import nltk

def ensure_nltk_data():
    # Where NLTK will look; add a custom path if you want a local cache
    # nltk.data.path.append("/path/to/nltk_data")

    checks = [
        ("taggers/averaged_perceptron_tagger", "averaged_perceptron_tagger"),
        ("taggers/averaged_perceptron_tagger_eng", "averaged_perceptron_tagger_eng"),
        ("tokenizers/punkt", "punkt"),
        ("tokenizers/punkt_tab", "punkt_tab"),                  # NLTK >= 3.8
        ("tokenizers/punkt_tab/english", "punkt_tab"),          # some installs expose language subdir
        ("chunkers/maxent_ne_chunker_tab", "maxent_ne_chunker_tab"),
        ("corpora/words", "words"),
    ]
    for resource_path, package in checks:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            try:
                nltk.download(package, quiet=False)
            except Exception as e:
                print(f"Failed to download {package}: {e}")

ensure_nltk_data()

COARSE_TAGS = [
    "LOCATION","PERSON","ORGANIZATION", "NULL"
]


# Penn tag -> coarse tag mapping
TAG_TO_COARSE = {
    "O": None,            # not one of the coarse types
    "B-MISC": None,
    "I-MISC": None,

    "B-PER": "PERSON",
    "I-PER": "PERSON",

    "B-ORG": "ORGANIZATION",
    "I-ORG": "ORGANIZATION",

    "B-LOC": "LOCATION",
    "I-LOC": "LOCATION",
}

TOKENIZER_PUNCT_RE = re.compile(r"^\W+$")

def tokenize_and_tag(sentence: str) -> List[Tuple[str, str]]:
    """
    Returns list of (token, IOB-tag) where IOB is in {O, B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC}.
    Mirrors your tokenize_and_tag style: drops pure punctuation before tagging.
    """
    # tokenize
    tokens = nltk.word_tokenize(sentence)
    # drop pure punctuation before tagging
    toks_for_tag = [t for t in tokens if not TOKENIZER_PUNCT_RE.match(t)]

    # POS tag → ne_chunk → IOB triples
    pos = nltk.pos_tag(toks_for_tag)
    tree = nltk.ne_chunk(pos, binary=False)
    iob_triples = nltk.tree2conlltags(tree)  # (token, POS, IOB-NLTK)

    def _normalize_iob(iob: str) -> str:
        if iob == "O":
            return "O"
        pref, label = iob.split("-", 1)  # e.g., "B", "PERSON"
        short = TAG_TO_COARSE.get(label, "MISC")
        return f"{pref}-{short}"

    tagged = [(tok, _normalize_iob(iob)) for tok, _, iob in iob_triples]
    return tagged

def coarse_bucket(tag: str) -> str:
    return TAG_TO_COARSE.get(tag, "NULL")

def annotate_row(sentence: str) -> Dict[str, object]:
    tagged = tokenize_and_tag(sentence)

    # Initialize buckets
    data = {"Sentence": sentence}
    for c in COARSE_TAGS:
        data[c] = []
        data[f"{c}_entity"] = []

    # Fill buckets
    for idx, (tok, penntag) in enumerate(tagged):
        coarse = coarse_bucket(penntag)
        data[coarse].append(tok)
        data[f"{coarse}_entity"].append(idx)

    return data

def build_dataframe(items: List[Dict[str, str]]) -> pd.DataFrame:
    rows = []
    for obj in items:
        sent = obj["Sentence"]
        rows.append(annotate_row(sent))
    # Ensure all columns exist and are ordered as requested
    ordered_cols = (
        ["location", "person", "organization",
         "location_pos", "person_pos", "organization_pos"] # position
    )
    df = pd.DataFrame(rows)
    # In case some columns are missing (no instances), add them as empty lists
    for col in ordered_cols:
        if col not in df:
            df[col] = [[] for _ in range(len(df))]
    return df[ordered_cols]


df = build_dataframe(tv_data)
print(df.head(1))
print(df.columns.tolist(), f"[{len(df)} rows x {len(df.columns)} columns]")


[nltk_data] Downloading package punkt to /home/basan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/basan/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


[('In', 'O'), ('the', 'O'), ('wake', 'O'), ('of', 'O'), ('the', 'O'), ('earthquake', 'O'), ('in', 'O'), ('California', 'B-MISC'), ('and', 'O'), ('the', 'O'), ('devastation', 'O'), ('of', 'O'), ('Hurricane', 'B-MISC'), ('Hugo', 'I-MISC'), ('many', 'O'), ('companies', 'O'), ('in', 'O'), ('disaster', 'O'), ('prone', 'O'), ('areas', 'O'), ('are', 'O'), ('pondering', 'O'), ('the', 'O'), ('question', 'O'), ('of', 'O'), ('preparedness', 'O')]


KeyError: None

In [6]:
print(df.shape)
df.head(3)

(180, 6)


Unnamed: 0,location,person,organization,location_pos,person_pos,organization_pos
0,[],[],[],[],[],[]
1,[],[],[],[],[],[]
2,[],[],[],[],[],[]


Get logit diff on 5 examples to test if we are able to get positive logit diff when we measure gold token vs. other words in the sentence.

In [None]:
for i in range(0, 20, 1):

    for tag in tag_set:

        item = df.iloc[i]
    
        if item[tag] == "[]":
            continue
        
        sentence = item["Sentence"].replace("\"","").replace(" .",".").replace("[", "").replace("%", "")
        tag = tag
        prompt = template.format(sentence=sentence, tag=tag)
        if not isinstance(item[tag], list):
            clean_answer = ast.literal_eval(item[tag])[0].replace("\"","")
        else:
            if item[tag] == []:
                continue
            clean_answer = item[tag][0].replace("\"","")
        clean_tokens = model.to_tokens(clean_answer, prepend_bos=False)
        if clean_tokens.shape[1] > 1:
            continue 

        corr_words_list = []
        for other_tag in coarse_tags:
            if other_tag != tag:

                if not isinstance(item[other_tag], list):
                    item[other_tag] = ast.literal_eval(item[other_tag])
                else:
                    if item[other_tag] == []:
                        continue
                    
                for corr in item[other_tag]:
                    
                    if "[" in corr or "'" in corr:
                        continue
                    else:
                        corr_words_list.append(corr.replace("\"",""))

        answer_tokens_list = []
        for word in corr_words_list:
            corr_tokens = model.to_tokens(word, prepend_bos=False)
            if corr_tokens.shape[1] > 1:
                continue
            answer_tokens_list.append(torch.stack([clean_tokens, corr_tokens], dim=-1).long())

        prompt_tokens = model.to_tokens(prompt, prepend_bos=False)
        with torch.no_grad():
            logits = model(prompt_tokens)

        logit_diffs = 0
        for j, answer_token_pair in enumerate(answer_tokens_list):
            log_diff = get_logit_diff(logits, answer_token_pair)
            logit_diffs += log_diff
            # print(model.to_str_tokens(answer_token_pair))
            # print(log_diff, "\n")
            
        logit_diffs /= len(answer_tokens_list) if answer_tokens_list else 1
        print(f"Average logit differences for sentence {i} and tag {tag}: {logit_diffs}")
        
    if i == 5:
        break

if torch.cuda.is_available():
    torch.cuda.empty_cache()

Average logit differences for sentence 0 and tag noun: 4.996672630310059
Average logit differences for sentence 0 and tag verb: 0.8935872912406921
Average logit differences for sentence 0 and tag adverb: 5.275399208068848
Average logit differences for sentence 0 and tag preposition: 0.27275362610816956
Average logit differences for sentence 0 and tag determiner: 5.34472131729126
Average logit differences for sentence 0 and tag conjunction: 0.29209867119789124
Average logit differences for sentence 0 and tag pronoun: 3.195990800857544
Average logit differences for sentence 0 and tag other: -0.06074707210063934
Average logit differences for sentence 1 and tag adjective: 3.6389198303222656
Average logit differences for sentence 1 and tag noun: -4.239579200744629
Average logit differences for sentence 1 and tag verb: 2.138639450073242
Average logit differences for sentence 1 and tag preposition: -1.855430245399475
Average logit differences for sentence 1 and tag conjunction: -2.61563301086

In [None]:
!nvidia-smi

Sat Sep  6 16:10:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-PCIE-32GB           On  |   00000000:37:00.0 Off |                    0 |
| N/A   31C    P0             26W /  250W |       3MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-PCIE-32GB           On  |   00

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


We are going to test the logit diff for clean and corrupted prompts.

In [None]:
# Prepare corrupted and clean prompts
prompts = []

for i in range(len(tv_data)):

    for tag in tag_set:

        item = df.iloc[i]
        print(item)
    
        if item[tag] == "[]":
            continue
        
        sentence = item["Sentence"].replace("\"","").replace(" .",".").replace("[", "").replace("%", "")
        tag = tag
        prompt = template.format(sentence=sentence, tag=tag)
        if not isinstance(item[tag], list):
            clean_answer = ast.literal_eval(item[tag])[0].replace("\"","")
        else:
            if item[tag] == []:
                continue
            clean_answer = item[tag][0].replace("\"","")
        clean_tokens = model.to_tokens(clean_answer, prepend_bos=False)
        
        if clean_tokens.shape[1] > 1:
            continue 

        corr_words_list = []
        for other_tag in coarse_tags:
            if other_tag != tag:

                if not isinstance(item[other_tag], list):
                    item[other_tag] = ast.literal_eval(item[other_tag])
                else:
                    if item[other_tag] == []:
                        continue
                    
                for corr in item[other_tag]:
                    
                    if "[" in corr or "'" in corr:
                        continue
                    else:
                        corr_words_list.append(corr.replace("\"",""))

        answer_tokens_list = []
        answers_list = []
        for word in corr_words_list:
            corr_tokens = model.to_tokens(word, prepend_bos=False)
            
            if corr_tokens.shape[1] > 1:
                continue
            
            answer_tokens_list.append(torch.stack([clean_tokens, corr_tokens], dim=-1).long())
            answers_list.append((clean_answer, word.replace("\"","")))

        prompts.append(
            {
                "clean_input":template.format(sentence=sentence, tag=tag),
                "corr_input":template.format(sentence=sentence, tag="null"),
                "answers_input_list": answers_list,
                "clean_prompt":model.to_tokens(template.format(sentence=sentence, tag=tag), prepend_bos=False),
                "corr_prompt":model.to_tokens(template.format(sentence=sentence, tag="null"), prepend_bos=False),
                "answers_list": answer_tokens_list
            }
        )

Sentence           Furthermore that period of time might still be...
adjective                                                         []
adverb                         [Furthermore, still, very, long, n't]
conjunction                                                    [and]
determiner                                                    [that]
noun                             [period, time, something, anything]
number                                                            []
other                                                   [might, can]
particle                                                          []
preposition                                              [of, about]
pronoun                                                  [I, it, we]
verb                                         [be, do, think, 's, do]
adjective_pos                                                     []
adverb_pos                                          [0, 6, 8, 9, 13]
conjunction_pos                   

In [None]:
def logit_diff_denoising(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch n_pairs 2"],
    flipped_logit_diff: float,
    clean_logit_diff: float,
    return_tensor: bool = False,
) -> Float[Tensor, ""]:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is
    same as on flipped input, and 1 when performance is same as on clean input.
    '''
    patched_logit_diff = get_logit_diff(logits, answer_tokens)
    ld = ((patched_logit_diff - flipped_logit_diff) / (clean_logit_diff  - flipped_logit_diff))
    if return_tensor:
        return ld
    else:
        return ld.item()


def logit_diff_noising(
        logits: Float[Tensor, "batch seq d_vocab"],
        clean_logit_diff: float,
        corrupted_logit_diff: float,
        answer_tokens: Float[Tensor, "batch n_pairs 2"],
        return_tensor: bool = False,
    ) -> float:
        '''
        We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),
        and -1 when performance has been destroyed (i.e. is same as ABC dataset).
        '''
        patched_logit_diff = get_logit_diff(logits, answer_tokens)
        ld = ((patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff))

        if return_tensor:
            return ld
        else:
            return ld.item()

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence

BATCH_SIZE = 1 # adjust based on your memory

all_results = []
clean_logit_diffs = []
corrupt_logit_diffs = []

def _to_cpu_detached(x):
    """Recursively move tensors to CPU and detach (handles tensors, lists, tuples, dicts)."""
    if torch.is_tensor(x):
        return x.detach().to("cpu")
    if isinstance(x, dict):
        return {k: _to_cpu_detached(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        t = [_to_cpu_detached(v) for v in x]
        return type(x)(t)
    return x


#for i in tqdm(range(0, len(prompts), BATCH_SIZE)):
for i in tqdm(range(0, 50, BATCH_SIZE)):
    model.reset_hooks(including_permanent=True)
    batch = prompts[i : i + BATCH_SIZE]

    clean_rows = [p["clean_prompt"].squeeze(0) for p in batch]
    corrupted_rows = [p["corr_prompt"].squeeze(0)  for p in batch]

    # print(batch[0]["clean_input"])
    # print(batch[0]["corr_input"])
    # print(batch[0]["answers_input_list"])

    pad_id = getattr(model.tokenizer, "pad_token_id", model.tokenizer.eos_token_id)
    clean_tokens = pad_sequence(clean_rows, batch_first=True, padding_value=pad_id).to(device)
    corrupted_tokens = pad_sequence(corrupted_rows, batch_first=True, padding_value=pad_id).to(device)

    if clean_tokens.shape[1] != corrupted_tokens.shape[1]:
        print("Skipping batch due to different lengths of clean and corrupted tokens")
        continue

    with torch.no_grad():
        clean_logits, clean_cache = model.run_with_cache(clean_tokens, return_type="logits")
    with torch.no_grad():
        corr_logits, corr_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
        del corr_cache

    answer_tokens = torch.stack([
    torch.stack([t.squeeze() for t in p["answers_list"]])  # [num_answers, 2]
        for p in batch
    ])

    clean_logit_diff = get_logit_diff(clean_logits, answer_tokens)
    corr_logit_diff = get_logit_diff(corr_logits, answer_tokens)

    del clean_logits, corr_logits
    torch.cuda.empty_cache()

    patching_metric = partial(
        logit_diff_denoising,
        clean_logit_diff=clean_logit_diff,
        flipped_logit_diff=corr_logit_diff,
        answer_tokens=answer_tokens,
        return_tensor=True,
    )

    with torch.inference_mode():
        results = act_patch(
            model=model,
            orig_input=corrupted_tokens,
            new_cache=clean_cache,
            patching_nodes=IterNode("z", -1), # iterating over all heads' output in all layers
            patching_metric=patching_metric,
            verbose=True,
    )
    results_cpu = _to_cpu_detached(results)
    results_cpu["z"] = torch.stack(results_cpu["z"]).reshape(model.cfg.n_layers, model.cfg.n_heads)
    all_results.append(results_cpu)

    del results, results_cpu  # results_cpu is on CPU; safe to drop too
    del clean_cache, clean_logit_diff, corr_logit_diff
    del clean_tokens, corrupted_tokens
    torch.cuda.empty_cache()

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)
Skipping batch due to different lengths of clean and corrupted tokens


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)
Skipping batch due to different lengths of clean and corrupted tokens


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)
Skipping batch due to different lengths of clean and corrupted tokens


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)
Skipping batch due to different lengths of clean and corrupted tokens


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)
Skipping batch due to different lengths of clean and corrupted tokens


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)
Skipping batch due to different lengths of clean and corrupted tokens


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=18, head=8)
Skipping batch due to different lengths of clean and corrupted tokens


In [None]:
#assert results.keys() == {"z", "q", "k", "v", "pattern"}

imshow_p(
    torch.stack([d['z'] for d in all_results]).mean(dim=0).unsqueeze(0) * 100,
    facet_col=0,
    facet_labels=["Output (z)"], #["Output", "Query", "Key", "Value", "Pattern"],
    title="Patching output of attention heads (denoising patch clean to corrupted)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=500,
    margin={"r": 100, "l": 100}
)