# Imports

In [None]:
from ast import literal_eval
import functools
import json
import os
import random
import shutil

# Scienfitic packages
import numpy as np
import pandas as pd
import torch
import datasets
from torch import cuda
torch.set_grad_enabled(False)

# Visuals
from matplotlib import pyplot as plt
import seaborn as sns
# Configure seaborn in ONE call. set_theme() re-applies the context, style and
# palette defaults every time it is called, so a later bare
# `set_theme(style=...)` would silently discard the 16pt font overrides.
sns.set_theme(context="notebook",
              style="whitegrid",
              rc={"font.size": 16,
                  "axes.titlesize": 16,
                  "axes.labelsize": 16,
                  "xtick.labelsize": 16.0,
                  "ytick.labelsize": 16.0,
                  "legend.fontsize": 16.0})
# Subset of the "Set1" palette: colours at indices 2-4 and 7 onwards.
palette_ = sns.color_palette("Set1")
palette = palette_[2:5] + palette_[7:]

# Utilities

from general_utils import (
  ModelAndTokenizer,
  make_inputs,
  decode_tokens,
  find_token_range,
  predict_from_input,
)

from patchscopes_utils import *

from tqdm import tqdm
tqdm.pandas()

In [None]:
# Model identifier -> hidden-state patch-hook setter. The load cell assigns the
# selected setter to `mt.set_hs_patch_hooks`.
model_to_hook = dict([
    # Pythia checkpoints -> *_neox hooks
    ("EleutherAI/pythia-6.9b", set_hs_patch_hooks_neox),
    ("EleutherAI/pythia-12b", set_hs_patch_hooks_neox),
    # Llama-2 / Vicuna checkpoints -> *_llama hooks
    ("meta-llama/Llama-2-13b-hf", set_hs_patch_hooks_llama),
    ("lmsys/vicuna-7b-v1.5", set_hs_patch_hooks_llama),
    ("./stable-vicuna-13b", set_hs_patch_hooks_llama),
    ("CarperAI/stable-vicuna-13b-delta", set_hs_patch_hooks_llama),
    # GPT-J -> *_gptj hooks
    ("EleutherAI/gpt-j-6b", set_hs_patch_hooks_gptj),
])

In [None]:
# Load model

# 0-shot with GPT-J
model_name = "gpt-j-6B"
sos_tok = False

# Resolve the hidden-state patch hook BEFORE loading weights so that an
# unsupported model fails fast. model_name may be a short or local name
# (e.g. "gpt-j-6B"), while model_to_hook is keyed by hub ids
# (e.g. "EleutherAI/gpt-j-6b"). When there is no exact match, fall back to a
# case-insensitive match on the last path component.
if model_name in model_to_hook:
    hs_patch_hook = model_to_hook[model_name]
else:
    _short_name = model_name.rstrip("/").split("/")[-1].lower()
    _candidate_hooks = {
        hook for key, hook in model_to_hook.items()
        if key.rstrip("/").split("/")[-1].lower() == _short_name
    }
    if len(_candidate_hooks) != 1:
        raise KeyError(f"No unique entry in model_to_hook for model {model_name!r}; add one.")
    hs_patch_hook = next(iter(_candidate_hooks))

# 12B/13B checkpoints are loaded in float16; otherwise torch_dtype=None is passed through.
if "13b" in model_name.lower() or "12b" in model_name.lower():
    torch_dtype = torch.float16
else:
    torch_dtype = None

my_device = torch.device("cuda:0")

mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=False,
    torch_dtype=torch_dtype,
    device=my_device,
)
mt.set_hs_patch_hooks = hs_patch_hook
mt.model.eval()

