# Script to plot average causal effects

This script loads sets of hundreds of causal traces that have been computed by the
`experiments.causal_trace` program, and then aggregates the results to compute
Average Indirect Effects and Average Total Effects, along with some related statistics.


In [None]:
import json
import re
import sys
from functools import lru_cache
from pathlib import Path

import numpy, os
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

sys.path.append("/raid/lingo/dez/code/rome")
from experiments import causal_trace

# Typeset figure text in Times New Roman, with math text in a serif face.
for _rc_key, _rc_value in {
    "font.family": "Times New Roman",
    "mathtext.fontset": "dejavuserif",
}.items():
    plt.rcParams[_rc_key] = _rc_value

# Uncomment the architecture to plot.
# NOTE: the plotting cell below assigns `arch = "gpt2"` / `archname = "GPT-2"`
# after this point, so those assignments take precedence over anything here.
# arch = "gpt2-xl"
# archname = "GPT-2-XL"

# arch = 'EleutherAI_gpt-j-6B'
# archname = 'GPT-J-6B'

# arch = 'EleutherAI_gpt-neox-20b'
# archname = 'GPT-NeoX-20B'


class Avg:
    """Accumulate samples and report their element-wise mean and std.

    Samples are stored as a list of arrays whose first axis is the sample
    axis. Statistics are computed over that axis after concatenating them.
    """

    def __init__(self):
        # Each entry is an array of shape (n_i, ...), holding n_i samples.
        self.d = []

    def add(self, v):
        """Add a single sample ``v`` (an array or a scalar)."""
        # asarray lets plain Python/NumPy scalars be added, not only ndarrays;
        # for an ndarray it is a no-copy view, so behavior is unchanged.
        self.d.append(numpy.asarray(v)[None])

    def add_all(self, vv):
        """Add a batch of samples stacked along the first axis of ``vv``."""
        self.d.append(vv)

    def _stacked(self):
        # Concatenate all samples along axis 0. Fail with a clear message when
        # nothing was added; numpy.concatenate([]) only gives an opaque error.
        if not self.d:
            raise ValueError("Avg is empty: no values have been added")
        return numpy.concatenate(self.d)

    def avg(self):
        """Return the mean over all added samples (axis 0)."""
        return self._stacked().mean(axis=0)

    def std(self):
        """Return the population standard deviation over all added samples."""
        return self._stacked().std(axis=0)

    def size(self):
        """Return the total number of samples added so far."""
        return sum(datum.shape[0] for datum in self.d)


