In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from group_sae.utils import MODEL_MAP

In [None]:
# Display titles for each task identifier (used as subplot titles).
task_mapping = dict(
    ioi="IOI",
    subject_verb="Subject-Verb Agreement",
    greater_than="Greater Than",
)

In [None]:
# Experiment selection: model size, attribution method and metric to load.
size = "1b"
method = "attrib"
what = "faithfulness"
model = "pythia-{}".format(size)
is_topk = True
# NOTE(review): absolute, user-specific path; this cell only runs on a machine
# that has this directory. Consider making the root configurable.
faith_dir = "/home/fbelotti/group-sae/faithfulness/pythia-{}_downstream/faithfulness_{}".format(
    size, "topk" if is_topk else "thr"
)

# One CSV is read per (task, cluster). "Baseline" is stored as G == 0 and
# "K<i>" as G == i. The cluster list does not depend on the task, so build it once.
clusters = ["Baseline"] + [f"K{i}" for i in range(1, MODEL_MAP[model]["n_layers"] - 1)]
dfs = []
for task in ["subject_verb", "ioi", "greater_than"]:
    for cluster in clusters:
        df = pd.read_csv(f"{faith_dir}/{model}_{task}_{cluster}_{method}_{what}.csv")
        df["G"] = int(cluster.split("K")[1]) if cluster != "Baseline" else 0
        df["task"] = task
        dfs.append(df)

# ignore_index=True gives the combined frame a fresh 0..n-1 index. Without it
# every file contributes its own 0-based row labels, so the index is full of
# duplicates, which is confusing to read and can break label-aligned operations.
faith_df = pd.concat(dfs, ignore_index=True)
if not is_topk:
    # Threshold runs: bucket N into 50 equal-width bins and keep each bin's midpoint.
    faith_df["N_cut"] = pd.cut(faith_df["N"], 50)
    faith_df["N_cut"] = faith_df["N_cut"].apply(lambda x: x.mid)

In [None]:
# Inspect the combined results frame built in the previous cell.
faith_df

In [None]:
# Mean score for every (G, N) pair, averaged over all rows sharing that pair;
# this is the data drawn in the "Average" panel of the line plot.
mean_df = (
    faith_df.groupby(["G", "N"])["score"]
    .agg("mean")
    .reset_index()
)

In [None]:
# Area under the score-vs-N curve for every (task, G) pair.

# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
# new name when it exists and fall back to the old one on older NumPy.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


def _curve_area(curve):
    """Integrate ``score`` over ``N`` with the trapezoidal rule.

    The rows are sorted by ``N`` first: the trapezoidal rule assumes the x
    values are ordered, and unsorted (or descending) rows would otherwise
    yield a wrong or negative area.
    """
    curve = curve.sort_values("N", kind="stable")
    return _trapezoid(curve["score"], curve["N"])


faith_df_integrated = (
    # Selecting the two needed columns keeps the grouping keys out of the
    # applied frame, which avoids the pandas deprecation warning about
    # groupby.apply operating on grouping columns.
    faith_df.groupby(["task", "G"])[["N", "score"]]
    .apply(_curve_area)
    .rename("area")
    .reset_index()
    .sort_values(by=["G"])
)

In [None]:
# Average the per-task areas for each G, ordered by G.
mean_df_integrated = (
    faith_df_integrated.groupby(["G"])["area"]
    .mean()
    .reset_index()
    .sort_values(by=["G"])
)

In [None]:
# Score-vs-N curves, one line per G: one panel per task plus an "Average"
# panel drawn from mean_df.

# "gray" is prepended so the first hue level gets it (intended for G == 0,
# which is relabelled "Baseline" in the legend below); the remaining levels
# use the "flare" palette.
palette = ["gray"] + sns.color_palette("flare", n_colors=len(mean_df_integrated), as_cmap=False)[
    1:
]
fig, axes = plt.subplots(2, 2, figsize=(15, 7), sharey=False)
for i, task in enumerate(["subject_verb", "ioi", "greater_than"]):
    ax = axes[i // 2, i % 2]
    sns.lineplot(
        data=faith_df[faith_df["task"] == task],
        x="N",
        y="score",
        hue="G",
        palette=palette,
        ax=ax,
        legend=i == 0,
    )
    if i == 0:
        # Capture the legend entries once, drop the per-axes legend, and reuse
        # the entries for a single figure-level legend after the loop.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend().remove()
        labels[labels.index("0")] = "Baseline"
    # The axes show N against the score; the previous "G"/"AUC" labels
    # belonged to the bar-plot figure and mislabelled these curves.
    ax.set_xlabel("N")
    ax.set_ylabel(what.title())
    ax.set_title(task_mapping[task])
    ax.yaxis.set_tick_params(labelleft=True)
ax = axes[1, 1]
sns.lineplot(
    data=mean_df,
    x="N",
    y="score",
    hue="G",
    palette=palette,
    ax=ax,
    legend=False,
)
# Label the average panel the same way as the task panels.
ax.set_xlabel("N")
ax.set_ylabel(what.title())
ax.set_title("Average")
ax.yaxis.set_tick_params(labelleft=True)
fig.suptitle(f"{model.title()} - {what.title()}")
# Shared legend placed just below the figure.
fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, -0.01), ncols=5)
plt.tight_layout()

In [None]:
# Area under the curve per G as bars: one panel per task plus an "Average"
# panel drawn from mean_df_integrated. All panels share the y-axis.

# "gray" is prepended so the first hue level gets it (intended for G == 0,
# the baseline); the remaining levels use the "flare" palette.
palette = ["gray"] + sns.color_palette("flare", n_colors=len(mean_df_integrated), as_cmap=False)[
    1:
]
fig, axes = plt.subplots(2, 2, figsize=(15, 7), sharey=True)
for i, task in enumerate(["subject_verb", "ioi", "greater_than"]):
    ax = axes[i // 2, i % 2]
    sns.barplot(
        data=faith_df_integrated[faith_df_integrated["task"] == task],
        x="G",
        y="area",
        hue="G",
        palette=palette,
        ax=ax,
        legend=False,
    )
    ax.set_xlabel("G")
    ax.set_ylabel("AUC")
    ax.set_title(task_mapping[task])
    # With sharey=True the inner panels hide their y tick labels; turn them
    # back on so each panel can be read on its own.
    ax.yaxis.set_tick_params(labelleft=True)
ax = axes[1, 1]
sns.barplot(
    data=mean_df_integrated,
    x="G",
    y="area",
    hue="G",
    palette=palette,
    ax=ax,
    legend=False,
)
# Label the average panel like the task panels; without this its y-axis
# falls back to the raw column name "area".
ax.set_xlabel("G")
ax.set_ylabel("AUC")
ax.set_title("Average")
ax.yaxis.set_tick_params(labelleft=True)
fig.suptitle(f"{model.title()} - {what.title()}")
plt.tight_layout()