In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Experiment configuration: location of the pickled run logs and which
# dataset/model family they belong to (both appear in every log filename).
DATA_PATH = "./logs/run3"
DATASET, MODEL_NAME = "Atari", "NatureDQNNetwork"
NUM_EXPERIENCES = 5  # number of test experiences each run is evaluated on

In [None]:
# All combinations of training: every strategy crossed with every
# (self-training, post-processing) variant. Keys are short run codes; values
# are the run-name fragment used in the pickled log filenames.
run2name = {
    f"{strategy_code}_{variant_code}": f"{strategy}_{variant}"
    for strategy_code, strategy in (
        ("n", "Naive"),
        ("r", "Replay"),
        ("j", "JointTraining"),
    )
    for variant_code, variant in (
        ("nst_npp", "NoSelfTraining_NoPostProcessing"),
        ("st_npp", "SelfTraining_NoPostProcessing"),
        ("nst_pp", "NoSelfTraining_PostProcessing"),
        # ("nst_pp_md", "NoSelfTraining_PostProcessing_MixedData"),  # currently disabled
    )
}

In [None]:
# Human-readable plot/table label for every run code (including the currently
# disabled MixedData variants). The plain "no self-training, no post-processing"
# variant is labelled by the strategy name alone.
run2label = {
    f"{strategy_code}_{variant_code}": f"{strategy}{suffix}"
    for strategy_code, strategy in (
        ("n", "Naive"),
        ("r", "Replay"),
        ("j", "JointTraining"),
    )
    for variant_code, suffix in (
        ("nst_npp", ""),
        ("st_npp", "_SelfTraining"),
        ("nst_pp", "_PostProcessing"),
        ("nst_pp_md", "_PostProcessing_MixedData"),
    )
}

In [None]:
# Load every run's pickled metric log and derive, per run code `k`:
#   * running_accuracy / running_ece: (k, [experience-averaged metric after each training step])
#   * final_accuracy / final_ece:     (k, last value of the curve above)
#   * ece_hist_vals:                  (k, [(mean, std) per bin]) of the ECE-histogram curves
#                                     after the final training step, across experiences
# `bins` holds the x-coordinates of the first histogram curve seen and is shared by all runs.

ACC_METRIC = "Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp"
ECE_METRIC = "ECE_Exp/eval_phase/test_stream/Task000/Exp"
ECE_HIST_METRIC = "ExpECEHistogram/eval_phase/test_stream/Exp"


def _experience_key(prefix, exp_id):
    """Full metric key for experience `exp_id` (zero-padded to three digits)."""
    return f"{prefix}{exp_id:03d}"


def _running_metric(data, prefix):
    """Experience-averaged `prefix` metric for every evaluation snapshot in `data`.

    `data[i]` is the metric dict of the i-th snapshot; each returned value is the
    mean of the metric over all NUM_EXPERIENCES test experiences. Runs with fewer
    snapshots than NUM_EXPERIENCES (JointTraining) are padded with their last value
    so every curve has exactly NUM_EXPERIENCES points.
    """
    values = []
    for i in range(len(data)):
        snapshot = data[i]
        total = sum(snapshot[_experience_key(prefix, j)] for j in range(NUM_EXPERIENCES))
        values.append(total / NUM_EXPERIENCES)
    if 0 < len(values) < NUM_EXPERIENCES:
        values += [values[-1]] * (NUM_EXPERIENCES - len(values))
    return values


def _final_hist_curves(snapshot):
    """(x, y) data of the last line on every axis of each experience's ECE-histogram figure."""
    curves = []
    for j in range(NUM_EXPERIENCES):
        hist_fig = snapshot[_experience_key(ECE_HIST_METRIC, j)]
        for hist_ax in hist_fig.get_axes():
            for line in hist_ax.get_lines()[-1:]:
                curves.append((line.get_xdata(), line.get_ydata()))
        # Unpickled figures may be re-registered with pyplot (if they were
        # pyplot-managed when saved); close them so they are not all rendered.
        plt.close(hist_fig)
    return curves


def _bin_mean_std(curves_y, n_bins):
    """(mean, std) of the first `n_bins` points across all curves, one pair per bin."""
    stats = []
    for b in range(n_bins):
        column = [y[b] for y in curves_y]
        stats.append((np.mean(column), np.std(column)))
    return stats


running_accuracy = []
running_ece = []
final_accuracy = []
final_ece = []
bins = None
ece_hist_vals = []

