In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import seaborn as sns

In [None]:
# Font sizes shared by every figure in this notebook:
# tick labels, then axis / colorbar labels.
TICK_SIZE, LABEL_SIZE = 26, 32

In [None]:
# Root of the training checkpoints. The results cells build paths as
# f"{checkpoint_dir}/checkpoint/<name>.pth", i.e. files are looked up in a
# `checkpoint/` sub-folder of this directory. Defaults to "./checkpoint/";
# set the CHECKPOINT_DIR environment variable to use a personal location
# without editing the notebook.
checkpoint_dir = os.environ.get("CHECKPOINT_DIR", "./checkpoint/")

In [None]:
# Results for anti-correlated error.
# Two upstream resnet18 tasks (task_up_1, task_up_2) are each run for every
# value in para_to_vary_model_list; every downstream (linear) checkpoint is
# indexed by one value for each upstream model. Collect the upstream test
# losses and the downstream test-loss matrix, cache them to an .npz, and plot
# the matrix as a heat map whose ticks are labelled with the upstream losses.
task_up_1 = "uu-ent-up1"
task_up_2 = "uu-ent-up2"
task_down = "uu-ent-down"
seed = 0
para_to_vary_model_list = np.arange(10, 100.1, 10)  # 10.0, 20.0, ..., 100.0
num_para = len(para_to_vary_model_list)
repre = "preds"
os.makedirs("results_for_figs", exist_ok=True)

# One upstream test loss per setting, for each of the two upstream tasks.
task_up_1_loss = np.zeros(num_para)
task_up_2_loss = np.zeros(num_para)
for task_up, task_up_loss in ((task_up_1, task_up_1_loss), (task_up_2, task_up_2_loss)):
    for i, para_to_vary_model in enumerate(para_to_vary_model_list):
        checkpoint_path = f"{checkpoint_dir}/checkpoint/svhn_resnet18_task_{task_up}_upstream_setting_None_para_to_vary_model_{para_to_vary_model}_seed_{seed}.pth"
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        task_up_loss[i] = checkpoint["loss"]
        del checkpoint

# Downstream loss for every (upstream-1 setting, upstream-2 setting) pair:
# rows follow task_up_1, columns follow task_up_2.
task_down_loss_mat = np.zeros([num_para, num_para])
for i, para_to_vary_model_1 in enumerate(para_to_vary_model_list):
    for j, para_to_vary_model_2 in enumerate(para_to_vary_model_list):
        checkpoint_path = f"{checkpoint_dir}/checkpoint/svhn_linear_task_{task_down}_upstream_setting_{task_up_1}_{para_to_vary_model_1}_resnet18_{seed}_{repre}_last_{task_up_2}_{para_to_vary_model_2}_resnet18_{seed}_{repre}_last_para_to_vary_model_None_seed_{seed}.pth"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        task_down_loss_mat[i, j] = checkpoint["loss"]
        del checkpoint

np.savez("results_for_figs/anti-correlated_error", loss_up_1=task_up_1_loss, loss_up_2=task_up_2_loss, loss_down=task_down_loss_mat)

results = np.load("results_for_figs/anti-correlated_error.npz")
loss_up_1 = results["loss_up_1"]
loss_up_2 = results["loss_up_2"]
loss_down = results["loss_down"]

fig, ax = plt.subplots(figsize=(10 * 0.9, 9 * 0.9))
hmap = ax.imshow(loss_down, cmap="YlOrRd")
cbar = fig.colorbar(hmap, ax=ax)
# Raw strings: "\e" / "\l" are not valid Python escapes and warn on 3.12+.
cbar.set_label(r"Downstream test loss ($\ell_w$)", size=LABEL_SIZE, fontweight="bold")
cbar.ax.tick_params(labelsize=TICK_SIZE)

ax.set_ylabel(r"Upstream test loss ($\ell_u$)", fontsize=LABEL_SIZE, fontweight="bold")
ax.set_xlabel(r"Upstream test loss ($\ell_v$)", fontsize=LABEL_SIZE, fontweight="bold")

# Label every `tick_step`-th cell with its upstream loss (set to 1 for dense ticks).
tick_step = 2
y_ticks = np.arange(0, len(loss_up_1), tick_step)
ax.set_yticks(y_ticks)
ax.set_yticklabels(["%.2f" % a for a in loss_up_1[y_ticks]], rotation=0, fontsize=TICK_SIZE)
x_ticks = np.arange(0, len(loss_up_2), tick_step)
ax.set_xticks(x_ticks)
ax.set_xticklabels(["%.2f" % a for a in loss_up_2[x_ticks]], rotation=45, fontsize=TICK_SIZE)
ax.invert_yaxis()
fig.savefig("results_for_figs/anti-correlated_error.pdf", bbox_inches="tight")

