In [1]:
import sys

# Make the project's ../src directory importable; prepend it so it takes precedence
# over any installed copy, and only once so re-running the cell is idempotent.
_SRC_DIR = "../src"
if _SRC_DIR not in sys.path:
    sys.path.insert(0, _SRC_DIR)


In [2]:
import os
import json
import glob
import copy
import pprint
import re
import ast

import bisect

import itertools
from collections import Counter
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

import seaborn as sns
from scipy.stats import gaussian_kde
from scipy.stats import ttest_ind

from IPython.display import display, Latex

import torch

from utils import read_json_file

In [3]:
from experiments import nethook
from experiments.tools import make_inputs
from experiments.utils import load_atlas
from experiments.dataset import KnownsDataset
from experiments.tools import (
    collect_embedding_std,
    calculate_hidden_flow,
    # plot_trace_heatmap,
    predict_token,
    find_token_ranges,
    prompt_segmenter
)



In [4]:
# NOTE(review): these cache locations are usually read when transformers / datasets / wandb
# are first imported. The `experiments` imports in the previous cell may already have
# imported them, in which case these settings have no effect — confirm, and if so move this
# block above the imports. Also confirm that "../../caches" (wandb) vs "../../.cache" (HF)
# is intentional.
os.environ["WANDB_CACHE_DIR"] = "../../caches/wandb"
os.environ["TRANSFORMERS_CACHE"] = "../../.cache/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "../../.cache/huggingface/datasets"

SIZE = "base"  # "base", "large"
QA_PROMPT_FORMAT = "question: {question} answer: <extra_id_0>"

reader_model_type = f"google/t5-{SIZE}-lm-adapt"
model_path = f"../data/atlas/models/atlas_nq/{SIZE}"
# Reuse the constant so the prompt format is defined in exactly one place.
model, opt = load_atlas(reader_model_type, model_path, n_context=1, qa_prompt_format=QA_PROMPT_FORMAT)
type(model)


Some weights of the model checkpoint at facebook/contriever were not used when initializing Contriever: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


atlas.atlas.Atlas

In [5]:
def safe_division(a, b, do_log=False, threshold=None):
    """Return the ratio ``a / b`` clamped from above, or the log-ratio when ``do_log``.

    Args:
        a, b: tensors (torch ops are applied to them directly).
        do_log: if True return ``log(a) - log(b)``; no clamping is applied in that case.
        threshold: upper bound for the plain ratio. ``None`` means float32 max / 10.
            An explicit ``0`` / ``0.0`` is honoured (previously it was treated as "unset").

    Returns:
        A tensor with the (clamped) ratio or the log-ratio.
    """
    if do_log:
        return torch.log(a) - torch.log(b)
    if threshold is None:
        threshold = np.finfo(np.float32).max / 10.0
    # Clamping keeps division by (near) zero from producing inf.
    return torch.clamp(a / b, max=threshold)


def calculate_te_ie(r, samples, experiment_type="a", do_log=True, threshold=1e-40):
    """Average total effect (TE) and indirect effects (IE) over ``samples`` samples.

    Each ``*_score`` entry of ``r`` is floored at ``threshold`` and then combined into a
    ``*_cf_score`` / ``*_ans_score`` ratio via ``safe_division`` (log-ratio if ``do_log``).
    For sample ``i`` three ratios are formed:

        clean     = cr_cf[i]          / cr_ans
        corrupted = crr_cf[i]         / crr_ans[i]
        restored  = crwrr_cf[:, :, i] / crwrr_ans[:, :, i]

    experiment_type "a" / "aa":  TE = corrupted - clean,  IE = restored - clean
    any other experiment_type:   TE = clean - corrupted,  IE = restored - corrupted

    Returns:
        (te, ie): ``te`` is the mean of the per-sample TEs; ``ie`` is the element-wise
        mean of the per-sample IE arrays.
    """
    def floored(key):
        return torch.clamp(r[key], min=threshold)

    def ratio(numerator, denominator):
        return safe_division(numerator, denominator, do_log=do_log)

    cr_cf, cr_ans = floored("cr_cf_score"), floored("cr_ans_score")
    crr_cf, crr_ans = floored("crr_cf_score"), floored("crr_ans_score")
    crwrr_cf, crwrr_ans = floored("crwrr_cf_score"), floored("crwrr_ans_score")

    compare_to_clean = experiment_type in ("a", "aa")
    total_effects, indirect_effects = [], []
    for idx in range(samples):
        # NOTE(review): cr_ans is used without a sample index (presumably a single clean-run score).
        clean = ratio(cr_cf[idx], cr_ans)
        corrupted = ratio(crr_cf[idx], crr_ans[idx])
        restored = ratio(crwrr_cf[:, :, idx], crwrr_ans[:, :, idx])

        if compare_to_clean:
            total, indirect = corrupted - clean, restored - clean
        else:
            total, indirect = clean - corrupted, restored - corrupted

        total_effects.append(total)
        indirect_effects.append(indirect.unsqueeze(-1))

    return torch.stack(total_effects).mean(), torch.cat(indirect_effects, dim=-1).mean(-1)


