# Imports

In [None]:
!pip install neptune-client

In [None]:
from itertools import cycle
import math
import os
import shutil
import zipfile

import matplotlib.pyplot as plt
import numpy as np


import neptune
import seaborn
from sklearn import metrics
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
def full_experiment_name(experiment):
    """Return a display name for an experiment object.

    ``experiment`` is expected to expose ``.name`` and ``.get_parameters()``
    (e.g. an object returned by ``project.get_experiments``). When the
    parameters contain ``size_after_pool`` its value is appended as
    ``_pool_size_<value>``; otherwise the bare name is returned.
    """
    params = experiment.get_parameters()
    if 'size_after_pool' not in params:
        return experiment.name
    return f"{experiment.name}_pool_size_{params['size_after_pool']}"

def download_artifacts(exp):
    """Fetch ``exp``'s artifacts and unpack them into the working directory.

    An existing ``output`` directory is removed first, so files left over from
    a previously downloaded experiment cannot be mixed with this one's. The
    code assumes ``exp.download_artifacts()`` writes ``output.zip`` into the
    current directory.
    """
    stale_dir = "output"
    if os.path.isdir(stale_dir):
        shutil.rmtree(stale_dir)

    exp.download_artifacts()

    archive_path = "output.zip"
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(".")

In [None]:
def mean_max(head_probs):
    """Use the largest class probability as the certainty score.

    For a 1-D ``(C,)`` or 2-D ``(N, C)`` input this is a plain max over the
    last (class) axis. For a 3-D ``(N, members, C)`` input, the layout the
    sibling methods (``mean_entropy``, ``count_agreement``, ...) expect, the
    member axis is averaged first. This yields one answer per sample instead of
    an ``(N, members)`` tensor that would mis-broadcast against ``(N,)`` labels.

    Returns:
        ``(answer, certainty)``: the argmax class and its probability. Unlike
        the sibling methods, no ensemble-probability tensor is returned.
    """
    if head_probs.dim() == 3:
        head_probs = head_probs.mean(dim=1)
    head_certainty, head_answer = head_probs.max(-1)
    return head_answer, head_certainty

def mean_entropy(head_probs):
    """Score certainty as the negative entropy of the member-averaged distribution.

    ``head_probs`` is averaged over dim 1 (presumably the ensemble-member axis
    of an ``(N, members, C)`` tensor). Returns
    ``(answer, certainty, ensemble_probs)``.
    """
    averaged = head_probs.mean(dim=1)
    distribution = torch.distributions.categorical.Categorical(averaged)
    return averaged.argmax(-1), -distribution.entropy(), averaged

def mean_second_diff(head_probs):
    """Score certainty as the gap between the top-1 and top-2 averaged probabilities.

    ``head_probs`` is averaged over dim 1 and the resulting per-sample
    distribution is sorted in descending order. Returns
    ``(answer, certainty, ensemble_probs)``.
    """
    averaged = head_probs.mean(dim=1)
    ranked_probs, ranked_classes = averaged.sort(dim=-1, descending=True)
    top1 = ranked_probs[:, 0]
    top2 = ranked_probs[:, 1]
    return ranked_classes[:, 0], top1 - top2, averaged

def count_agreement(head_probs):
    """Score certainty as the fraction of members voting for the ensemble answer.

    The ensemble answer is the argmax of ``head_probs`` averaged over dim 1.
    Each member's own argmax is compared with it, and the share that agree is
    the certainty. Returns ``(answer, certainty, ensemble_probs)``.
    """
    averaged = head_probs.mean(dim=1)
    answer = averaged.argmax(-1)
    member_votes = head_probs.argmax(-1)
    n_members = head_probs.size(dim=1)
    agreeing = (member_votes == answer.unsqueeze(1)).sum(dim=-1)
    return answer, torch.true_divide(agreeing, n_members), averaged


In [None]:
def calibrate_heads(method, preds, labels, iters=5000, lr=1e-2 / 2, optimizer_class=torch.optim.SGD, params=None):
    """Calibrate every head of a single-network model independently.

    Args:
        method: a calibration function such as ``temperature_scaling``. It is
            called as ``method(head_preds, labels, iters=..., lr=...,
            optimizer_class=..., params=...)`` and must return
            ``(calibrated_preds, fitted_params)``.
        preds: tensor indexed ``[sample, head, net, ...]``; the net axis
            (dim 2) must have size 1.
        labels: class targets forwarded to ``method``.
        iters, lr, optimizer_class: forwarded to ``method``.
        params: optional per-head list of previously fitted parameters. When
            given, ``params[i]`` is passed for head ``i`` so ``method`` reuses
            it instead of fitting.

    Returns:
        ``(calibrated, fitted_params)``. ``calibrated`` stacks the per-head
        outputs on dim 1 and re-inserts a size-1 net axis at dim 2.
        ``fitted_params`` is the per-head list of parameters.
    """
    assert preds.size(2) == 1
    calib_head_preds = []
    calib_head_params = []
    for i in range(preds.size(1)):
        print(f'Calibrating head {i}')
        # Index the size-1 net axis explicitly. A blanket ``.squeeze()`` would
        # also drop the sample or class axis whenever either has size 1.
        head_preds = preds[:, i, 0]
        calib_params = None if params is None else params[i]
        calibrated_head_preds, calib_params = method(
            head_preds, labels, iters=iters, lr=lr,
            optimizer_class=optimizer_class, params=calib_params)
        calib_head_preds.append(calibrated_head_preds)
        calib_head_params.append(calib_params)
    # unsqueeze so there is 1 net in the "ensemble"
    return torch.stack(calib_head_preds, dim=1).unsqueeze(2), calib_head_params


def temperature_scaling(preds, labels, iters, lr, optimizer_class, params=None):
    """Calibrate logits with a single learned temperature ``t``.

    The calibrated distribution is ``softmax(preds / t)``. When ``params`` is
    None, ``t`` starts at 1 and is fitted by minimising NLL on ``labels``.

    Args:
        preds: logits, class axis last.
        labels: integer class targets for ``F.nll_loss``.
        iters: number of optimisation steps (unused when ``params`` is given).
        lr: learning rate for ``optimizer_class``.
        optimizer_class: torch optimizer constructor, e.g. ``torch.optim.SGD``.
        params: a previously fitted temperature tensor. When given, it is
            applied as-is and no fitting happens.

    Returns:
        ``(probs, t)``: calibrated probabilities on CPU and the temperature.
    """
    if params is not None:
        t = params
    else:
        t = torch.tensor([1.], requires_grad=True)
        optimizer = optimizer_class([t], lr=lr)
        for idx in range(iters):
            temped_l_probs = torch.log_softmax(preds.detach() / t, dim=-1)
            nll = F.nll_loss(temped_l_probs, labels)
            optimizer.zero_grad()
            nll.backward()
            optimizer.step()
            if idx % 100 == 0:
                print(f'(temperature_scaling) nll: {nll.item()}')
    # Apply the temperature the same way it was fitted: divide, not multiply.
    return torch.softmax(preds.detach() / t, dim=-1).cpu().detach(), t


def vector_scaling(preds, labels, iters, lr, optimizer_class, params=None):
    """Calibrate logits with a per-class scale ``w`` and bias ``b``.

    The calibrated distribution is ``softmax(w * preds + b)``. When ``params``
    is None, ``w`` starts at ones and ``b`` at zeros, and both are fitted by
    minimising NLL on ``labels``.

    Args:
        preds: logits, class axis last.
        labels: integer class targets for ``F.nll_loss``.
        iters: number of optimisation steps (unused when ``params`` is given).
        lr: learning rate for ``optimizer_class``.
        optimizer_class: torch optimizer constructor.
        params: previously fitted ``(w, b)``; applied as-is when given.

    Returns:
        ``(probs, (w, b))``: calibrated probabilities on CPU and the parameters.
    """
    if params is not None:
        w, b = params
    else:
        w = torch.ones(preds.size(-1), requires_grad=True)
        b = torch.zeros(preds.size(-1), requires_grad=True)
        optimizer = optimizer_class([w, b], lr=lr)
        for idx in range(iters):
            temped_l_probs = torch.log_softmax(w * preds.detach() + b, dim=-1)
            nll = F.nll_loss(temped_l_probs, labels)
            optimizer.zero_grad()
            nll.backward()
            optimizer.step()
            if idx % 100 == 0:
                print(f'(vector_scaling) nll: {nll.item()}')
    return torch.softmax(w * preds.detach() + b, dim=-1).cpu().detach(), (w, b)