In [None]:
# Results for data-distribution mismatch.
# Upstream resnet18 runs for para_to_vary_model = 10, 20, ..., 100 (plotted as
# "Training subset (p%)") and the linear downstream models built on their
# "feat" representation. Losses are averaged over `num_seed` seeds, cached to an
# .npz and drawn on twin y-axes (upstream left, downstream right).
task_up = "dd-mis-up"
task_down = "dd-mis-down"
repre = "feat"
model_specify = "last"
para_to_vary_model_list = np.arange(10, 100.1, 10)
num_para = len(para_to_vary_model_list)
num_seed = 1

model = "resnet18"
task_up_loss = np.zeros(num_para)
task_down_loss = np.zeros(num_para)
task_up_loss_std = np.zeros(num_para)
task_down_loss_std = np.zeros(num_para)
os.makedirs("results_for_figs", exist_ok=True)

for i, para_to_vary_model in enumerate(para_to_vary_model_list):
    para_to_vary_model = float(para_to_vary_model)
    loss = np.zeros([num_seed])
    for seed in range(num_seed):
        checkpoint_path = f"{checkpoint_dir}/checkpoint/svhn_{model}_task_{task_up}_upstream_setting_None_para_to_vary_model_{para_to_vary_model}_seed_{seed}.pth"
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        loss[seed] = checkpoint["knn_loss"]
        del checkpoint
    task_up_loss[i] = loss.mean()
    task_up_loss_std[i] = loss.std()

for i, para_to_vary_model in enumerate(para_to_vary_model_list):
    para_to_vary_model = float(para_to_vary_model)
    loss = np.zeros([num_seed])
    for seed in range(num_seed):
        checkpoint_path = f"{checkpoint_dir}/checkpoint/svhn_linear_task_{task_down}_upstream_setting_{task_up}_{para_to_vary_model}_{model}_{seed}_{repre}_{model_specify}_para_to_vary_model_None_seed_{seed}.pth"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        loss[seed] = checkpoint["loss"]
        del checkpoint
    task_down_loss[i] = loss.mean()
    task_down_loss_std[i] = loss.std()

np.savez("results_for_figs/data-distribution_mismatch", loss_up_mean=task_up_loss, loss_up_std=task_up_loss_std, loss_down_mean=task_down_loss, loss_down_std=task_down_loss_std)

results = np.load("results_for_figs/data-distribution_mismatch.npz")
loss_up_mean = results["loss_up_mean"]
loss_up_std = results["loss_up_std"]
loss_down_mean = results["loss_down_mean"]
loss_down_std = results["loss_down_std"]

x_values = para_to_vary_model_list
# `np.int` was removed in NumPy 1.24; the builtin `int` is the drop-in replacement.
x_ticks = x_values.astype(int)

fig, ax = plt.subplots(figsize=(10, 8))
caps_set = [None, None]
(_, caps_set[0], _) = ax.errorbar(x_values, loss_up_mean, yerr=loss_up_std, marker=".", label="Upstream model", capsize=8, lw=4, color="C0")
ax.set_xlabel("Training subset (p%)", fontsize=LABEL_SIZE, fontweight="bold")
ax.set_ylabel("Upstream test loss", fontsize=LABEL_SIZE, fontweight="bold")
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_ticks, fontsize=TICK_SIZE)
ax.tick_params(axis='y', labelsize=TICK_SIZE)

ax2 = ax.twinx()
(_, caps_set[1], _) = ax2.errorbar(x_values, loss_down_mean, yerr=loss_down_std, marker=".", label="downstream", capsize=8, lw=4, color="C1")
ax2.set_ylabel("Downstream test loss", fontsize=LABEL_SIZE, fontweight="bold")
ax2.tick_params(axis='y', labelsize=TICK_SIZE)
ax2.locator_params(axis='y', nbins=5)

# Thicker error-bar caps.
for caps in caps_set:
    for cap in caps:
        cap.set_markeredgewidth(2)

plt.xticks(fontsize=TICK_SIZE)
plt.yticks(fontsize=TICK_SIZE)

ax.set_axisbelow(True)
ax.grid()
ax.spines['right'].set_linewidth(3)
ax.spines['left'].set_linewidth(3)
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_linewidth(3)

fig.savefig("results_for_figs/data-distribution_mismatch.pdf", bbox_inches='tight')