def calculate_post_proc(r, samples, experiment_type="a"):
    """Return a new record with total/indirect effects attached.

    numpy arrays in ``r`` are turned into tensors for the computation. When ``r["status"]``
    is truthy, ``te_log``/``ie_log`` (log-ratio) and ``te``/``ie`` (plain ratio) are computed
    with ``calculate_te_ie``; otherwise all four are set to None. Every tensor in the
    returned dict is converted back to a numpy array. ``r`` itself is not modified.
    """
    as_tensors = {}
    for key, value in r.items():
        as_tensors[key] = torch.from_numpy(value) if isinstance(value, np.ndarray) else value

    if not as_tensors["status"]:
        effects = {"te_log": None, "ie_log": None, "te": None, "ie": None}
    else:
        te_log, ie_log = calculate_te_ie(as_tensors, samples=samples, experiment_type=experiment_type, do_log=True)
        te, ie = calculate_te_ie(as_tensors, samples=samples, experiment_type=experiment_type, do_log=False)
        effects = {"te_log": te_log, "ie_log": ie_log, "te": te, "ie": ie}
    as_tensors.update(effects)

    return {
        key: (value.detach().cpu().numpy() if torch.is_tensor(value) else value)
        for key, value in as_tensors.items()
    }


In [6]:
class Avg:
    """Accumulator of observation rows that reports their column-wise mean / std.

    ``size`` is the length of the zero vector returned by :meth:`avg` / :meth:`std`
    when nothing has been accumulated.
    """

    def __init__(self, size=12, name=None):
        self.d = []        # list of arrays; rows are concatenated along axis 0
        self.size = size   # fallback length for empty results (an int attribute, not a method)
        self.name = name

    def add(self, v):
        """Append a single observation ``v`` as one row, ignoring empty arrays."""
        if v.size > 0:
            self.d.append(v[None])

    def add_all(self, vv):
        """Append a block of observations (rows along axis 0), ignoring empty arrays."""
        if vv.size > 0:
            self.d.append(vv)

    def _stacked(self):
        """Concatenate all non-empty blocks along axis 0, or return None if there are none."""
        blocks = [arr for arr in self.d if arr.size > 0]
        return np.concatenate(blocks) if blocks else None

    def avg(self):
        """Column-wise mean of all rows, or zeros of length ``size`` if empty."""
        stacked = self._stacked()
        return np.zeros(self.size) if stacked is None else stacked.mean(axis=0)

    def std(self):
        """Column-wise (population) std of all rows, or zeros of length ``size`` if empty."""
        stacked = self._stacked()
        return np.zeros(self.size) if stacked is None else stacked.std(axis=0)

    def count(self):
        """Number of accumulated rows.

        Replaces the former ``size()`` method, which was unreachable because the
        ``self.size`` instance attribute shadows it.
        """
        return sum(datum.shape[0] for datum in self.d)

    def humanize(self):
        """``name`` with underscores turned into spaces and the first letter capitalized."""
        return (self.name or "").replace("_", " ").capitalize()

    def __repr__(self):
        return f"<Avg name={self.name}>"


def plot_array(
    differences,
    ax,
    labels,
    kind=None,
    savepdf=None,
    title=None,
    low_score=None,
    high_score=None,
    archname="Atlas",
    show_y_labels=True
):
    """Draw a heatmap of ``differences`` on ``ax`` with an "AIE" colorbar.

    Args:
        differences: 2-D array; one row per entry of ``labels``, one column per layer.
        ax: matplotlib Axes to draw on.
        labels: y tick labels (one per row).
        kind: None, "mlp" or "attn"; selects the colormap and the x-axis label.
        savepdf: optional output path; its parent directory is created when it has one.
        title: optional axes title.
        low_score, high_score: color limits; default to ``differences.min()`` / ``.max()``.
        archname: unused (only referenced by the old commented-out x label).
        show_y_labels: if False the y tick labels are hidden.
    """
    if low_score is None:
        low_score = differences.min()
    if high_score is None:
        high_score = differences.max()

    answer = "AIE"
    n_columns = differences.shape[1]

    h = ax.pcolor(
        differences,
        cmap={None: "Purples", "mlp": "Greens", "attn": "Reds"}[kind],
        vmin=low_score,
        vmax=high_score,
    )

    if title:
        ax.set_title(title)

    ax.invert_yaxis()
    # Ticks at cell centres (pcolor cells span [i, i + 1]).
    ax.set_yticks([0.5 + i for i in range(len(differences))])
    ax.set_xticks([0.5 + i for i in range(n_columns)])
    ax.set_xticklabels(list(range(n_columns)))
    ax.set_yticklabels(labels if show_y_labels else [])

    if kind is None:
        ax.set_xlabel("single patched layer")
    else:
        ax.set_xlabel(f"center of interval of 6 patched {kind} layers")

    # Attach the colorbar to the figure that owns ``ax``, not whatever pyplot considers current.
    cb = ax.figure.colorbar(h, ax=ax)
    # The following should be cb.ax.set_xlabel(answer), but this is broken in matplotlib 3.5.1.
    cb.ax.set_title(str(answer).strip(), y=-0.16)

    if savepdf:
        out_dir = os.path.dirname(savepdf)
        if out_dir:  # os.makedirs("") raises for a bare file name
            os.makedirs(out_dir, exist_ok=True)
        ax.figure.savefig(savepdf, bbox_inches="tight")
    

