In [None]:
import _pickle as pickle
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import timeit

from itertools import product

In [None]:
# Human-readable subplot titles keyed by evaluation name. Every evaluation name
# has the form "pretrain-sample_{high|low}_prob_class_only-start_pos_{0|1|7}",
# optionally followed by "-flip_label"; the title depends only on whether the
# conditioning class is the high- or the low-frequency one.
map_eval_name = {
    "pretrain-sample_{}_prob_class_only-start_pos_{}{}".format(freq, start_pos, suffix): title
    for start_pos in (0, 1, 7)
    for suffix in ("", "-flip_label")
    for freq, title in (
        ("high", "Condition on High Frequency"),
        ("low", "Condition on Low Frequency"),
    )
}

# Values of the "stats_key" column that are aggregated and plotted below.
stats_keys = [
    "accuracy",
    "p_iwl",
    "context contains query class",
]

In [None]:
# Location of the simple_icl checkout and of the raw run results.
# Set the SIMPLE_ICL_REPO / SIMPLE_ICL_RESULTS environment variables to point
# elsewhere instead of editing absolute paths here (cluster layout, for
# reference: /home/chanb/src/simple_icl and /home/chanb/scratch/simple_icl/results).
repo_path = os.environ.get(
    "SIMPLE_ICL_REPO",
    "/Users/chanb/research/ualberta/simple_icl",
)
results_dir = os.environ.get(
    "SIMPLE_ICL_RESULTS",
    "/Users/chanb/research/ualberta/simple_icl/results/simple_icl/results",
)

In [None]:
# Experiment family to analyse. Other available families:
# "simple_icl-fixed_g-context_alpha", "simple_icl-learned_g".
variant_name = "simple_icl-fixed_g"

# Gradient steps between consecutive logged checkpoints; converts checkpoint
# indices into gradient-step values on the x-axis of the plots below.
checkpoint_steps = 100

# Aggregated statistics for this experiment family. Columns used below include
# "variant", "eval_name", "stats_key", "stats" and "seed".
stats_file = os.path.join(repo_path, "cc_utils/agg_stats", "{}.feather".format(variant_name))
stats = pd.read_feather(stats_file)

os.makedirs(
    os.path.join(repo_path, "cc_utils/acc-plots", variant_name),
    exist_ok=True
)

# Append the variant directory only once, so re-running this cell does not
# produce results_dir/<variant>/<variant>/...
if os.path.basename(os.path.normpath(results_dir)) != variant_name:
    results_dir = os.path.join(results_dir, variant_name)
# Each entry of eval_namess is the pair of evaluations (high- vs low-frequency
# conditioning class) drawn side by side in one figure; that figure is saved as
# plot_names[i] and titled plot_titles[i]. Entries cycle through start
# positions 0, 1 and 7, each followed by its flipped-label counterpart.
eval_namess = [
    [
        "pretrain-sample_{}_prob_class_only-start_pos_{}{}".format(freq, start_pos, suffix)
        for freq in ("high", "low")
    ]
    for start_pos in (0, 1, 7)
    for suffix in ("", "-flip_label")
]
plot_names = [
    "iwl",
    "iwl-flip_label",
    "icl-last_context",
    "icl-last_context-flip_label",
    "icl-except_first_context",
    "icl-except_first_context-flip_label",
]
plot_titles = [
    "In-weight Evaluation",
    "In-weight Evaluation with Flipped Label",
    "In-context Evaluation with Last Context",
    "In-context Evaluation with Last Context + Flipped Label",
    "In-context Evaluation with Contexts but First",
    "In-context Evaluation with Contexts but First + Flipped Label",
]

# The three lists are consumed together with zip(), which would silently drop
# trailing entries if their lengths ever diverged.
assert len(eval_namess) == len(plot_names) == len(plot_titles), (
    "eval_namess, plot_names and plot_titles must have equal length"
)

In [None]:
# Every distinct configuration label in the "variant" column; one set of
# figures is produced per label below.
variants = stats.loc[:, "variant"].unique()

# Plot Accuracy over Training Checkpoints