def matrix_scaling(preds, labels, iters, lr, optimizer_class, params=None):
    """Calibrate logits with a full linear map: ``softmax(preds @ w + b)``.

    When ``params`` is None, ``w`` starts at the identity and ``b`` at zeros,
    and both are fitted by minimising NLL on ``labels``.

    Args:
        preds: logits, class axis last.
        labels: integer class targets for ``F.nll_loss``.
        iters: number of optimisation steps (unused when ``params`` is given).
        lr: learning rate for ``optimizer_class``.
        optimizer_class: torch optimizer constructor.
        params: previously fitted ``(w, b)``; applied as-is when given.

    Returns:
        ``(probs, (w, b))``: calibrated probabilities on CPU and the parameters.
    """
    if params is not None:
        w, b = params
    else:
        w = torch.eye(preds.size(-1), requires_grad=True)
        b = torch.zeros(preds.size(-1), requires_grad=True)
        optimizer = optimizer_class([w, b], lr=lr)
        for idx in range(iters):
            temped_l_probs = torch.log_softmax(preds.detach() @ w + b, dim=-1)
            nll = F.nll_loss(temped_l_probs, labels)
            optimizer.zero_grad()
            nll.backward()
            optimizer.step()
            if idx % 100 == 0:
                print(f'(matrix_scaling) nll: {nll.item()}')
    return torch.softmax(preds.detach() @ w + b, dim=-1).cpu().detach(), (w, b)


def dirichlet_calibration(probs, labels, iters, lr, optimizer_class, params=None):
    """Calibrate probabilities with a linear map on their logs.

    The calibrated distribution is ``softmax(log(probs) @ w + b)``. When
    ``params`` is None, ``w`` starts at the identity and ``b`` at zeros, and
    both are fitted by minimising NLL on ``labels``.

    Args:
        probs: class probabilities (floating dtype), class axis last.
        labels: integer class targets for ``F.nll_loss``.
        iters: number of optimisation steps (unused when ``params`` is given).
        lr: learning rate for ``optimizer_class``.
        optimizer_class: torch optimizer constructor.
        params: previously fitted ``(w, b)``; applied as-is when given.

    Returns:
        ``(probs, (w, b))``: calibrated probabilities on CPU and the parameters.
    """
    # log(0) = -inf turns into NaN after the matrix product (-inf * 0), so
    # clamp to the dtype's smallest positive normal value first. The log is
    # computed once instead of on every optimisation step.
    ln_probs = probs.detach().clamp_min(torch.finfo(probs.dtype).tiny).log()

    if params is not None:
        w, b = params
    else:
        w = torch.eye(probs.size(-1), requires_grad=True)
        b = torch.zeros(probs.size(-1), requires_grad=True)
        optimizer = optimizer_class([w, b], lr=lr)
        for idx in range(iters):
            temped_l_probs = torch.log_softmax(ln_probs @ w + b, dim=-1)
            nll = F.nll_loss(temped_l_probs, labels)
            optimizer.zero_grad()
            nll.backward()
            optimizer.step()
            if idx % 100 == 0:
                print(f'(dirichlet_calibration) nll: {nll.item()}')
    return torch.softmax(ln_probs @ w + b, dim=-1).cpu().detach(), (w, b)

#===============================================================================

# Driver: for every experiment id listed in ids_to_calibrate, create new
# pseudo-experiments "<id>_<method name>" whose per-head outputs are calibrated.
# NOTE(review): the experiment_* dicts used below are not defined in this part
# of the notebook; they are presumably populated by an earlier cell.
ids_to_calibrate = []
# ids_to_calibrate = ['CON1-676', 'CON1-677', 'CON1-676_ada_weighting', 'CON1-677_r']
# Collects the ids of calibrated variants whose train outputs contain no NaN.
ids_to_add = []
for exp_id in ids_to_calibrate:
    # Already-softmaxed outputs are calibrated in probability space (Dirichlet);
    # raw logits use the logit-space methods.
    if experiment_preds_softmaxed[exp_id] == True:
        methods = [dirichlet_calibration]
    else:
        # methods = [temperature_scaling, vector_scaling, matrix_scaling]
        methods = [temperature_scaling, matrix_scaling]
    for calibration_method in methods:
        new_id = f'{exp_id}_{calibration_method.__name__}'
        
        experiment_names[new_id] = f'{experiment_names[exp_id]}_{calibration_method.__name__}'
        experiment_cls_weights[new_id] = experiment_cls_weights[exp_id]

        # Fit calibration parameters on the training split.
        train_logits = experiment_train_logits[exp_id]
        train_last_logits = experiment_train_last_logits[exp_id]
        train_labels = experiment_train_labels[exp_id]
        print(f'calibrating {experiment_names[exp_id]}')
        experiment_train_logits[new_id], calib_params = calibrate_heads(calibration_method, train_logits, train_labels)
        # The final-layer logits are only softmaxed, not calibrated, so they
        # match the probability-valued outputs stored for the heads.
        experiment_train_last_logits[new_id] = torch.softmax(train_last_logits, dim=-1)
        experiment_train_labels[new_id] = train_labels

        test_logits = experiment_test_logits[exp_id]
        test_last_logits = experiment_test_last_logits[exp_id]
        test_labels = experiment_test_labels[exp_id]
        # note that training calib params are used here
        experiment_test_logits[new_id], _ = calibrate_heads(calibration_method, test_logits, test_labels, params=calib_params)
        experiment_test_last_logits[new_id] = torch.softmax(test_last_logits, dim=-1)
        experiment_test_labels[new_id] = test_labels

        experiment_total_params[new_id] = experiment_total_params[exp_id]
        experiment_total_ops[new_id] = experiment_total_ops[exp_id]

        # Calibrated outputs are probabilities, so downstream code must not softmax again.
        experiment_preds_softmaxed[new_id] = True
        
        # Only the train-split calibrated outputs are checked for NaN here.
        if not torch.isnan(experiment_train_logits[new_id]).any():
            ids_to_add.append(new_id)
    

In [None]:
def plot_roc_curves(data, head_i):
    """Overlay the correct-answer ROC curve of head ``head_i`` for several runs.

    ``data`` is an iterable of ``(name, auc_scores, roc_curves, accs)`` tuples,
    where ``roc_curves[head_i]`` begins with ``(fpr, tpr, ...)``. Each run gets
    the next ``tab10`` colour and a legend entry with that head's AUC and
    accuracy.
    """
    colors = cycle(seaborn.color_palette('tab10'))
    plt.figure(figsize=(15, 7))
    for (name, auc_scores, curve_data, accs), color in zip(data, colors):
        head_curve = curve_data[head_i]
        legend_label = f'{name} CA-AUC={auc_scores[head_i]:0.2f} ACC={accs[head_i]:0.2f}'
        plt.plot(head_curve[0], head_curve[1], label=legend_label, color=color)

    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.legend()
    plt.title(f'head {head_i}')
    plt.show()


def plot_aucs(data):
    """Scatter per-layer correct-answer AUC (left panel) and accuracy (right panel).

    ``data`` is an iterable of ``(name, auc_scores, roc_curves, accs)`` tuples;
    ``roc_curves`` is ignored. Layers are numbered from 1.
    """
    colors = cycle(seaborn.color_palette("tab10"))
    _, (auc_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 7))
    auc_ax.set_ylabel("Correct-answer AUC")
    auc_ax.set_xlabel("Layer")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.set_xlabel("Layer")
    for (name, auc_scores, _, accs), color in zip(data, colors):
        layers = np.arange(len(auc_scores)) + 1
        auc_ax.scatter(layers, auc_scores, label=name, color=color, marker='x')
        acc_ax.scatter(layers, accs, color=color, marker='+')
    plt.show()


# auc_data = []
# for exp_id in experiment_subset_ids:
#     test_logits = experiment_test_logits[exp_id]
#     test_last_logits = experiment_test_last_logits[exp_id]
#     test_labels = experiment_test_labels[exp_id]
#     auc_scores, roc_curves = uncertainty_auroc_check(test_logits, test_last_logits, test_labels, softmax=not experiment_preds_softmaxed[exp_id])
#     last_acc, ens_accs, _ = check_head_pred_acc(test_logits, test_last_logits, test_labels, softmax=not experiment_preds_softmaxed[exp_id])
#     accs = ens_accs + [last_acc]
#     auc_data.append((experiment_names[exp_id], auc_scores, roc_curves, accs))