In [None]:
def run_experiment(task_type, task_name, data_dir, output_dir, batch_size=512, n_samples=-1,
                   save_output=True, replace=False, only_correct=False, is_icl=True):
    """Run the attribute-extraction patching sweep for one task pickle.

    Loads ``{data_dir}/{task_type}/{task_name}.pkl`` and drops problematic
    prompts. It then expands every remaining row over all
    (layer_source, layer_target) pairs in ``range(mt.num_layers - 1)`` and
    scores them with ``evaluate_attriburte_exraction_batch``. The results can
    optionally be saved as TSV and pickle.

    Args:
        task_type: Sub-directory (under ``data_dir`` and ``output_dir``) of the task.
        task_name: File stem of the task pickle.
        data_dir: Root directory of the preprocessed inputs.
        output_dir: Root directory for results.
        batch_size: Batch size forwarded to the evaluator.
        n_samples: If > 0 and smaller than the number of datapoints, randomly
            subsample that many datapoints (random_state=42).
        save_output: Write ``.tsv`` and ``.pkl`` results when True.
        replace: Recompute even if the output pickle already exists.
        only_correct: Keep only rows whose ``is_correct_baseline`` is True.
        is_icl: Forwarded to the evaluator.

    Returns:
        The results DataFrame. Returns None when the output pickle already
        exists and ``replace`` is False.
    """
    fdir_out = f"{output_dir}/{task_type}"
    fname_stem = f"{fdir_out}/{task_name}_only_correct_{only_correct}"
    fname_out = f"{fname_stem}.pkl"
    if not replace and os.path.exists(fname_out):
        print(f"File {fname_out} exists. Skipping...")
        return None
    print(f"Running experiment on {task_type}/{task_name}...")
    df = pd.read_pickle(f"{data_dir}/{task_type}/{task_name}.pkl")
    if only_correct:
        df = df[df["is_correct_baseline"]].reset_index(drop=True)
    # Dropping empty prompt sources. This is an artifact of saving and reloading inputs.
    df = df[df["prompt_source"] != ""].reset_index(drop=True)
    # Dropping prompt sources with \n. pandas read_pickle is not able to handle them
    # properly and drops the rest of the input.
    df = df[~df["prompt_source"].str.contains("\n", regex=False)].reset_index(drop=True)
    # After manual inspection, these examples seem to have tokenization issues. Dropping.
    if task_name == "star_constellation":
        df = df[~df["prompt_source"].str.contains("service", regex=False)].reset_index(drop=True)
    elif task_name == "object_superclass":
        df = df[~df["prompt_source"].str.contains(
            "Swainson ’ s hawk and the prairie", regex=False)].reset_index(drop=True)
    print(f"\tNumber of samples: {len(df)}")

    # Expand each row over every (source, target) layer pair. This sweep stops at
    # mt.num_layers - 1, whereas the next-token helpers below use mt.num_layers.
    # NOTE(review): confirm that excluding the last layer is intentional.
    n_layers = mt.num_layers - 1
    records = [
        {**dict(row), "layer_source": layer_source, "layer_target": layer_target}
        for _, row in tqdm(df.iterrows(), total=len(df))
        for layer_source in range(n_layers)
        for layer_target in range(n_layers)
    ]
    experiment_df = pd.DataFrame.from_records(records)

    if 0 < n_samples < len(experiment_df):
        experiment_df = experiment_df.sample(n=n_samples, replace=False, random_state=42).reset_index(drop=True)

    print(f"\tNumber of datapoints for patching experiment: {len(experiment_df)}")

    eval_results = evaluate_attriburte_exraction_batch(mt, experiment_df, batch_size=batch_size, is_icl=is_icl)

    # Trim to the number of evaluated rows. Presumably the evaluator can return
    # fewer rows than it was given. .copy() makes the column assignments below
    # write into an independent frame rather than a slice of experiment_df.
    results_df = experiment_df.head(len(eval_results["is_correct_patched"])).copy()
    for key, value in eval_results.items():
        results_df[key] = list(value)

    if save_output:
        os.makedirs(fdir_out, exist_ok=True)
        results_df.to_csv(f"{fname_stem}.tsv", sep="\t")
        results_df.to_pickle(fname_out)

    return results_df

In [None]:
import pdb

# Run the attribute-extraction sweep over every preprocessed task pickle.
# On failure, open the post-mortem debugger on the exception, then continue
# with the next task.
for task_type in ["commonsense_updated_target_prompts", "factual_updated_target_prompts"]:
    # sorted() gives a deterministic processing order; os.listdir order is arbitrary.
    for fname in tqdm(sorted(os.listdir(f"./preprocessed_data/{task_type}"))):
        if not fname.endswith(".pkl"):
            continue
        task_name = fname[:-len(".pkl")]
        print(f"Processing {fname}...")
        try:
            run_experiment(task_type, task_name,
                           data_dir="./preprocessed_data",
                           output_dir="./outputs/results_ae",
                           batch_size=512,
                           is_icl=False,
                           only_correct=True,
                           replace=False,
                           )
        except Exception:
            # Not a bare `except:`, so KeyboardInterrupt still stops the whole loop.
            pdb.post_mortem()

