# Reliability Diagram

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from os.path import join

from cal_methods import HistogramBinning, TemperatureScaling
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression

import sys
from os import path
sys.path.append(path.dirname(path.dirname( path.abspath("utility"))))
from utility.unpickle_probs import unpickle_probs
from utility.evaluation import get_bin_info, softmax, ECE, MCE

## Load in the data

In [None]:
# PATH = '../../logits/EXP2/hist_1'
# files = [
#     'cp_10_logits.p',
#     'cp_50_logits.p',
#     'cp_100_logits.p',
#     'cp_200_logits.p',
#     'cp_300_logits.p',
#     'cp_500_logits.p'
# ]

### Reliability diagrams as subplots

In [None]:
# Reliability diagram plotting for the subplot case.
def rel_diagram_sub(accs, confs, ax, M = 10, name = "Reliability Diagram", xname = "", yname=""):
    """Draw one reliability diagram onto an existing axes object.

    For every bar position the smaller of (accuracy, confidence) is drawn as
    a solid blue "Outputs" bar and the larger one as a translucent red "Gap"
    bar behind it, so the visible red part is the difference between the two.

    Args:
        accs: Accuracies, one value per bar.
        confs: Confidences, one value per bar (same length as ``accs``).
        ax: Axes-like object to draw on; it is used through its ``bar``,
            ``plot``, ``legend`` and ``set_*`` methods.
        M: Number of bins; fixes the bar width (``1/M``) and the bar centres.
        name: Title of the subplot.
        xname: Label of the x axis.
        yname: Label of the y axis.
    """
    # Row-wise sort: column 0 ends up with the smaller value of each
    # (accuracy, confidence) pair, column 1 with the larger one.
    pairs = np.sort(np.column_stack([accs, confs]), axis=1)
    lower = pairs[:, 0]
    upper = pairs[:, 1]

    width = 1/M
    half_width = width/2
    centres = np.arange(half_width, 1+half_width, width)

    # The gap bars get the lower zorder so they sit behind the output bars.
    gap_bars = ax.bar(centres, upper, width=width, edgecolor="red", color="red",
                      alpha=0.3, label="Gap", linewidth=2, zorder=2)

    # Solid bars carrying the smaller value of each pair.
    output_bars = ax.bar(centres, lower, width=width, edgecolor="black",
                         color="blue", label="Outputs", zorder=3)

    # Square plot with the dashed diagonal as reference line.
    ax.set_aspect('equal')
    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.legend(handles=[gap_bars, output_bars])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title(name, fontsize=24)
    for set_label, text in ((ax.set_xlabel, xname), (ax.set_ylabel, yname)):
        set_label(text, fontsize=22, color="black")

In [None]:
def get_pred_conf(y_probs, normalize = False):
    """Return the predicted class and its confidence for every row.

    Args:
        y_probs: 2-D array of per-class scores, one row per sample.
        normalize: When true, the top score of each row is divided by the
            sum of that row; otherwise the top score is returned as is.

    Returns:
        Tuple ``(y_preds, y_confs)``: the column index of the largest score
        in each row and the (optionally normalized) largest score itself.
    """
    top_scores = np.max(y_probs, axis=1)

    if normalize:
        top_scores = top_scores/np.sum(y_probs, axis=1)

    # The prediction is the class holding the maximum score.
    return np.argmax(y_probs, axis=1), top_scores

## Calibration methods for both 1-vs-rest and multiclass approach