def uncertainty_auroc_check(logits, last_logits, labels, unc_method=mean_max, softmax=True):
    """ROC/AUC of "certainty predicts correctness" for every head and the final layer.

    Args:
        logits: 4-D tensor indexed ``[sample, head, member, class]``.
        last_logits: final-layer outputs indexed ``[sample, class]``.
        labels: integer class labels, one per sample.
        unc_method: returns ``(answer, certainty)`` or
            ``(answer, certainty, ensemble_probs)``. Only the first two values
            are used, so both ``mean_max`` and the 3-tuple methods work.
        softmax: apply softmax to the inputs first; set False when they are
            already probabilities.

    Returns:
        ``(auc_scores, roc_curves)``: one entry per head plus a final entry for
        the last layer. Each ROC entry is ``(fpr, tpr, thresholds)``.
    """
    probs = logits.softmax(-1) if softmax else logits
    last_probs = last_logits.softmax(-1) if softmax else last_logits

    def _certainty_and_correct(member_probs):
        answer, certainty = unc_method(member_probs)[:2]
        correct = answer == labels
        return certainty.cpu().detach().numpy(), correct.cpu().detach().numpy()

    per_output = [
        _certainty_and_correct(probs[:, head_i, :, :])
        for head_i in range(logits.size(1))
    ]
    # The last layer is treated as a one-member ensemble.
    per_output.append(_certainty_and_correct(last_probs.unsqueeze(1)))

    roc_curves = []
    auc_scores = []
    for certainty_responses, correct_responses in per_output:
        fpr, tpr, thresholds = metrics.roc_curve(correct_responses, certainty_responses)
        roc_curves.append((fpr, tpr, thresholds))
        auc_scores.append(metrics.auc(fpr, tpr))

    return auc_scores, roc_curves

def mean_max(head_probs):
    """Use the largest class probability as the certainty score.

    For a 1-D ``(C,)`` or 2-D ``(N, C)`` input this is a plain max over the
    last (class) axis. For a 3-D ``(N, members, C)`` input, the layout the
    sibling methods (``mean_entropy``, ``count_agreement``, ...) expect, the
    member axis is averaged first. This yields one answer per sample instead of
    an ``(N, members)`` tensor that would mis-broadcast against ``(N,)`` labels.

    Returns:
        ``(answer, certainty)``: the argmax class and its probability. Unlike
        the sibling methods, no ensemble-probability tensor is returned.
    """
    if head_probs.dim() == 3:
        head_probs = head_probs.mean(dim=1)
    head_certainty, head_answer = head_probs.max(-1)
    return head_answer, head_certainty

def mean_entropy(head_probs):
    """Score certainty as the negative entropy of the member-averaged distribution.

    ``head_probs`` is averaged over dim 1 (presumably the ensemble-member axis
    of an ``(N, members, C)`` tensor). Returns
    ``(answer, certainty, ensemble_probs)``.
    """
    averaged = head_probs.mean(dim=1)
    distribution = torch.distributions.categorical.Categorical(averaged)
    return averaged.argmax(-1), -distribution.entropy(), averaged

def mean_second_diff(head_probs):
    """Score certainty as the gap between the top-1 and top-2 averaged probabilities.

    ``head_probs`` is averaged over dim 1 and the resulting per-sample
    distribution is sorted in descending order. Returns
    ``(answer, certainty, ensemble_probs)``.
    """
    averaged = head_probs.mean(dim=1)
    ranked_probs, ranked_classes = averaged.sort(dim=-1, descending=True)
    top1 = ranked_probs[:, 0]
    top2 = ranked_probs[:, 1]
    return ranked_classes[:, 0], top1 - top2, averaged

def count_agreement(head_probs):
    """Score certainty as the fraction of members voting for the ensemble answer.

    The ensemble answer is the argmax of ``head_probs`` averaged over dim 1.
    Each member's own argmax is compared with it, and the share that agree is
    the certainty. Returns ``(answer, certainty, ensemble_probs)``.
    """
    averaged = head_probs.mean(dim=1)
    answer = averaged.argmax(-1)
    member_votes = head_probs.argmax(-1)
    n_members = head_probs.size(dim=1)
    agreeing = (member_votes == answer.unsqueeze(1)).sum(dim=-1)
    return answer, torch.true_divide(agreeing, n_members), averaged


In [None]:
def confidence_ece(probs, labels, n_bins = 20):
    """Expected calibration error of the top-1 confidence.

    Samples are grouped into ``n_bins`` equal-width bins over [0, 1] by their
    maximum probability. The ECE is the sample-weighted mean of
    ``|bin accuracy - bin mean confidence|``.

    Args:
        probs: class probabilities, class axis last (presumably ``(N, C)``).
        labels: integer class labels, one per sample.
        n_bins: number of confidence bins.

    Returns:
        The ECE as a Python float.
    """
    num_examples = labels.size(0)
    max_probs, max_inds = probs.max(-1)
    bins = torch.linspace(0., 1., steps=n_bins + 1)
    conf_ece = 0.0
    for i in range(n_bins):
        bin_start = bins[i]
        bin_end = bins[i + 1]
        # Bins are half-open [start, end), except the last, which is closed so
        # that a confidence of exactly 1.0 is counted instead of dropped.
        if i == n_bins - 1:
            below_end = max_probs <= bin_end
        else:
            below_end = max_probs < bin_end
        examples_in_bin = (max_probs >= bin_start) & below_end
        bin_size = examples_in_bin.sum().item()
        if bin_size > 0:
            bin_acc = (max_inds[examples_in_bin] == labels[examples_in_bin]).sum().item() / bin_size
            bin_avg_confidence = max_probs[examples_in_bin].sum().item() / bin_size
            conf_ece += bin_size * abs(bin_acc - bin_avg_confidence)
    conf_ece /= num_examples
    return conf_ece


def classwise_ece(probs, labels, n_bins = 20):
    """Class-wise expected calibration error.

    For every class ``k``, samples are binned by ``probs[:, k]`` into
    ``n_bins`` equal-width bins over [0, 1]. Each bin contributes
    ``|fraction labelled k - mean probability of k|`` weighted by its size.
    The total is averaged over classes and samples.

    Args:
        probs: class probabilities indexed ``[sample, class]``.
        labels: integer class labels, one per sample.
        n_bins: number of probability bins per class.

    Returns:
        The class-wise ECE as a Python float.
    """
    num_examples = probs.size(0)
    num_classes = probs.size(-1)
    bins = torch.linspace(0., 1., steps=n_bins + 1)
    class_ece = 0.0
    for k in range(num_classes):
        class_probs = probs[:, k]
        for i in range(n_bins):
            bin_start = bins[i]
            bin_end = bins[i + 1]
            # Half-open bins, with the last one closed so probability 1.0 is counted.
            if i == n_bins - 1:
                below_end = class_probs <= bin_end
            else:
                below_end = class_probs < bin_end
            examples_in_bin = (class_probs >= bin_start) & below_end
            bin_size = examples_in_bin.sum().item()
            if bin_size > 0:
                bin_class_prop = (labels[examples_in_bin] == k).sum().item() / bin_size
                bin_avg_prob = class_probs[examples_in_bin].sum().item() / bin_size
                class_ece += bin_size * abs(bin_class_prop - bin_avg_prob)
    class_ece /= num_classes * num_examples
    return class_ece

def calibration_histogram(preds, labels, title):
    """Plot a reliability histogram of top-1 confidence and print its error.

    ``preds`` are used directly as probabilities; no softmax is applied.
    Samples are grouped into 20 equal-width confidence bins. The printed error
    is the sample-weighted mean of ``|bin accuracy - bin lower edge|``.

    Args:
        preds: class probabilities, shape [N, C].
        labels: integer class labels, shape [N].
        title: plot title, also used in the printed message.
    """
    probs = preds
    max_probs, _ = probs.max(1)

    bins_num = 20
    bin_accs = []
    bin_examples = []
    for bin_idx in range(bins_num):
        bin_start = bin_idx / bins_num
        bin_end = (bin_idx + 1) / bins_num
        # Half-open [start, end) bins so edge values land in exactly one bin.
        # The last bin is closed so a confidence of exactly 1.0 is counted.
        if bin_idx == bins_num - 1:
            below_end = max_probs <= bin_end
        else:
            below_end = max_probs < bin_end
        examples_in_bin = (max_probs >= bin_start) & below_end
        # One count per bin, used below to weight each bin's error.
        bin_examples.append(examples_in_bin.sum().item())

        if bin_examples[-1] > 0:
            bin_preds = preds[examples_in_bin]
            bin_labels = labels[examples_in_bin]
            bin_acc = (bin_preds.argmax(1) == bin_labels).float().mean().item()
        else:
            bin_acc = 0.
        bin_accs.append(bin_acc)
    bin_examples = np.array(bin_examples)

    bin_starts = np.arange(bins_num) / bins_num
    error = np.abs(np.array(bin_accs) - bin_starts) * bin_examples
    error = error.sum() / bin_examples.sum()
    plt.title(title)
    plt.bar(range(bins_num), bin_accs)
    plt.plot(range(bins_num), bin_starts)
    plt.scatter(range(bins_num), bin_starts, zorder=100)
    plt.xticks(np.linspace(0, bins_num, num=5), np.linspace(0, 1, num=5))

    plt.show()
    print(f"Error {title}: {error}")