def is_valid_attrs(attrs):
    """True iff ``attrs`` is a list of exactly two 2-tuples whose elements are all ``int``.

    Empty ranges (``[]``), lists instead of tuples, and non-int bounds are all rejected.
    """
    if not (isinstance(attrs, list) and len(attrs) == 2):
        return False
    return all(
        isinstance(span, tuple)
        and len(span) == 2
        and all(isinstance(bound, int) for bound in span)
        for span in attrs
    )
    
        
def healthy_r(r):
    """Check whether a post-processed record is usable.

    Checks, in order: ``"scores"`` is present; ``"status"`` is present and truthy;
    ``attributes_loc`` contains at least one location; ``te_log`` and then ``te`` are
    neither NaN nor inf.

    Returns:
        (ok, logs): on the first failing check ``ok`` is False and ``logs`` holds one
        message naming it; if every check passes, ``(True, [])``.
    """
    if "scores" not in r:
        return False, ["no `scores`"]

    if not r.get("status"):
        return False, ["`status == False`"]

    if not sorted(itertools.chain.from_iterable(r["attributes_loc"].values())):
        return False, ["no `attribute_loc`"]

    for key in ("te_log", "te"):
        value = r[key]
        if np.isnan(value) or np.isinf(value):
            return False, [f"no `{key}`"]

    return True, []


def find_insert_position(tuples_list, target_tuple):
    """Position of ``target_tuple`` among ``tuples_list`` ordered by span end.

    Bisects (right) the end values ``t[1]`` of ``tuples_list`` with the start
    ``target_tuple[0]``; assumes ``tuples_list`` is sorted by end.
    """
    span_ends = [span[1] for span in tuples_list]
    return bisect.bisect_right(span_ends, target_tuple[0])

def transform_to_title_case(input_string):
    """Capitalize every single-space-separated word, preserving the original spacing."""
    return " ".join(map(str.capitalize, input_string.split(" ")))

def _first_range(ranges):
    """First range returned by ``find_token_ranges``, or ``[]`` when nothing was found."""
    return ranges[0] if isinstance(ranges, list) and len(ranges) > 0 else []


def _add_region_effects(avg, ie, subject_span, object_span, context_start, context_end):
    """Accumulate one record's per-token effects into the context/subject/object accumulators.

    ``subject_span`` / ``object_span`` are half-open ``(start, end)`` token ranges. The context
    ``ie[context_start:context_end]`` is split into the parts before, between and after the two
    spans, ordered by position.
    """
    first_span, second_span = sorted([subject_span, object_span])
    before = ie[context_start:first_span[0]]
    between = ie[first_span[1]:second_span[0]]
    after = ie[second_span[1]:context_end]

    avg["begining_of_context"].add_all(before)
    avg["context_in_between_tokens"].add_all(between)
    avg["rest_of_context_tokens"].add_all(after)

    # "Relation" tokens are every context token outside the subject and object spans.
    for region in (before, between, after):
        avg["relation_set_tokens"].add_all(region)

    for role, (start, end) in (("subject", subject_span), ("object", object_span)):
        avg[f"{role}_set_tokens"].add_all(ie[start:end])
        avg[f"first_{role}_token"].add(ie[start])
        avg[f"middle_{role}_tokens"].add_all(ie[start + 1:end - 1])
        # Spans are half-open, so the last token is end - 1 (this was ie[end] for the
        # subject in experiment "b", an off-by-one).
        avg[f"last_{role}_token"].add(ie[end - 1])