@lru_cache(maxsize=None)
def get_tokenizer(arch="gpt2"):
    """Return the Hugging Face ``AutoTokenizer`` for ``arch``, memoized per name.

    ``transformers`` is imported lazily, only when a tokenizer is first requested.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(arch)
    return tokenizer

def get_raws(dataset="winoventi", mediation="med", corrupted="subj_last",
             data_dir="../data"):
    """Load the causal-tracing subset of ``dataset``, keyed by ``known_id``.

    Full records come from ``{data_dir}/mediation/full/{dataset}_med_subj_last.json``.
    Each one gets an ``"attribute"`` entry, taken from the ``"subject"`` field
    of the record with the same ``"prompt"`` in
    ``{data_dir}/mediation/full/{dataset}_med_attr.json``.
    The subset and its ``known_id`` values come from
    ``{data_dir}/mediation/{dataset}_{mediation}_{corrupted}.json``.

    Args:
        dataset: dataset name used as the file-name prefix.
        mediation: mediation tag of the causal-tracing subset file.
        corrupted: corrupted-span tag of the causal-tracing subset file.
        data_dir: root data directory. The default ``"../data"`` is resolved
            against the current working directory, as before.

    Returns:
        dict mapping ``known_id`` to the full record (with ``"attribute"`` added).
        The full records are mutated in place.

    Raises:
        KeyError: if a prompt has no counterpart in the file it is looked up in.
    """
    mediation_dir = Path(data_dir) / "mediation"
    full_dir = mediation_dir / "full"

    def _load(path):
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with path.open("r", encoding="utf-8") as handle:
            return json.load(handle)

    raws_full_subj = _load(full_dir / f"{dataset}_med_subj_last.json")
    raws_full_attr = _load(full_dir / f"{dataset}_med_attr.json")

    raws_full_subj_by_prompt = {raw["prompt"]: raw for raw in raws_full_subj}
    raws_full_attr_by_prompt = {raw["prompt"]: raw for raw in raws_full_attr}
    # In the attr file, the span stored under "subject" is the attribute.
    for prompt, raw in raws_full_subj_by_prompt.items():
        raw["attribute"] = raws_full_attr_by_prompt[prompt]["subject"]

    # Now load the subset used during causal tracing.
    raws = _load(mediation_dir / f"{dataset}_{mediation}_{corrupted}.json")
    return {
        raw["known_id"]: raws_full_subj_by_prompt[raw["prompt"]]
        for raw in raws
    }


def read_knowlege(kind=None,
                  arch="gpt2",
                  dataset="winoventi",
                  mediation="med",
                  corrupted="subj_first"):
    """Aggregate saved causal-trace scores by prompt region.

    Reads each ``knowledge_<id>{_kind}.npz`` file under
    ``../results/ns3_r0_{arch}/{dataset}_{mediation}_{corrupted}/causal_trace/cases``.
    The rows of its ``scores`` array (indexed by token position) are pooled
    into six regions, in this row order:

    1. the first subject mention;
    2. tokens between it and the attribute;
    3. the attribute;
    4. tokens between the attribute and the subject mention selected by the
       record's ``"occurrence"``;
    5. that subject mention;
    6. all remaining tokens.

    A case is skipped when its file cannot be loaded, when its name does not
    match ``kind``, when it records ``correct_prediction`` as false, or when
    ``causal_trace.find_token_range`` fails for one of its spans.

    Args:
        kind: ``None``, or a module kind such as ``"mlp"`` / ``"attn"``. It is
            used as the ``_kind`` file-name suffix.
        arch: architecture name. Used in the results path and, lower-cased,
            as the tokenizer name.
        dataset, mediation, corrupted: select the results directory and the
            raw prompts (via ``get_raws``).

    Returns:
        dict with:

        - ``low_score``: mean of the files' ``low_score`` values.
        - ``result`` / ``result_std``: arrays with one row per region, holding
          the mean / std of that region's score rows.
        - ``size``: number of cases aggregated.
    """
    dirname = Path(
        f"../results/ns3_r0_{arch}/{dataset}_{mediation}_{corrupted}"
        "/causal_trace/cases"
    )
    kindcode = "" if not kind else f"_{kind}"
    # Raw string (so "\d" is a real regex escape), escaped suffix and dot, and
    # fullmatch below. When kindcode is "", the glob also yields files such as
    # knowledge_5_mlp.npz, which must be rejected.
    pattern = re.compile(rf"knowledge_(\d+){re.escape(kindcode)}\.npz")
    tokenizer = get_tokenizer(arch.lower())
    raws_by_known_id = get_raws(dataset=dataset, mediation=mediation, corrupted=corrupted)

    size = 0
    avg_ls = Avg()
    # One accumulator per prompt region, in the row order of the result.
    region_avgs = [Avg() for _ in range(6)]
    for file in tqdm(dirname.glob(f"knowledge_*{kindcode}.npz"), desc=f"kind={kind}"):
        # Parse the known ID first. If it fails we are in the kindcode="" case
        # and the file belongs to an mlp/attn run, so don't bother loading it.
        match = pattern.fullmatch(file.name)
        if match is None:
            continue
        known_id = int(match.group(1))

        try:
            npz = numpy.load(str(file))
        except Exception:  # unreadable or partially written file: skip it
            continue
        # Close the archive's file handle as soon as the arrays are read.
        with npz as data:
            # Only consider cases where the model begins with the correct prediction.
            if "correct_prediction" in data and not data["correct_prediction"]:
                continue
            scores = data["scores"]
            low_score = data["low_score"]

        raw = raws_by_known_id[known_id]
        tokens_prompt = tokenizer(raw["prompt"], add_special_tokens=False).input_ids
        try:
            subj_first_i, subj_first_j = causal_trace.find_token_range(
                tokenizer, tokens_prompt, raw["subject"]
            )
            subj_last_i, subj_last_j = causal_trace.find_token_range(
                tokenizer, tokens_prompt, raw["subject"], occurrence=raw["occurrence"]
            )
            attr_i, attr_j = causal_trace.find_token_range(
                tokenizer, tokens_prompt, raw["attribute"]
            )
        except Exception:  # a span could not be located: skip this case
            continue

        size += 1
        avg_ls.add(low_score)
        # Token-index bounds of each region; a stop of None means "to the end".
        bounds = [
            (subj_first_i, subj_first_j),  # first subject mention
            (subj_first_j, attr_i),        # between subject and attribute
            (attr_i, attr_j),              # attribute
            (attr_j, subj_last_i),         # between attribute and subject mention
            (subj_last_i, subj_last_j),    # mention selected by "occurrence"
            (subj_last_j, None),           # remaining tokens
        ]
        for region, (start, stop) in zip(region_avgs, bounds):
            region.add_all(scores[start:stop])

    return dict(
        low_score=avg_ls.avg(),
        result=numpy.stack([region.avg() for region in region_avgs]),
        result_std=numpy.stack([region.std() for region in region_avgs]),
        size=size,
    )


def plot_array(
    differences,
    kind=None,
    savepdf=None,
    title=None,
    low_score=None,
    high_score=None,
    archname="GPT2",
    corrupted="subj_first",
):
    """Draw a heatmap of per-region, per-layer effects; optionally save it.

    Args:
        differences: 2-D array with one row per prompt region and one column
            per plotted layer. It is expected to have six rows, matching the
            labels below.
        kind: ``None``, ``"mlp"`` or ``"attn"``. Selects the colormap and the
            x-axis label.
        savepdf: optional output path. Parent directories are created as needed.
        title: optional axes title.
        low_score, high_score: color-scale limits. They default to the data
            min / max.
        archname: model name shown in the x-axis label.
        corrupted: which region was corrupted; its label is marked with ``*``.
    """
    if low_score is None:
        low_score = differences.min()
    if high_score is None:
        high_score = differences.max()
    answer = "AIE"
    labels = [
        "First subj mention" + ("*" if corrupted == "subj_first" else ""),
        "Between subj and attr",
        "Attr" + ("*" if corrupted == "attr" else ""),
        "Between attr and subj",
        "Second subj mention" + ("*" if corrupted == "subj_last" else ""),
        "Last tokens",
    ]

    fig, ax = plt.subplots(figsize=(7, 4), dpi=200)
    h = ax.pcolor(
        differences,
        cmap={None: "Purples", "mlp": "Greens", "attn": "Reds"}[kind],
        vmin=low_score,
        vmax=high_score,
    )
    if title:
        ax.set_title(title)
    ax.invert_yaxis()
    ax.set_yticks([0.5 + i for i in range(len(differences))])
    ax.set_xticks([0.5 + i for i in range(0, differences.shape[1] - 6, 5)])
    ax.set_xticklabels(list(range(0, differences.shape[1] - 6, 5)))
    ax.set_yticklabels(labels)
    if kind is None:
        ax.set_xlabel(f"single patched layer within {archname}")
    else:
        ax.set_xlabel(f"center of interval of patched {kind} layers")
    cb = plt.colorbar(h)
    # The following should be cb.ax.set_xlabel(answer), but this is broken in matplotlib 3.5.1.
    if answer:
        cb.ax.set_title(str(answer).strip(), y=-0.16, fontsize=10)

    if savepdf:
        # os.path.dirname is "" for a bare file name, and os.makedirs("") raises,
        # so only create a directory when the path actually has one.
        save_dir = os.path.dirname(savepdf)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(savepdf, bbox_inches="tight")
    plt.show()

# Configuration shared by the heatmaps here and the line plots further below.
arch = "gpt2"
archname = "GPT-2"
dataset = "CounterFact"
mediation = "med"
corrupted = "attr"
# Every plot reuses the color-scale maximum of the first (hidden-state) plot.
high_score = None

_EFFECT_NAMES = {
    None: "Indirect Effect of $h_i^{(l)}$",
    "mlp": "Indirect Effect of MLP",
    "attn": "Indirect Effect of Attn",
}
_MEDIATION_TAG = "Med" if mediation == "med" else "Unmed"

for kind in (None, "mlp", "attn"):
    d = read_knowlege(
        kind=kind,
        arch=arch,
        dataset=dataset.lower(),
        mediation=mediation,
        corrupted=corrupted,
    )
    count = d["size"]
    what = _EFFECT_NAMES[kind]
    title = f"[{dataset}/{_MEDIATION_TAG}] Avg {what} over {count} prompts"
    # Effect relative to the corrupted baseline; negative values are clipped to 0.
    result = numpy.clip(d["result"] - d["low_score"], 0, None)
    kindcode = f"_{kind}" if kind is not None else ""
    if kind is None:
        high_score = result.max()
    plot_array(
        result,
        kind=kind,
        title=title,
        low_score=0.0,
        high_score=high_score,
        archname=archname,
        corrupted=corrupted,
        savepdf=f"results/{arch}/causal_trace/summary_pdfs/rollup{kindcode}.pdf",
    )

## Plot line graph

To make the 95% confidence intervals visible, we also plot the data as line graphs below.

In [None]:
import math

# Region labels in the row order returned by read_knowlege; the corrupted
# region is starred, matching the heatmap labels in plot_array.
labels = [
    "First subj mention" + ("*" if corrupted == "subj_first" else ""),
    "Between subj and attr",
    "Attr" + ("*" if corrupted == "attr" else ""),
    "Between attr and subj",
    "Second subj mention" + ("*" if corrupted == "subj_last" else ""),
    "Last tokens",
]
color_order = [0, 1, 2, 4, 5, 3]

cmap = plt.get_cmap("tab10")
fig, axes = plt.subplots(1, 3, figsize=(13, 3.5), sharey=True, dpi=200)
for j, (kind, title) in enumerate(
    [
        (None, "single hidden vector"),
        ("mlp", "run of 10 MLP lookups"),
        ("attn", "run of 10 Attn modules"),
    ]
):
    print(f"Reading {kind}")
    # Keyword arguments, with the same configuration as the heatmap cell above.
    d = read_knowlege(
        kind=kind,
        arch=arch,
        dataset=dataset.lower(),
        mediation=mediation,
        corrupted=corrupted,
    )
    # Number of prompts aggregated for this kind specifically.
    n_prompts = d["size"]
    x = list(range(d["result"].shape[1]))
    for i, label in enumerate(labels):
        y = d["result"][i] - d["low_score"]
        std = d["result_std"][i]
        # 95% normal-approximation interval. NOTE(review): result_std is computed
        # over all token rows of the region, while n_prompts counts prompts;
        # confirm which sample size is intended.
        error = std * 1.96 / math.sqrt(n_prompts)
        axes[j].fill_between(
            x, y - error, y + error, alpha=0.3, color=cmap.colors[color_order[i]]
        )
        axes[j].plot(x, y, label=label, color=cmap.colors[color_order[i]])

    axes[j].set_title(f"Average indirect effect of a {title}")
    axes[j].set_ylabel("Average indirect effect on p(o)")
    axes[j].set_xlabel(f"Layer number in {archname}")
axes[1].legend(frameon=False)
plt.tight_layout()
lineplot_path = f"results/{arch}/causal_trace/summary_pdfs/lineplot-causaltrace.pdf"
os.makedirs(os.path.dirname(lineplot_path), exist_ok=True)
plt.savefig(lineplot_path)
plt.show()