def get_calibrated_preds(preds, labels, num_heads=7):
    """Calibrate the first ``num_heads`` heads, plotting before/after histograms.

    Args:
        preds: per-head probabilities indexed ``[sample, head, class]``.
        labels: integer class labels, one per sample.
        num_heads: how many heads (from index 0) to calibrate. The default of 7
            keeps the previous hard-coded behaviour.

    Returns:
        The calibrated predictions stacked along dim 1.

    NOTE(review): relies on ``calibrate_dataset``, which is not defined in
    this part of the notebook.
    """
    all_calibrated_preds = []
    for idx in range(num_heads):
        head_preds = preds[:, idx, :]
        calibration_histogram(head_preds, labels, "before calibration")
        calibrated_preds = calibrate_dataset(head_preds, labels)
        calibration_histogram(calibrated_preds, labels, "after calibration")
        all_calibrated_preds.append(calibrated_preds)
    return torch.stack(all_calibrated_preds, dim=1)

In [None]:
def get_eces(logits, labels, bins=20, softmax=True):
    """Confidence and class-wise ECE for every head.

    Args:
        logits: tensor indexed ``[sample, head, class]``.
        labels: integer class labels, one per sample.
        bins: number of calibration bins, forwarded to both ECE functions
            (this argument was previously ignored).
        softmax: apply softmax first; set False for inputs that are already
            probabilities.

    Returns:
        ``(conf_eces, class_eces)``: NumPy arrays with one value per head.
    """
    probs = torch.softmax(logits, dim=-1) if softmax else logits
    conf_eces = []
    class_eces = []
    for head_i in range(logits.size(1)):
        head_probs = probs[:, head_i, :]
        conf_eces.append(confidence_ece(head_probs, labels, n_bins=bins))
        class_eces.append(classwise_ece(head_probs, labels, n_bins=bins))
    return np.array(conf_eces), np.array(class_eces)

def get_ensemble_eces(exp, dataset="test", unc_method=mean_max, bins=20):
    """Confidence and class-wise ECE for every head of an experiment dict.

    Args:
        exp: dict with ``'<dataset>_logits'`` (indexed ``[sample, head, class]``)
            and ``'<dataset>_labels'``.
        dataset: split prefix, e.g. ``"test"``.
        unc_method: accepted for interface compatibility; not used.
        bins: number of calibration bins, forwarded to both ECE functions
            (this argument was previously ignored).

    Returns:
        ``(conf_eces, class_eces)``: NumPy arrays with one value per head.
    """
    probs = torch.softmax(exp[f'{dataset}_logits'], dim=-1)
    labels = exp[f'{dataset}_labels']
    conf_eces = []
    class_eces = []
    for head_i in range(probs.size(1)):
        head_probs = probs[:, head_i, :]
        conf_eces.append(confidence_ece(head_probs, labels, n_bins=bins))
        class_eces.append(classwise_ece(head_probs, labels, n_bins=bins))
    return np.array(conf_eces), np.array(class_eces)


def plot_eces(data):
    """Scatter per-layer confidence ECE (top panel) and class-wise ECE (bottom panel).

    ``data`` is an iterable of ``(name, conf_eces, class_eces)`` tuples. Runs
    whose name contains "scaling" are drawn with circles, all others with
    crosses. Layers are numbered from 1.
    """
    current_palette = cycle(seaborn.color_palette("tab10"))
    fig, axs = plt.subplots(2, 1, figsize=(16, 16))
    ax1 = axs[0]
    ax2 = axs[1]
    ax1.set_ylabel("Confidence ECE")
    ax1.set_xlabel("Layer")
    ax2.set_ylabel("Classwise ECE")
    ax2.set_xlabel("Layer")
    for name, conf_eces, class_eces in data:
        current_color = next(current_palette)
        marker_type = "o" if "scaling" in name else "x"
        ax1.scatter(np.arange(len(conf_eces)) + 1, conf_eces, label=name, color=current_color,
                    marker=marker_type)
        # Label these points too. ax2.legend() needs labelled artists, otherwise
        # matplotlib warns and draws an empty legend.
        ax2.scatter(np.arange(len(class_eces)) + 1, class_eces, label=name, color=current_color,
                    marker=marker_type)
    ax1.legend()
    ax2.legend()
    plt.show()



# experiment = experiment_subset[-1]
# test_logits = experiment_test_logits[exp_id]
# test_last_logits = experiment_test_last_logits[exp_id]
# test_labels = experiment_test_labels[exp_id]
# ece_data = []
# conf_eces, class_eces = get_ensemble_eces(test_logits, test_labels)
# ece_data.append((full_experiment_name(experiment), conf_eces, class_eces))
# for net_i in range(test_logits.size(-2)):
#     conf_eces, class_eces = get_eces(test_logits[:, :, net_i, :], test_labels)
#     ece_data.append((f'{full_experiment_name(experiment)}_net_{net_i}', conf_eces, class_eces))
# plot_eces(ece_data)

In [None]:
def get_head_accs(logits, labels):
    """Top-1 accuracy of each head.

    ``logits`` is indexed ``[sample, head, class]`` and ``labels`` holds one
    class id per sample. Returns a NumPy array with one accuracy per head.
    """
    predictions = logits.argmax(-1)
    hits = predictions == labels.unsqueeze(1)
    return hits.float().mean(0).numpy()

def head_improvement_matrix(exp, dataset="test", softmax=True):
    """Plot, for each pair of heads (i, j), how often head j fixes head i's mistakes.

    Cell ``[i, j]`` is the fraction of samples misclassified by head ``i`` that
    head ``j`` classifies correctly. A row whose head never errs is 0/0, i.e.
    NaN. ``softmax`` is accepted but unused; softmax is always applied, which
    does not change the argmax.

    Returns:
        ``(head_improvement, fig)``: the NumPy matrix and the current figure.
    """
    probs = exp[f'{dataset}_logits'].softmax(-1)
    targets = exp[f'{dataset}_labels']
    n_heads = probs.size(1)

    # hit[n, h] is True when head h classifies sample n correctly.
    hit = probs.argmax(-1) == targets.unsqueeze(-1)
    improved_counts = torch.zeros(n_heads, n_heads, dtype=int)
    total_counts = torch.zeros(n_heads, n_heads, dtype=int)
    for row in range(n_heads):
        missed_by_row = ~hit[:, row]
        improved_counts[row, :] += hit[missed_by_row].sum(dim=0)
        total_counts[row, :] += missed_by_row.sum(dim=0)

    head_improvement = (improved_counts.float() /
                        total_counts.float()).detach().cpu().numpy()

    plt.figure(figsize=(30, 30))
    plt.ylabel("How many misclassified by head...")
    plt.xlabel("...are answered correctly by head")
    tick_positions = list(range(n_heads))
    plt.xticks(tick_positions)
    plt.yticks(tick_positions)
    plt.imshow(head_improvement, vmin=0., vmax=1., cmap=plt.get_cmap('viridis'))
    for (row_idx, col_idx), val in np.ndenumerate(head_improvement):
        plt.text(col_idx,
                 row_idx,
                 f"{val * 100:.2f}%",
                 horizontalalignment="center",
                 c="white")

    return head_improvement, plt.gcf()