In [None]:
def _mean_std(values):
    """Return (mean, std) over axis 0 of ``values``, or (None, None) if empty.

    np.mean of an empty list is a NaN scalar (plus a RuntimeWarning), which
    would later fail on len(); returning None lets the plotting code skip it.
    """
    if len(values) == 0:
        return None, None
    return np.mean(values, axis=0), np.std(values, axis=0)


# Collect the "stats" entries once per (variant, eval_name, stats_key) instead
# of re-scanning the whole table with boolean masks for every combination.
stats_lists = {
    key: group["stats"].to_list()
    for key, group in stats.groupby(
        ["variant", "eval_name", "stats_key"], observed=True, sort=False
    )
}

for eval_names, plot_name, plot_title in zip(eval_namess, plot_names, plot_titles):
    tic = timeit.default_timer()

    # agg_stats[variant][eval_name][stats_key] = (mean, std) over the matching
    # rows (presumably one per seed); each is indexed by checkpoint.
    agg_stats = {
        variant: {
            eval_name: {
                stats_key: _mean_std(stats_lists.get((variant, eval_name, stats_key), []))
                for stats_key in stats_keys
            }
            for eval_name in eval_names
        }
        for variant in variants
    }

    for variant in variants:
        variant_plot_dir = os.path.join(repo_path, "cc_utils/acc-plots", variant_name, variant)
        os.makedirs(variant_plot_dir, exist_ok=True)
        # squeeze=False keeps `axes` 2-D, so indexing also works for one eval name.
        fig, axes = plt.subplots(
            1, len(eval_names), figsize=(5 * len(eval_names), 5), squeeze=False
        )
        for eval_i, eval_name in enumerate(eval_names):
            ax = axes[0, eval_i]
            ax.set_ylim(-0.1, 1.1)
            ax.set_title(map_eval_name[eval_name])

            curr_agg = agg_stats[variant][eval_name]
            if any(curr_agg[key][0] is None for key in stats_keys):
                # No rows for this (variant, eval_name): leave the panel empty.
                continue

            # Only the first panel carries labels, so fig.legend lists each curve once.
            show_label = eval_i == 0
            x_range = np.arange(len(curr_agg["accuracy"][0])) * checkpoint_steps
            ax.plot(
                x_range,
                # Divided by 100 to share the [0, 1] axis with the probability
                # curves (accuracy presumably stored as a percentage).
                np.array(curr_agg["accuracy"][0]) / 100.0,
                label="Accuracy" if show_label else "",
                linewidth=3,
                c="red",
                alpha=0.7,
            )
            ax.plot(
                x_range,
                curr_agg["p_iwl"][0],
                label="$\\alpha(x)$" if show_label else "",
                linestyle="--",
                c="black",
                alpha=0.3
            )
            ax.plot(
                x_range,
                curr_agg["context contains query class"][0],
                # Raw string: "\g" is not a valid Python escape sequence.
                label=r"% $\geq 1$ Context from Query Class" if show_label else "",
                linestyle="-.",
                c="black",
                alpha=0.3
            )

        fig.supylabel("Accuracy/Prob.")
        fig.supxlabel("Gradient Steps")
        fig.suptitle(plot_title)
        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=5,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8",
        )

        fig.savefig(
            os.path.join(variant_plot_dir, "{}.png".format(plot_name)),
            format="png",
            bbox_inches="tight",
            dpi=600,
        )
        plt.close(fig)

    toc = timeit.default_timer()
    # Report the plot set just finished (not loop variables leaked from above).
    print("Done {} {} in {:.2f}s".format(variant_name, plot_name, toc - tic))

**TODO:** Double-check these plots — why does accuracy never reach 0?

# Plot Final-Checkpoint Accuracy across Settings

In [None]:
# The two hyperparameter columns compared below: one line per value of
# variant_1, with variant_2 on the x-axis.
variant_1 = "high_prob"
variant_2 = "ground_truth_prob"

# Distinct values of each column, converted to float and sorted ascending.
variant_1_values, variant_2_values = (
    sorted(map(float, stats[column].unique()))
    for column in (variant_1, variant_2)
)

In [None]:
# Show the distinct raw values of variant_1 and how many rows each has,
# rather than dumping the entire column.
stats[variant_1].value_counts().sort_index()