In [None]:
def cal_res(method, path, file, M = 15, name = "", approach = "single", m_kwargs = None):
    """Calibrate one logits file and return test-set reliability statistics.

    The file is loaded with ``unpickle_probs``; the validation split is used
    to fit the calibration model and the test split is used for evaluation.

    Args:
        method: Calibration model class exposing ``fit`` and ``predict``, or
            ``None`` to evaluate the uncalibrated softmax probabilities.
        path: Directory that contains ``file``. (Inside this function the
            name shadows the ``os.path`` module imported at the top.)
        file: Name of the pickled logits file.
        M: Number of confidence bins; the bin size is ``1/M``.
        name: Unused; kept so existing positional calls keep working.
        approach: ``"single"`` fits one model per class (1-vs-rest) on the
            softmax probabilities; any other value fits a single model on
            the raw logits.
        m_kwargs: Keyword arguments for ``method``; ``None`` means none.

    Returns:
        ``((accs, confs, bin_lengths), ece, mce)`` computed on the test
        split. The uncalibrated case (``method is None``) returns this same
        structure; it used to return only the ``get_bin_info`` result, so
        callers unpacking three values bound accuracies, confidences and bin
        lengths to the names meant for bin info, ECE and MCE.
    """
    # Avoid a shared mutable default argument.
    if m_kwargs is None:
        m_kwargs = {}

    bin_size = 1/M

    FILE_PATH = join(path, file)
    (y_logits_val, y_val), (y_logits_test, y_test) = unpickle_probs(FILE_PATH)

    y_probs_val = softmax(y_logits_val)
    y_probs_test = softmax(y_logits_test)

    if method is None:
        # Uncalibrated baseline: evaluate the plain softmax output.
        pass

    elif approach == "single":

        K = y_probs_test.shape[1]

        # Go through all the classes
        for k in range(K):
            # Prep class labels (1 fixed true class, 0 other classes)
            # assumes y_val is a 2-D column array (hence the [:, 0]) - TODO confirm
            y_cal = np.array(y_val == k, dtype="int")[:, 0]

            # Train model on the single probability column of class "k"
            model = method(**m_kwargs)
            model.fit(y_probs_val[:, k], y_cal)

            # Replace the column with the values predicted by the fitted model
            y_probs_val[:, k] = model.predict(y_probs_val[:, k])
            y_probs_test[:, k] = model.predict(y_probs_test[:, k])

            # Replace NaN with 0, as it should be close to zero  # TODO is it needed?
            y_probs_test[np.isnan(y_probs_test)] = 0
            y_probs_val[np.isnan(y_probs_val)] = 0

    else:
        # One model for all classes, fitted directly on the logits.
        model = method(**m_kwargs)
        model.fit(y_logits_val, y_val)

        y_probs_test = model.predict(y_logits_test)

    # Predictions depend only on the final probabilities, so they are
    # computed once here instead of after every class inside the loop.
    y_preds_test, y_confs_test = get_pred_conf(y_probs_test, normalize = False)

    ece = ECE(y_confs_test, y_preds_test, y_test, bin_size = bin_size)
    mce = MCE(y_confs_test, y_preds_test, y_test, bin_size = bin_size)
    accs_test, confs_test, len_bins_test = get_bin_info(y_confs_test, y_preds_test, y_test, bin_size = bin_size)

    return ((accs_test, confs_test, len_bins_test), ece, mce)
    

In [None]:
def gen_plots(files, plot_names = None, M = 15):
    """Plot and save a row of four reliability diagrams for every logits file.

    For each file the uncalibrated, temperature-scaled, histogram-binned and
    isotonic-regression results are obtained from ``cal_res`` and drawn side
    by side with ``rel_diagram_sub``. The figure is saved as a PDF named
    after the second ``_``-separated token of the file name, shown, and then
    closed.

    Args:
        files: Iterable of logits file names located in the module-level
            ``PATH`` directory.
        plot_names: Currently unused; kept for interface compatibility.
        M: Number of confidence bins passed on to ``cal_res`` and
            ``rel_diagram_sub``.
    """
    for file in files:

        # NOTE(review): every cal_res result is unpacked as
        # (bin_info, ece, mce); verify the uncalibrated call (method=None)
        # returns that same structure.
        uc, uc_ece, uc_mce = cal_res(None, PATH, file, M)
        ts, ts_ece, ts_mce = cal_res(TemperatureScaling, PATH, file, M, "", "multi")
        hb, hb_ece, hb_mce = cal_res(HistogramBinning, PATH, file, M, "", "single", {'M':M})
        ir, ir_ece, ir_mce = cal_res(IsotonicRegression, PATH, file, M, "", "single", {'y_min':0, 'y_max':1})

        accs_confs = [uc, ts, hb, ir]

        # Epoch tag taken from the file name, e.g. "cp_10_logits.p" -> "10".
        epoch = file.split('_')[1]

        plt.style.use('ggplot')
        fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(22.5, 6), sharex='col', sharey='row')
        fig.suptitle('epoch_' + epoch, size = 30)
        names = [" Uncal|ECE:{}|MCE:{} ".format(uc_ece, uc_mce),
                 " Temp|ECE:{}|MCE:{} ".format(ts_ece, ts_mce),
                 " Histo|ECE:{}|MCE:{}".format(hb_ece, hb_mce),
                 " Iso|ECE:{}|MCE:{} ".format(ir_ece, ir_mce)]

        for bin_info, sub_ax, title in zip(accs_confs, ax, names):
            rel_diagram_sub(bin_info[0], bin_info[1], sub_ax, M = M, name = title, xname="Confidence")

        ax[0].set_ylabel("Accuracy", color = "black")

        for ax_temp in ax:
            plt.setp(ax_temp.get_xticklabels(), rotation='horizontal', fontsize=18)
            plt.setp(ax_temp.get_yticklabels(), fontsize=18)

        plt.savefig("../../reliability_diagrams/EXP2/ls_1/epoch_" + epoch + ".pdf", format='pdf', dpi=1000, bbox_inches='tight', pad_inches=0.2)
        plt.show()
        # Release the figure so figures do not pile up across files.
        plt.close(fig)

In [None]:
# Generate and save the reliability diagrams for every entry of `files`, using 15 bins.
# NOTE(review): `files` (and the `PATH` global read by gen_plots) only appear as
# commented-out code in the "Load in the data" cell; define them before running this cell.
gen_plots(files, M = 15)