In [1]:
import _pickle as pickle
import matplotlib.pyplot as plt
import numpy as np
import os
import timeit

In [2]:
map_eval_name = {
    "pretrain-sample_high_prob_class_only-start_pos_0": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_0": "Condition on Low Frequency",
    "pretrain-sample_high_prob_class_only-start_pos_0-flip_label": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_0-flip_label": "Condition on Low Frequency",
    
    "pretrain-sample_high_prob_class_only-start_pos_1": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_1": "Condition on Low Frequency",
    "pretrain-sample_high_prob_class_only-start_pos_1-flip_label": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_1-flip_label": "Condition on Low Frequency",
    
    "pretrain-sample_high_prob_class_only-start_pos_7": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_7": "Condition on Low Frequency",
    "pretrain-sample_high_prob_class_only-start_pos_7-flip_label": "Condition on High Frequency",
    "pretrain-sample_low_prob_class_only-start_pos_7-flip_label": "Condition on Low Frequency",
}

stats_keys = [
    "accuracy",
    "p_iwl",
    "diff last context",
]

In [3]:
# variant_name = "simple_icl-default"
# prefixes = ["high_prob_0.5", "high_prob_0.99"]

variant_name = "simple_icl-g-high_prob_0.99"
prefixes = ["ground_truth_prob_0.0", "ground_truth_prob_0.5", "ground_truth_prob_0.75", "ground_truth_prob_0.9"]


results_dir = "/home/chanb/scratch/simple_icl/results/{}".format(variant_name)
eval_namess = [
    [
        "pretrain-sample_high_prob_class_only-start_pos_0",
        "pretrain-sample_low_prob_class_only-start_pos_0",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_0-flip_label",
        "pretrain-sample_low_prob_class_only-start_pos_0-flip_label",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_1",
        "pretrain-sample_low_prob_class_only-start_pos_1",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_1-flip_label",
        "pretrain-sample_low_prob_class_only-start_pos_1-flip_label",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_7",
        "pretrain-sample_low_prob_class_only-start_pos_7",
    ],
    [
        "pretrain-sample_high_prob_class_only-start_pos_7-flip_label",
        "pretrain-sample_low_prob_class_only-start_pos_7-flip_label",
    ],
]
plot_names = [
    "iwl",
    "iwl-flip_label",
    "icl-last_context",
    "icl-last_context-flip_label",
    "icl-except_first_context",
    "icl-except_first_context-flip_label",
]
plot_titles = [
    "In-weight Evaluation",
    "In-weight Evaluation with Flipped Label",
    "In-context Evaluation with Last Context",
    "In-context Evaluation with Last Context + Flipped Label",
    "In-context Evaluation with Contexts but First",
    "In-context Evaluation with Contexts but First + Flipped Label",
]

In [None]:
for prefix in prefixes:
    run_names = [os.path.join(results_dir, run_name) for run_name in os.listdir(results_dir) if run_name.startswith(prefix)]

    for eval_names, plot_name, plot_title in zip(eval_namess, plot_names, plot_titles):
        tic = timeit.default_timer()
        stats = dict()
        for run_i, run_name in enumerate(run_names):
            data = pickle.load(open(os.path.join(run_name, "evaluation.pkl"), "rb"))
            for eval_name in eval_names:
                stats.setdefault(eval_name, dict())
                for stats_key in stats_keys:
                    stats[eval_name].setdefault(stats_key, np.zeros((len(run_names), len(data["checkpoint_steps"]))))
                    stats[eval_name][stats_key][run_i] = data["stats"][eval_name][stats_key]

        agg_stats = dict()
        for eval_name in stats:
            agg_stats.setdefault(eval_name, dict())
            for stats_key in stats[eval_name]:
                agg_stats[eval_name][stats_key] = (
                    np.mean(stats[eval_name][stats_key], axis=0),
                    np.std(stats[eval_name][stats_key], axis=0),
                )

        fig, axes = plt.subplots(1, len(eval_names), figsize=(5 * len(eval_names), 5))

        for eval_i, eval_name in enumerate(eval_names):
            ax = axes[eval_i]
            ax.plot(
                data["checkpoint_steps"],
                np.array(agg_stats[eval_name]["accuracy"][0]) / 100.0,
                label="Accuracy" if eval_i == 0 else "",
                linewidth=3,
                c="red",
                alpha=0.7,
            )
            ax.plot(
                data["checkpoint_steps"],
                agg_stats[eval_name]["p_iwl"][0],
                label="$\\alpha(x)$" if eval_i == 0 else "",
                linestyle="--",
                c="black",
                alpha=0.3
            )
            ax.plot(
                data["checkpoint_steps"],
                data["stats"][eval_name]["diff last context"],
                label="% Diff. Last Context." if eval_i == 0 else "",
                linestyle="-.",
                c="black",
                alpha=0.3
            )
            ax.set_ylim(-0.1, 1.1)
            ax.set_title(map_eval_name[eval_name])

        fig.supylabel("Accuracy/Prob.")
        fig.supxlabel("Gradient Steps")
        fig.suptitle(plot_title)
        fig.legend(
            bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
            loc="lower center",
            ncols=5,
            borderaxespad=0.0,
            frameon=True,
            fontsize="8", 
        )

        plt.savefig(
            os.path.join("/home/chanb/src/simple_icl/cc_utils/plots", "{}-{}-{}.png".format(variant_name, prefix, plot_name)),
            format="png",
            bbox_inches="tight",
            dpi=600,
        )
        toc = timeit.default_timer()
        print("Done {} {} in {}s".format(prefix, eval_name, toc - tic))

