In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")

In [None]:
import os
from pathlib import Path
import pickle
import timeit
#
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, models
from dotted_dict import DottedDict
from tqdm import tqdm
import pprint
#
from csprites.datasets import ClassificationDataset
import utils
from backbone import get_backbone
from optimizer import get_optimizer

In [None]:
class Net(nn.Module):
    """MLP probe head: ``n_hid`` hidden Linear+ReLU layers, then a Linear classifier.

    With ``n_hid=0`` (the default) the whole network is a single
    ``nn.Linear(d_in, n_classes)``, i.e. a pure linear probe.

    Args:
        n_classes: number of output logits.
        d_in: dimensionality of the input representation.
        d_hid: width of every hidden layer (unused when ``n_hid == 0``).
        n_hid: number of hidden layers.
    """

    def __init__(self, n_classes, d_in, d_hid=1024, n_hid=0):
        super().__init__()
        widths = [d_in] + [d_hid] * n_hid + [n_classes]
        blocks = []
        # Every transition except the final one is followed by a ReLU.
        for in_features, out_features in zip(widths[:-2], widths[1:-1]):
            blocks += [nn.Linear(in_features, out_features), nn.ReLU()]
        blocks.append(nn.Linear(widths[-2], widths[-1]))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Return unnormalised class logits for the batch ``x``."""
        return self.layers(x)

def get_datasets(target_idx, p_R_train, p_R_valid, p_Y_train, p_Y_valid, batch_size):
    """Build DataLoaders that pair saved representations with one label column.

    Args:
        target_idx: column of the label arrays used as the classification target.
        p_R_train, p_R_valid: paths to ``.npy`` representation arrays (one row per
            sample; ``shape[1]`` is taken as the representation size).
        p_Y_train, p_Y_valid: paths to ``.npy`` integer label arrays indexed as
            ``[sample, target]``.
        batch_size: batch size of both loaders.

    Returns:
        ``(dl_train, dl_valid, d_r)``: a shuffled training loader, a validation
        loader that iterates in file order, and the representation dimensionality.

    Raises:
        ValueError: if the train and valid representations have different widths.
    """
    R_train = torch.as_tensor(np.load(p_R_train), dtype=torch.float32)
    R_valid = torch.as_tensor(np.load(p_R_valid), dtype=torch.float32)
    #
    Y_train = torch.as_tensor(np.load(p_Y_train), dtype=torch.long)[:, target_idx]
    Y_valid = torch.as_tensor(np.load(p_Y_valid), dtype=torch.long)[:, target_idx]
    #
    d_r = R_train.shape[1]
    if R_valid.shape[1] != d_r:
        raise ValueError(
            f"train/valid representation widths differ: {d_r} vs {R_valid.shape[1]}")
    #
    ds_train = torch.utils.data.TensorDataset(R_train, Y_train)
    ds_valid = torch.utils.data.TensorDataset(R_valid, Y_valid)
    # Pinned host memory only helps host->CUDA copies; skip it on CPU-only machines.
    pin_memory = torch.cuda.is_available()
    dl_train = DataLoader(
        ds_train,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=True)
    # Evaluation does not need shuffling; a fixed order keeps the per-epoch
    # validation loss deterministic.
    dl_valid = DataLoader(
        ds_valid,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=False)
    return dl_train, dl_valid, d_r

def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is trained (clear grads, forward,
    backward, step); otherwise it is evaluated under ``torch.no_grad``.
    The loss is averaged per sample rather than per batch, so a short final
    batch is not over-weighted; this assumes ``criterion`` returns the batch
    mean (the ``nn.CrossEntropyLoss`` default). Both values are ``nan`` when
    ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    n_total = 0
    n_correct = 0
    pbar = tqdm(loader, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        if training:
            model.zero_grad(set_to_none=True)
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
        else:
            with torch.no_grad():
                out = model(x)
                loss = criterion(out, y)
        #
        batch_size = y.size(0)
        batch_loss = loss.item()
        batch_correct = (out.argmax(dim=1) == y).sum().item()
        #
        loss_sum += batch_loss * batch_size
        n_total += batch_size
        n_correct += batch_correct
        pbar.set_postfix({'loss': batch_loss, 'acc': batch_correct / batch_size})
    if n_total == 0:
        return float('nan'), float('nan')
    return loss_sum / n_total, n_correct / n_total


def train_model(model, num_epochs, optimizer, criterion,
                dl_train=None, dl_valid=None, device=None):
    """Train ``model`` for ``num_epochs`` epochs, evaluating on the validation set each epoch.

    Args:
        model: network to train, already placed on its target device.
        num_epochs: number of full passes over the training loader.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function called as ``criterion(logits, targets)``.
        dl_train, dl_valid: loaders yielding ``(x, y)`` batches. When omitted, the
            notebook-level ``dl_train`` / ``dl_valid`` globals are used; passing
            them explicitly avoids hidden state.
        device: where batches are moved; defaults to the device of the model's
            first parameter.

    Returns:
        ``DottedDict`` with ``train`` and ``valid`` entries, each holding per-epoch
        lists ``loss`` (per-sample mean), ``acc`` and ``epoch`` (1-based).
    """
    # Fallback for callers that do not pass loaders: read the notebook globals.
    if dl_train is None:
        dl_train = globals()["dl_train"]
    if dl_valid is None:
        dl_valid = globals()["dl_valid"]
    if device is None:
        device = next(model.parameters()).device
    #
    stats = DottedDict({
        split: {'loss': [], 'acc': [], 'epoch': []}
        for split in ('train', 'valid')
    })
    desc_tmp = "Epoch [{:3}/{:3}] {}:"
    # Each epoch: one training pass (with optimizer), then one evaluation pass.
    phases = (('train', dl_train, optimizer), ('valid', dl_valid, None))
    #
    for epoch_idx in range(1, num_epochs + 1):
        for split, loader, phase_optimizer in phases:
            desc = desc_tmp.format(epoch_idx, num_epochs, split)
            loss, acc = _run_epoch(model, loader, criterion, device, desc,
                                   optimizer=phase_optimizer)
            split_stats = stats[split]
            split_stats['loss'].append(loss)
            split_stats['acc'].append(acc)
            split_stats['epoch'].append(epoch_idx)
    return stats

# Settings

In [None]:
# Linear-probe configuration. It is pickled next to the results it produced so
# every result file can be traced back to the settings that generated it.
linprob_config = DottedDict(dict(
    # output locations
    p_eval='eval',
    p_eval_results='results.pkl',  # NOTE(review): not read anywhere in this notebook
    p_config="linprob_config.pkl",
    p_results='results.pkl',
    # hardware
    device="cuda",
    cuda_visible_devices='0',
    # probe head: n_hid=0 makes Net a single Linear layer (a pure linear probe)
    n_hid=0,
    d_hid=1024,
    # optimisation
    batch_size=1024,
    optimizer='adam',
    optimizer_args={'lr': 0.001, 'weight_decay': 1e-6},
    num_epochs=40,
))

In [None]:
# TORCH SETTINGS
torch.backends.cudnn.benchmark = True
# CUDA_VISIBLE_DEVICES is only honoured if CUDA has not yet been initialised in this
# kernel, so restart the kernel after changing it.
os.environ["CUDA_VISIBLE_DEVICES"] = linprob_config.cuda_visible_devices
#
# Seed the RNGs behind weight initialisation and DataLoader shuffling so that
# Restart & Run All reproduces the probe results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
#
# Fall back to CPU rather than failing later when CUDA was requested but is absent.
requested_device = linprob_config.device
if str(requested_device).startswith("cuda") and not torch.cuda.is_available():
    print("CUDA requested but not available -- using CPU instead.")
    requested_device = "cpu"
device = torch.device(requested_device)

In [None]:
# Root directory of all runs for this dataset. Set the CSPRITES_EXPERIMENTS environment
# variable to point elsewhere instead of editing the hard-coded default.
p_experiments_base = Path(os.environ.get(
    "CSPRITES_EXPERIMENTS",
    "/mnt/experiments/csprites/single_csprites_64x64_n7_c128_a32_p10_s3_bg_inf_random_function_100000",
))
# Experiments to train linear probes for (uncomment entries to include them).
p_experiments = [
#    p_experiments_base / "SUP_[ResNet-18]_target_[shape]",
#    p_experiments_base / "SUP_[ResNet-18]_target_[scale]",
#    p_experiments_base / "SUP_[ResNet-18]_target_[color]",
#    p_experiments_base / "SUP_[ResNet-18]_target_[angle]",
#    p_experiments_base / "SUP_[ResNet-18]_target_[py]",
#    p_experiments_base / "SUP_[ResNet-18]_target_[px]",
#    p_experiments_base / 'BTwins_[ResNet-18]_LARS',
#    p_experiments_base / 'BTwins_[ResNet-18]_Adam',
#    p_experiments_base / 'BTwins_[ResNet-18]_aug_GEO_only',
#    p_experiments_base / 'BTwins_[ResNet-18]_aug_STYLE_only',
#    p_experiments_base / 'BTwins_[ResNet-18]_4L_dp_2048',
#    p_experiments_base / 'BTwins_[ResNet-18]_4L_geo_style_02_08',
#    p_experiments_base / 'BTwins_[ResNet-18]_4L_geo_style_05_05',
#    p_experiments_base / 'BTwins_[ResNet-18]_4L_geo_style_08_02',
    p_experiments_base / 'BTwins_[ResNet-18]_geo_style_50_50_100'
]
# Report every missing directory at once rather than stopping at the first one.
missing = [p for p in p_experiments if not p.exists()]
assert not missing, "missing experiment directories:\n" + "\n".join(str(p) for p in missing)

In [None]:
# Train one probe per (experiment, target variable) on the representations saved in
# each experiment directory, plot its learning curves, and pickle the per-target
# stats together with the probe config next to the experiment.
experiment_results = {}
for p_experiment in p_experiments:
    experiment_name = p_experiment.name
    print(experiment_name)
    #
    # experiment config (local pickle -- only ever load trusted files this way)
    p_experiment_config = p_experiment / "config.pkl"
    with open(p_experiment_config, "rb") as file:
        experiment_config = pickle.load(file)
    #
    # dataset config: provides the target-variable names and their class counts
    p_ds_config = Path(experiment_config.p_data) / "config.pkl"
    with open(p_ds_config, "rb") as file:
        ds_config = pickle.load(file)

    results = {}
    # The position in ds_config["classes"] is used as the label column index
    # (assumes the saved label arrays follow that column order).
    for target_idx, target_variable in enumerate(ds_config["classes"]):
        print(target_variable)
        n_classes = ds_config["n_classes"][target_variable]
        #
        dl_train, dl_valid, d_r = get_datasets(
            target_idx,
            p_experiment / experiment_config["p_R_train"],
            p_experiment / experiment_config["p_R_valid"],
            p_experiment / experiment_config["p_Y_train"],
            p_experiment / experiment_config["p_Y_valid"],
            linprob_config.batch_size,
        )
        #
        model = Net(n_classes, d_r, linprob_config.d_hid, linprob_config.n_hid)
        model = model.to(device)
        #
        optimizer = get_optimizer(linprob_config.optimizer, model.parameters(), linprob_config.optimizer_args)
        criterion = nn.CrossEntropyLoss()
        # train_model reads the loaders through the module-level dl_train / dl_valid
        # assigned just above.
        stats = train_model(model, linprob_config.num_epochs, optimizer, criterion)
        results[target_variable] = stats
        #
        # learning curves: loss (left) and accuracy (right)
        fig, axes = plt.subplots(1, 2)
        for ax, metric, title in ((axes[0], "loss", "Loss"), (axes[1], "acc", "Acc")):
            ax.plot(stats.train.epoch, stats.train[metric], label="train")
            ax.plot(stats.valid.epoch, stats.valid[metric], label="valid")
            ax.set_title(title)
            ax.set_xlabel("epoch")
            ax.legend()
        fig.suptitle(target_variable)
        fig.tight_layout()
        plt.show()
        # Release the figure explicitly: this loop creates one per probe.
        plt.close(fig)
    #
    p_results = p_experiment / linprob_config["p_results"]
    p_linprob_config = p_experiment / linprob_config["p_config"]
    with open(p_results, "wb") as file:
        pickle.dump(results, file)
    with open(p_linprob_config, "wb") as file:
        pickle.dump(linprob_config, file)
    experiment_results[experiment_name] = results

# Load the saved linear-probe results for all experiments

In [None]:
# Experiments whose saved probe results are compared below (uncomment to include).
p_experiments = [
    p_experiments_base / "SUP_[ResNet-18]_target_[shape]",
    p_experiments_base / "SUP_[ResNet-18]_target_[scale]",
    p_experiments_base / "SUP_[ResNet-18]_target_[color]",
    p_experiments_base / "SUP_[ResNet-18]_target_[angle]",
    p_experiments_base / "SUP_[ResNet-18]_target_[py]",
    p_experiments_base / "SUP_[ResNet-18]_target_[px]",
#    p_experiments_base / 'BTwins_[ResNet-18]_LARS',
    p_experiments_base / 'BTwins_[ResNet-18]_Adam',
    p_experiments_base / 'BTwins_[ResNet-18]_aug_STYLE_only',
    p_experiments_base / 'BTwins_[ResNet-18]_aug_GEO_only',
#    p_experiments_base / 'BTwins_[ResNet-18]_4L_dp_2048',
    p_experiments_base / 'BTwins_[ResNet-18]_4L_geo_style_02_08',
    p_experiments_base / 'BTwins_[ResNet-18]_4L_geo_style_05_05',
    p_experiments_base / 'BTwins_[ResNet-18]_4L_geo_style_08_02',
#    p_experiments_base / 'BTwins_[ResNet-18]_geo_style_50_50_100'
]
# Print every missing directory before failing, rather than only the first one.
missing = [p for p in p_experiments if not p.exists()]
for p in missing:
    print(p)
assert not missing, f"{len(missing)} experiment directories are missing (listed above)"

In [None]:
# Reload the per-target probe statistics that the training loop pickled into each
# experiment directory, keyed by experiment directory name.
experiment_results = {}
for p_experiment in p_experiments:
    experiment_name = p_experiment.name
    print(experiment_name)
    p_results = p_experiment / linprob_config["p_results"]
    with p_results.open("rb") as file:
        results = pickle.load(file)
    experiment_results[experiment_name] = results

In [None]:
# All cross-experiment outputs (merged results, probe config, heat-map) go into a
# single evaluation folder under the experiments base directory.
p_eval = p_experiments_base / linprob_config["p_eval"]
p_eval.mkdir(exist_ok=True)
p_results, p_config, p_plot = (
    p_eval / linprob_config["p_results"],
    p_eval / linprob_config["p_config"],
    p_eval / "results.png",
)

In [None]:
# Persist the merged results and the probe config that produced them, in that order.
for p_out, payload in ((p_results, experiment_results), (p_config, linprob_config)):
    with open(p_out, "wb") as file:
        pickle.dump(payload, file)

In [None]:
# Summarise every (model, target) probe by its validation accuracy averaged over the
# last `n_avrg` epochs. Rows of `accs` follow `model_names`, columns `target_variables`.
n_avrg = 5
model_names = list(experiment_results.keys())
assert model_names, "experiment_results is empty -- load results first"
# Column order comes from the first model; every model must provide these targets.
target_variables = list(experiment_results[model_names[0]].keys())
for model_name in model_names:
    missing_targets = set(target_variables) - set(experiment_results[model_name].keys())
    assert not missing_targets, f"{model_name} lacks results for {missing_targets}"
# A run with no recorded epochs yields nan (np.mean of an empty slice) instead of crashing.
accs = np.array([
    [np.mean(experiment_results[model_name][target_variable]["valid"]['acc'][-n_avrg:])
     for target_variable in target_variables]
    for model_name in model_names
])

In [None]:
# Heat-map of mean validation accuracy: one row per model, one column per target.
scale_factor = 2  # figure inches per heat-map cell
n_rows = len(model_names)
n_cols = len(target_variables)

fig, ax = plt.subplots(figsize=(n_cols * scale_factor,
                                n_rows * scale_factor))
im = ax.imshow(accs, cmap="copper")
# imshow scales colours to the data range, so a colour bar is needed to read them.
fig.colorbar(im, ax=ax)
#
ax.set_xticks(np.arange(n_cols))
ax.set_yticks(np.arange(n_rows))
#
ax.set_xticklabels(target_variables)
ax.set_yticklabels(model_names)
#
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

# Annotate every cell with its accuracy value.
for col_idx in range(n_cols):
    for row_idx in range(n_rows):
        ax.text(col_idx, row_idx, "{:.2f}".format(accs[row_idx, col_idx]),
                ha="center", va="center", color="w")

ax.set_title("Accs")
fig.tight_layout()
fig.savefig(p_plot)
plt.show()

In [None]:
# Mean accuracy of each model across all target variables (one value per row of `accs`).
means = np.mean(accs, axis=1)

In [None]:
# One line per model: the name left-aligned in a 40-character column, then its mean.
for model_name, mean in zip(model_names, means):
    line = f"{model_name:<40}: {mean:.2f}"
    print(line)

In [None]:
# Per-model mean accuracies (same order as `model_names`), shown as the cell output.
means