In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")

In [None]:
import os
import pickle
import pprint
import re
import timeit
from pathlib import Path
#
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, models
from dotted_dict import DottedDict
from tqdm import tqdm
#
from csprites.datasets import ClassificationDataset
import utils
from backbone import get_backbone
from optimizer import get_optimizer
from plot_utils import *

In [None]:
class Net(nn.Module):
    """Probe network: ``n_hid`` hidden (Linear -> ReLU) layers of width ``d_hid``
    followed by a Linear output layer with ``n_classes`` outputs.

    With the default ``n_hid=0`` this is a single ``nn.Linear(d_in, n_classes)``,
    i.e. a plain linear probe on ``d_in``-dimensional inputs.
    """

    def __init__(self, n_classes, d_in, d_hid=1024, n_hid=0):
        super().__init__()
        modules = []
        fan_in = d_in
        for _ in range(n_hid):
            modules += [nn.Linear(fan_in, d_hid), nn.ReLU()]
            fan_in = d_hid
        # Output layer has no activation: the module returns unnormalised logits.
        modules.append(nn.Linear(fan_in, n_classes))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Map a batch of feature vectors ``x`` to class logits."""
        return self.layers(x)

def get_datasets(target_idx, p_R_train, p_R_valid, p_Y_train, p_Y_valid, batch_size):
    """Build train/valid DataLoaders over precomputed representations.

    Args:
        target_idx: column of the label arrays used as the prediction target.
        p_R_train, p_R_valid: paths to ``.npy`` feature arrays; rows are samples
            and ``shape[1]`` is the feature dimension.
        p_Y_train, p_Y_valid: paths to ``.npy`` label arrays, indexed as
            ``Y[:, target_idx]`` (one column per target variable).
        batch_size: batch size of both loaders.

    Returns:
        ``(dl_train, dl_valid, d_r)``: a shuffled training loader, a validation
        loader in fixed order, and the feature dimension ``d_r``.
    """
    R_train = torch.as_tensor(np.load(p_R_train), dtype=torch.float32)
    R_valid = torch.as_tensor(np.load(p_R_valid), dtype=torch.float32)
    # int64 labels, as expected by nn.CrossEntropyLoss; keep only the target column.
    Y_train = torch.as_tensor(np.load(p_Y_train), dtype=torch.long)[:, target_idx]
    Y_valid = torch.as_tensor(np.load(p_Y_valid), dtype=torch.long)[:, target_idx]
    d_r = R_train.shape[1]

    dl_train = DataLoader(
        torch.utils.data.TensorDataset(R_train, Y_train),
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True)
    # No shuffling for evaluation: a fixed batch composition keeps the
    # per-batch-averaged validation loss comparable across epochs.
    dl_valid = DataLoader(
        torch.utils.data.TensorDataset(R_valid, Y_valid),
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False)
    return dl_train, dl_valid, d_r

def _run_epoch(model, loader, desc, device, criterion, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Trains (forward, backward, ``optimizer.step()``) when ``optimizer`` is given;
    otherwise evaluates in ``model.eval()`` mode with gradients disabled.
    ``mean_loss`` is the mean of the per-batch losses; ``accuracy`` is the
    fraction of samples whose argmax prediction equals the target. Both are
    ``nan`` when ``loader`` yields no batches (instead of dividing by zero).
    """
    training = optimizer is not None
    model.train(training)
    n_steps = 0
    loss_sum = 0.0
    n_total = 0
    n_correct = 0
    pbar = tqdm(loader, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        if training:
            # Reset gradients by dropping them (same as zero_grad(set_to_none=True)).
            for param in model.parameters():
                param.grad = None
        with torch.set_grad_enabled(training):
            out = model(x)
            loss = criterion(out, y)
        if training:
            loss.backward()
            optimizer.step()
        _, y_pred = torch.max(out, 1)
        total = y.size(0)
        correct = (y_pred == y).sum().item()
        batch_loss = loss.item()
        loss_sum += batch_loss
        n_total += total
        n_correct += correct
        n_steps += 1
        pbar.set_postfix({'loss': batch_loss, 'acc': correct / total})
    if n_steps == 0:
        return float('nan'), float('nan')
    return loss_sum / n_steps, n_correct / n_total


def train_model(model, num_epochs, optimizer, criterion,
                dl_train=None, dl_valid=None, device=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        model: module mapping a feature batch to class logits.
        num_epochs: number of passes over the training loader.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(logits, targets)``.
        dl_train, dl_valid: training / validation loaders. When omitted, the
            notebook-global ``dl_train`` / ``dl_valid`` are used (the original
            behaviour, kept so existing four-argument calls still work).
        device: device batches are moved to; defaults to the global ``device``.

    Returns:
        DottedDict with ``train`` and ``valid`` entries, each holding parallel
        per-epoch lists ``loss`` (mean batch loss), ``acc`` (sample accuracy)
        and ``epoch`` (1-based epoch index).
    """
    # Fallback to notebook globals for backward compatibility; pass the
    # loaders explicitly to avoid depending on hidden kernel state.
    namespace = globals()
    if dl_train is None:
        dl_train = namespace['dl_train']
    if dl_valid is None:
        dl_valid = namespace['dl_valid']
    if device is None:
        device = namespace['device']

    stats = DottedDict({
        split: {'loss': [], 'acc': [], 'epoch': []}
        for split in ('train', 'valid')
    })
    desc_tmp = "Epoch [{:3}/{:3}] {}:"
    for epoch_idx in range(1, num_epochs + 1):
        # Train first, then evaluate the freshly updated model.
        for split, loader, split_optimizer in (('train', dl_train, optimizer),
                                               ('valid', dl_valid, None)):
            desc = desc_tmp.format(epoch_idx, num_epochs, split)
            loss, acc = _run_epoch(model, loader, desc, device, criterion, split_optimizer)
            stats[split]['loss'].append(loss)
            stats[split]['acc'].append(acc)
            stats[split]['epoch'].append(epoch_idx)
    return stats

# Settings

In [None]:
# linprob config
# Settings for the probes trained below. The dict itself is pickled next to
# each experiment's results (file name under key 'p_config').
linprob_config = DottedDict(dict(
    # output directory / file names
    p_eval='eval',
    p_eval_results='results.pkl',
    p_config="linprob_config.pkl",
    p_results='results.pkl',
    # hardware
    device="cuda",
    cuda_visible_devices='0',
    # probe architecture: n_hid hidden layers of width d_hid (0 = linear probe)
    n_hid=0,
    d_hid=1024,
    # optimisation
    batch_size=1024,
    optimizer='adam',
    optimizer_args={'lr': 0.001, 'weight_decay': 1e-6},
    num_epochs=20,
    eval_name="Width_and_Depth",
))

In [None]:
# TORCH SETTINGS
# The device mask only has an effect if CUDA has not been initialised yet in
# this kernel, so set it before any CUDA call.
os.environ["CUDA_VISIBLE_DEVICES"] = linprob_config.cuda_visible_devices
device = torch.device(linprob_config.device)
# Seed the RNGs behind probe weight init and DataLoader shuffling so that
# Restart & Run All reproduces the same accuracies.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.benchmark = True

In [None]:
# Width x depth sweep (alpha=0.5, beta=0.5, wdkl=10).
# NOTE: the next cell reassigns p_experiments_base / experiment_names /
# p_experiments, so only the last selection cell that ran is used for training.
p_experiments_base = Path("/mnt/experiments/csprites/single_csprites_64x64_n7_c128_a32_p10_s3_bg_inf_random_function_100000") / "beta"
experiment_names = [
    "Beta_w8_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_20-26-36",
    "Beta_w8_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_19-56-29",
    "Beta_w8_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_19-26-26",
    "Beta_w8_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_18-56-30",
    "Beta_w16_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_22-27-56",
    "Beta_w16_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_21-57-28",
    "Beta_w16_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_21-27-05",
    "Beta_w16_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_20-56-43",
    "Beta_w32_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_00-29-59",
    "Beta_w32_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_23-28-55",
    "Beta_w32_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_22-58-22",
    "Beta_w32_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_23-59-27",
    "Beta_w64_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_01-31-01",
    "Beta_w64_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_01-00-30",
    "Beta_w64_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_02-01-25",
    "Beta_w64_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_02-31-53",
    "Beta_w128_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_04-33-57",
    "Beta_w128_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_03-02-22",
    "Beta_w128_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_04-03-29",
    "Beta_w128_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_03-32-53",
    "Beta_w256_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_06-36-04",
    "Beta_w256_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_05-34-55",
    "Beta_w256_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_05-04-23",
    "Beta_w256_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_06-05-29",
    "Beta_w512_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_08-38-17",
    "Beta_w512_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_08-07-42",
    "Beta_w512_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_07-37-07",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_18-51-20",
    "Beta_w1024_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_09-08-54",
    "Beta_w1024_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_09-39-36",
    "Beta_w1024_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_10-10-18",
    "Beta_w1024_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_10-39-43",
]

p_experiments = [p_experiments_base / name for name in experiment_names]

# Fail fast, naming every missing run directory rather than a bare AssertionError.
missing = [str(p) for p in p_experiments if not p.exists()]
assert not missing, "missing experiment directories:\n" + "\n".join(missing)

In [None]:
# (alpha, beta) sweep of w512_d1 runs for wdkl in {100, 50, 10}.
# NOTE: this reassigns p_experiments_base / experiment_names / p_experiments,
# overriding the width x depth selection of the previous cell.
p_experiments_base = Path("/mnt/experiments/csprites/single_csprites_64x64_n7_c128_a32_p10_s3_bg_inf_random_function_100000") / "beta"
experiment_names = [
    "Beta_w512_d1_alpha_0.1_beta_0.9_wdkl_100_[ResNet-18]_2021-08-31_05-35-28",
    "Beta_w512_d1_alpha_0.2_beta_0.8_wdkl_100_[ResNet-18]_2021-08-31_06-36-50",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_100_[ResNet-18]_2021-08-31_05-04-48",
    "Beta_w512_d1_alpha_0.8_beta_0.2_wdkl_100_[ResNet-18]_2021-08-31_07-07-30",
    "Beta_w512_d1_alpha_0.9_beta_0.1_wdkl_100_[ResNet-18]_2021-08-31_06-06-06",
    "Beta_w512_d1_alpha_1.1111111111111112_beta_10.0_wdkl_100_[ResNet-18]_2021-08-31_08-39-33",
    "Beta_w512_d1_alpha_1.25_beta_5.0_wdkl_100_[ResNet-18]_2021-08-31_09-40-55",
    "Beta_w512_d1_alpha_2.0_beta_2.0_wdkl_100_[ResNet-18]_2021-08-31_07-38-10",
    "Beta_w512_d1_alpha_5.0_beta_1.25_wdkl_100_[ResNet-18]_2021-08-31_09-10-13",
    "Beta_w512_d1_alpha_10.0_beta_1.1111111111111112_wdkl_100_[ResNet-18]_2021-08-31_08-08-53",
    "Beta_w512_d1_alpha_0.1_beta_0.9_wdkl_50_[ResNet-18]_2021-08-31_00-28-35",
    "Beta_w512_d1_alpha_0.2_beta_0.8_wdkl_50_[ResNet-18]_2021-08-31_01-29-57",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_50_[ResNet-18]_2021-08-30_23-57-54",
    "Beta_w512_d1_alpha_0.8_beta_0.2_wdkl_50_[ResNet-18]_2021-08-31_02-00-36",
    "Beta_w512_d1_alpha_0.9_beta_0.1_wdkl_50_[ResNet-18]_2021-08-31_00-59-16",
    "Beta_w512_d1_alpha_1.1111111111111112_beta_10.0_wdkl_50_[ResNet-18]_2021-08-31_03-32-47",
    "Beta_w512_d1_alpha_1.25_beta_5.0_wdkl_50_[ResNet-18]_2021-08-31_04-34-09",
    "Beta_w512_d1_alpha_2.0_beta_2.0_wdkl_50_[ResNet-18]_2021-08-31_02-31-18",
    "Beta_w512_d1_alpha_5.0_beta_1.25_wdkl_50_[ResNet-18]_2021-08-31_04-03-32",
    "Beta_w512_d1_alpha_10.0_beta_1.1111111111111112_wdkl_50_[ResNet-18]_2021-08-31_03-02-04",
    "Beta_w512_d1_alpha_0.1_beta_0.9_wdkl_10_[ResNet-18]_2021-08-30_19-21-46",
    "Beta_w512_d1_alpha_0.2_beta_0.8_wdkl_10_[ResNet-18]_2021-08-30_20-23-05",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_07-06-32",
    "Beta_w512_d1_alpha_0.8_beta_0.2_wdkl_10_[ResNet-18]_2021-08-30_20-53-43",
    "Beta_w512_d1_alpha_0.9_beta_0.1_wdkl_10_[ResNet-18]_2021-08-30_19-52-22",
    "Beta_w512_d1_alpha_1.1111111111111112_beta_10.0_wdkl_10_[ResNet-18]_2021-08-30_22-25-51",
    "Beta_w512_d1_alpha_1.25_beta_5.0_wdkl_10_[ResNet-18]_2021-08-30_23-27-14",
    "Beta_w512_d1_alpha_2.0_beta_2.0_wdkl_10_[ResNet-18]_2021-08-30_21-24-23",
    "Beta_w512_d1_alpha_5.0_beta_1.25_wdkl_10_[ResNet-18]_2021-08-30_22-56-32",
    "Beta_w512_d1_alpha_10.0_beta_1.1111111111111112_wdkl_10_[ResNet-18]_2021-08-30_21-55-05",
]

p_experiments = [p_experiments_base / name for name in experiment_names]

# Fail fast, naming every missing run directory rather than a bare AssertionError.
missing = [str(p) for p in p_experiments if not p.exists()]
assert not missing, "missing experiment directories:\n" + "\n".join(missing)

In [None]:
# Print every entry under p_experiments_base that is not in the current
# selection (useful for spotting runs that were left out).
selected_names = set(experiment_names)
for p in sorted(p_experiments_base.glob("*")):
    if p.name not in selected_names:
        print(p.name)

In [None]:
# Train one probe per (experiment, representation, target variable) and store
# the per-probe training curves inside each experiment directory.

# Representation key -> suffix of the experiment-config entries that hold its
# .npy paths ("p_R_train" + suffix, "p_R_valid" + suffix, ...).
# NOTE: "bacbone" is misspelled, but the key is kept because the pickled
# results and the aggregation cell further down look it up by this exact name.
rep_key_suffixes = {"bacbone": "", "betapro": "_bp"}


def load_experiment_configs(p_experiment):
    """Return ``(experiment_config, ds_config)`` for one experiment directory.

    ``experiment_config`` is unpickled from ``<p_experiment>/config.pkl``;
    ``ds_config`` from ``config.pkl`` in the directory named by its ``p_data``.
    NOTE: uses pickle, so only point this at trusted experiment directories.
    """
    with open(p_experiment / "config.pkl", "rb") as file:
        experiment_config = pickle.load(file)
    with open(Path(experiment_config.p_data) / "config.pkl", "rb") as file:
        ds_config = pickle.load(file)
    return experiment_config, ds_config


# Load all configs up front: fails fast on a broken experiment and yields an
# exact progress total instead of assuming 6 target variables per dataset.
configs = [load_experiment_configs(p) for p in p_experiments]
max_steps = sum(len(rep_key_suffixes) * len(ds_config["classes"]) for _, ds_config in configs)

step = 1
print_str = "[{:>3}/{:>3}] {:>10} {:<10} {}"
for p_experiment, (experiment_config, ds_config) in zip(p_experiments, configs):
    experiment_name = p_experiment.name
    ds_reps = {
        ds_key: {
            name: experiment_config["p_" + name + suffix]
            for name in ("R_train", "R_valid", "Y_train", "Y_valid")
        }
        for ds_key, suffix in rep_key_suffixes.items()
    }
    results = {ds_key: {} for ds_key in ds_reps}
    for ds_key, rep_paths in ds_reps.items():
        # Assumes class names are unique, so the enumerate index is the label column.
        for target_idx, target_variable in enumerate(ds_config["classes"]):
            print(print_str.format(step, max_steps, ds_key, target_variable, experiment_name))
            n_classes = ds_config["n_classes"][target_variable]
            # dl_train / dl_valid are bound at cell (global) scope on purpose:
            # train_model picks them up from the notebook globals.
            dl_train, dl_valid, d_r = get_datasets(
                target_idx,
                p_experiment / rep_paths["R_train"],
                p_experiment / rep_paths["R_valid"],
                p_experiment / rep_paths["Y_train"],
                p_experiment / rep_paths["Y_valid"],
                linprob_config.batch_size,
            )
            model = Net(n_classes, d_r, linprob_config.d_hid, linprob_config.n_hid).to(device)
            optimizer = get_optimizer(linprob_config.optimizer, model.parameters(), linprob_config.optimizer_args)
            criterion = nn.CrossEntropyLoss()
            results[ds_key][target_variable] = train_model(model, linprob_config.num_epochs, optimizer, criterion)
            step += 1
    # Persist the curves and the probe settings that produced them.
    with open(p_experiment / linprob_config["p_results"], "wb") as file:
        pickle.dump(results, file)
    with open(p_experiment / linprob_config["p_config"], "wb") as file:
        pickle.dump(linprob_config, file)

# Visualize Experiments

In [None]:
# Default plot name; the experiment-selection cell below reassigns vis_name.
vis_name = "Depth_vs_Width"

In [None]:
# Named experiment selections for the visualisation. The selected key is also
# used as vis_name, i.e. the plot file name, so the saved figure always
# matches the runs it shows.

# Width x depth sweep (alpha=0.5, beta=0.5, wdkl=10), grouped by width
# (w8 ... w1024) with depths d4 ... d1 inside each width.
experiment_names_width_first = [
    "Beta_w8_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_20-26-36",
    "Beta_w8_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_19-56-29",
    "Beta_w8_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_19-26-26",
    "Beta_w8_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_18-56-30",
    "Beta_w16_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_22-27-56",
    "Beta_w16_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_21-57-28",
    "Beta_w16_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_21-27-05",
    "Beta_w16_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_20-56-43",
    "Beta_w32_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_00-29-59",
    "Beta_w32_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_23-59-27",
    "Beta_w32_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_23-28-55",
    "Beta_w32_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_22-58-22",
    "Beta_w64_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_02-31-53",
    "Beta_w64_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_02-01-25",
    "Beta_w64_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_01-31-01",
    "Beta_w64_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_01-00-30",
    "Beta_w128_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_04-33-57",
    "Beta_w128_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_04-03-29",
    "Beta_w128_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_03-32-53",
    "Beta_w128_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_03-02-22",
    "Beta_w256_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_06-36-04",
    "Beta_w256_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_06-05-29",
    "Beta_w256_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_05-34-55",
    "Beta_w256_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_05-04-23",
    "Beta_w512_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_08-38-17",
    "Beta_w512_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_08-07-42",
    "Beta_w512_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_07-37-07",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_18-51-20",
    "Beta_w1024_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_10-39-43",
    "Beta_w1024_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_10-10-18",
    "Beta_w1024_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_09-39-36",
    "Beta_w1024_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_09-08-54",
]
# Same runs grouped by depth (d4 ... d1) with widths w8 ... w1024 inside each depth.
experiment_names_depth_first = [
    "Beta_w8_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_20-26-36",
    "Beta_w16_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_22-27-56",
    "Beta_w32_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_00-29-59",
    "Beta_w64_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_02-31-53",
    "Beta_w128_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_04-33-57",
    "Beta_w256_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_06-36-04",
    "Beta_w512_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_08-38-17",
    "Beta_w1024_d4_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_10-39-43",

    "Beta_w8_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_19-56-29",
    "Beta_w16_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_21-57-28",
    "Beta_w32_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_23-59-27",
    "Beta_w64_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_02-01-25",
    "Beta_w128_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_04-03-29",
    "Beta_w256_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_06-05-29",
    "Beta_w512_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_08-07-42",
    "Beta_w1024_d3_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_10-10-18",

    "Beta_w8_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_19-26-26",
    "Beta_w16_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_21-27-05",
    "Beta_w32_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_23-28-55",
    "Beta_w64_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_01-31-01",
    "Beta_w128_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_03-32-53",
    "Beta_w256_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_05-34-55",
    "Beta_w512_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_07-37-07",
    "Beta_w1024_d2_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_09-39-36",

    "Beta_w8_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_18-56-30",
    "Beta_w16_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_20-56-43",
    "Beta_w32_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_22-58-22",
    "Beta_w64_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_01-00-30",
    "Beta_w128_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_03-02-22",
    "Beta_w256_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_05-04-23",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-30_18-51-20",
    "Beta_w1024_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_09-08-54",
]
# w512_d1 runs sweeping (alpha, beta) for wdkl in {100, 50, 10}; the
# commented-out entries (alpha > 1) are excluded from the comparison.
experiment_names_new = [
    "Beta_w512_d1_alpha_0.1_beta_0.9_wdkl_100_[ResNet-18]_2021-08-31_05-35-28",
    "Beta_w512_d1_alpha_0.2_beta_0.8_wdkl_100_[ResNet-18]_2021-08-31_06-36-50",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_100_[ResNet-18]_2021-08-31_05-04-48",
    "Beta_w512_d1_alpha_0.8_beta_0.2_wdkl_100_[ResNet-18]_2021-08-31_07-07-30",
    "Beta_w512_d1_alpha_0.9_beta_0.1_wdkl_100_[ResNet-18]_2021-08-31_06-06-06",
#    "Beta_w512_d1_alpha_1.1111111111111112_beta_10.0_wdkl_100_[ResNet-18]_2021-08-31_08-39-33",
#    "Beta_w512_d1_alpha_1.25_beta_5.0_wdkl_100_[ResNet-18]_2021-08-31_09-40-55",
#    "Beta_w512_d1_alpha_2.0_beta_2.0_wdkl_100_[ResNet-18]_2021-08-31_07-38-10",
#    "Beta_w512_d1_alpha_5.0_beta_1.25_wdkl_100_[ResNet-18]_2021-08-31_09-10-13",
#    "Beta_w512_d1_alpha_10.0_beta_1.1111111111111112_wdkl_100_[ResNet-18]_2021-08-31_08-08-53",
    "Beta_w512_d1_alpha_0.1_beta_0.9_wdkl_50_[ResNet-18]_2021-08-31_00-28-35",
    "Beta_w512_d1_alpha_0.2_beta_0.8_wdkl_50_[ResNet-18]_2021-08-31_01-29-57",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_50_[ResNet-18]_2021-08-30_23-57-54",
    "Beta_w512_d1_alpha_0.8_beta_0.2_wdkl_50_[ResNet-18]_2021-08-31_02-00-36",
    "Beta_w512_d1_alpha_0.9_beta_0.1_wdkl_50_[ResNet-18]_2021-08-31_00-59-16",
#    "Beta_w512_d1_alpha_1.1111111111111112_beta_10.0_wdkl_50_[ResNet-18]_2021-08-31_03-32-47",
#    "Beta_w512_d1_alpha_1.25_beta_5.0_wdkl_50_[ResNet-18]_2021-08-31_04-34-09",
#    "Beta_w512_d1_alpha_2.0_beta_2.0_wdkl_50_[ResNet-18]_2021-08-31_02-31-18",
#    "Beta_w512_d1_alpha_5.0_beta_1.25_wdkl_50_[ResNet-18]_2021-08-31_04-03-32",
#    "Beta_w512_d1_alpha_10.0_beta_1.1111111111111112_wdkl_50_[ResNet-18]_2021-08-31_03-02-04",
    "Beta_w512_d1_alpha_0.1_beta_0.9_wdkl_10_[ResNet-18]_2021-08-30_19-21-46",
    "Beta_w512_d1_alpha_0.2_beta_0.8_wdkl_10_[ResNet-18]_2021-08-30_20-23-05",
    "Beta_w512_d1_alpha_0.5_beta_0.5_wdkl_10_[ResNet-18]_2021-08-31_07-06-32",
    "Beta_w512_d1_alpha_0.8_beta_0.2_wdkl_10_[ResNet-18]_2021-08-30_20-53-43",
    "Beta_w512_d1_alpha_0.9_beta_0.1_wdkl_10_[ResNet-18]_2021-08-30_19-52-22",
#    "Beta_w512_d1_alpha_1.1111111111111112_beta_10.0_wdkl_10_[ResNet-18]_2021-08-30_22-25-51",
#    "Beta_w512_d1_alpha_1.25_beta_5.0_wdkl_10_[ResNet-18]_2021-08-30_23-27-14",
#    "Beta_w512_d1_alpha_2.0_beta_2.0_wdkl_10_[ResNet-18]_2021-08-30_21-24-23",
#    "Beta_w512_d1_alpha_5.0_beta_1.25_wdkl_10_[ResNet-18]_2021-08-30_22-56-32",
#    "Beta_w512_d1_alpha_10.0_beta_1.1111111111111112_wdkl_10_[ResNet-18]_2021-08-30_21-55-05",
]

experiment_selections = {
    "width_first": experiment_names_width_first,
    "depth_first": experiment_names_depth_first,
    "alpha_beta_wdkl": experiment_names_new,
}
# Choose the selection by name so vis_name (and the plot file) cannot drift
# from experiment_names.
vis_name = "alpha_beta_wdkl"
experiment_names = experiment_selections[vis_name]

In [None]:
# Reload the probe results written by the training cell, keyed by the
# experiment directory name.
p_experiments = list(map(p_experiments_base.joinpath, experiment_names))
experiment_results = dict()
for p_experiment in p_experiments:
    experiment_name = p_experiment.name
    print(experiment_name)
    p_results = p_experiment.joinpath(linprob_config["p_results"])
    with p_results.open("rb") as file:
        experiment_results[experiment_name] = results = pickle.load(file)

In [None]:
# Output locations for the aggregated evaluation, under p_experiments_base.
# NOTE: only the plot file is named after vis_name; the results / config files
# are shared by every selection and are overwritten on each run.
p_eval = p_experiments_base / linprob_config["p_eval"]
p_eval.mkdir(parents=True, exist_ok=True)
# p_eval_results is the config key dedicated to the aggregated results file.
p_results = p_eval / linprob_config["p_eval_results"]
p_config = p_eval / linprob_config["p_config"]
p_plot = p_eval / "{}.png".format(vis_name)

In [None]:
# Write the aggregated results and the probe config into the eval directory.
for p_out, payload in ((p_results, experiment_results), (p_config, linprob_config)):
    with open(p_out, "wb") as file:
        pickle.dump(payload, file)

In [None]:
# Number of final epochs whose validation accuracies are averaged per probe.
n_avrg = 5
assert experiment_results, "no experiment results loaded - check experiment_names"
model_names = list(experiment_results)
# Representation and target-variable names come from the first experiment;
# assumes every experiment was probed on the same ones.
first_result = experiment_results[model_names[0]]
representations = list(first_result)
target_variables = list(first_result[representations[0]])

In [None]:
# Mean validation accuracy over the last n_avrg epochs: one row per
# (experiment, representation), one column per target variable.
skip_reps = {"bacbone"}  # representations left out of the comparison
all_accs = []
all_names = []
for model_name in model_names:
    # Drop the trailing run timestamp (e.g. "_2021-08-30_20-26-36") for the
    # label, whatever the year.
    short_name = re.sub(r"_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}$", "", model_name)
    for rep in representations:
        if rep in skip_reps:
            continue
        rep_accs = []
        for target_variable in target_variables:
            last_accs = experiment_results[model_name][rep][target_variable]["valid"]['acc'][-n_avrg:]
            rep_accs.append(sum(last_accs) / len(last_accs))
        all_names.append(short_name + "_{}".format(rep))
        all_accs.append(rep_accs)
accs = np.array(all_accs)
names = all_names

In [None]:
# Plot the accuracy matrix (rows: names, columns: target_variables) and save it
# to p_plot. NOTE(review): plot_mat comes from the star import of plot_utils;
# confirm its exact signature and output format there.
plot_mat(accs, names, target_variables, p_file=p_plot)

In [None]:
# Row means: one average accuracy (over target variables) per row of accs.
means = np.mean(accs, axis=1)

In [None]:
# One line per (experiment, representation) row: label padded to 40 chars, mean with 2 decimals.
for label, row_mean in zip(names, means):
    line = "{:<40}: {:.2f}".format(label, row_mean)
    print(line)