In [1]:
import _pickle as pickle
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import timeit

from itertools import product

In [2]:
map_eval_name = {
    "pretraining": "",
    "pretrain-sample_high_prob_class_only-start_pos_0": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_0": "Condition on Low Frequency",
    "pretrain-sample_high_prob_class_only-start_pos_0-flip_label": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_0-flip_label": "Condition on Low Frequency",
    
    "pretrain-sample_high_prob_class_only-start_pos_1": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_1": "Condition on Low Frequency",
    "pretrain-sample_high_prob_class_only-start_pos_1-flip_label": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_1-flip_label": "Condition on Low Frequency",
    
    "pretrain-sample_high_prob_class_only-start_pos_7": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_7": "Condition on Low Frequency",
    "pretrain-sample_high_prob_class_only-start_pos_7-flip_label": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_7-flip_label": "Condition on Low Frequency",
}

stats_keys = [
    "accuracy",
    "p_iwl",
    "context contains query class",
]

In [3]:
repo_path = "/Users/chanb/research/ualberta/simple_icl"
results_dir = "/Users/chanb/research/ualberta/simple_icl/results/simple_icl/results"

# repo_path = "/home/chanb/src/simple_icl"
# results_dir = "/home/chanb/scratch/simple_icl/results"

In [8]:
# variant_name = "simple_icl-fixed_g"
variant_name = "simple_icl-fixed_g-context_alpha"
# variant_name = "simple_icl-learned_g"

checkpoint_steps = 100

stats_file = os.path.join(repo_path, "cc_utils/agg_stats", "{}.feather".format(variant_name))
stats = pd.read_feather(stats_file)

os.makedirs(
    os.path.join(repo_path, "cc_utils/acc-plots", variant_name),
    exist_ok=True
)

results_dir = os.path.join(results_dir, variant_name)
eval_namess = [
    [
        "pretraining"
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_0",
        "pretrain-sample_low_prob_class_only-start_pos_0",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_0-flip_label",
        "pretrain-sample_low_prob_class_only-start_pos_0-flip_label",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_1",
        "pretrain-sample_low_prob_class_only-start_pos_1",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_1-flip_label",
        "pretrain-sample_low_prob_class_only-start_pos_1-flip_label",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_7",
        "pretrain-sample_low_prob_class_only-start_pos_7",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_7-flip_label",
        "pretrain-sample_low_prob_class_only-start_pos_7-flip_label",
    ],
]
plot_names = [
    "pretraining",
    "iwl",
    "iwl-flip_label",
    "icl-last_context",
    "icl-last_context-flip_label",
    "icl-except_first_context",
    "icl-except_first_context-flip_label",
]
plot_titles = [
    "Pretraining",
    "In-weight Evaluation",
    "In-weight Evaluation with Flipped Label",
    "In-context Evaluation with Last Context",
    "In-context Evaluation with Last Context + Flipped Label",
    "In-context Evaluation with Contexts but First",
    "In-context Evaluation with Contexts but First + Flipped Label",
]

In [9]:
variants = stats["variant"].unique()

# Plot Accuracy based on Checkpoint Intervals