## Plot
Heatmaps conditioned on whether the base model's original prediction was correct.

In [None]:
def plot_heatmap(fname, _vmin=0, _vmax=1):
    """Plot layer-by-layer accuracy heatmaps for a pickled results DataFrame.

    For every (layer_source, layer_target) cell, this plots the mean of:
      * ``is_correct_patched`` over rows with a correct baseline,
      * ``is_correct_patched`` over rows with an incorrect baseline,
      * ``is_correct_baseline`` over all rows,
      * ``is_correct_patched`` over all rows.
    If an ``is_correct_probe`` column exists, it also plots the same three
    views (correct baseline, incorrect baseline, all rows) for that column.

    Args:
        fname: Path to the pickled results DataFrame.
        _vmin: Lower colour-scale limit shared by all heatmaps.
        _vmax: Upper colour-scale limit shared by all heatmaps.
    """
    df = pd.read_pickle(fname)
    plot_ttl = f"{fname}\n{model_name.strip('./')}"

    def _plot_layer_heatmap(frame, metric, title_suffix):
        # Mean of `metric` per (layer_source, layer_target). Rows are
        # layer_source and columns are layer_target; the y axis is inverted so
        # source-layer indices increase upwards.
        data = frame.groupby(['layer_source', 'layer_target'])[metric].mean().unstack()
        ax = sns.heatmap(data=data, cmap="crest_r", vmin=_vmin, vmax=_vmax)
        ax.invert_yaxis()
        ax.set_title(f"{plot_ttl} - {title_suffix}")
        plt.show()
        plt.clf()

    correct_df = df[df["is_correct_baseline"]].reset_index(drop=True)
    incorrect_df = df[df["is_correct_baseline"].eq(False)].reset_index(drop=True)

    _plot_layer_heatmap(correct_df, "is_correct_patched",
                        "accuracy\n(successful patch conditional on SUCCESSFUL original)")
    _plot_layer_heatmap(incorrect_df, "is_correct_patched",
                        "accuracy\n(successful patch conditional on UNSUCCESSFUL original)")
    _plot_layer_heatmap(df, "is_correct_baseline", "successful original")
    _plot_layer_heatmap(df, "is_correct_patched", "successful patched")

    if "is_correct_probe" in df.columns:
        # Each probe heatmap plots is_correct_probe, matching its title.
        _plot_layer_heatmap(correct_df, "is_correct_probe",
                            "accuracy\n(probe success conditional on SUCCESSFUL original)")
        _plot_layer_heatmap(incorrect_df, "is_correct_probe",
                            "accuracy\n(probe success conditional on UNSUCCESSFUL original)")
        _plot_layer_heatmap(df, "is_correct_probe", "successful probe")

In [None]:
# NOTE(review): run_experiment above writes to
# ./outputs/results_ae/<task_type>/<task_name>_only_correct_<flag>.pkl, and the
# driver loop uses task types "*_updated_target_prompts". This path may be
# stale; confirm it exists.
plot_heatmap(fname="./outputs/results_ae/commonsense/fruit_inside_color.pkl", _vmin=0, _vmax=1)

**Exp 0: Linguistic.**
Uses the source prompt with a fixed prompt-ID prompt as the target.

**Exp 1: Linguistic.**
Uses source and target prompts that are paraphrases of each other, i.e., worded differently but semantically similar.

**Exp 2: Commonsense.**
Sampling source prompts.

**Exp 3: Factual.**
Combines different tasks to form multi-hop reasoning queries.
The range of the source task should match the domain of the target task.