def tokens_space_division(data, experiment_type, num_layers=12, do_log=True):
    """Average per-layer indirect effects over token regions of each prompt.

    For every post-processed record, the indirect-effect matrix (``ie_log`` if ``do_log``,
    else ``ie``; rows are indexed by input-token position — presumably (tokens, layers),
    TODO confirm) is split into question tokens, context regions around a "subject" span
    and an "object" span, and the last token. Each region's rows go into an ``Avg``.

    Which spans play subject / object depends on ``experiment_type``:
        "a": subject = located ``r["prompt"]["subj"]``,  object = last attribute span
        "b": subject = last attribute span,              object = located ``r["cf"][0]``
        "c": subject = located ``r["prompt"]["subj"]``,  object = located ``r["cf"][0]``
    "aa" / "bb" / "cc" are accepted as aliases (``calculate_te_ie`` treats "a" and "aa" alike).
    Records whose spans are not two ``(int, int)`` tuples (e.g. a span that could not be
    located) are skipped.

    Returns:
        dict with "high_score", "low_score" (averages), "labels" (humanized region names),
        "result" / "result_std" (per-layer mean / std array per region, in label order; zero
        vectors for regions with no data), "size" (``num_layers``), "print_out" and the raw
        "avg" accumulators.

    Raises:
        ValueError: for an unknown ``experiment_type``.
    """
    experiment_type = {"aa": "a", "bb": "b", "cc": "c"}.get(experiment_type, experiment_type)
    if experiment_type not in ("a", "b", "c"):
        raise ValueError(f"Unknown experiment_type {experiment_type!r}; expected 'a', 'b' or 'c'.")

    avg_effects = [
        "question_tokens",
        "begining_of_context",
        "first_subject_token",
        "middle_subject_tokens",
        "last_subject_token",
        "context_in_between_tokens",
        "first_object_token",
        "middle_object_tokens",
        "last_object_token",
        "rest_of_context_tokens",
        "last_token",
        "subject_set_tokens",
        "object_set_tokens",
        "relation_set_tokens",
    ]
    avg_scores = ["high_score", "low_score", "te", "fixed_score"]
    avg = {name: Avg(size=1, name=name) for name in avg_scores}
    avg.update({name: Avg(size=num_layers, name=name) for name in avg_effects})

    for r in tqdm(data, total=len(data)):
        tokens = r["input_tokens"]
        attribute_loc = sorted(itertools.chain.from_iterable(r["attributes_loc"].values()))[-1]
        start_of_answer = tokens.index("▁answer")
        start_of_context = tokens.index("<extra_id_0>") + 1
        end_of_prompt = len(tokens)

        context_bounds = prompt_segmenter([tokens])[0]["context"]
        object_ranges = _first_range(
            find_token_ranges(model.reader_tokenizer, r["input_ids"], r["cf"][0], bounds=context_bounds)
        )
        subject_ranges = _first_range(
            find_token_ranges(model.reader_tokenizer, r["input_ids"], r["prompt"]["subj"], bounds=context_bounds)
        )

        te, ie = (r["te_log"], r["ie_log"]) if do_log else (r["te"], r["ie"])

        subject_span, object_span = {
            "a": (subject_ranges, attribute_loc),
            "b": (attribute_loc, object_ranges),
            "c": (subject_ranges, object_ranges),
        }[experiment_type]
        if not is_valid_attrs([subject_span, object_span]):
            continue

        avg["high_score"].add(np.array(r["cr_ans_score"]))
        avg["low_score"].add(np.array(r["crr_score"]))
        avg["te"].add(np.array(te))
        avg["fixed_score"].add(ie.max())

        avg["question_tokens"].add_all(ie[0:start_of_answer])
        avg["last_token"].add(ie[-1])

        # NOTE(review): the context region starts 4 tokens past the token after "<extra_id_0>"
        # (presumably skipping a fixed context prefix — confirm) and stops before the final
        # token, which is counted separately as "last_token".
        _add_region_effects(avg, ie, subject_span, object_span, start_of_context + 4, end_of_prompt - 1)

    # Computed once after accumulation (previously recomputed on every iteration).
    result = [avg[name].avg() for name in avg_effects]
    result_std = [avg[name].std() for name in avg_effects]

    print_out = [
        {"METRIC": "Average Total Effect", "VALUE": avg["te"].avg()},
    ]

    return {
        "high_score": avg["high_score"].avg(),
        "low_score": avg["low_score"].avg(),
        "labels": [avg[name].humanize() for name in avg_effects],
        "result": result,
        "result_std": result_std,
        "size": num_layers,
        "print_out": print_out,
        "avg": avg,
    }

In [7]:
def has_parametric_behavior(r, threshold=6):
    """True if ``r["answer"]`` occurs in at least ``threshold`` of ``r["crr_predicted"][1:]``.

    The first prediction is excluded; occurrence is tested with ``in``.
    """
    answer = r["answer"]
    hits = sum(1 for prediction in r["crr_predicted"][1:] if answer in prediction)
    return hits >= threshold


def data_prep(cases_directory, experiment_type, do_log=True, parametric_behavior=False, samples=6):
    """Load causal-tracing case files, post-process them and aggregate per token region.

    Args:
        cases_directory: list of directories containing ``*.bin`` case files.
        experiment_type: forwarded to ``calculate_post_proc`` and ``tokens_space_division``.
        do_log: aggregate log-ratio effects (True) or plain-ratio effects (False).
        parametric_behavior: if True keep only cases whose ``"ordinary_r"`` record passes
            ``has_parametric_behavior``.
        samples: samples per case passed to ``calculate_post_proc`` (was hard-coded to 6).

    Returns:
        (data, experiments): ``experiments`` maps "ordinary_r" / "no_attn_r" / "no_mlp_r" to
        their post-processed records; ``data`` maps the same keys to the output of
        ``tokens_space_division`` on those records.
    """
    experiments = {
        "ordinary_r": [],
        "no_attn_r": [],
        "no_mlp_r": [],
    }
    for directory in tqdm(cases_directory, total=len(cases_directory)):
        # glob order is filesystem-dependent; sort so runs are reproducible.
        for f in sorted(glob.glob(os.path.join(directory, "*.bin"))):
            # NOTE: torch.load unpickles the file — only use on trusted case files.
            r = torch.load(f, map_location="cpu")

            if parametric_behavior and not has_parametric_behavior(r["ordinary_r"]):
                continue

            for key in r:
                experiments[key].append(
                    calculate_post_proc(r[key], samples=samples, experiment_type=experiment_type)
                )

    data = {
        key: tokens_space_division(experiments[key], experiment_type=experiment_type, do_log=do_log)
        for key in experiments
    }
    return data, experiments

    
