In [39]:
from itertools import combinations
from pathlib import Path
from typing import Literal, Optional

from matplotlib import pyplot as plt
from matplotlib import rcParams
import numpy as np
import pandas as pd
import seaborn as sns

In [2]:
# Global plot defaults: 10x6-inch figures rendered at 150 dpi, ggplot look.
# (Cells that pass an explicit figsize override the size default.)
rcParams["figure.figsize"] = (10, 6)
rcParams["figure.dpi"] = 150
plt.style.use("ggplot")

# Constants

In [105]:
# Root folder of the extraction CSVs; RCCS results are read from its "rccs" subfolder.
SAVE_DIR = Path("extraction_results")
DEBERTA = "deberta-v2-xxlarge-mnli"

# Order of datasets on heatmap axes ("all" = the probe trained on every dataset).
DATASET_ORDER = (
    "imdb",
    "amazon-polarity",
    "ag-news",
    "dbpedia-14",
    "copa",
    "rte",
    "boolq",
    "qnli",
    "piqa",
    "all",
)
# Order in which probing methods are plotted.
METHOD_ORDER = (
    "CCS",
    "LR",
    "CCS-md",
    "LR-md",
    "RCCS",
)

# Utils

In [32]:
def load_probs(
    model_name: str,
    train: str,
    test: str,
    method: str = "CCS",
    save_dir: Optional[Path] = None,
):
    """Load the per-example probability CSVs of one trained probe.

    Files are read from ``<save_dir>/states_<model_name>_<method>/<train>/``,
    or from ``<save_dir>/rccs/states_...`` when ``method`` starts with "RCCS".

    Args:
        model_name: Model identifier used in the folder name.
        train: Training dataset name (sub-folder of the states folder).
        test: Test dataset name, used as the CSV filename prefix; "all" loads
            every ``*_<method>.csv`` in the folder.
        method: Probing method; also the CSV filename suffix.
        save_dir: Root results directory; defaults to ``SAVE_DIR``.

    Returns:
        All matching CSVs concatenated into one DataFrame, in sorted filename
        order. Each file's original row index is kept (indices may repeat).

    Raises:
        ValueError: If no CSV matches.
    """
    save_dir = save_dir or SAVE_DIR

    base_dir = (save_dir / "rccs") if method.startswith("RCCS") else save_dir
    folder = base_dir / f"states_{model_name}_{method}" / train
    pattern = f"*_{method}.csv" if test == "all" else f"{test}*_{method}.csv"

    # Sort so the row order does not depend on the filesystem's listing order.
    csv_paths = sorted(folder.glob(pattern))
    if not csv_paths:
        raise ValueError(f"No probability CSVs matching {pattern!r} in {folder}")
    return pd.concat([pd.read_csv(path) for path in csv_paths])


def get_max_acc_df(df):
    """Flip the accuracies of probes whose in-domain accuracy is below 0.5.

    Rows are grouped by (model, prefix, method, prompt_level) and then by
    ``train`` dataset. For each such probe the flip decision is made once,
    from its in-domain accuracy:

    * ``train == "all"``: the mean accuracy over all of the probe's rows;
    * otherwise: the accuracy of the single row with ``test == train``.

    If that accuracy is below 0.5, every row of the probe gets
    ``accuracy = 1 - accuracy``. This is deliberately NOT a per-row
    max(acc, 1 - acc): a transfer accuracy can stay below 0.5 when the
    in-domain result says the probe's orientation is correct.

    Rows with a missing value in any of the four grouping columns are dropped
    (pandas ``groupby`` default).

    Args:
        df: Stats frame with columns model, prefix, method, prompt_level,
            train, test and accuracy. It is not modified.

    Returns:
        A new DataFrame with the (possibly flipped) rows, reordered by group.
        If no rows remain, an empty frame with the same columns.

    Raises:
        AssertionError: If a non-"all" probe does not have exactly one
            in-domain (``test == train``) row.
    """
    processed_df = []
    for key, exp_df in df.groupby(["model", "prefix", "method", "prompt_level"]):
        for train_ds in exp_df["train"].unique():
            train_df = exp_df[exp_df["train"] == train_ds].copy()
            if train_ds == "all":
                # A probe trained on all datasets has no single in-domain test,
                # so use its mean accuracy across the test datasets.
                in_domain_acc = train_df["accuracy"].mean()
            else:
                in_domain_rows = train_df.loc[train_df["test"] == train_ds, "accuracy"]
                assert len(in_domain_rows) == 1, (
                    f"expected exactly one in-domain row for {key} train={train_ds!r}, "
                    f"found {len(in_domain_rows)}"
                )
                in_domain_acc = in_domain_rows.iloc[0]
            if in_domain_acc < 0.5:
                print(f"Flipping accuracy for {key} {train_ds}")
                train_df["accuracy"] = 1 - train_df["accuracy"]

            processed_df.append(train_df)

    if not processed_df:
        # pd.concat([]) raises "No objects to concatenate"; return an empty frame instead.
        return df.iloc[:0].copy()
    return pd.concat(processed_df)