In [None]:
def run_experiment(task_type, task_name, batch_size=512, n_samples=-1, save_output=True):
    """Run the next-token patching sweep for one preprocessed task.

    NOTE: this redefines the attribute-extraction ``run_experiment`` defined
    earlier in the notebook. Cells executed after this one get this version.

    Keeps rows whose ``target_baseline_target`` equals
    ``target_baseline_prediction_gpt-j-6B`` and expands each over all
    (layer_source, layer_target) pairs in ``range(mt.num_layers)``. The pairs
    are scored with ``evaluate_patch_next_token_prediction_batch``, and the
    results are optionally saved under ``./outputs/results/{task_type}/``.

    Args:
        task_type: Sub-directory of the task under the data/results roots.
        task_name: File stem of the task pickle.
        batch_size: Batch size forwarded to the evaluator.
        n_samples: If > 0 and smaller than the number of datapoints, randomly
            subsample that many datapoints (random_state=42).
        save_output: Write ``.tsv`` and ``.pkl`` results when True.

    Returns:
        The results DataFrame with ``prec_1``, ``surprisal`` and ``next_token``
        columns added.
    """
    print(f"Running experiment on {task_type}/{task_name}...")
    df = pd.read_pickle(f"./outputs/preprocessed_data/{task_type}/{task_name}.pkl")
    filtered_df = df[df["target_baseline_target"] == df["target_baseline_prediction_gpt-j-6B"]]
    print(f"\tNumber of filtered samples: {len(filtered_df)}")

    # Convert rows once instead of re-running iterrows() for every layer pair.
    rows = [dict(row) for _, row in filtered_df.iterrows()]
    num_layers = mt.num_layers
    records = [
        {**row, "layer_source": layer_source, "layer_target": layer_target}
        for layer_source in tqdm(range(num_layers))
        for layer_target in range(num_layers)
        for row in rows
    ]
    experiment_df = pd.DataFrame.from_records(records)

    if 0 < n_samples < len(experiment_df):
        experiment_df = experiment_df.sample(n=n_samples, replace=False, random_state=42).reset_index(drop=True)

    print(f"\tNumber of datapoints for patching experiment: {len(experiment_df)}")

    prec_1, surprisal, next_token = evaluate_patch_next_token_prediction_batch(mt, experiment_df, batch_size=batch_size)

    # Trim to the evaluated rows. .copy() lets the column assignments below
    # target an independent frame instead of a slice of experiment_df.
    results_df = experiment_df.head(len(prec_1)).copy()
    results_df['prec_1'] = prec_1
    results_df['surprisal'] = surprisal
    results_df['next_token'] = next_token

    if save_output:
        fdir_out = f"./outputs/results/{task_type}"
        os.makedirs(fdir_out, exist_ok=True)
        results_df.to_csv(f"{fdir_out}/{task_name}.tsv", sep="\t")
        results_df.to_pickle(f"{fdir_out}/{task_name}.pkl")

    return results_df

In [None]:
def run_experiment_prompt_id(task_type, task_name, batch_size=512, n_samples=-1, save_output=True):
    """Run the next-token patching sweep using a fixed prompt-ID target prompt.

    Every row's ``prompt_target`` is replaced by a fixed token-repetition prompt
    with ``position_target = -1``. The baseline columns are dropped and the
    rows are de-duplicated on ``sample_id``. Each row is then expanded over all
    (layer_source, layer_target) pairs in ``range(mt.num_layers)`` and scored
    with ``evaluate_patch_next_token_prediction_batch``. Results are optionally
    saved under ``./outputs/results/prompt_id/{task_type}/``.

    Args:
        task_type: Sub-directory of the task under the data/results roots.
        task_name: File stem of the task pickle.
        batch_size: Batch size forwarded to the evaluator.
        n_samples: If > 0 and smaller than the number of datapoints, randomly
            subsample that many datapoints (random_state=42).
        save_output: Write ``.tsv`` and ``.pkl`` results when True.

    Returns:
        The results DataFrame with ``prec_1``, ``surprisal`` and ``next_token``
        columns added.
    """
    print(f"Running experiment with prompt ID on {task_type}/{task_name}...")
    df = pd.read_pickle(f"./outputs/preprocessed_data/{task_type}/{task_name}.pkl")
    df["prompt_target"] = "cat cat dog dog 1234 1234 hello hello {}"
    df = df.drop(['target_baseline', 'target_baseline_target', "target_baseline_prediction_gpt-j-6B",
                  "target_template_cropped_toks"], axis=1)
    df["position_target"] = -1

    # Dropping duplicate target examples: only the source matters here, and the
    # prompt ID is used as the target.
    df = df.drop_duplicates(subset=["sample_id"])

    print(f"\tNumber of samples: {len(df)}")

    # Convert rows once instead of re-running iterrows() for every layer pair.
    rows = [dict(row) for _, row in df.iterrows()]
    num_layers = mt.num_layers
    records = [
        {**row, "layer_source": layer_source, "layer_target": layer_target}
        for layer_source in tqdm(range(num_layers))
        for layer_target in range(num_layers)
        for row in rows
    ]
    experiment_df = pd.DataFrame.from_records(records)
    if 0 < n_samples < len(experiment_df):
        experiment_df = experiment_df.sample(n=n_samples, replace=False, random_state=42).reset_index(drop=True)

    print(f"\tNumber of datapoints for patching experiment: {len(experiment_df)}")

    prec_1, surprisal, next_token = evaluate_patch_next_token_prediction_batch(mt, experiment_df, batch_size=batch_size)

    # Trim to the evaluated rows; .copy() so assignments do not target a slice.
    results_df = experiment_df.head(len(prec_1)).copy()
    results_df['prec_1'] = prec_1
    results_df['surprisal'] = surprisal
    results_df['next_token'] = next_token
    if save_output:
        fdir_out = f"./outputs/results/prompt_id/{task_type}"
        os.makedirs(fdir_out, exist_ok=True)
        results_df.to_csv(f"{fdir_out}/{task_name}.tsv", sep="\t")
        results_df.to_pickle(f"{fdir_out}/{task_name}.pkl")

    return results_df