In [None]:
# Stand-alone legend figure for the twin-axis line plots. It reuses the labelled
# artists of `ax` (upstream) and `ax2` (downstream) from the data-distribution
# mismatch cell above, so that cell must be run first.
figsize = (5, 1)
fig_leg = plt.figure(figsize=figsize)
legend_properties = {'weight': 'bold', 'size': LABEL_SIZE}
ax_leg = fig_leg.add_subplot(111)
# Merge the upstream and downstream entries into a single legend.
handles_up, labels_up = ax.get_legend_handles_labels()
handles_down, labels_down = ax2.get_legend_handles_labels()
legend_handles_labels = (handles_up + handles_down, labels_up + labels_down)
# Actually draw the merged legend; without this call the saved PDF is empty.
ax_leg.legend(*legend_handles_labels, loc="center", ncol=max(1, len(legend_handles_labels[0])), prop=legend_properties, frameon=False)
ax_leg.axis('off')
fig_leg.savefig('results_for_figs/legend.pdf', bbox_inches='tight')

In [None]:
# Results for loss-function mismatch.
# Upstream resnet18 runs for para_to_vary_model = 0, 1, ..., 9 (plotted as
# "Rate of noisy examples (r)") and the MLP-128 downstream models built on their
# "logits" representation. Losses are averaged over `num_seed` seeds, cached to
# an .npz and drawn on twin y-axes (upstream left, downstream right).
task_up = "lf-mis-up"
task_down = "lf-mis-down"
repre = "logits"
model_specify = "best"
para_to_vary_model_list = np.arange(0, 10, 1)
num_para = len(para_to_vary_model_list)
num_seed = 40

model = "resnet18"
task_up_loss = np.zeros(num_para)
task_down_loss = np.zeros(num_para)
task_up_loss_std = np.zeros(num_para)
task_down_loss_std = np.zeros(num_para)
os.makedirs("results_for_figs", exist_ok=True)

for i, para_to_vary_model in enumerate(para_to_vary_model_list):
    para_to_vary_model = float(para_to_vary_model)
    loss = np.zeros([num_seed])
    for seed in range(num_seed):
        checkpoint_path = f"{checkpoint_dir}/checkpoint/svhn_{model}_task_{task_up}_upstream_setting_None_para_to_vary_model_{para_to_vary_model}_seed_{seed}.pth"
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        loss[seed] = checkpoint["best_state"]["loss"]
        del checkpoint
    task_up_loss[i] = loss.mean()
    task_up_loss_std[i] = loss.std()

for i, para_to_vary_model in enumerate(para_to_vary_model_list):
    para_to_vary_model = float(para_to_vary_model)
    loss = np.zeros([num_seed])
    for seed in range(num_seed):
        checkpoint_path = f"{checkpoint_dir}/checkpoint/svhn_MLP-128_task_{task_down}_upstream_setting_{task_up}_{para_to_vary_model}_{model}_{seed}_{repre}_{model_specify}_para_to_vary_model_None_seed_{seed}.pth"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        loss[seed] = checkpoint["loss"]
        del checkpoint
    task_down_loss[i] = loss.mean()
    task_down_loss_std[i] = loss.std()

np.savez("results_for_figs/loss-function_mismatch", loss_up_mean=task_up_loss, loss_up_std=task_up_loss_std, loss_down_mean=task_down_loss, loss_down_std=task_down_loss_std)

results = np.load("results_for_figs/loss-function_mismatch.npz")
loss_up_mean = results["loss_up_mean"]
loss_up_std = results["loss_up_std"]
loss_down_mean = results["loss_down_mean"]
loss_down_std = results["loss_down_std"]

# Plot directly against the varied parameter (evenly spaced, so the figure is
# unchanged) rather than at placeholder x = 10..100 relabelled as 0..9.
# `np.int` was removed in NumPy 1.24; the builtin `int` is the drop-in replacement.
x_ticks = para_to_vary_model_list.astype(int)

fig, ax = plt.subplots(figsize=(10, 8))
caps_set = [None, None]
(_, caps_set[0], _) = ax.errorbar(x_ticks, loss_up_mean, yerr=loss_up_std, marker=".", label="upstream", capsize=8, lw=4, color="C0")
ax.set_xlabel("Rate of noisy examples (r)", fontsize=LABEL_SIZE, fontweight="bold")
ax.set_ylabel("Upstream test loss", fontsize=LABEL_SIZE, fontweight="bold")
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_ticks, fontsize=TICK_SIZE)
ax.tick_params(axis='y', labelsize=TICK_SIZE)

ax2 = ax.twinx()
(_, caps_set[1], _) = ax2.errorbar(x_ticks, loss_down_mean, yerr=loss_down_std, marker=".", label="downstream", capsize=8, lw=4, color="C1")
ax2.set_ylabel("Downstream test loss", fontsize=LABEL_SIZE, fontweight="bold")
ax2.tick_params(axis='y', labelsize=TICK_SIZE)
ax2.locator_params(axis='y', nbins=5)

