In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from avalanche.training.supervised import DER, Replay
from avalanche.benchmarks.classic import SplitCIFAR100, SplitTinyImageNet
from ResNet18 import resnet18
from avalanche.training.plugins import EvaluationPlugin
from ECE_metrics import ExperienceECE, ExpECEHistogram
from avalanche.evaluation.metrics import accuracy_metrics
from torch.nn import CrossEntropyLoss

In [None]:
# Evaluation configuration.
# Preferred GPU is index 3; fall back so the notebook still runs on machines
# that expose fewer GPUs (use the first one) or no CUDA device at all (use the CPU).
device = "cuda:3"
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() <= 3:
    device = "cuda:0"
num_bins = 16  # passed as `num_bins` to the ECE metrics below
# Alternative benchmark:
# num_classes, benchmark = 100, SplitCIFAR100(n_experiences=10)
num_classes, benchmark = 200, SplitTinyImageNet(n_experiences=10)
test_stream = benchmark.test_stream

In [None]:
# Network whose checkpoints are evaluated; the output size follows the benchmark.
model = resnet18(num_classes)
model_name = "ResNet18"

# Checkpoint path prefix; "_exp<k>.pt" is appended per experience when loading.
# Other runs that can be evaluated by swapping the prefix:
#   "./logs/SplitCIFAR100_DER_2000_2/model_DER_NoSelfTraining_NoPostProcessing1"
#   "./logs/TinyImageNet_D2/model_DER_NoSelfTraining_NoPostProcessing2"
model_base_path = "./logs/TinyImageNet_R2/model_Replay_NoSelfTraining_NoPostProcessing2"

In [None]:
# Metrics collected on every evaluation pass: accuracy at all granularities
# plus the two ECE metrics, both using the configured number of bins.
eval_metrics = [
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    ExperienceECE(num_bins=num_bins),
    ExpECEHistogram(num_bins=num_bins),
]
plugins = EvaluationPlugin(*eval_metrics)

# Alternative strategy for evaluating the DER checkpoints:
# strategy = DER(
#     model,
#     optimizer=None,
#     device=device,
#     evaluator=plugins,
#     mem_size=2000,
#     batch_size_mem=32,
#     alpha=0.1,
#     beta=0.5,
#     eval_mb_size=32,
# )

# The strategy is only used for evaluation in this notebook, hence no optimizer.
strategy_kwargs = dict(
    optimizer=None,
    criterion=CrossEntropyLoss(),
    device=device,
    evaluator=plugins,
    mem_size=2000,
    eval_mb_size=32,
)
strategy = Replay(model, **strategy_kwargs)

In [None]:
# Evaluate one checkpoint per experience on the full test stream.
# results[k] is the metric dictionary returned for the checkpoint "_exp<k>.pt".
results = []
for experience in range(len(test_stream)):
    checkpoint_path = f"{model_base_path}_exp{experience}.pt"
    print("Computing experience", experience)
    print("Loading model:", checkpoint_path)
    state_dict = torch.load(checkpoint_path, map_location=device)
    strategy.model.load_state_dict(state_dict)
    exp_metrics = strategy.eval(benchmark.test_stream)
    results.append(exp_metrics)

In [None]:
# Aggregate the per-checkpoint evaluation results into the series plotted below.
# Outputs (one entry per evaluated run, a single run here):
#   running_accuracy / running_ece : per-checkpoint average over the experiences trained so far
#   final_accuracy / final_ece     : last value of each running series
#   mafia                          : ECE of experience k measured with checkpoint k
#   bins, ece_hist_vals            : x positions and per-bin (mean, std) of the calibration curves


def _running_average(eval_results, metric_prefix):
    """Average a per-experience metric over the experiences seen so far.

    Args:
        eval_results: list of metric dictionaries, one per checkpoint; entry ``i``
            must contain the keys ``metric_prefix + "000"`` ... ``metric_prefix + f"{i:03d}"``.
        metric_prefix: metric key without the trailing three-digit experience index.

    Returns:
        ``(averages, diagonal)`` where ``averages[i]`` is the mean of the metric over
        experiences ``0..i`` taken from ``eval_results[i]`` and ``diagonal[i]`` is the
        metric of experience ``i`` taken from ``eval_results[i]``.

    Raises:
        KeyError: if an expected metric key is missing from a dictionary.
    """
    averages, diagonal = [], []
    for i, exp_metrics in enumerate(eval_results):
        seen = []
        for j in range(i + 1):
            key = f"{metric_prefix}{j:03d}"
            print(i, j, key, exp_metrics[key])
            seen.append(exp_metrics[key])
        averages.append(sum(seen) / (i + 1))
        diagonal.append(seen[-1])
    return averages, diagonal


def _broadcast_single(values, length):
    """Repeat a one-element series ``length`` times; return any other series unchanged.

    A run with a single checkpoint (e.g. joint training) yields one value, which is
    drawn as a flat line across all experiences.
    """
    return values * length if len(values) == 1 else values


if not results:
    raise ValueError("`results` is empty; run the evaluation cell first.")

# The evaluation cell runs every checkpoint on the whole test stream, so this is
# the number of per-experience entries available in each result dictionary.
num_experiences = len(test_stream)

running_accuracy = []
running_ece = []
final_accuracy = []
final_ece = []
ece_hist_vals = []