Done ground_truth_prob_0.0 pretrain-sample_low_prob_class_only-start_pos_0 in 98.3193833399564s
Done ground_truth_prob_0.0 pretrain-sample_low_prob_class_only-start_pos_0-flip_label in 92.18565797200426s
Done ground_truth_prob_0.0 pretrain-sample_low_prob_class_only-start_pos_1 in 92.24864241096657s
Done ground_truth_prob_0.0 pretrain-sample_low_prob_class_only-start_pos_1-flip_label in 92.05579636094626s
Done ground_truth_prob_0.0 pretrain-sample_low_prob_class_only-start_pos_7 in 90.76008329202887s
Done ground_truth_prob_0.0 pretrain-sample_low_prob_class_only-start_pos_7-flip_label in 91.16979015001561s
Done ground_truth_prob_0.5 pretrain-sample_low_prob_class_only-start_pos_0 in 93.50471578491852s
Done ground_truth_prob_0.5 pretrain-sample_low_prob_class_only-start_pos_0-flip_label in 91.37264405505266s
Done ground_truth_prob_0.5 pretrain-sample_low_prob_class_only-start_pos_1 in 91.10508855199441s
Done ground_truth_prob_0.5 pretrain-sample_low_prob_class_only-start_pos_1-flip_labe

In [None]:
assert 0

In [None]:
# TODO: Check P(g(x) = 0) = 0 with ICL flip label, see if we get 50/50
prefix = "ground_truth_prob_0.0"
eval_names = [
    "pretrain-sample_high_prob_class_only-start_pos_1-flip_label",
    "pretrain-sample_low_prob_class_only-start_pos_1-flip_label",
]

In [None]:
run_names = [os.path.join(results_dir, run_name) for run_name in os.listdir(results_dir) if run_name.startswith(prefix)]

In [None]:
data["stats"][eval_name].keys()

In [None]:
import jax

In [None]:
eval_name

In [None]:
jax.nn.softmax(data["stats"][eval_name]["similarity"][0] / 0.1)

In [None]:
len(data["stats"][eval_name]["ic_pred"])

In [None]:
data["stats"][eval_name]["ic_pred"][0], data["stats"][eval_name]["iw_pred"][0], data["stats"][eval_name]["p_iwl"][0]

In [None]:
stats = dict()
for run_i, run_name in enumerate(run_names):
    data = pickle.load(open(os.path.join(run_name, "evaluation.pkl"), "rb"))
    for eval_name in eval_names:
        stats.setdefault(eval_name, dict())
        for stats_key in stats_keys:
            stats[eval_name].setdefault(stats_key, np.zeros((len(run_names), len(data["checkpoint_steps"]))))
            stats[eval_name][stats_key][run_i] = data["stats"][eval_name][stats_key]

agg_stats = dict()
for eval_name in stats:
    agg_stats.setdefault(eval_name, dict())
    for stats_key in stats[eval_name]:
        agg_stats[eval_name][stats_key] = (
            np.mean(stats[eval_name][stats_key], axis=0),
            np.std(stats[eval_name][stats_key], axis=0),
        )

In [None]:
agg_stats

In [None]:
fig, axes = plt.subplots(1, len(eval_names), figsize=(5 * len(eval_names), 5))

for eval_i, eval_name in enumerate(eval_names):
    ax = axes[eval_i]
    ax.plot(
        data["checkpoint_steps"],
        np.array(agg_stats[eval_name]["accuracy"][0]) / 100.0,
        label="Accuracy" if eval_i == 0 else "",
        linewidth=3,
        c="red",
        alpha=0.7,
    )
    ax.plot(
        data["checkpoint_steps"],
        agg_stats[eval_name]["p_iwl"][0],
        label="$\\alpha(x)$" if eval_i == 0 else "",
        linestyle="--",
        c="black",
        alpha=0.3
    )
    ax.plot(
        data["checkpoint_steps"],
        data["stats"][eval_name]["diff last context"],
        label="% Diff. Last Context." if eval_i == 0 else "",
        linestyle="-.",
        c="black",
        alpha=0.3
    )
    ax.set_ylim(-0.1, 1.1)
    ax.set_title(map_eval_name[eval_name])

fig.supylabel("Accuracy/Prob.")
fig.supxlabel("Gradient Steps")
fig.suptitle(plot_title)
fig.legend(
    bbox_to_anchor=(0.0, 1.0, 1.0, 0.0),
    loc="lower center",
    ncols=5,
    borderaxespad=0.0,
    frameon=True,
    fontsize="8",
)
plt.show()