## Plots

In [None]:
def make_plots_from_df(df, plot_ttl, metrics_to_plot):
    """Show three plots per metric for a layer-sweep results DataFrame.

    For each metric in ``metrics_to_plot``, in order:
      1. a line plot over ``layer_target``, using only rows where
         ``layer_source == layer_target``;
      2. a line plot over ``layer_target`` with one line per ``layer_source``;
      3. a heatmap of the metric's mean per (layer_source, layer_target), with
         the y axis inverted. It uses the "crest" colormap for "surprisal" and
         "crest_r" for every other metric.
    """
    diagonal = df.loc[df['layer_source'] == df['layer_target']].reset_index(drop=True)

    def _finish(axis, title):
        # Title the current axes, render the figure, then clear it for the next plot.
        axis.set_title(title)
        plt.show()
        plt.clf()

    for metric in metrics_to_plot:
        _finish(sns.lineplot(data=diagonal, x='layer_target', y=metric),
                f"{plot_ttl} - source == target layer - {metric}")

        _finish(sns.lineplot(data=df, x='layer_target', y=metric, hue="layer_source"),
                f"{plot_ttl} - {metric}")

        grid = df.groupby(['layer_source', 'layer_target'])[metric].mean().unstack()
        colormap = "crest_r" if metric != "surprisal" else "crest"
        heat_axis = sns.heatmap(data=grid, cmap=colormap)
        heat_axis.invert_yaxis()
        _finish(heat_axis, f"{plot_ttl} - {metric}")

def make_plots_from_file(fname, metrics_to_plot):
    """Load a pickled results DataFrame, plot ``metrics_to_plot`` and return it.

    The plot title is ``fname`` followed, on a new line, by the global
    ``model_name`` with leading/trailing '.' and '/' characters stripped.
    """
    loaded = pd.read_pickle(fname)
    model_label = model_name.strip('./')
    title = f"{fname}\n{model_label}"
    make_plots_from_df(loaded, title, metrics_to_plot)
    return loaded

## Exp 0 figures

In [None]:
# Exp 0 example:
# source: the opposite of **small** is
# target: cat cat dog dog 1234 1234 hello hello {}
make_plots_from_file(fname="./outputs/results/prompt_id/linguistic/adj_antonym.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/prompt_id/linguistic/verb_past_tense.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/prompt_id/linguistic/word_first_letter.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

## Exp 1 figures

In [None]:
make_plots_from_file(fname="./outputs/results/linguistic/adj_antonym.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/linguistic/verb_past_tense.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/linguistic/word_first_letter.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

## Exp 2 figures

In [None]:
make_plots_from_file(fname="./outputs/results/commonsense/task_done_by_person.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/commonsense/task_done_by_tool.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/commonsense/fruit_inside_color.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/commonsense/work_location.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/commonsense/substance_phase.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

In [None]:
make_plots_from_file(fname="./outputs/results/commonsense/object_superclass.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])

## Exp 3 figures

In [None]:
make_plots_from_file(fname="./outputs/results/factual/combined_multihop.pkl",
                     metrics_to_plot=["prec_1", "surprisal"])