def plot_r_impact(ordinary, no_attn, no_mlp, title, token_idx=-1, bar_width=0.2, savepdf=None, show_plots=True):
    """Grouped per-layer bar chart of one row of three effect arrays.

    Args:
        ordinary, no_attn, no_mlp: 2-D arrays (rows x layers), plotted as
            "Impact of single state", "Impact with Attn severed" and
            "Impact with MLP severed" respectively.
        title: axes title; also the PDF file name (stripped, spaces -> "_", lower-cased).
        token_idx: row of the three arrays to plot.
        bar_width: bar width; the three series are offset by -/0/+ this amount.
        savepdf: optional existing directory to write ``<title>.pdf`` into.
        show_plots: show the figure if True, otherwise close it.
    """
    fig, ax = plt.subplots(1, figsize=(5, 3), dpi=300)

    series = (
        (ordinary, -bar_width, "#7261ab", "Impact of single state"),
        (no_attn, 0.0, "#f3201b", "Impact with Attn severed"),
        (no_mlp, bar_width, "#20b020", "Impact with MLP severed"),
    )
    for values, offset, color, label in series:
        row = values[token_idx]
        ax.bar([i + offset for i in range(len(row))], row, width=bar_width, color=color, label=label)

    n_layers = ordinary.shape[1]
    ax.set_title(title, pad=30)
    ax.set_xticks(list(range(n_layers)))
    ax.set_xticklabels(list(range(n_layers)), rotation=0)
    ax.set_ylabel("AIE")
    ax.set_xlabel("Layers")
    ax.yaxis.set_major_formatter(mtick.ScalarFormatter())

    # Legend above the axes (hence the title padding).
    ax.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, 1.19),
        ncol=2,
        fontsize=8,
        frameon=False,
    )

    if savepdf:
        # Strip before replacing spaces; stripping afterwards was a no-op.
        file_name = title.strip().replace(" ", "_").lower() + ".pdf"
        fig.savefig(os.path.join(savepdf, file_name), bbox_inches="tight")

    if show_plots:
        plt.show()
    else:
        plt.close(fig)


def set_plot_font_size(ax, font_size):
    """Apply ``font_size`` to the title, both axis labels and every tick label of ``ax``."""
    texts = [
        ax.title,
        ax.xaxis.label,
        ax.yaxis.label,
        *ax.get_xticklabels(),
        *ax.get_yticklabels(),
    ]
    for text in texts:
        text.set_fontsize(font_size)

In [8]:
experiment_type = "aa"
savepdf = "../data/figures/synthetic_context/pse/exp1_object/"
os.makedirs(savepdf, exist_ok=True)

# Relations whose cases are aggregated (PopQA relation names, then PEQ property ids).
POPQA_RELATIONS = [
    "capital", "capital_of", "color", "composer", "country", "father",
    "genre", "occupation", "place_of_birth", "religion", "sport",
]
PEQ_RELATIONS = [
    "P17", "P19", "P20", "P36", "P69", "P106", "P127", "P131",
    "P159", "P175", "P176", "P276", "P407", "P413", "P495", "P740",
]
cases_directory = [
    f"../experiments/ct/{dataset}/{experiment_type}/matched-both-repr/{relation}/cases"
    for dataset, relations in (("popqa", POPQA_RELATIONS), ("peq", PEQ_RELATIONS))
    for relation in relations
]

# Cases are read from the "aa" directories but post-processed with experiment_type="a".
data, experiments = data_prep(cases_directory, experiment_type="a", parametric_behavior=False)

# Negative average effects are clipped to 0 for plotting; loop-invariant, so computed once.
r = {key: np.array(data[key]["result"]).clip(0, None) for key in data}
for i, division in enumerate(data["ordinary_r"]["labels"]):
    if division not in ["Relation set tokens", "Subject set tokens", "Object set tokens"]:
        continue

    plot_r_impact(
        r["ordinary_r"][i:i+1],
        r["no_attn_r"][i:i+1],
        r["no_mlp_r"][i:i+1],
        title=transform_to_title_case(division),
        token_idx=0,
        bar_width=0.25,
        savepdf=savepdf,
        show_plots=False,
    )
    