def load_stats_dfs(
    model_name: str,
    train: Optional[str] = None,
    test: Optional[str] = None,
    method: Optional[str] = None,
    prefix: Optional[str] = None,
    save_dir: Optional[Path] = None,
    max_acc: bool = True,
):
    """Load the summary-stats CSVs of a model, one DataFrame per file, optionally filtered.

    Reads ``<model_name>*.csv`` (or ``<model_name>_<prefix>_*.csv`` when
    ``prefix`` is given) from ``save_dir``, or from ``save_dir / "rccs"`` when
    ``method`` starts with "RCCS".

    Args:
        model_name: Model identifier used as the CSV filename prefix.
        train: Keep only rows with this ``train`` value; None or "all" keeps
            every train dataset.
        test: Keep only rows with this ``test`` value (each file must then have
            exactly one row left); None or "all" keeps every test dataset.
        method: Keep only rows with this ``method``.
        prefix: Prompt prefix in the filename; None matches any prefix.
        save_dir: Root results directory; defaults to ``SAVE_DIR``.
        max_acc: If True, apply ``get_max_acc_df`` to each file's rows.

    Returns:
        A list of DataFrames, one per matching CSV, in sorted filename order.

    Raises:
        ValueError: If no CSV matches.
        AssertionError: If the ``test`` filter does not leave exactly one row per file.
    """
    save_dir = save_dir or SAVE_DIR

    if method is not None and method.startswith("RCCS"):
        base_dir = save_dir / "rccs"
    else:
        base_dir = save_dir

    if prefix is None:
        pattern = f"{model_name}*.csv"
    else:
        pattern = f"{model_name}_{prefix}_*.csv"
    # Sort so the per-file (seed) order is deterministic across filesystems.
    csv_paths = sorted(base_dir.glob(pattern))

    if not csv_paths:
        raise ValueError(
            f"No csvs found for {model_name}, {train}, {test}, {method} "
            f"(prefix={prefix}, pattern={pattern!r} in {base_dir})"
        )
    dfs = [pd.read_csv(path) for path in csv_paths]

    # Filter by train & method.
    # NOTE(review): train == "all" disables the train filter, although "all" is
    # also a real train value (a probe trained on every dataset) — confirm intended.
    if train is not None and train != "all":
        dfs = [df[df["train"] == train] for df in dfs]
    if method is not None:
        dfs = [df[df["method"] == method] for df in dfs]

    # Flip BEFORE narrowing to one test dataset: get_max_acc_df needs each
    # probe's in-domain (test == train) row, or for train == "all" the mean over
    # all test datasets. Filtering by test first removes that information.
    if max_acc:
        dfs = [get_max_acc_df(df) for df in dfs]

    if test is not None and test != "all":
        dfs = [df[df["test"] == test] for df in dfs]
        assert all(len(df) == 1 for df in dfs), (
            f"expected exactly one row per file for test={test!r}, "
            f"got {[len(df) for df in dfs]}"
        )

    return dfs


def load_stats(
    model_name: str,
    train: str,
    test: str,
    method: str,
    prefix: str,
    save_dir: Optional[Path] = None,
):
    """Summarise one probe configuration as per-file (per-seed) arrays.

    Each value is the mean of a stats column over the rows that
    ``load_stats_dfs`` returns for one CSV. With ``test="all"`` this averages
    over every test dataset; otherwise it is the single selected test's value.

    Returns:
        Dict mapping each of "accuracy", "loss", "cons_loss", "sim_loss" that
        is present in every CSV to an array with one entry per CSV.
    """
    per_seed = load_stats_dfs(model_name, train, test, method, prefix, save_dir=save_dir)

    stats = {}
    for column in ("accuracy", "loss", "cons_loss", "sim_loss"):
        # Only report stats that every file provides.
        if not all(column in seed_df.columns for seed_df in per_seed):
            continue
        stats[column] = np.array([seed_df[column].mean() for seed_df in per_seed])
    return stats  # each value has shape (seeds,)