def get_bounds(logits, unc_method=mean_max, softmax=True):
    """Smallest and largest certainty that ``unc_method`` assigns across all heads.

    Args:
        logits: tensor indexed ``[sample, head, ...]``; ``logits[:, h, :]`` is
            passed to ``unc_method`` for each head ``h``.
        unc_method: returns ``(answer, certainty)`` or
            ``(answer, certainty, ensemble_probs)``; only the certainty is used.
        softmax: apply softmax first; set False for probability inputs.

    Returns:
        ``(min_value, max_value)`` as Python floats.
    """
    probs = logits.softmax(-1) if softmax else logits
    min_value = math.inf
    max_value = -math.inf
    for head_i in range(logits.size(1)):
        head_certainty = unc_method(probs[:, head_i, :])[1]
        min_value = min(min_value, head_certainty.min().item())
        max_value = max(max_value, head_certainty.max().item())
    return min_value, max_value

def ensemble_predictor(exp, threshold=1., dataset="test", unc_method=mean_max):
    """Simulate certainty-threshold early exiting.

    Each sample exits at the first head whose certainty (from ``unc_method``)
    is strictly greater than ``threshold`` and takes that head's answer.
    Samples that never exit use the last head's argmax.

    Args:
        exp: dict with ``'<dataset>_logits'`` (indexed ``[sample, head, class]``),
            ``'<dataset>_labels'``, ``'num_heads'`` and ``'total_ops'`` (the cost
            of exiting at each head).
        threshold: certainty threshold for exiting.
        dataset: split prefix, e.g. ``"test"``.
        unc_method: returns ``(answer, certainty)`` or
            ``(answer, certainty, ensemble_probs)``; only the first two are used.

    Returns:
        ``(accuracy, avg_ops)``: the accuracy and the mean exit cost over samples.
    """
    probs = exp[f'{dataset}_logits'].softmax(-1)
    num_samples = len(probs)
    last_head = exp['num_heads'] - 1
    answered = torch.tensor([-1] * num_samples)
    ensemble_pred = torch.tensor([-1] * num_samples)

    for head_i in range(exp['num_heads']):
        head_probs = probs[:, head_i, :]
        head_answer, head_certainty = unc_method(head_probs)[:2]
        agreement = (head_certainty > threshold) & (answered < 0)
        answered[agreement] = head_i
        ensemble_pred[agreement] = head_answer[agreement]
    # unanswered samples get the original (last head) answer
    unanswered = answered < 0
    ensemble_pred[unanswered] = probs[:, last_head, :].argmax(-1)[unanswered]
    answered[unanswered] = last_head

    head_ops = exp['total_ops'].expand(num_samples, -1)
    avg_ops = head_ops.gather(1, answered.unsqueeze(1)).mean().item()
    acc = (ensemble_pred == exp[f'{dataset}_labels']).float().mean().item()
    return acc, avg_ops
    
def ensemble_check(exp, x_linspace, dataset="test", unc_method=mean_max):
    """Run ``ensemble_predictor`` at every threshold in ``x_linspace``.

    Returns ``(accs, avg_ops)``: two lists parallel to ``x_linspace``.
    """
    results = [
        ensemble_predictor(exp, thresh, dataset=dataset, unc_method=unc_method)
        for thresh in x_linspace
    ]
    accs = [acc for acc, _ in results]
    avg_ops = [ops for _, ops in results]
    return accs, avg_ops

def patience_check(exp, dataset="test"):
    """Run ``patience_predictor`` for every patience value 1..num_heads.

    Returns ``(accs, avg_ops)``: two lists ordered by increasing patience.
    """
    results = [
        patience_predictor(exp, patience_val, dataset=dataset)
        for patience_val in range(1, exp['num_heads'] + 1)
    ]
    accs = [acc for acc, _ in results]
    avg_ops = [ops for _, ops in results]
    return accs, avg_ops

def patience_predictor(exp, patience_thresh, dataset="test"):
    """Simulate patience-based early exiting.

    A sample exits at the first head where its prediction has been the same
    for ``patience_thresh`` consecutive heads, and takes that prediction.
    Samples that never exit use the last head's prediction.

    Args:
        exp: dict with ``'<dataset>_logits'`` (indexed ``[sample, head, class]``),
            ``'<dataset>_labels'``, ``'num_heads'`` and ``'total_ops'``.
        patience_thresh: number of consecutive agreeing heads required.
        dataset: split prefix, e.g. ``"test"``.

    Returns:
        ``(accuracy, avg_ops)``: the accuracy and the mean exit cost over samples.
    """
    all_preds = exp[f'{dataset}_logits'].argmax(-1)
    n = len(all_preds)

    exit_head = torch.tensor([-1] * n)
    final_pred = torch.tensor([-1] * n)
    previous = torch.tensor([-1] * n)
    streak = torch.tensor([0] * n)

    for head_i in range(exp['num_heads']):
        current = all_preds[:, head_i]
        streak = torch.where(current == previous, streak + 1, 1)
        exits_now = (exit_head < 0) & (streak >= patience_thresh)
        exit_head[exits_now] = head_i
        final_pred[exits_now] = current[exits_now]
        previous = current.clone()

    # Samples that never reached the patience threshold fall back to the last head.
    pending = exit_head < 0
    final_pred[pending] = current[pending]
    exit_head[pending] = head_i

    ops_per_sample = exp['total_ops'].expand(n, -1).gather(1, exit_head.unsqueeze(1))
    avg_ops = ops_per_sample.mean().item()
    acc = (final_pred == exp[f'{dataset}_labels']).float().mean().item()
    return acc, avg_ops


## Improvability

In [None]:
# Base font size for axis labels, titles and legends in the plots below
# (tick labels use FONT_SIZE - 4, titles FONT_SIZE + 1).
FONT_SIZE = 22
# Fixed colour per method name. Presumably taken from matplotlib's
# "fivethirtyeight" palette. The plotting functions look colours up by
# experiment/method name, so every plotted name must be a key here.
FIVE_THIRTY_EIGHT = {
        "SDN": "#30a2da",
        "PABEE": "#fc4f30",
        "Zero Time Waste": "#e5ae38",
        "Stacking": "#6d904f",
        "Ensembling": "#810f7c",
}

def improvability(exp, dataset="test", show_plot=False):
    """Hindsight improvability of every head after the first.

    For head h >= 1, this is the fraction of samples misclassified by head h
    that at least one earlier head (0..h-1) classified correctly.
    ``show_plot`` is accepted but unused.

    Returns:
        NumPy array of length ``num_heads - 1``. An entry is NaN when head h
        makes no mistakes.
    """
    hits = (exp[f'{dataset}_logits'].argmax(-1) == exp[f'{dataset}_labels'].unsqueeze(1)).numpy()
    scores = []
    for head_idx in range(1, hits.shape[1]):
        missed = ~hits[:, head_idx]
        rescued_earlier = hits[missed, :head_idx].any(axis=1)
        scores.append(rescued_earlier.mean())
    return np.array(scores)

def correctness(exp, dataset="test"):
    """For each head h >= 1: among samples head h gets right, the fraction the
    second-to-last head also gets right.

    The reference column is always ``-2`` (the second-to-last head),
    independent of h. NOTE(review): confirm a fixed reference head is intended.

    Returns:
        NumPy array of length ``num_heads - 1``. An entry is NaN when head h
        is never correct.
    """
    hits = (exp[f'{dataset}_logits'].argmax(-1) == exp[f'{dataset}_labels'].unsqueeze(1)).numpy()
    fractions = []
    for head_idx in range(1, hits.shape[1]):
        right_here = hits[:, head_idx]
        fractions.append(hits[right_here, -2].mean())
    return np.array(fractions)



def plot_improvability(exps, title=""):
    """Plot hindsight improvability per IC for each experiment in ``exps``.

    ``exps`` maps a method name (which must be a key of ``FIVE_THIRTY_EIGHT``)
    to an experiment dict. "Base Network" and "SDN+Stacking" are skipped.
    A random-guessing baseline is computed from the last experiment's head and
    class counts and printed.
    """
    # Local import: this function used ``mpl``, which is only imported in a
    # later notebook cell.
    from matplotlib.ticker import MultipleLocator

    plt.figure(figsize=(15, 9))
    color_scheme = FIVE_THIRTY_EIGHT
    for exp_id, exp in exps.items():
        if exp_id == "Base Network" or exp_id == "SDN+Stacking":
            continue
        improv = improvability(exp)
        plt.plot(
            np.arange(len(improv)) + 2, improv, marker="o", markersize=20, linewidth=3.,
            label=exp_id, c=color_scheme[exp_id])
    num_heads = exp['num_heads']
    num_classes = exp['test_logits'].shape[-1]

    base = (1 / num_classes)
    random_improv = (1 - base ** np.arange(num_heads))[1:]
    print(random_improv)
    plt.xlabel("IC", fontsize=FONT_SIZE)
    plt.ylabel("Hindsight Improvability", fontsize=FONT_SIZE)
    plt.title(title + "\n(lower is better)", fontdict={'fontsize': FONT_SIZE + 1})

    ax = plt.gca()
    ax.xaxis.set_major_locator(MultipleLocator(1))
    # ``Tick.label`` is deprecated in recent matplotlib; tick_params also
    # applies to ticks created after this call.
    ax.tick_params(axis="both", which="major", labelsize=FONT_SIZE - 4)

    plt.legend(prop={'size': FONT_SIZE})
    plt.show()