100%|██████████| 27/27 [00:31<00:00,  1.18s/it]
100%|██████████| 1246/1246 [00:12<00:00, 99.41it/s]
100%|██████████| 1246/1246 [00:12<00:00, 100.03it/s]
100%|██████████| 1246/1246 [00:12<00:00, 100.81it/s]


In [9]:
experiment_type = "bb"
savepdf = "../data/figures/synthetic_context/pse/exp2_subject"
os.makedirs(savepdf, exist_ok=True)

# Relations whose cases are aggregated (PopQA relation names, then PEQ property ids).
POPQA_RELATIONS = [
    "capital", "capital_of", "color", "composer", "country", "father",
    "genre", "occupation", "place_of_birth", "religion", "sport",
]
PEQ_RELATIONS = [
    "P17", "P19", "P20", "P36", "P69", "P106", "P127", "P131",
    "P159", "P175", "P176", "P276", "P407", "P413", "P495", "P740",
]
cases_directory = [
    f"../experiments/ct/{dataset}/{experiment_type}/matched-both-repr/{relation}/cases"
    for dataset, relations in (("popqa", POPQA_RELATIONS), ("peq", PEQ_RELATIONS))
    for relation in relations
]

# Cases are read from the "bb" directories but post-processed with experiment_type="b".
data, experiments = data_prep(cases_directory, experiment_type="b")

# Negative average effects are clipped to 0 for plotting; loop-invariant, so computed once.
r = {key: np.array(data[key]["result"]).clip(0, None) for key in data}
for i, division in enumerate(data["ordinary_r"]["labels"]):
    if division not in ["Relation set tokens", "Subject set tokens", "Object set tokens"]:
        continue

    plot_r_impact(
        r["ordinary_r"][i:i+1],
        r["no_attn_r"][i:i+1],
        r["no_mlp_r"][i:i+1],
        title=division,
        token_idx=0,
        bar_width=0.25,
        savepdf=savepdf,
        show_plots=False,
    )

100%|██████████| 27/27 [01:37<00:00,  3.62s/it]
100%|██████████| 3726/3726 [02:10<00:00, 28.62it/s]
100%|██████████| 3726/3726 [02:08<00:00, 28.93it/s]
100%|██████████| 3726/3726 [02:11<00:00, 28.23it/s]


In [10]:
experiment_type = "cc"
savepdf = "../data/figures/synthetic_context/pse/exp2_relation"
os.makedirs(savepdf, exist_ok=True)

# Relations whose cases are aggregated (PopQA relation names, then PEQ property ids).
POPQA_RELATIONS = [
    "capital", "capital_of", "color", "composer", "country", "father",
    "genre", "occupation", "place_of_birth", "religion", "sport",
]
PEQ_RELATIONS = [
    "P17", "P19", "P20", "P36", "P69", "P106", "P127", "P131",
    "P159", "P175", "P176", "P276", "P407", "P413", "P495", "P740",
]
cases_directory = [
    f"../experiments/ct/{dataset}/{experiment_type}/matched-both-repr/{relation}/cases"
    for dataset, relations in (("popqa", POPQA_RELATIONS), ("peq", PEQ_RELATIONS))
    for relation in relations
]

# Cases are read from the "cc" directories but post-processed with experiment_type="c".
data, experiments = data_prep(cases_directory, experiment_type="c")

# Negative average effects are clipped to 0 for plotting; loop-invariant, so computed once.
r = {key: np.array(data[key]["result"]).clip(0, None) for key in data}
for i, division in enumerate(data["ordinary_r"]["labels"]):
    if division not in ["Relation set tokens", "Subject set tokens", "Object set tokens"]:
        continue

    plot_r_impact(
        r["ordinary_r"][i:i+1],
        r["no_attn_r"][i:i+1],
        r["no_mlp_r"][i:i+1],
        title=division,
        token_idx=0,
        bar_width=0.25,
        savepdf=savepdf,
        show_plots=False,
    )

100%|██████████| 27/27 [01:19<00:00,  2.93s/it]
100%|██████████| 3726/3726 [02:16<00:00, 27.21it/s]
100%|██████████| 3726/3726 [02:12<00:00, 28.12it/s]
100%|██████████| 3726/3726 [02:06<00:00, 29.36it/s]