In [10]:
for eval_names, plot_name, plot_title in zip(eval_namess, plot_names, plot_titles):
    tic = timeit.default_timer()

    agg_stats = dict()
    for variant in variants:
        agg_stats.setdefault(variant, dict())
        for eval_name in eval_names:
            agg_stats[variant].setdefault(eval_name, dict())
            for stats_key in stats_keys:
                curr_stats = stats.loc[
                    (stats["variant"] == variant)
                    & (stats["eval_name"] == eval_name)
                    & (stats["stats_key"] == stats_key)
                ]["stats"].to_list()
                agg_stats[variant][eval_name][stats_key] = (
                    np.mean(curr_stats, axis=0),
                    np.std(curr_stats, axis=0),
                )

    for variant in variants:
        os.makedirs(os.path.join(repo_path, "cc_utils/acc-plots", variant_name, variant), exist_ok=True)
        fig, axes = plt.subplots(1, len(eval_names), figsize=(5 * len(eval_names), 5))
        for eval_i, eval_name in enumerate(eval_names):
            if len(eval_names) > 1:
                ax = axes[eval_i]
            else:
                ax = axes
            x_range = np.arange(len(agg_stats[variant][eval_name]["accuracy"][0])) * checkpoint_steps
            ax.plot(
                x_range,
                np.array(agg_stats[variant][eval_name]["accuracy"][0]) / 100.0,
                label="Accuracy" if eval_i == 0 else "",
                linewidth=3,
                c="red",
                alpha=0.7,
            )
            ax.plot(
                x_range,
                agg_stats[variant][eval_name]["p_iwl"][0],
                label="$\\alpha(x)$" if eval_i == 0 else "",
                linestyle="--",
                c="black",
                alpha=0.3
            )
            ax.plot(
                x_range,
                agg_stats[variant][eval_name]["context contains query class"][0],
                label="% $\geq 1$ Context from Query Class" if eval_i == 0 else "",
                linestyle="-.",
                c="black",
                alpha=0.3
            )
            ax.set_ylim(-0.1, 1.1)
            ax.set_title(map_eval_name[eval_name])

        fig.supylabel("Accuracy/Prob.")
        fig.supxlabel("Gradient Steps")
        fig.suptitle(plot_title)
        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=5,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8", 
        )

        plt.savefig(
            os.path.join(repo_path, "cc_utils/acc-plots", "{}/{}/{}.png".format(variant_name, variant, plot_name)),
            format="png",
            bbox_inches="tight",
            dpi=600,
        )
        plt.close(fig)

    toc = timeit.default_timer()
    print("Done {} {} {} in {}s".format(variant_name, variant, eval_name, toc - tic))

Done simple_icl-fixed_g-context_alpha high_prob_0.9-ground_truth_prob_0.75 pretraining in 25.783180166996317s
Done simple_icl-fixed_g-context_alpha high_prob_0.9-ground_truth_prob_0.75 pretrain-sample_low_prob_class_only-start_pos_0 in 50.36323950000224s
Done simple_icl-fixed_g-context_alpha high_prob_0.9-ground_truth_prob_0.75 pretrain-sample_low_prob_class_only-start_pos_0-flip_label in 50.570351208007196s
Done simple_icl-fixed_g-context_alpha high_prob_0.9-ground_truth_prob_0.75 pretrain-sample_low_prob_class_only-start_pos_1 in 50.210196457992424s
Done simple_icl-fixed_g-context_alpha high_prob_0.9-ground_truth_prob_0.75 pretrain-sample_low_prob_class_only-start_pos_1-flip_label in 48.25654795799346s
Done simple_icl-fixed_g-context_alpha high_prob_0.9-ground_truth_prob_0.75 pretrain-sample_low_prob_class_only-start_pos_7 in 50.521385208005086s
Done simple_icl-fixed_g-context_alpha high_prob_0.9-ground_truth_prob_0.75 pretrain-sample_low_prob_class_only-start_pos_7-flip_label in 51.

# Plot Accuracy based on Setting

In [24]:
variant_1 = "high_prob"
variant_2 = "ground_truth_prob"

variant_1_values = sorted([float(el) for el in stats[variant_1].unique()])
variant_2_values = sorted([float(el) for el in stats[variant_2].unique()])

In [25]:
stats[variant_1]

0         0.8
1         0.8
2         0.8
3         0.8
4         0.8
         ... 
196975    0.9
196976    0.9
196977    0.9
196978    0.9
196979    0.9
Name: high_prob, Length: 196980, dtype: object

In [26]:
stats_key = "p_iwl"
checkpoint_i = 10

os.makedirs(
    os.path.join(repo_path, "cc_utils/{}-{}-plots".format(variant_1, variant_2), variant_name),
    exist_ok=True,
)