## Time-Acc Plot

In [None]:
import matplotlib as mpl
from matplotlib.legend_handler import HandlerLine2D, HandlerTuple


def draw_time_acc_plot(data, ax, first=False, title=None):
    """Draw accuracy vs. average inference cost on ``ax``.

    Args:
        data: list of ``(name, accs, avg_ops, head_accs, head_ops)`` tuples.
            ``data[0]`` is the baseline: its ``head_accs`` is drawn as a black
            "X" at its single-element ``head_ops`` cost, which also sets the
            x-axis scale (percent of baseline cost). Every other entry is drawn
            as a curve plus per-IC markers (markers skipped for "PABEE"); their
            names must be keys of ``FIVE_THIRTY_EIGHT``.
        ax: matplotlib axes to draw on.
        first: when True, also draw the legend with a combined "IC" marker entry.
        title: axes title.
    """
    current_palette = FIVE_THIRTY_EIGHT

    baseline_name, _, _, baseline_acc, baseline_ops = data[0]
    baseline_ops = baseline_ops.item()

    for name, acc, avg_ops, head_accs, head_ops in data[1:]:
        print("OPS", avg_ops, baseline_ops)
        avg_ops = np.array(avg_ops)
        current_color = current_palette[name]
        ax.plot(
            avg_ops, acc, label=name,
            ls="-",
            linewidth=3.5,
            color=current_color)
        if name == "PABEE":
            continue
        ax.scatter(head_ops, head_accs, s=90, marker='o', color=current_color,
                   edgecolors="black", linewidths=0., zorder=3)

    ax.scatter(
        baseline_ops, baseline_acc, s=250, marker='X',
        color="black", label=baseline_name, zorder=3, linewidths=0.)

    ax.set_xlim(right=1.1 * baseline_ops)
    ax.set_ylabel('Accuracy', fontsize=FONT_SIZE)
    ax.set_xlabel('Inference Time', fontsize=FONT_SIZE)
    ax.set_title(title, fontdict={'fontsize': FONT_SIZE + 1})

    ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(baseline_ops / 4))
    ax.xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=baseline_ops))

    # ``Tick.label`` is deprecated in recent matplotlib; set sizes via tick_params.
    ax.tick_params(axis="both", which="major", labelsize=FONT_SIZE - 4)

    if first:
        handles, labels = ax.get_legend_handles_labels()
        circles = []
        for name, _, _, _, _ in data[1:]:
            if name == "PABEE":
                continue
            circles += [mpl.lines.Line2D([], [], color=FIVE_THIRTY_EIGHT[name], marker='o', linestyle='None',
                                         markersize=10, label='IC')]

        # One legend entry showing every method's IC marker side by side.
        handles += [tuple(circles)]
        labels += ["IC"]
        ax.legend(handles=handles, labels=labels, prop={'size': FONT_SIZE}, handler_map={tuple: HandlerTuple(ndivide=None)})
    

def plot_time_acc(experiments, patience_id=None, dev=False):
    """Collect accuracy-vs-cost data for every experiment on the test split.

    Despite the name, nothing is drawn here. The result is meant for
    ``draw_time_acc_plot``.

    Args:
        experiments: mapping from name to experiment dict.
        patience_id: optional key of an experiment to also evaluate with
            patience-based exiting. Its entry, named "PABEE", is inserted just
            before the last entry.
        dev: use 10 thresholds instead of 200 for faster iteration.

    Returns:
        List of ``(name, accs, avg_ops, head_accs, total_ops)`` tuples.
    """
    dataset = "test"
    print("=" * 15, dataset.upper(), "=" * 15)
    unc_method = mean_max
    num_thresholds = 10 if dev else 200
    time_acc_data = []
    for exp_id, exp in experiments.items():
        print(f'Total params for experiment {exp_id}: {exp["total_params"]}')
        head_accs = get_head_accs(exp[f'{dataset}_logits'], exp[f'{dataset}_labels'])
        min_value, _ = get_bounds(exp[f'{dataset}_logits'], unc_method=unc_method)
        x_linspace = np.linspace(min_value, 1., num=num_thresholds)
        accs, avg_ops = ensemble_check(exp, x_linspace, dataset=dataset, unc_method=unc_method)
        time_acc_data.append((f'{exp_id}', accs, avg_ops, head_accs, exp['total_ops']))

    if patience_id is not None:
        exp = experiments[patience_id]
        head_accs = get_head_accs(exp[f'{dataset}_logits'], exp[f'{dataset}_labels'])
        accs, avg_ops = patience_check(exp, dataset=dataset)

        patience_data = ('PABEE', accs, avg_ops, head_accs, exp['total_ops'])
        time_acc_data = time_acc_data[:-1] + [patience_data] + [time_acc_data[-1]]
    return time_acc_data

def _last_acc_within_budget(accs, avg_ops, ops_budget):
    """Return the accuracy of the last operating point costing less than ``ops_budget``, or None."""
    affordable = accs[avg_ops < ops_budget]
    if len(affordable) == 0:
        return None
    return affordable[-1]


def _format_budget_acc(acc):
    """Format an accuracy as a one-decimal percentage, or '-' when no point fits the budget."""
    return "-" if acc is None else f"{acc * 100:.1f}"


def draw_table(experiments, baseline_id, dataset="test", patience_id=None):
    """Print LaTeX-style table rows of accuracy under compute budgets.

    For every non-baseline experiment, a row is printed with the accuracy of
    the last threshold (in threshold order) whose average cost is below 25%,
    50%, 75% and 100% of the baseline's cost, followed by the best accuracy
    overall. A budget no threshold fits into prints "-" instead of raising
    IndexError.

    Args:
        experiments: mapping from id to experiment dict.
        baseline_id: key of the single-exit baseline experiment.
        dataset: split prefix, e.g. ``"test"``.
        patience_id: optional key of an experiment to also report with
            patience-based exiting.
    """
    baseline_exp = experiments[baseline_id]
    baseline_ops = baseline_exp['total_ops'][0]
    baseline_acc = (
        baseline_exp[f'{dataset}_logits'].squeeze().argmax(-1) == baseline_exp[f'{dataset}_labels']).float().mean()

    print(f"Baseline: {baseline_acc * 100:.1f}%")

    budgets = [0.25, 0.5, 0.75, 1.]
    for exp_id, exp in experiments.items():
        if exp_id == baseline_id:
            continue
        unc_method = mean_max
        min_value, _ = get_bounds(exp[f'{dataset}_logits'], unc_method=unc_method)
        x_linspace = np.linspace(min_value, 1., num=1000)
        accs, avg_ops = ensemble_check(exp, x_linspace, dataset=dataset, unc_method=unc_method)
        avg_ops = torch.tensor(avg_ops)
        accs = torch.tensor(accs)

        print(exp["name"])
        print("& ", end="")
        for thresh in budgets:
            last_acc = _last_acc_within_budget(accs, avg_ops, thresh * baseline_ops)
            print(_format_budget_acc(last_acc), end=" & ")
        print(f"{accs.max() * 100:.1f} ")

    if patience_id is not None:
        exp = experiments[patience_id]
        accs, avg_ops = patience_check(exp, dataset=dataset)
        avg_ops = torch.tensor(avg_ops)
        accs = torch.tensor(accs)

        print("patience")
        for thresh in budgets:
            last_acc = _last_acc_within_budget(accs, avg_ops, thresh * baseline_ops)
            print(_format_budget_acc(last_acc), end=" & ")
        print(f"{accs.max() * 100:.1f} & ")