In [11]:
def path_specific_effects_per_relation(
    cases_directories,
    experiment_type,
    savepdf=None,
    divisions=("Relation set tokens", "Subject set tokens", "Object set tokens"),
):
    """Plot path-specific effects separately for each relation's cases directory.

    Each directory is loaded on its own with ``data_prep``. For every token
    division listed in ``divisions``, ``plot_r_impact`` is called on the
    zero-clipped ``ordinary_r``, ``no_attn_r`` and ``no_mlp_r`` results for
    that division.

    Args:
        cases_directories: Iterable of paths of the form ``.../<relation>/cases``.
        experiment_type: Forwarded unchanged to ``data_prep``.
        savepdf: Optional root output directory. When truthy, the figures for
            each directory go to ``<savepdf>/<relation>``, created if missing.
            ``<relation>`` is the name of the cases directory's parent folder.
            When falsy, ``savepdf=None`` is passed to ``plot_r_impact``.
        divisions: Division labels to plot; all other labels are skipped.
            Defaults to the relation, subject and object token sets.

    Returns:
        None. Output is produced only through ``plot_r_impact``.
    """
    for cases_directory in cases_directories:
        if savepdf:
            # normpath drops a trailing separator, so ".../P17/cases/" still
            # yields "P17" rather than "cases".
            relation = os.path.basename(os.path.dirname(os.path.normpath(cases_directory)))
            relation_savepdf = os.path.join(savepdf, relation)
            os.makedirs(relation_savepdf, exist_ok=True)
        else:
            relation_savepdf = None

        data, _ = data_prep([cases_directory], experiment_type=experiment_type)

        # Negative effects are clipped to zero. Build the arrays once per
        # directory rather than once per division.
        r = {key: np.array(data[key]["result"]).clip(0, None) for key in data}

        for i, division in enumerate(data["ordinary_r"]["labels"]):
            if division not in divisions:
                continue

            # `i:i+1` slices a single division while keeping a leading axis.
            plot_r_impact(
                r["ordinary_r"][i:i+1],
                r["no_attn_r"][i:i+1],
                r["no_mlp_r"][i:i+1],
                title=f"{division}",
                token_idx=0,
                bar_width=0.25,
                savepdf=relation_savepdf,
                show_plots=False,
            )

In [12]:
# Experiment 1 (object): per-relation path-specific-effect plots.
experiment_type = "aa"
# One ".../<dataset>/<experiment_type>/matched-both-repr/<relation>/cases"
# directory per PopQA relation and per PEQ relation id, in the same order as the
# original hand-written list, without 27 near-identical literals.
cases_directory_relations = [
    f"../experiments/ct/{dataset}/{experiment_type}/matched-both-repr/{relation}/cases"
    for dataset, relations in (
        ("popqa", (
            "capital", "capital_of", "color", "composer", "country", "father",
            "genre", "occupation", "place_of_birth", "religion", "sport",
        )),
        ("peq", (
            "P17", "P19", "P20", "P36", "P69", "P106", "P127", "P131", "P159",
            "P175", "P176", "P276", "P407", "P413", "P495", "P740",
        )),
    )
    for relation in relations
]
savepdf = "../data/figures/synthetic_context/pse/exp1_object_per_relation"
os.makedirs(savepdf, exist_ok=True)

# NOTE(review): the directories use experiment_type "aa" while data_prep is
# given "a"; confirm this mapping is intended.
path_specific_effects_per_relation(cases_directory_relations, experiment_type="a", savepdf=savepdf)