for k, name in run2name.items():
    print(f">> {name} <<")
    # NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced logs.
    with open(f"{DATA_PATH}/{DATASET}_{MODEL_NAME}_{name}_dict", "rb") as file:
        data = pickle.load(file)

    acc = _running_metric(data, ACC_METRIC)
    running_accuracy.append((k, acc))
    final_accuracy.append((k, acc[-1]))

    ece = _running_metric(data, ECE_METRIC)
    running_ece.append((k, ece))
    final_ece.append((k, ece[-1]))

    # ECE histograms from the last snapshot, i.e. after the final training step.
    curves = _final_hist_curves(data[len(data) - 1])
    if bins is None and curves:
        bins = curves[0][0]
    if bins is None:
        raise ValueError(f"no ECE histogram curves found for run {name!r}")
    ece_hist_vals.append((k, _bin_mean_std([y for _, y in curves], len(bins))))


In [None]:
# PLOT 1: Average accuracy on all experiences after training on exp j
# One curve per run: x = index of the last trained experience, y = test accuracy
# averaged over all experiences.

x_axis = list(range(NUM_EXPERIENCES))
fig, ax = plt.subplots(figsize=(7, 6))
for run_key, curve in running_accuracy:
    ax.plot(x_axis, curve, label=run2label[run_key])
ax.set_title('Average Experience Accuracy')
ax.set_xlabel('#TrainedExperience')
ax.set_ylabel('Accuracy')
ax.set_ylim(0, 1)
ax.set_xlim(0, NUM_EXPERIENCES - 1)
ax.set_xticks(x_axis)
ax.set_xticklabels(x_axis)
ax.legend(loc='upper right', fontsize='small', ncol=2)
plt.show()

In [None]:
# PLOT 2: Average ece on all experiences after training on exp j
# One curve per run: x = index of the last trained experience, y = ECE averaged
# over all experiences.

x_axis = list(range(NUM_EXPERIENCES))
fig, ax = plt.subplots(figsize=(7, 6))
for run_key, curve in running_ece:
    ax.plot(x_axis, curve, label=run2label[run_key])
ax.set_title('Average Experience ECE')
ax.set_xlabel('#TrainedExperience')
ax.set_ylabel('ECE')
ax.set_ylim(0, 1)
ax.set_xlim(0, NUM_EXPERIENCES - 1)
ax.set_xticks(x_axis)
ax.set_xticklabels(x_axis)
ax.legend(loc='upper right', fontsize='small', ncol=2)
plt.show()

In [None]:
# TABLE : average accuracy/ece on all experiences at the end of training

# final_accuracy and final_ece are pairwise zipped below; make sure both were
# produced for the same runs in the same order (guards against stale kernel state).
assert [n for n, _ in final_accuracy] == [n for n, _ in final_ece], \
    "final_accuracy and final_ece list different runs"

table_data = [
    (run2label[n], round(acc, 2), round(ece, 2))
    for (n, acc), (_, ece) in zip(final_accuracy, final_ece)
]

dt = pd.DataFrame(table_data, columns=["RunName", "Accuracy", "ECE"])
dt  # last expression -> rich (HTML) display in the notebook

In [None]:
# HISTOGRAM : avg/std across all experiences at the end of training
# One panel per run: mean ECE-histogram curve (solid) with a +/-1 std band across
# experiences, plotted against the ideal diagonal.

valid_colors = ['green', 'red', 'cyan', 'magenta', 'black', 'purple', 'orange', 'brown', 'gray', 'olive', 'indigo', 'turquoise']

n_cols = 4
n_rows = max(1, (len(ece_hist_vals) + n_cols - 1) // n_cols)
# Keeps the original 12x8 figure for a 3x4 grid; height scales with the row count.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 8 * n_rows / 3), squeeze=False)
axs = axs.flatten()
for i, (name, vals) in enumerate(ece_hist_vals):
    color = valid_colors[i % len(valid_colors)]  # cycle if there are more runs than colors
    mean = [mu for mu, _ in vals]
    lower = [mu - sd for mu, sd in vals]
    upper = [mu + sd for mu, sd in vals]
    axs[i].plot([0, 1], [0, 1], '--', label='ideal')
    axs[i].plot(bins, mean, label=run2label[name], color=color)
    axs[i].fill_between(bins, lower, upper, alpha=0.3, linestyle='--', color=color)
    axs[i].legend(loc='upper left', fontsize='small')
# Hide grid cells that have no run (e.g. 9 active runs on a 3x4 grid).
for unused_ax in axs[len(ece_hist_vals):]:
    unused_ax.set_visible(False)
fig.suptitle('ECE histogram: mean ± std across experiences at the end of training')
plt.tight_layout()
plt.show()