In [None]:
def prepare_regular_experiment(exp_id):
    """Fetch a Neptune experiment by id and collect its test-set outputs.

    Artifacts are read from the local ``output/`` directory after
    ``download_artifacts`` has fetched them; only the test split is loaded.

    Returns a dict with keys:
        'name'         -- ``full_experiment_name`` of the experiment,
        'test_logits'  -- logits with the head axis at dim 1 (saved last-layer
                          logits are appended as the final head when present,
                          otherwise a singleton head axis is inserted),
        'test_labels'  -- the saved test labels,
        'total_params', 'total_ops' -- per-head tensors (converted from
                          ``{index: value}`` dicts when saved that way),
        'num_heads'    -- size of dim 1 of 'test_logits'.
    """
    def _load(path):
        return torch.load(path, map_location="cpu")

    run = project.get_experiments(exp_id)[0]
    exp_dict = {}
    exp_dict['name'] = full_experiment_name(run)
    download_artifacts(run)

    logits = _load('output/test_logits')
    if logits.dim() <= 2:
        # No separate last-layer output: add a head axis of size one.
        logits = logits.unsqueeze(1)
    else:  # last_logits non-None
        final_head = _load('output/test_last_logits')
        logits = torch.cat((logits.squeeze(), final_head.unsqueeze(1)), dim=1)

    exp_dict['test_logits'] = logits
    exp_dict['test_labels'] = _load('output/test_labels')

    params = _load('output/total_params')
    ops = _load('output/total_ops')
    if isinstance(params, dict):
        params = torch.tensor([params[i] for i in range(max(params.keys()) + 1)])
        ops = torch.tensor([ops[i] for i in range(max(ops.keys()) + 1)])

    print(exp_id, ops.shape)
    exp_dict.update(total_params=params, total_ops=ops, num_heads=logits.shape[1])
    return exp_dict



def prepare_ensb_experiment(tag):
    """Fetch the first Neptune experiment carrying ``tag`` and collect its test outputs.

    Loads the same artifacts and builds the same dict as
    ``prepare_regular_experiment``, but looks the experiment up by tag and
    prints the per-head params/ops tensors themselves.

    Args:
        tag: Neptune tag to search for; the first matching experiment is used.

    Returns:
        dict with keys 'name', 'test_logits' (head axis at dim 1),
        'test_labels', 'total_params', 'total_ops', 'num_heads'.

    Raises:
        ValueError: if no experiment carries ``tag``.
    """
    # Query by keyword: get_experiments' first positional argument is the
    # experiment-id filter (as used by prepare_regular_experiment), not a tag.
    exps = project.get_experiments(tag=tag)
    if not exps:
        raise ValueError(f"no Neptune experiment found with tag {tag!r}")
    exp = exps[0]

    exp_dict = {}
    exp_dict['name'] = full_experiment_name(exp)
    download_artifacts(exp)

    test_logits = torch.load('output/test_logits', map_location="cpu")

    if len(test_logits.shape) > 2:  # last_logits non-None
        test_last_logits = torch.load('output/test_last_logits', map_location="cpu")
        test_logits = torch.cat([
            test_logits.squeeze(),
            test_last_logits.unsqueeze(1),
        ], 1)
    else:
        # No separate last-layer output: add a head axis so dim 1 indexes heads.
        test_logits = test_logits.unsqueeze(1)

    exp_dict['test_logits'] = test_logits
    exp_dict['test_labels'] = torch.load('output/test_labels', map_location='cpu')

    total_params = torch.load('output/total_params', map_location="cpu")
    total_ops = torch.load('output/total_ops', map_location="cpu")

    # Saved either as {head_index: value} dicts or as tensors; normalise to tensors.
    if isinstance(total_params, dict):
        total_params = torch.tensor([total_params[n] for n in range(max(total_params.keys()) + 1)])
        total_ops = torch.tensor([total_ops[n] for n in range(max(total_ops.keys()) + 1)])

    print(total_params, total_ops)
    exp_dict['total_params'] = total_params
    exp_dict['total_ops'] = total_ops
    exp_dict['num_heads'] = test_logits.shape[1]

    return exp_dict

def prepare_run_ensb_experiment(tag):
    """Assemble a running-ensemble result from the per-head runs tagged ``tag``.

    Each experiment tagged ``tag`` must have a ``head_idx`` parameter; their
    ``test_logits`` artifacts are stacked on dim 1 in head order. The single
    ``child_of_<parent_tag>`` tag on the first experiment identifies one
    parent experiment, whose ``total_params`` / ``total_ops`` artifacts are
    used. A cumulative per-head cost, computed from the number of classes and
    scaled by 1e-9, is added to ``total_ops``.

    Returns:
        dict with 'name' (the tag), 'test_logits', 'test_labels' (taken from
        the last head's artifacts), 'total_params', 'total_ops', 'num_heads'.

    Raises:
        ValueError: if no experiment has ``tag``, if there is not exactly one
            ``child_of_`` tag, or if there is not exactly one parent experiment.
        KeyError: if some head index between 0 and the largest one is missing.
    """
    exps = project.get_experiments(tag=tag)
    if not exps:
        raise ValueError(f"no Neptune experiments found with tag {tag!r}")

    exp_dict = {}
    exp_dict['name'] = tag

    child_tags = [t for t in exps[0].get_tags() if "child_of_" in t]
    if len(child_tags) != 1:
        raise ValueError(f"expected exactly one 'child_of_' tag for {tag!r}, got {child_tags}")
    parent_tag = child_tags[0].replace("child_of_", "")
    parent_exps = project.get_experiments(tag=parent_tag)
    if len(parent_exps) != 1:
        raise ValueError(f"expected one parent experiment tagged {parent_tag!r}, got {parent_exps}")

    download_artifacts(parent_exps[0])

    total_params = torch.load("output/total_params", map_location="cpu")
    total_ops = torch.load("output/total_ops", map_location="cpu")

    # Normalise {head_index: value} dicts to tensors, as the other prepare_* helpers do.
    if isinstance(total_params, dict):
        total_params = torch.tensor([total_params[n] for n in range(max(total_params.keys()) + 1)])
    if isinstance(total_ops, dict):
        total_ops = torch.tensor([total_ops[n] for n in range(max(total_ops.keys()) + 1)])

    # Later experiments win if two share a head_idx.
    head_dict = {int(e.get_parameters()['head_idx']): e for e in exps}

    test_logits = []
    for key in range(max(head_dict.keys()) + 1):
        if key not in head_dict:
            raise KeyError(f"no experiment with head_idx={key} under tag {tag!r}")
        download_artifacts(head_dict[key])
        test_logits.append(torch.load("output/test_logits", map_location="cpu"))

    exp_dict['test_logits'] = torch.stack(test_logits, 1)

    num_classes = exp_dict['test_logits'].shape[-1]

    # Extra cost added per head (cumulative, scaled by 1e-9): head 0 costs 2*C;
    # head i adds (i+1)*C multiplications, i*C additions and C bias terms.
    # Use a floating dtype: zeros_like on an integer tensor would truncate
    # these sub-unit increments to 0.
    ops_dtype = total_ops.dtype if total_ops.is_floating_point() else torch.get_default_dtype()
    new_ops = torch.zeros_like(total_ops, dtype=ops_dtype)
    new_ops[0] = (num_classes + num_classes) / 1e9
    for idx in range(1, len(total_ops)):
        mul_ops = (idx + 1) * num_classes
        add_ops = idx * num_classes
        bias_ops = num_classes
        new_ops[idx] = new_ops[idx - 1] + (mul_ops + add_ops + bias_ops) / 1e9

    print(new_ops, total_ops)
    exp_dict['total_params'] = total_params
    exp_dict['total_ops'] = total_ops + new_ops

    print(tag, total_ops.shape)
    # Count heads from the stacked logits, consistent with the other prepare_* helpers.
    exp_dict['num_heads'] = exp_dict['test_logits'].shape[1]

    # Labels come from the artifacts of the last head downloaded above.
    exp_dict['test_labels'] = torch.load("output/test_labels", map_location="cpu")

    return exp_dict


# Time-Acc Plot

In [None]:
# Neptune project handle; the prepare_* helpers call project.get_experiments on it.
# NOTE(review): project name and api_token are placeholders — fill in before running.
project = neptune.init('TODO fill me', api_token='=')

In [None]:
# CIFAR-10 / MobileNet runs: regular runs are fetched by Neptune experiment id,
# the Zero Time Waste running ensemble by its tag.
cifar10_experiments = {}