# Load data

In [33]:
# Every stats CSV for the DeBERTa model (all prefixes and methods), one DataFrame
# per file, with load_stats_dfs's default max_acc flipping applied.
stats_dfs = load_stats_dfs(model_name=DEBERTA)

# Transfer accuracy

In [101]:
# Which prompt prefix to plot, and whether to append "+-std" to each cell.
prefix = "normal"
std_annot = False

# Keep only rows for the chosen prefix; skip files that have none.
filtered_data = []
for seed_df in stats_dfs:
    prefix_df = seed_df[seed_df["prefix"] == prefix]
    if not prefix_df.empty:
        filtered_data.append(prefix_df)
if not filtered_data:
    raise ValueError(f"No rows with prefix {prefix!r} in stats_dfs")

combined_df = pd.concat(filtered_data, ignore_index=True)

# Mean / std of accuracy across files for every (model, method, train, test).
# NOTE(review): prompt_level is not a grouping key, so if several prompt levels
# are present they are pooled with the seeds — confirm this is intended.
agg_df = (
    combined_df
    .groupby(["model", "method", "train", "test"])
    .agg(mean_accuracy=("accuracy", "mean"), std_accuracy=("accuracy", "std"))
    .reset_index()
)


def _format_cells(frame: pd.DataFrame) -> pd.DataFrame:
    """Render every cell of ``frame`` as a 2-decimal string (NaN becomes "nan")."""
    # Column-wise Series.map instead of DataFrame.applymap (deprecated in pandas 2.1).
    return frame.apply(lambda col: col.map("{:.2f}".format))


for model in agg_df["model"].unique():
    model_df = agg_df[agg_df["model"] == model]
    for method in METHOD_ORDER:
        method_df = model_df[model_df["method"] == method]
        if method_df.empty:
            continue

        # Train datasets on rows, test datasets on columns, both in DATASET_ORDER.
        # pivot() only accepts keyword arguments since pandas 2.0.
        pivot_table_mean = method_df.pivot(
            index="train", columns="test", values="mean_accuracy"
        ).reindex(index=DATASET_ORDER, columns=DATASET_ORDER)
        pivot_table_std = method_df.pivot(
            index="train", columns="test", values="std_accuracy"
        ).reindex(index=DATASET_ORDER, columns=DATASET_ORDER)

        # Cell annotations: mean, optionally followed by "+-std" in a smaller font.
        annotations = _format_cells(pivot_table_mean)
        if std_annot:
            annotations = annotations + "+-" + _format_cells(pivot_table_std)
            annot_size = 6
        else:
            annot_size = 10

        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(
            pivot_table_mean,
            annot=annotations,
            fmt="",
            cmap="viridis",
            cbar_kws={"label": "Mean Accuracy"},
            vmin=0.5,
            vmax=1,
            annot_kws={"size": annot_size},
            ax=ax,
        )
        ax.set_title(f"{model} - {method} Mean Accuracy")
        ax.set_xlabel("Test Dataset")
        ax.set_ylabel("Train Dataset")
        plt.show()
        plt.close(fig)  # release the figure so the loop does not accumulate them


In [60]:
# Aggregated transfer accuracies for the CCS probe only.
agg_df.loc[agg_df["method"].eq("CCS")]

Unnamed: 0,model,method,train,test,mean_accuracy,std_accuracy
0,deberta-v2-xxlarge-mnli,CCS,ag-news,ag-news,0.889594,0.011566
1,deberta-v2-xxlarge-mnli,CCS,ag-news,amazon-polarity,0.821023,0.043104
2,deberta-v2-xxlarge-mnli,CCS,ag-news,boolq,0.660775,0.018843
3,deberta-v2-xxlarge-mnli,CCS,ag-news,copa,0.546167,0.007097
4,deberta-v2-xxlarge-mnli,CCS,ag-news,dbpedia-14,0.964125,0.003495
...,...,...,...,...,...,...
85,deberta-v2-xxlarge-mnli,CCS,rte,dbpedia-14,0.645031,0.032496
86,deberta-v2-xxlarge-mnli,CCS,rte,imdb,0.621442,0.017998
87,deberta-v2-xxlarge-mnli,CCS,rte,piqa,0.539841,0.008679
88,deberta-v2-xxlarge-mnli,CCS,rte,qnli,0.638950,0.010866


