<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [None]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
cd /content && rm -rf /content/rome
git clone https://github.com/kmeng01/rome rome > install.log 2>&1
pip install -r /content/rome/scripts/colab_reqs/rome.txt >> install.log 2>&1
pip install --upgrade google-cloud-storage >> install.log 2>&1

In [None]:
IS_COLAB = False
try:
    import google.colab, torch, os

    IS_COLAB = True
    os.chdir("/content/rome")
    if not torch.cuda.is_available():
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    pass

## Causal Tracing

A demonstration of the double-intervention causal tracing method.

The strategy used by causal tracing is to understand important
states within a transformer by doing two interventions simultaneously:

1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens
   to frustrate the ability of the transformer to accurately complete factual
   prompts about the subject.
2. Restore a subset of the internal hidden states.  In our paper, we scan
   hidden states at all layers and all tokens, searching for individual states
   that carry the necessary information for the transformer to recover its
   capability to complete the factual prompt.
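Here is a minimal sketch of the two interventions. It uses a hypothetical one-token toy model whose scalar hidden state passes through residual layers. None of this is ROME code. It compares three runs directly: clean, corrupted, and corrupted with one state restored.

```python
import math
import random

# Hypothetical toy model: three residual "layers" acting on a scalar hidden state.
LAYERS = [
    lambda h: 0.5 * math.tanh(h),
    lambda h: -0.2 * h,
    lambda h: math.sin(h),
]


def forward(embedding, noise=0.0, restore_layer=None, clean_states=None):
    """Return the hidden state after each layer.

    Intervention 1: `noise` is added to the input embedding.
    Intervention 2: if `restore_layer` is given, that layer's output is
    overwritten with the value it had in the clean run (`clean_states`).
    """
    h = embedding + noise
    states = []
    for i, layer in enumerate(LAYERS):
        h = h + layer(h)  # residual update
        if i == restore_layer:
            h = clean_states[i]  # copy the uncorrupted state back in
        states.append(h)
    return states


def score(states):
    """Stand-in for p(answer): a sigmoid of the final hidden state."""
    return 1.0 / (1.0 + math.exp(-states[-1]))


rng = random.Random(0)
noise = 3 * rng.gauss(0.0, 1.0)
clean = forward(1.0)
corrupted = forward(1.0, noise=noise)
restored = forward(1.0, noise=noise, restore_layer=1, clean_states=clean)
print(score(clean), score(corrupted), score(restored))
```

The toy has a single token, so restoring one state recovers the clean output exactly. In a real transformer, the other corrupted tokens still feed later layers, so recovery is only partial. That partial recovery is what the heatmaps measure.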

The traces of decisive states can be shown on a heatmap.  This notebook
demonstrates the code for conducting causal traces and creating these heatmaps.

In [None]:
%load_ext autoreload
%autoreload 2

The `experiments.causal_trace` module contains a set of functions for running causal traces.

In this notebook, we reproduce, demonstrate and discuss the interesting functions.

We begin by importing several utility functions that deal with tokens and transformer models.

In [1]:
import sys
sys.path.append("/home/hthakur/model_editing/rome")  # local ROME checkout; adjust this path for your machine
import os, re, json
import torch, numpy
from collections import defaultdict
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import CounterFactDataset

torch.set_grad_enabled(False)


Now we load a model and tokenizer, and show that it can complete a couple of factual statements correctly.

In [2]:
model_name = "gpt2-xl"  # or "EleutherAI/gpt-j-6B" or "EleutherAI/gpt-neox-20b"
mt = ModelAndTokenizer(
    model_name,
    torch_dtype=(torch.float16 if "20b" in model_name else None),
)

In [None]:
predict_token(
    mt,
    ["Megan Rapinoe plays the sport of", "The Space Needle is in the city of"],
    return_p=True,
)

To obfuscate the subject during Causal Tracing, we use noise sampled from a zero-centered spherical Gaussian whose standard deviation is 3 times the standard deviation $\sigma$ of the model's embeddings. Let's compute that value.
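The arithmetic is simple: pool every embedding value, take the standard deviation, and multiply by 3. Below is a stdlib-only sketch. The helper name and the input values are made up for illustration. `statistics.stdev` uses the same n - 1 (sample) estimator as the default `torch.Tensor.std()`.

```python
import statistics


def noise_level_from_embeddings(embedding_rows, multiplier=3.0):
    """Return `multiplier` times the standard deviation of all embedding values.

    `embedding_rows` is a list of per-token embedding vectors (lists of floats),
    pooled together as in collect_embedding_std.
    """
    values = [v for row in embedding_rows for v in row]
    return multiplier * statistics.stdev(values)  # sample (n - 1) standard deviation


print(noise_level_from_embeddings([[1.0, 2.0], [3.0, 4.0]]))
```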

In [3]:
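from tqdm import tqdm

# Local copy of collect_embedding_std (it shadows the imported version) with a progress bar.
def collect_embedding_std(mt, subjects):
    alldata = []
    for s in tqdm(subjects, total=len(subjects)):
        inp = make_inputs(mt.tokenizer, [s])
        # Record the output of the embedding layer for this subject.
        with nethook.Trace(mt.model, layername(mt.model, 0, "embed")) as t:
            mt.model(**inp)
            alldata.append(t.output[0])
    alldata = torch.cat(alldata)
    # Standard deviation over every embedding value of every subject token.
    noise_level = alldata.std().item()
    return noise_level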

In [50]:
import random


In [5]:
knowns = CounterFactDataset(DATA_DIR)  # Dataset of known facts
# noise_level = 3 * collect_embedding_std(mt, [k["requested_rewrite"]["subject"] for k in kk])


Loaded dataset with 21919 elements


In [5]:
noise_level

0.1351652629673481

In [52]:
knowns = CounterFactDataset(DATA_DIR)  # Dataset of known facts
indices = random.sample(range(len(knowns)), 5000)  # estimate the noise level from 5000 randomly sampled subjects
noise_level = 3 * collect_embedding_std(mt, [k["requested_rewrite"]["subject"] for k in [knowns[x] for x in indices]])
print(f"Using noise level {noise_level}")

Loaded dataset with 21919 elements



Using noise level 0.1353171207010746





## Tracing a single location

The core intervention in causal tracing is captured in this function:

`trace_with_patch` runs a single causal trace.

It enables running a batch of inferences with two interventions.

  1. Random noise can be added to corrupt the inputs of some of the batch.
  2. At any point, clean non-noised state can be copied over from an
     uncorrupted batch member to other batch members.
  
The convention used by this function is that the zeroth element of the
batch is the uncorrupted run, and the subsequent elements of the batch
are the corrupted runs.  The argument `tokens_to_mix` specifies a
(begin, end) range of token positions whose embeddings will be corrupted
by adding Gaussian noise, for every batch input other than the first.
Alternatively, subsequent runs could be corrupted by simply providing
different input tokens via the passed input batch.

To ensure that corrupted behavior is representative, in practice, we
will actually run several (ten) corrupted runs in the same batch,
each with its own sample of noise.
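The batch layout described above can be sketched without PyTorch. This is a hypothetical helper over plain lists, not the notebook's code. Row 0 stays clean. Each corrupted row gets its own Gaussian noise sample, applied only to the half-open token range `(begin, end)`. The `noise=0.135` in the usage line is roughly the noise level computed earlier.

```python
import random


def corrupt_batch(clean_embeddings, tokens_to_mix, noise, num_corrupted, seed=1):
    """Return a batch of num_corrupted + 1 copies of one prompt's embeddings.

    Row 0 is left clean. Rows 1..num_corrupted get independent Gaussian noise,
    scaled by `noise`, on token positions in the half-open range tokens_to_mix.
    """
    rng = random.Random(seed)  # fixed seed, echoing the notebook's RandomState(1)
    begin, end = tokens_to_mix
    batch = [[list(vec) for vec in clean_embeddings] for _ in range(num_corrupted + 1)]
    for row in batch[1:]:
        for t in range(begin, end):
            row[t] = [x + noise * rng.gauss(0.0, 1.0) for x in row[t]]
    return batch


emb = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]  # 3 tokens, 2-dim embeddings (made up)
batch = corrupt_batch(emb, tokens_to_mix=(0, 2), noise=0.135, num_corrupted=10)
```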

Then when running, a specified set of hidden states will be uncorrupted
by restoring their values to the same vector that they had in the
zeroth uncorrupted run.  This set of hidden states is listed in
states_to_patch, by listing [(token_index, layername), ...] pairs.
To trace the effect of just a single state, this can be just a single
token/layer pair.  To trace the effect of restoring a set of states,
any number of token indices and layers can be listed.
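`trace_with_patch` first groups the `states_to_patch` pairs by layer to build `patch_spec`. It then copies the clean state with `h[1:, t] = h[0, t]`. Both steps can be mirrored on nested lists. The helper names and layer names below are illustrative only.

```python
from collections import defaultdict


def build_patch_spec(states_to_patch):
    """Group (token_index, layer_name) pairs into {layer_name: [token_index, ...]}."""
    spec = defaultdict(list)
    for token_index, layer_name in states_to_patch:
        spec[layer_name].append(token_index)
    return dict(spec)


def restore(hidden, token_indices):
    """Overwrite rows 1: with row 0 (the clean run) at the given token positions.

    hidden[row][token] is one hidden-state vector; it is mutated in place and returned.
    """
    for t in token_indices:
        for row in hidden[1:]:
            row[t] = list(hidden[0][t])
    return hidden


spec = build_patch_spec([(3, "layer_10"), (4, "layer_10"), (3, "layer_11")])
print(spec)  # {'layer_10': [3, 4], 'layer_11': [3]}
```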

Note that this function is also in experiments.causal_trace; the code
is shown here to show the logic.

In [6]:
def trace_with_patch(
    model,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) pairs to restore
    answers_t,  # Answer probabilities to collect
    tokens_to_mix,  # Range of tokens to corrupt (begin, end)
    noise=0.1,  # Level of noise to add
    trace_layers=None,  # List of traced outputs to return
):
    prng = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise
    patch_spec = defaultdict(list)
    for t, l in states_to_patch:
        patch_spec[l].append(t)
    embed_layername = layername(model, 0, "embed")

    def untuple(x):
        return x[0] if isinstance(x, tuple) else x

    # Define the model-patching rule.
    def patch_rep(x, layer):
        if layer == embed_layername:
            # If requested, we corrupt a range of token embeddings on batch items x[1:]
            if tokens_to_mix is not None:
                b, e = tokens_to_mix
                x[1:, b:e] += noise * torch.from_numpy(
                    prng.randn(x.shape[0] - 1, e - b, x.shape[2])
                ).to(x.device)
            return x
        if layer not in patch_spec:
            return x
        # If this layer is in the patch_spec, restore the uncorrupted hidden state
        # for selected tokens.
        h = untuple(x)
        for t in patch_spec[layer]:
            h[1:, t] = h[0, t]
        return x

    # With the patching rules defined, run the patched model in inference.
    additional_layers = [] if trace_layers is None else trace_layers
    with torch.no_grad(), nethook.TraceDict(
        model,
        [embed_layername] + list(patch_spec.keys()) + additional_layers,
        edit_output=patch_rep,
    ) as td:
        outputs_exp = model(**inp)

    # We report softmax probabilities for the answers_t token predictions of interest.
    probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    # If tracing all layers, collect all activations together to return.
    if trace_layers is not None:
        all_traced = torch.stack(
            [untuple(td[layer].output).detach().cpu() for layer in trace_layers], dim=2
        )
        return probs, all_traced

    return probs
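The score returned above is the softmax probability of the answer token at the last position, averaged over the corrupted rows. Row 0, the clean run, is excluded by `logits[1:, -1, :]`. Below is a pure-Python sketch of that reduction, using hypothetical last-position logits as plain lists. It indexes each row and then averages. The notebook averages and then indexes, but both give the mean of the same per-row probabilities.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def answer_probability(last_position_logits, answer_index):
    """Mean p(answer) over rows 1: (the corrupted runs); row 0 is the clean run."""
    corrupted = last_position_logits[1:]
    probs = [softmax(row)[answer_index] for row in corrupted]
    return sum(probs) / len(probs)


logits = [[5.0, 0.0], [0.0, 0.0], [2.0, 0.0]]  # made-up 2-token vocabulary
print(answer_probability(logits, answer_index=0))
```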

## Scanning all locations

A causal flow heatmap is created by repeating `trace_with_patch` at every individual hidden state, and measuring the impact of restoring state at each location.
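The scan is a plain double loop over token positions and layers. In this sketch, `restore_score` is a hypothetical callback standing in for one `trace_with_patch` call. Each cell costs one batched forward pass, so a 10-token prompt on a 48-layer model such as GPT-2 XL needs 480 of them.

```python
def scan(num_tokens, num_layers, restore_score):
    """Return a num_tokens x num_layers grid; cell [t][l] = restore_score(t, l)."""
    return [[restore_score(t, layer) for layer in range(num_layers)] for t in range(num_tokens)]


# Toy score function, only to show the grid layout (rows = tokens, columns = layers).
grid = scan(3, 4, lambda t, layer: 10 * t + layer)
print(grid)  # [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]
```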

The `calculate_hidden_flow` function does this loop.  It handles both the case of restoring a single hidden state, and also restoring MLP or attention states.  Because MLP and attention make small residual contributions, to observe a causal effect in those cases, we need to restore several layers of contributions at once, which is done by `trace_important_window`.
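Near the first and last layers, the restored window is clipped, so it covers fewer layers. Below is a sketch of the window bounds as a standalone function. The helper name is hypothetical, but the arithmetic is copied from `trace_important_window` below. It uses 48 layers, as in GPT-2 XL.

```python
def window_layers(layer, num_layers, window=10):
    """Return the layers restored together for the window centred on `layer`.

    Same bounds as trace_important_window: start window // 2 layers below
    `layer`, stop ceil(window / 2) layers above it (exclusive), and clip to
    the valid range [0, num_layers).
    """
    start = max(0, layer - window // 2)
    stop = min(num_layers, layer - (-window // 2))  # layer - (-w // 2) == layer + ceil(w / 2)
    return list(range(start, stop))


for centre in (0, 20, 47):
    print(centre, window_layers(centre, num_layers=48))
```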

In [7]:
def calculate_hidden_flow(
    mt, prompt, subject, samples=10, noise=0.1, window=10, kind=None
):
    """
    Runs causal tracing over every token/layer combination in the network
    and returns a dictionary numerically summarizing the results.
    """
    inp = make_inputs(mt.tokenizer, [prompt] * (samples + 1))
    with torch.no_grad():
        answer_t, base_score = [d[0] for d in predict_from_input(mt.model, inp)]
    [answer] = decode_tokens(mt.tokenizer, [answer_t])
    e_range = find_token_range(mt.tokenizer, inp["input_ids"][0], subject)
    low_score = trace_with_patch(
        mt.model, inp, [], answer_t, e_range, noise=noise
    ).item()
    if not kind:
        differences = trace_important_states(
            mt.model, mt.num_layers, inp, e_range, answer_t, noise=noise
        )
    else:
        differences = trace_important_window(
            mt.model,
            mt.num_layers,
            inp,
            e_range,
            answer_t,
            noise=noise,
            window=window,
            kind=kind,
        )
    differences = differences.detach().cpu()
    return dict(
        scores=differences,
        low_score=low_score,
        high_score=base_score,
        input_ids=inp["input_ids"][0],
        input_tokens=decode_tokens(mt.tokenizer, inp["input_ids"][0]),
        subject_range=e_range,
        answer=answer,
        window=window,
        kind=kind or "",
    )


def trace_important_states(model, num_layers, inp, e_range, answer_t, noise=0.1):
    ntoks = inp["input_ids"].shape[1]
    table = []
    for tnum in range(ntoks):
        row = []
        for layer in range(0, num_layers):
            r = trace_with_patch(
                model,
                inp,
                [(tnum, layername(model, layer))],
                answer_t,
                tokens_to_mix=e_range,
                noise=noise,
            )
            row.append(r)
        table.append(torch.stack(row))
    return torch.stack(table)


def trace_important_window(
    model, num_layers, inp, e_range, answer_t, kind, window=10, noise=0.1
):
    ntoks = inp["input_ids"].shape[1]
    table = []
    for tnum in range(ntoks):
        row = []
        for layer in range(0, num_layers):
            layerlist = [
                (tnum, layername(model, L, kind))
                for L in range(
                    max(0, layer - window // 2), min(num_layers, layer - (-window // 2))
                )
            ]
            r = trace_with_patch(
                model, inp, layerlist, answer_t, tokens_to_mix=e_range, noise=noise
            )
            row.append(r)
        table.append(torch.stack(row))
    return torch.stack(table)

## Plotting the results

The `plot_trace_heatmap` function draws the data on a heatmap.  It is also available in `experiments.causal_trace`; a local copy is defined below.
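Two details of the plot can be checked in isolation. First, subject tokens are marked with a trailing `*`. Second, x-axis ticks sit at cell centres every 5 layers, stopping before `num_layers - 6`. Below is a sketch of both, with hypothetical helper names and a made-up tokenization.

```python
def heatmap_labels(tokens, subject_range):
    """Copy the token labels, appending '*' to those in the half-open subject_range."""
    labels = list(tokens)
    for i in range(*subject_range):
        labels[i] = labels[i] + "*"
    return labels


def layer_ticks(num_layers):
    """Return (tick positions, tick labels): every 5th layer, at cell centres."""
    layers = list(range(0, num_layers - 6, 5))
    return [0.5 + layer for layer in layers], layers


print(heatmap_labels(["The", " Eiff", "el", " Tower", " is"], (1, 4)))
print(layer_ticks(48)[1])  # [0, 5, 10, 15, 20, 25, 30, 35, 40]
```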


In [8]:
from matplotlib import pyplot as plt

def plot_trace_heatmap(result, savepdf=None, title=None, xlabel=None, modelname=None):
    differences = result["scores"]
    low_score = result["low_score"]
    answer = result["answer"]
    kind = (
        None
        if (not result["kind"] or result["kind"] == "None")
        else str(result["kind"])
    )
    window = result.get("window", 10)
    labels = list(result["input_tokens"])
    for i in range(*result["subject_range"]):
        labels[i] = labels[i] + "*"

    with plt.rc_context(rc={"font.family": "serif", "font.size": 8}):
        fig, ax = plt.subplots(figsize=(3.5, 2), dpi=200)
        h = ax.pcolor(
            differences,
            cmap={None: "Purples", "None": "Purples", "mlp": "Greens", "attn": "Reds"}[
                kind
            ],
            vmin=low_score,
        )
        ax.invert_yaxis()
        ax.set_yticks([0.5 + i for i in range(len(differences))])
        ax.set_xticks([0.5 + i for i in range(0, differences.shape[1] - 6, 5)])
        ax.set_xticklabels(list(range(0, differences.shape[1] - 6, 5)))
        ax.set_yticklabels(labels)
        if not modelname:
            modelname = "GPT"
        if not kind:
            ax.set_title("Impact of restoring state after corrupted input")
            ax.set_xlabel(f"single restored layer within {modelname}")
        else:
            kindname = "MLP" if kind == "mlp" else "Attn"
            ax.set_title(f"Impact of restoring {kindname} after corrupted input")
            ax.set_xlabel(f"center of interval of {window} restored {kindname} layers")
        cb = plt.colorbar(h)
        if title is not None:
            ax.set_title(title)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        elif answer is not None:
            # The following should be cb.ax.set_xlabel, but this is broken in matplotlib 3.5.1.
            cb.ax.set_title(f"p({str(answer).strip()})", y=-0.16, fontsize=10)
        if savepdf:
            os.makedirs(os.path.dirname(savepdf), exist_ok=True)
            plt.savefig(savepdf, bbox_inches="tight")
            plt.close()
        else:
            plt.show()

In [33]:
def plot_hidden_flow(
    mt,
    prompt,
    subject=None,
    samples=10,
    noise=0.1,
    window=10,
    kind=None,
    modelname=None,
    savepdf=None,
):
    if subject is None:
        subject = guess_subject(prompt)
    result = calculate_hidden_flow(
        mt, prompt, subject, samples=samples, noise=noise, window=window, kind=kind
    )
    plot_trace_heatmap(result, savepdf, modelname=modelname)


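def plot_all_flow(mt, prompt, subject=None, noise=0.1, modelname=None, savepdf=None):
    # Only the single hidden-state trace (kind=None) is plotted. Adding "mlp" or "attn"
    # here would reuse the same savepdf path, so each plot would overwrite the previous one.
    for kind in [None]:
        plot_hidden_flow(
            mt, prompt, subject, modelname=modelname, noise=noise, kind=kind, savepdf=savepdf
        )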

In [2]:
# (impact, case_id) pairs; the loop below prints impact as "<impact>/1000 correct".
bad = [
    (98, 193),
    (99, 343),
    (30, 341),
    (6, 1416),
    (1, 173),
    (99, 531),
    (5, 1420),
    (0, 266),
    (11, 240),
    (20, 1508),
    (8, 946),
]
bad_ids = [i[1] for i in bad]

In [9]:

for i in knowns:
    for j in bad:
        if j[1] == i["case_id"]:
            prompt = i["requested_rewrite"]["prompt"].format(i["requested_rewrite"]["subject"])
            target_new = i["requested_rewrite"]["target_new"]["str"]
            target_old = i["requested_rewrite"]["target_true"]["str"]
            sent = "{} {} -> {}".format(prompt, target_old, target_new)
            print("Edit: {} | Impact: {}/1000 correct".format(sent, j[0]))

Edit: Chicago is a twin city of Warsaw -> Istanbul | Impact: 1/1000 correct
Edit: Oslo is a twin city of Copenhagen -> Tehran | Impact: 98/1000 correct
Edit: Arthur is located in Illinois -> California | Impact: 11/1000 correct
Edit: Moscow is a twin city of Amsterdam -> Miami | Impact: 0/1000 correct
Edit: Olot, located in Spain -> India | Impact: 30/1000 correct
Edit: Jennings can be found in Louisiana -> Maryland | Impact: 99/1000 correct
Edit: Junnar, which is located in India -> Belarus | Impact: 99/1000 correct
Edit: Bay, which is located in Philippines -> Italy | Impact: 8/1000 correct
Edit: Manchester is a twin city of Amsterdam -> Munich | Impact: 6/1000 correct
Edit: Life was originally aired on NBC -> HBO | Impact: 5/1000 correct
Edit: Japan, in Asia -> Antarctica | Impact: 20/1000 correct
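In the loop above, the first element of each `bad` tuple is printed as the number correct out of 1000. The second element is matched against the record's `case_id`. Below is a sketch of the formatting step on a hypothetical CounterFact-style record. The field names follow the loop above, and the values are taken from one printed line.

```python
def describe_edit(record, correct, total=1000):
    """Format a CounterFact-style record the way the loop above prints it."""
    rewrite = record["requested_rewrite"]
    prompt = rewrite["prompt"].format(rewrite["subject"])
    sent = "{} {} -> {}".format(prompt, rewrite["target_true"]["str"], rewrite["target_new"]["str"])
    return "Edit: {} | Impact: {}/{} correct".format(sent, correct, total)


# Hypothetical record; only the fields the loop reads are filled in.
record = {
    "requested_rewrite": {
        "prompt": "{} is a twin city of",
        "subject": "Moscow",
        "target_true": {"str": "Amsterdam"},
        "target_new": {"str": "Miami"},
    }
}
print(describe_edit(record, correct=0))
# Edit: Moscow is a twin city of Amsterdam -> Miami | Impact: 0/1000 correct
```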


In [36]:
with tqdm(total=len(bad_ids)) as pbar:
    for idx, item in enumerate(knowns):
        case_id = item["case_id"]
        if case_id in bad_ids:
            print(case_id)
            prompt = item["requested_rewrite"]["prompt"].replace("{}", item["requested_rewrite"]["subject"])
            plot_all_flow(mt, prompt, subject=item["requested_rewrite"]["subject"], savepdf=f"/home/hthakur/model_editing/rome/experiments/plots/causal/bad_{case_id}.pdf", noise=noise_level)
            pbar.update(1) 
        else:
            continue

173
193
240
266
341
343
531
946
1416
1420
1508


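The next cells trace additional cases: first a random sample of held-out cases that are in neither `bad` nor `both.txt`, then a single prompt, which can be changed to any factual statement to trace.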

In [21]:
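import ast

with open("/home/hthakur/model_editing/rome/experiments/both.txt", "r") as f:
    # Each line holds a Python tuple literal; literal_eval parses it without executing code.
    both = [ast.literal_eval(line.strip()) for line in f if line.strip()]

# Swap the two elements of each tuple.
both = [(x[1], x[0]) for x in both]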
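Below is a standalone sketch of the loading step. It assumes each non-blank line of the file is a 2-tuple literal, and the sample contents are made up. `ast.literal_eval` accepts only Python literals, so, unlike `eval`, it rejects anything that would run code.

```python
import ast
import io


def load_swapped_pairs(fh):
    """Parse one 2-tuple literal per non-blank line and swap its elements."""
    pairs = []
    for line in fh:
        line = line.strip()
        if not line:
            continue
        first, second = ast.literal_eval(line)  # literals only; never runs code
        pairs.append((second, first))
    return pairs


sample = io.StringIO("(193, 98)\n(343, 99)\n\n")  # made-up file contents
print(load_swapped_pairs(sample))  # [(98, 193), (99, 343)]
```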

In [26]:
import warnings
warnings.filterwarnings("ignore")
import random
random.seed(123)

indices = list(range(1500))  # candidate dataset indices for the held-out sample
# Note: `both` holds tuples, so `bad_ids` mixes ints and tuples; a tuple never equals an
# integer index, so the entries from `both` exclude nothing in the set difference below.
bad_ids = [i[1] for i in bad] + both
good_ids = list(set(indices) - set(bad_ids))
random.shuffle(good_ids)
good_ids = good_ids[:30]  # trace 30 randomly chosen held-out cases
with tqdm(total=len(good_ids)) as pbar:
    for i in good_ids:
        item = knowns[i]
        case_id = item["case_id"]
        if case_id in good_ids:
            print(case_id)
            prompt = item["requested_rewrite"]["prompt"].replace("{}", item["requested_rewrite"]["subject"])
            plot_all_flow(mt, prompt, subject=item["requested_rewrite"]["subject"], savepdf=f"/home/hthakur/model_editing/rome/experiments/plots/causal/holdout_{case_id}.pdf", noise=noise_level)
            pbar.update(1)

931
655
1036
636
1028
714
635
638

KeyboardInterrupt: 
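The held-out selection can be written as a small function. This is a hypothetical sketch, not the notebook's code. It removes the excluded IDs, shuffles with a fixed seed, and keeps the first `k`. Sorting the pool first makes the result depend only on the seed. The excluded IDs must be plain integers, because a tuple never equals an integer index and so would silently exclude nothing. The last test assertion demonstrates exactly that pitfall.

```python
import random


def sample_holdout(candidates, excluded, k, seed=123):
    """Return up to k candidates not in `excluded`, in a seeded random order.

    Both arguments should hold plain integers: a tuple in `excluded` never
    equals an integer candidate, so it would silently exclude nothing.
    """
    pool = sorted(set(candidates) - set(excluded))  # sort so order depends only on the seed
    rng = random.Random(seed)
    rng.shuffle(pool)
    return pool[:k]


held_out = sample_holdout(range(1500), excluded=[173, 193, 240], k=30)
print(held_out[:5])
```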

In [68]:
plot_all_flow(mt, "The Eiffel Tower is located in", subject="Eiffel Tower", savepdf=f"/home/hthakur/model_editing/rome/experiments/plots/causal/good_rome.pdf", noise=noise_level)

Here we trace a few more factual statements: the first five records of the CounterFact dataset loaded above.

In [None]:
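for knowledge in knowns[:5]:
    # CounterFact records store the prompt template and subject under "requested_rewrite".
    rewrite = knowledge["requested_rewrite"]
    prompt = rewrite["prompt"].format(rewrite["subject"])
    plot_all_flow(mt, prompt, rewrite["subject"], noise=noise_level)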