accuracy_curve, _ = _running_average(
    results, "Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp"
)
running_accuracy.append(_broadcast_single(accuracy_curve, num_experiences))
final_accuracy.append(running_accuracy[-1][-1])

# `mafia` keeps the diagonal: ECE on experience k right after checkpoint k.
ece_curve, mafia = _running_average(
    results, "ECE_Exp/eval_phase/test_stream/Task000/Exp"
)
running_ece.append(_broadcast_single(ece_curve, num_experiences))
final_ece.append(running_ece[-1][-1])

# Recover the calibration curves from the histogram figures produced by the last
# checkpoint: the last line drawn on each axes supplies the x (bins) and y values.
hist_prefix = "ExpECEHistogram/eval_phase/test_stream/Exp"
last_metrics = results[-1]
bins = None
curves = []
for j in range(num_experiences):
    key = f"{hist_prefix}{j:03d}"
    print(-1, j, key, last_metrics[key])
    fig = last_metrics[key]
    for ax in fig.get_axes():
        for line in ax.get_lines()[-1:]:
            if bins is None:
                bins = line.get_xdata()
            curves.append(line.get_ydata())

if bins is None:
    raise ValueError(
        "No calibration curve found in the ExpECEHistogram figures of the last evaluation."
    )

# Per-bin mean and standard deviation across all recovered curves.
bin_vals = []
for b in range(len(bins)):
    column = [curve[b] for curve in curves]
    bin_vals.append((np.mean(column), np.std(column)))
ece_hist_vals.append(bin_vals)

In [None]:
# Average accuracy over the experiences trained so far; one curve per entry of running_accuracy.
plt.figure(figsize=(7, 6))
# Size the x-axis from the data instead of assuming 10 experiences, so curves of
# another length do not make plt.plot raise (10 is kept as the default for an empty plot).
num_points = max((len(vals) for vals in running_accuracy), default=10)
x_axis = list(range(1, num_points + 1))
for vals in running_accuracy:
    plt.plot(x_axis[:len(vals)], vals)
plt.title('Average Experience Accuracy')
plt.xlabel('#Trained Experience')
plt.ylabel('Accuracy')
plt.ylim(0, 1)
plt.xlim(1, max(num_points, 2))  # avoid identical limits when there is a single point
plt.xticks(x_axis, x_axis)
# To export the figure, call plt.savefig(<path>, dpi=400) before plt.show().
plt.show()

In [None]:
# Average ECE over the experiences trained so far; one curve per entry of running_ece.
plt.figure(figsize=(7, 6))
# Size the x-axis from the data instead of assuming 10 experiences, so curves of
# another length do not make plt.plot raise (10 is kept as the default for an empty plot).
num_points = max((len(vals) for vals in running_ece), default=10)
x_axis = list(range(1, num_points + 1))
for vals in running_ece:
    plt.plot(x_axis[:len(vals)], vals)
plt.title('Average Experience ECE')
plt.xlabel('#Trained Experience')
plt.ylabel('ECE')
plt.ylim(0, 1)
plt.xlim(1, max(num_points, 2))  # avoid identical limits when there is a single point
plt.xticks(x_axis, x_axis)
plt.grid(True, linestyle='--')
# To export the figure, call plt.savefig(<path>, dpi=400) before plt.show().
plt.show()

In [None]:
# ECE of each experience measured with the checkpoint saved right after training on it
# (`mafia` holds the i == j entries collected in the aggregation cell).
# The x positions follow the data so the plot works for any number of experiences.
x = list(range(1, len(mafia) + 1))
plt.plot(x, mafia)
plt.title('ECE on the Most Recently Trained Experience')
plt.xlabel('#Trained Experience')
plt.ylabel('ECE')
plt.ylim(0, 1)
plt.xlim(1, max(len(x), 2))  # avoid identical limits when there is a single point
plt.xticks(x, x)
plt.grid(True, linestyle='--')
plt.show()

In [None]:
# Show the last running-average accuracy and ECE of each run as the cell output.
(final_accuracy, final_ece)

In [None]:
# Calibration curves: per-bin mean accuracy with +/- one standard deviation,
# one panel per entry of ece_hist_vals, drawn against the ideal diagonal.
num_panels = len(ece_hist_vals)
ncols = 4
# Keep the original 3x4 layout for up to 12 runs and add rows beyond that
# instead of indexing past the end of the axes array.
nrows = max(3, (num_panels + ncols - 1) // ncols)
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 8 * nrows / 3))
axs = axs.flatten()
for panel, vals in zip(axs, ece_hist_vals):
    means = [mean for mean, _ in vals]
    stds = [std for _, std in vals]
    panel.plot([0, 1], [0, 1], '--', label='ideal')
    panel.errorbar(bins, means, yerr=stds, marker="o", linestyle="--", capsize=3, capthick=1)
    panel.set_ylim(-0.05, 1)
    panel.set_xlim(0, 1)
    panel.set_ylabel("Accuracy")
    panel.set_xlabel("Confidence")
    panel.grid(True, linestyle='--')
# Hide the panels that received no data so the figure shows no empty frames.
for unused in axs[num_panels:]:
    unused.set_visible(False)
plt.tight_layout()
# To export the figure, call plt.savefig(<path>, dpi=400) before plt.show().
plt.show()