In [15]:
# Per-seed stats of the CCS probe trained on imdb, averaged over all test datasets.
stats_dict = load_stats(DEBERTA, train="imdb", test="all", method="CCS", prefix="normal")

In [16]:
# Display the per-seed arrays (one entry per stats CSV) for each available stat.
stats_dict

{'accuracy': array([0.73433219, 0.73827072, 0.73210774, 0.73512759, 0.73495521,
        0.73751971, 0.7392942 , 0.73676192, 0.73923347, 0.73621103]),
 'loss': array([0.48174268, 0.49184621, 0.48894357, 0.50459721, 0.46131563,
        0.50674405, 0.50379212, 0.48608973, 0.48088501, 0.48863824]),
 'cons_loss': array([0.17713956, 0.17177885, 0.17801163, 0.18261436, 0.17174362,
        0.18504469, 0.18097273, 0.17447856, 0.17481037, 0.17436141]),
 'sim_loss': array([0.30460312, 0.32006736, 0.31093193, 0.32198284, 0.28957201,
        0.32169935, 0.3228194 , 0.31161117, 0.30607464, 0.31427682])}

In [4]:
# Load one stats CSV (prefix "normal-dot", file index 0) to inspect its layout.
# Built from SAVE_DIR / DEBERTA instead of a duplicated hardcoded path.
fpath = SAVE_DIR / f"{DEBERTA}_normal-dot_0.csv"
df = pd.read_csv(fpath)

In [6]:
# Distinct training datasets present in this file.
df["train"].unique()

array(['all', 'imdb', 'amazon-polarity', 'ag-news', 'dbpedia-14', 'copa',
       'rte', 'boolq', 'qnli', 'piqa'], dtype=object)

In [8]:
# Every (train, test) pair in this file, in row order; zip over the two columns
# instead of a row-by-row .iloc loop.
train_test_list = list(zip(df["train"], df["test"]))

# Distinct pairs, sorted. NOTE: despite the name, these are (train, test) pairs;
# the name is kept because it is an existing notebook variable.
test_test_sets = sorted(set(train_test_list))
test_test_sets

[('ag-news', 'ag-news'),
 ('ag-news', 'amazon-polarity'),
 ('ag-news', 'boolq'),
 ('ag-news', 'copa'),
 ('ag-news', 'dbpedia-14'),
 ('ag-news', 'imdb'),
 ('ag-news', 'piqa'),
 ('ag-news', 'qnli'),
 ('ag-news', 'rte'),
 ('all', 'ag-news'),
 ('all', 'amazon-polarity'),
 ('all', 'boolq'),
 ('all', 'copa'),
 ('all', 'dbpedia-14'),
 ('all', 'imdb'),
 ('all', 'piqa'),
 ('all', 'qnli'),
 ('all', 'rte'),
 ('amazon-polarity', 'ag-news'),
 ('amazon-polarity', 'amazon-polarity'),
 ('amazon-polarity', 'boolq'),
 ('amazon-polarity', 'copa'),
 ('amazon-polarity', 'dbpedia-14'),
 ('amazon-polarity', 'imdb'),
 ('amazon-polarity', 'piqa'),
 ('amazon-polarity', 'qnli'),
 ('amazon-polarity', 'rte'),
 ('boolq', 'ag-news'),
 ('boolq', 'amazon-polarity'),
 ('boolq', 'boolq'),
 ('boolq', 'copa'),
 ('boolq', 'dbpedia-14'),
 ('boolq', 'imdb'),
 ('boolq', 'piqa'),
 ('boolq', 'qnli'),
 ('boolq', 'rte'),
 ('copa', 'ag-news'),
 ('copa', 'amazon-polarity'),
 ('copa', 'boolq'),
 ('copa', 'copa'),
 ('copa', 'dbpedia-