cifar10_experiments["Base Network"] = prepare_regular_experiment("CON1-2505")
cifar10_experiments["SDN"] = prepare_regular_experiment("CON1-2508")
# cifar10_experiments["SDN+Stacking"] = prepare_regular_experiment("CON1-2507")  # disabled
cifar10_experiments["Zero Time Waste"] = prepare_run_ensb_experiment('20210116_cifar10_mobilenet_running_ensb')

In [None]:
# CIFAR-100 / VGG16-BN runs: regular runs by experiment id, Zero Time Waste by tag.
cifar100_experiments = {}

cifar100_experiments["Base Network"] = prepare_regular_experiment("CON1-2948")
cifar100_experiments["SDN"] = prepare_regular_experiment("CON1-2963")
# cifar100_experiments["SDN+Stacking"] = prepare_regular_experiment("CON1-2964")  # disabled
cifar100_experiments["Zero Time Waste"] = prepare_run_ensb_experiment('20210116_cifar100_vgg16bn_running_ensb')

In [None]:
# Tiny ImageNet / ResNet56 runs: regular runs by experiment id, Zero Time Waste by tag.
tinyimagenet_experiments = {}

tinyimagenet_experiments["Base Network"] = prepare_regular_experiment("CON1-3223")
tinyimagenet_experiments["SDN"] = prepare_regular_experiment("CON1-3354")
# tinyimagenet_experiments["SDN+Stacking"] = prepare_regular_experiment("CON1-3355")  # disabled
tinyimagenet_experiments["Zero Time Waste"] = prepare_run_ensb_experiment('20210123_tinyimagenet_resnet56_running_ensb')

In [None]:
DEV = False
# Set the style before creating the figure: axes read rcParams (grid, facecolor,
# spines) when they are created, so styling afterwards does not restyle them.
seaborn.set_style('whitegrid')
fig, axes = plt.subplots(1, 1, figsize=(15, 9))
axes = [axes]  # wrap the single Axes so the panel below can use axes[0]

# Single-panel time/accuracy plot for ResNet56 on Tiny ImageNet, with a PABEE
# curve computed from the SDN run.
time_acc_data = plot_time_acc(tinyimagenet_experiments, patience_id='SDN', dev=DEV)
draw_time_acc_plot(time_acc_data, ax=axes[0], first=True, title="ResNet56 - Tiny ImageNet")
axes[0].set_xlabel('Inference Time', fontsize=FONT_SIZE)

plt.show()

In [None]:
DEV = False
# Style first so the axes created by subplots pick it up.
seaborn.set_style('whitegrid')
fig, axes = plt.subplots(1, 2, figsize=(25, 10))

time_acc_data = plot_time_acc(cifar100_experiments, patience_id='SDN', dev=DEV)
draw_time_acc_plot(time_acc_data, ax=axes[0], first=True, title="VGG16 on CIFAR-100")

time_acc_data = plot_time_acc(tinyimagenet_experiments, patience_id='SDN', dev=DEV)
draw_time_acc_plot(time_acc_data, ax=axes[1], first=False, title="ResNet56 on TinyImagenet")

# Clear the right panel's y-axis label.
axes[1].set_ylabel('', fontsize=FONT_SIZE)

# tight_layout must run after titles and labels exist; on empty axes it has nothing to fit.
fig.tight_layout()
# pyplot.show takes no figure argument (only the keyword `block`).
plt.show()

In [None]:
# Re-run tight_layout on the most recent `fig` and show open figures.
# NOTE(review): with the Jupyter inline backend the previous show() may already
# have closed this figure, in which case nothing is re-rendered here — evaluating
# `fig` as the cell's last expression would redisplay it; confirm intent.
fig.tight_layout()
# pyplot.show takes no figure argument (only the keyword `block`).
plt.show()

In [None]:
# Per-experiment head-improvement matrices, then the improvability plot.
# NOTE(review): `experiments` is not assigned in this part of the notebook —
# confirm it is defined by an earlier cell and refers to the intended runs.
for exp_id, exp in experiments.items():
    head_improvement, fig = head_improvement_matrix(exp, dataset="test")
    plt.title(exp['name'])
    # pyplot.show takes no figure argument (only the keyword `block`).
    plt.show()

plot_improvability(experiments)

# Improvability

In [None]:
# CIFAR-100 / MobileNet runs used for the hindsight-improvability plot.
c100_mobile_experiments = {}

c100_mobile_experiments['Base Network'] = prepare_regular_experiment('CON1-2506')
c100_mobile_experiments['SDN'] = prepare_regular_experiment('CON1-2509')
# c100_mobile_experiments['SDN+Stacking'] = prepare_regular_experiment('CON1-2510')  # disabled
c100_mobile_experiments["Zero Time Waste"] = prepare_run_ensb_experiment('20210116_cifar100_mobilenet_running_ensb')


In [None]:
# Hindsight-improvability plot for the CIFAR-100 MobileNet runs in c100_mobile_experiments.
plot_improvability(c100_mobile_experiments, title="Hindsight Improvability for MobileNet on CIFAR-100")

# Ablations

In [None]:
# Ablation runs for VGG16-BN on CIFAR-100: base network, SDN, stacking, the
# ensembling-only variant (tag with `_baseline` suffix), and full Zero Time Waste.
c100_vgg_experiments = {}

c100_vgg_experiments['Base Network'] = prepare_regular_experiment('CON1-2948')
c100_vgg_experiments['SDN'] = prepare_regular_experiment('CON1-2963')
c100_vgg_experiments['Stacking'] = prepare_regular_experiment('CON1-2964')
c100_vgg_experiments["Ensembling"] = prepare_run_ensb_experiment('20210116_cifar100_vgg16bn_running_ensb_baseline')
c100_vgg_experiments["Zero Time Waste"] = prepare_run_ensb_experiment('20210116_cifar100_vgg16bn_running_ensb')

In [None]:
# Ablation runs for ResNet56 on CIFAR-100: base network, SDN, stacking, the
# ensembling-only variant (tag with `_baseline` suffix), and full Zero Time Waste.
c100_resnet_experiments = {}

c100_resnet_experiments['Base Network'] = prepare_regular_experiment('CON1-2544')
c100_resnet_experiments['SDN'] = prepare_regular_experiment('CON1-2563')
c100_resnet_experiments['Stacking'] = prepare_regular_experiment('CON1-2564')
c100_resnet_experiments["Ensembling"] = prepare_run_ensb_experiment('20210116_cifar100_resnet56_running_ensb_baseline')
c100_resnet_experiments["Zero Time Waste"] = prepare_run_ensb_experiment('20210116_cifar100_resnet56_running_ensb')

In [None]:
# NOTE(review): ids_to_include / tags_to_include are not used in this cell, and
# the ids match the CIFAR-100 ResNet56 runs rather than the runs loaded below.
ids_to_include = ['CON1-2544', 'CON1-2563', 'CON1-2564']
# tags_to_include = ['20210109_first_exp_running_ensb_test', '20210109_first_exp_running_ensb_train']
tags_to_include = ['20210116_cifar100_resnet56_running_ensb']

# CIFAR-10 / MobileNet runs. This dict needs its own name: reusing
# c100_mobile_experiments would discard the CIFAR-100 MobileNet runs loaded
# earlier and leave c10_mobile_experiments undefined (NameError below).
c10_mobile_experiments = {}

c10_mobile_experiments['Base Network'] = prepare_regular_experiment('CON1-2505')
c10_mobile_experiments['SDN'] = prepare_regular_experiment('CON1-2508')
c10_mobile_experiments['SDN+Stacking'] = prepare_regular_experiment('CON1-2507')
c10_mobile_experiments["Zero Time Waste"] = prepare_run_ensb_experiment('20210116_cifar10_mobilenet_running_ensb')


In [None]:
DEV = False
# Style first so the axes created by subplots pick it up.
seaborn.set_style('whitegrid')
fig, axes = plt.subplots(2, 1, figsize=(15, 18))

# Ablation time/accuracy panels on CIFAR-100 (no PABEE curve: patience_id=None).
time_acc_data = plot_time_acc(c100_resnet_experiments, patience_id=None, dev=DEV)
draw_time_acc_plot(time_acc_data, ax=axes[0], first=True, title="ResNet56 - CIFAR-100")

time_acc_data = plot_time_acc(c100_vgg_experiments, patience_id=None, dev=DEV)
draw_time_acc_plot(time_acc_data, ax=axes[1], first=False, title="VGG - CIFAR-100")
# Only the bottom panel gets an x-axis label.
axes[1].set_xlabel('Inference Time', fontsize=FONT_SIZE)

plt.show()