# Thicker error-bar caps.
for caps in caps_set:
    for cap in caps:
        cap.set_markeredgewidth(2)

plt.xticks(fontsize=TICK_SIZE)
plt.yticks(fontsize=TICK_SIZE)

ax.set_axisbelow(True)
ax.grid()

ax.spines['right'].set_linewidth(3)
ax.spines['left'].set_linewidth(3)
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_linewidth(3)

fig.savefig('results_for_figs/loss-function_mismatch.pdf', bbox_inches='tight')

In [None]:
# Results for hypothesis-space mismatch.
# One upstream run per architecture in `model_set` (plotted by the matching
# entry of `depth_set`) and the linear downstream model built on its "feat"
# representation. Losses are averaged over `num_seed` seeds, cached to an .npz
# and drawn on twin y-axes (upstream left, downstream right).
num_seed = 1
task_up = "hs-mis-up"
task_down = "hs-mis-down"
repre = "feat"
model_specify = "last"

model_set = ["convnet", "convnet-512", "convnet-512-256", "convnet-512-256-128", "convnet-512-256-128-64"]
depth_set = ["1", "2", "3", "4", "5"]  # x-axis label for each entry of model_set

num_model = len(model_set)
task_up_loss = np.zeros(num_model)
task_down_loss = np.zeros(num_model)
task_up_loss_std = np.zeros(num_model)
task_down_loss_std = np.zeros(num_model)
os.makedirs("results_for_figs", exist_ok=True)

for i, model in enumerate(model_set):
    loss = np.zeros([num_seed])
    for seed in range(num_seed):
        checkpoint_path = f"{checkpoint_dir}/checkpoint/svhn_{model}_task_{task_up}_upstream_setting_None_para_to_vary_model_None_seed_{seed}.pth"
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        loss[seed] = checkpoint["knn_loss"]
        del checkpoint
    task_up_loss[i] = loss.mean()
    task_up_loss_std[i] = loss.std()

for i, model in enumerate(model_set):
    loss = np.zeros([num_seed])
    for seed in range(num_seed):
        checkpoint_path = f"{checkpoint_dir}/checkpoint/svhn_linear_task_{task_down}_upstream_setting_{task_up}_None_{model}_{seed}_{repre}_{model_specify}_para_to_vary_model_None_seed_{seed}.pth"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        loss[seed] = checkpoint["loss"]
        del checkpoint
    task_down_loss[i] = loss.mean()
    task_down_loss_std[i] = loss.std()

np.savez("results_for_figs/hypothesis-space_mismatch", loss_up_mean=task_up_loss, loss_up_std=task_up_loss_std, loss_down_mean=task_down_loss, loss_down_std=task_down_loss_std)

results = np.load("results_for_figs/hypothesis-space_mismatch.npz")
loss_up_mean = results["loss_up_mean"]
loss_up_std = results["loss_up_std"]
loss_down_mean = results["loss_down_mean"]
loss_down_std = results["loss_down_std"]

x_positions = range(len(depth_set))

fig, ax = plt.subplots(figsize=(10, 8))
caps_set = [None, None]
(_, caps_set[0], _) = ax.errorbar(x_positions, loss_up_mean, yerr=loss_up_std, marker=".", label="upstream", capsize=8, lw=4, color="C0")

ax.set_xticks(x_positions)
ax.set_xticklabels(depth_set, fontsize=TICK_SIZE)
ax.set_ylabel("Upstream Test loss", fontsize=LABEL_SIZE, fontweight="bold")
ax.set_xlabel("Number of layers in MLP", fontsize=LABEL_SIZE, fontweight="bold")
ax.tick_params(axis='y', labelsize=TICK_SIZE)

ax2 = ax.twinx()
(_, caps_set[1], _) = ax2.errorbar(x_positions, loss_down_mean, yerr=loss_down_std, marker=".", label="downstream", capsize=8, lw=4, color="C1")

# Thicker error-bar caps.
for caps in caps_set:
    for cap in caps:
        cap.set_markeredgewidth(2)

ax2.set_ylabel("Downstream test loss", fontsize=LABEL_SIZE, fontweight="bold")
ax2.tick_params(axis='y', labelsize=TICK_SIZE)
ax2.locator_params(axis='y', nbins=5)

plt.xticks(fontsize=TICK_SIZE)
plt.yticks(fontsize=TICK_SIZE)

ax.set_axisbelow(True)
ax.grid()
ax.spines['right'].set_linewidth(3)
ax.spines['left'].set_linewidth(3)
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_linewidth(3)

# Save next to the other figures (previously written to the current directory).
fig.savefig('results_for_figs/hypothesis-space_mismatch.pdf', bbox_inches='tight')