for eval_names, plot_name, plot_title in zip(eval_namess, plot_names, plot_titles):
    tic = timeit.default_timer()

    agg_stats = dict()
    for (variant_1_value, variant_2_value) in product(variant_1_values, variant_2_values):
        agg_stats.setdefault(variant_1_value, dict())
        for eval_name in eval_names:
            agg_stats[variant_1_value].setdefault(eval_name, dict(mean=[], std=[]))
            curr_stats = stats.loc[
                (stats[variant_1] == str(variant_1_value))
                & (stats[variant_2] == str(variant_2_value))
                & (stats["eval_name"] == eval_name)
                & (stats["stats_key"] == stats_key)
            ]["stats"].to_list()

            agg_stats[variant_1_value][eval_name]["mean"].append(
                np.mean(curr_stats, axis=0)[checkpoint_i].item()
            )
            agg_stats[variant_1_value][eval_name]["std"].append(
                np.std(curr_stats, axis=0)[checkpoint_i].item()
            )
    
    fig, axes = plt.subplots(1, len(eval_names), figsize=(5 * len(eval_names), 5))
    for eval_i, eval_name in enumerate(eval_names):
        for variant_1_value in variant_1_values:
            ax = axes[eval_i]
            ax.plot(
                variant_2_values,
                np.array(agg_stats[variant_1_value][eval_name]["mean"]),
                label="{} - {}".format(variant_1, variant_1_value) if eval_i == 0 else "",
                linewidth=1,
                alpha=0.7,
            )
            ax.set_xlim(-0.1, 1.1)
            ax.set_ylim(-0.1, 1.1)
            ax.set_title(map_eval_name[eval_name])

        fig.supylabel("Accuracy")
        fig.supxlabel(variant_2)
        fig.suptitle(plot_title)
        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=5,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8", 
        )

    plt.savefig(
        os.path.join(repo_path, "cc_utils/{}-{}-plots".format(variant_1, variant_2), variant_name, "{}.png".format(plot_name)),
        format="png",
        bbox_inches="tight",
        dpi=600,
    )
    plt.close(fig)

    toc = timeit.default_timer()
    print("Done {} in {}s".format(eval_name, toc - tic))

Done pretrain-sample_low_prob_class_only-start_pos_0 in 3.436614334001206s
Done pretrain-sample_low_prob_class_only-start_pos_0-flip_label in 3.40101758298988s
Done pretrain-sample_low_prob_class_only-start_pos_1 in 3.377094750001561s
Done pretrain-sample_low_prob_class_only-start_pos_1-flip_label in 3.4244206250004936s
Done pretrain-sample_low_prob_class_only-start_pos_7 in 3.3892472079896834s
Done pretrain-sample_low_prob_class_only-start_pos_7-flip_label in 3.4258552500104997s


# Individual Check

In [11]:
stats["stats_key"].unique()

array(['accuracy', 'p_iwl', 'context contains query class', 'ic_pred',
       'iw_pred', 'num p_iwl >= 0.5',
       'p_iwl given context contains query class'], dtype=object)

In [12]:
stats.loc[
    # (stats["variant"] == "high_prob_0.8-ground_truth_prob_0.0")
    # & (stats["eval_name"] == "pretrain-sample_low_prob_class_only-start_pos_0")
    # & (stats["seed"] == "seed_0")
    # & (stats["stats_key"] == "p_iwl")
    (stats["eval_name"] == "pretrain-sample_low_prob_class_only-start_pos_0")
    & (stats["seed"] == "seed_0")
    & (stats["stats_key"] == "iw_pred")
]

Unnamed: 0,variant,seed,eval_name,stats_key,stats,high_prob,ground_truth_prob
2468,high_prob_0.75-ground_truth_prob_0.25,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.27000001072883606, 0.23100000619888306, 0.2...",0.75,0.25
2937,high_prob_0.75-ground_truth_prob_1.0,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.75,1.0
12317,high_prob_0.75-ground_truth_prob_0.1,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.10000000149011612, 0.08800000697374344, 0.1...",0.75,0.1
12786,high_prob_0.8-ground_truth_prob_0.9,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.10000000149011612, 0.08800000697374344, 0.1...",0.8,0.9
13255,high_prob_0.8-ground_truth_prob_0.25,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.27000001072883606, 0.23100000619888306, 0.2...",0.8,0.25
13724,high_prob_0.8-ground_truth_prob_0.1,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.10000000149011612, 0.08800000697374344, 0.1...",0.8,0.1
15600,high_prob_0.67-ground_truth_prob_0.75,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.27000001072883606, 0.23100000619888306, 0.2...",0.67,0.75
33422,high_prob_0.9-ground_truth_prob_0.0,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.9,0.0
39988,high_prob_0.67-ground_truth_prob_0.5,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.5270000100135803, 0.4970000088214874, 0.537...",0.67,0.5
46554,high_prob_0.75-ground_truth_prob_0.5,seed_0,pretrain-sample_low_prob_class_only-start_pos_0,iw_pred,"[0.5270000100135803, 0.4970000088214874, 0.537...",0.75,0.5