In [None]:
stats_key = "accuracy"

setting_plot_dir = os.path.join(
    repo_path, "cc_utils/{}-{}-plots".format(variant_1, variant_2), variant_name
)
os.makedirs(setting_plot_dir, exist_ok=True)

# Match settings numerically. variant_1_values / variant_2_values hold floats,
# and str(float) does not always reproduce the stored text (e.g. "1" vs "1.0"),
# in which case a string comparison would silently select no rows.
variant_1_column = stats[variant_1].astype(float)
variant_2_column = stats[variant_2].astype(float)
stats_key_mask = stats["stats_key"] == stats_key

for eval_names, plot_name, plot_title in zip(eval_namess, plot_names, plot_titles):
    tic = timeit.default_timer()

    # agg_stats[variant_1_value][eval_name]["mean"/"std"] are lists aligned with
    # variant_2_values, holding the mean/std of the last entry of `stats_key`
    # over the matching rows (presumably one per seed).
    agg_stats = dict()
    for variant_1_value, variant_2_value in product(variant_1_values, variant_2_values):
        setting_mask = (
            (variant_1_column == variant_1_value)
            & (variant_2_column == variant_2_value)
            & stats_key_mask
        )
        for eval_name in eval_names:
            curr_agg = agg_stats.setdefault(variant_1_value, dict()).setdefault(
                eval_name, dict(mean=[], std=[])
            )
            curr_stats = stats.loc[
                setting_mask & (stats["eval_name"] == eval_name), "stats"
            ].to_list()
            if len(curr_stats) == 0:
                # Missing setting: NaN leaves a gap in the line instead of
                # raising IndexError on the scalar NaN that np.mean([]) returns.
                curr_agg["mean"].append(np.nan)
                curr_agg["std"].append(np.nan)
                continue
            curr_agg["mean"].append(np.mean(curr_stats, axis=0)[-1].item())
            curr_agg["std"].append(np.std(curr_stats, axis=0)[-1].item())

    # squeeze=False keeps `axes` 2-D, so indexing also works for one eval name.
    fig, axes = plt.subplots(
        1, len(eval_names), figsize=(5 * len(eval_names), 5), squeeze=False
    )
    for eval_i, eval_name in enumerate(eval_names):
        ax = axes[0, eval_i]
        for variant_1_value in variant_1_values:
            ax.plot(
                variant_2_values,
                # Divided by 100 to map onto [0, 1] (presumably a percentage).
                np.array(agg_stats[variant_1_value][eval_name]["mean"]) / 100.0,
                # Only the first panel carries labels, so fig.legend lists each line once.
                label="{} - {}".format(variant_1, variant_1_value) if eval_i == 0 else "",
                linewidth=1,
                alpha=0.7,
            )
        ax.set_xlim(-0.1, 1.1)
        ax.set_ylim(-0.1, 1.1)
        ax.set_title(map_eval_name[eval_name])

    # Figure-level decorations go after the panel loop so the legend is added
    # to the figure exactly once.
    fig.supylabel("Accuracy")
    fig.supxlabel(variant_2)
    fig.suptitle(plot_title)
    fig.legend(
        bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
        loc="lower center",
        ncols=5,
        borderaxespad=0.0,
        frameon=True,
        fontsize="8",
    )

    fig.savefig(
        os.path.join(setting_plot_dir, "{}.png".format(plot_name)),
        format="png",
        bbox_inches="tight",
        dpi=600,
    )
    plt.close(fig)

    toc = timeit.default_timer()
    # Report the plot set just finished; this cell defines no `variant`.
    print("Done {} {} in {:.2f}s".format(variant_name, plot_name, toc - tic))

# Inspect an Individual Run

In [None]:
# Spot-check the raw rows of a single run: one variant, one evaluation, one seed.
individual_run = {
    "variant": "high_prob_0.8-ground_truth_prob_0.0",
    "eval_name": "pretrain-sample_low_prob_class_only-start_pos_0",
    "seed": "seed_0",
}
individual_mask = pd.Series(True, index=stats.index)
for column, value in individual_run.items():
    individual_mask = individual_mask & (stats[column] == value)
stats.loc[individual_mask]