100%|██████████| 1/1 [00:00<00:00,  1.02it/s]
100%|██████████| 101/101 [00:00<00:00, 681.08it/s]
100%|██████████| 101/101 [00:00<00:00, 687.21it/s]
100%|██████████| 101/101 [00:00<00:00, 678.64it/s]
100%|██████████| 1/1 [00:00<00:00,  4.07it/s]
100%|██████████| 26/26 [00:00<00:00, 1201.14it/s]
100%|██████████| 26/26 [00:00<00:00, 1212.27it/s]
100%|██████████| 26/26 [00:00<00:00, 1218.70it/s]
100%|██████████| 1/1 [00:00<00:00, 31.16it/s]
100%|██████████| 4/4 [00:00<00:00, 1210.04it/s]
100%|██████████| 4/4 [00:00<00:00, 1336.40it/s]
100%|██████████| 4/4 [00:00<00:00, 1355.41it/s]
100%|██████████| 1/1 [00:00<00:00, 27.14it/s]
100%|██████████| 4/4 [00:00<00:00, 1211.88it/s]
100%|██████████| 4/4 [00:00<00:00, 1299.25it/s]
100%|██████████| 4/4 [00:00<00:00, 1308.47it/s]
100%|██████████| 1/1 [00:01<00:00,  1.03s/it]
100%|██████████| 101/101 [00:00<00:00, 643.23it/s]
100%|██████████| 101/101 [00:00<00:00, 649.66it/s]
100%|██████████| 101/101 [00:00<00:00, 644.72it/s]
100%|██████████| 1/1 [00:0

In [13]:
# Experiment 2 (subject): per-relation path-specific-effect plots.
experiment_type = "bb"
# One ".../<dataset>/<experiment_type>/matched-both-repr/<relation>/cases"
# directory per PopQA relation and per PEQ relation id, in the same order as the
# original hand-written list, without 27 near-identical literals.
cases_directory_relations = [
    f"../experiments/ct/{dataset}/{experiment_type}/matched-both-repr/{relation}/cases"
    for dataset, relations in (
        ("popqa", (
            "capital", "capital_of", "color", "composer", "country", "father",
            "genre", "occupation", "place_of_birth", "religion", "sport",
        )),
        ("peq", (
            "P17", "P19", "P20", "P36", "P69", "P106", "P127", "P131", "P159",
            "P175", "P176", "P276", "P407", "P413", "P495", "P740",
        )),
    )
    for relation in relations
]
savepdf = "../data/figures/synthetic_context/pse/exp2_subject_per_relation"
os.makedirs(savepdf, exist_ok=True)

# NOTE(review): the directories use experiment_type "bb" while data_prep is
# given "b"; confirm this mapping is intended.
path_specific_effects_per_relation(cases_directory_relations, experiment_type="b", savepdf=savepdf)

100%|██████████| 1/1 [00:03<00:00,  3.04s/it]
100%|██████████| 303/303 [00:00<00:00, 361.86it/s]
100%|██████████| 303/303 [00:00<00:00, 361.63it/s]
100%|██████████| 303/303 [00:00<00:00, 366.06it/s]
100%|██████████| 1/1 [00:00<00:00,  1.40it/s]
100%|██████████| 78/78 [00:00<00:00, 861.03it/s]
100%|██████████| 78/78 [00:00<00:00, 874.12it/s]
100%|██████████| 78/78 [00:00<00:00, 857.38it/s]
100%|██████████| 1/1 [00:00<00:00,  9.19it/s]
100%|██████████| 12/12 [00:00<00:00, 1344.76it/s]
100%|██████████| 12/12 [00:00<00:00, 1279.47it/s]
100%|██████████| 12/12 [00:00<00:00, 1297.84it/s]
100%|██████████| 1/1 [00:00<00:00,  7.02it/s]
100%|██████████| 12/12 [00:00<00:00, 1222.06it/s]
100%|██████████| 12/12 [00:00<00:00, 1245.65it/s]
100%|██████████| 12/12 [00:00<00:00, 1256.09it/s]
100%|██████████| 1/1 [00:03<00:00,  3.31s/it]
100%|██████████| 302/302 [00:00<00:00, 327.56it/s]
100%|██████████| 302/302 [00:00<00:00, 331.79it/s]
100%|██████████| 302/302 [00:00<00:00, 331.86it/s]
100%|██████████| 

In [14]:
# Relation-token experiment: per-relation path-specific-effect plots.
experiment_type = "cc"
# One ".../<dataset>/<experiment_type>/matched-both-repr/<relation>/cases"
# directory per PopQA relation and per PEQ relation id, in the same order as the
# original hand-written list, without 27 near-identical literals.
cases_directory_relations = [
    f"../experiments/ct/{dataset}/{experiment_type}/matched-both-repr/{relation}/cases"
    for dataset, relations in (
        ("popqa", (
            "capital", "capital_of", "color", "composer", "country", "father",
            "genre", "occupation", "place_of_birth", "religion", "sport",
        )),
        ("peq", (
            "P17", "P19", "P20", "P36", "P69", "P106", "P127", "P131", "P159",
            "P175", "P176", "P276", "P407", "P413", "P495", "P740",
        )),
    )
    for relation in relations
]
# NOTE(review): this output folder is named "exp2_..." like the subject
# experiment's; confirm the experiment number is intended.
savepdf = "../data/figures/synthetic_context/pse/exp2_relation_per_relation"
os.makedirs(savepdf, exist_ok=True)

# NOTE(review): the directories use experiment_type "cc" while data_prep is
# given "c"; confirm this mapping is intended.
path_specific_effects_per_relation(cases_directory_relations, experiment_type="c", savepdf=savepdf)

100%|██████████| 1/1 [00:06<00:00,  6.01s/it]
100%|██████████| 303/303 [00:00<00:00, 362.59it/s]
100%|██████████| 303/303 [00:00<00:00, 362.45it/s]
100%|██████████| 303/303 [00:00<00:00, 356.77it/s]
100%|██████████| 1/1 [00:00<00:00,  1.26it/s]
100%|██████████| 78/78 [00:00<00:00, 884.30it/s]
100%|██████████| 78/78 [00:00<00:00, 896.67it/s]
100%|██████████| 78/78 [00:00<00:00, 901.78it/s]
100%|██████████| 1/1 [00:00<00:00,  8.80it/s]
100%|██████████| 12/12 [00:00<00:00, 1271.74it/s]
100%|██████████| 12/12 [00:00<00:00, 1285.94it/s]
100%|██████████| 12/12 [00:00<00:00, 1304.03it/s]
100%|██████████| 1/1 [00:00<00:00,  7.65it/s]
100%|██████████| 12/12 [00:00<00:00, 1220.61it/s]
100%|██████████| 12/12 [00:00<00:00, 1244.82it/s]
100%|██████████| 12/12 [00:00<00:00, 1265.41it/s]
100%|██████████| 1/1 [00:03<00:00,  3.22s/it]
100%|██████████| 302/302 [00:00<00:00, 339.36it/s]
100%|██████████| 302/302 [00:00<00:00, 352.10it/s]
100%|██████████| 302/302 [00:00<00:00, 343.01it/s]
100%|██████████| 