In [None]:
# Reload imported modules automatically before each cell runs, so edits to external code take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import matplotlib.pyplot as plt
import pickle
import pathlib
import numpy as np
import warnings

import sklearn.metrics
import scipy.stats

In [None]:
# Folder holding the pickled outputs of the run being evaluated (result_prob.pkl and result_score.pkl).
# The path is relative to the notebook's working directory.
FOLDER_RESULTS = pathlib.Path("../Result/Task1/electra_base-fl/")

In [None]:
# Unpickle the saved outputs of the run. Unpickling can execute arbitrary code,
# so only point FOLDER_RESULTS at result files you trust.
probabilities = pickle.loads((FOLDER_RESULTS / r"result_prob.pkl").read_bytes())
scores = pickle.loads((FOLDER_RESULTS / r"result_score.pkl").read_bytes())

In [None]:
# Inspect the keys of probabilities[1]; the ECE evaluation loop iterates over these as model configurations.
probabilities[1].keys()

In [None]:
def _compute_mean_accuracy(binnumber, labels, predictions, n_bins=10):
    mean_accuray = np.full(n_bins + 2, np.nan)
    for i_bin in range(n_bins + 2):
        mask = binnumber == i_bin
        mean_accuray[i_bin] = sklearn.metrics.accuracy_score(labels[mask], predictions[mask])
    return mean_accuray[slice(1,-1)]
    
def _compute_mean_and_standard_deviation(values):
    mean = np.mean(values)
    standard_deviation = np.sqrt(np.sum((values-mean)**2)/len(values))
    return mean, standard_deviation


def _get_ece(confidence, labels, predictions, n_bins=10):
    """Expected calibration error of one set of predictions.

    Confidences are grouped into ``n_bins`` equal-width bins on [0, 1]. Per-bin mean
    confidence and accuracy are combined, weighted by bin size.

    Raises
    ------
    ValueError
        If the number of binned samples differs from the number of labels
        (e.g. when some confidences fall outside [0, 1]).
    """
    with warnings.catch_warnings():
        # Silence RuntimeWarnings from the helpers (presumably from averaging over
        # empty bins, which yields NaN); NaN bins are skipped by np.nansum in _compute_ece.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        edges = get_bins(n_bins=n_bins)
        mean_confidence_per_bin = _compute_confidence_mean(confidence, edges)
        counts, _edges, bin_ids = get_counts_edges_binnumber(confidence, edges)
        mean_accuracy_per_bin = _compute_mean_accuracy(bin_ids, labels, predictions, n_bins=n_bins)
        total = _get_total_number_samples(counts)
        check_number_samples(labels, total)
        return _compute_ece(mean_confidence_per_bin, counts, mean_accuracy_per_bin, total)

def _compute_ece(confidence_mean, n_per_bin, accuracy_mean, number_samples):
    return np.nansum(n_per_bin*np.abs(accuracy_mean-confidence_mean))/number_samples

def check_number_samples(labels, number_samples):
    """Raise ``ValueError`` unless ``number_samples`` equals the length of ``labels`` along its first axis."""
    n_labels = np.shape(labels)[0]
    if number_samples == n_labels:
        return
    raise ValueError(f"{number_samples=} != {n_labels}")

def _get_total_number_samples(n_per_bin):
    return np.nansum(n_per_bin)

def get_counts_edges_binnumber(confidence, bins):
    """Histogram ``confidence`` over the edges ``bins``.

    Returns the ``scipy.stats.binned_statistic`` result: per-bin counts, the bin
    edges, and the bin number assigned to every sample.
    """
    # The "count" statistic ignores the values; ones are passed only to satisfy the signature.
    placeholder_values = np.ones_like(confidence)
    return scipy.stats.binned_statistic(x=confidence, values=placeholder_values, statistic="count", bins=bins)

def _compute_confidence_mean(confidence, bins):
    confidence_mean, bin_edges, binnumber = scipy.stats.binned_statistic(confidence, confidence, statistic="mean", bins=bins)
    return confidence_mean

def get_bins(n_bins):
    """Return the ``n_bins + 1`` equally spaced edges that split [0, 1] into ``n_bins`` bins."""
    n_edges = n_bins + 1
    return np.linspace(start=0, stop=1, num=n_edges)

def get_confidence(probabilities, predictions):
    """Return, for each sample, the probability assigned to its predicted class.

    Parameters
    ----------
    probabilities : array-like, shape (n_samples, n_classes)
        Per-class probabilities of each sample.
    predictions : array-like of int, shape (n_samples,)
        Predicted class index of each sample.

    Returns
    -------
    numpy.ndarray, shape (n_samples,)
        ``probabilities[i, predictions[i]]`` for every row ``i``, with the dtype
        of ``probabilities``.

    Raises
    ------
    ValueError
        If ``predictions`` does not hold exactly one entry per row of ``probabilities``
        (previously this was silently truncated or zero-filled).
    """
    probabilities = np.asarray(probabilities)
    predictions = np.asarray(predictions)
    n_rows = probabilities.shape[0]
    if predictions.shape != (n_rows,):
        raise ValueError(
            f"expected one prediction per row ({n_rows}), got predictions of shape {predictions.shape}"
        )
    if n_rows == 0:
        # An empty list becomes a float array, which cannot be used as an index.
        return np.zeros_like(probabilities[:, 0])
    # Integer-array indexing picks column predictions[i] from row i in one vectorized step.
    return probabilities[np.arange(n_rows), predictions]

    
def compute_ece(labels, prediction_probabilities):
    """Compute the expected calibration error (ECE) of every run and summarise it.

    Parameters
    ----------
    labels : array-like, shape (n_samples,)
        True class labels, shared by all runs.
    prediction_probabilities : array-like
        Indexed by run; each entry is a 2-D array of per-class probabilities
        (presumably shape (n_runs, n_samples, n_classes) overall — confirm).

    Returns
    -------
    tuple
        ``(mean, standard_deviation)`` of the per-run ECE values; the standard
        deviation is the population (ddof=0) one.

    Also prints the array of per-run ECE values.
    """
    n_runs = np.shape(prediction_probabilities)[0]
    ece = np.zeros(n_runs)
    for run_index, run_probabilities in enumerate(prediction_probabilities):
        # The predicted class is the most probable one; its probability is the confidence.
        run_predictions = np.argmax(run_probabilities, axis=1)
        ece[run_index] = _get_ece(
            confidence=get_confidence(run_probabilities, run_predictions),
            labels=labels,
            predictions=run_predictions,
        )
    print(f"{ece=}")
    return _compute_mean_and_standard_deviation(ece)

In [None]:
# The test labels do not depend on the model configuration, so build them once.
labels = np.array(probabilities[0]['test_in'])
# (mean, standard deviation) of the per-run ECE for each model configuration.
ece_results = {}
for model_config in probabilities[1].keys():
    print(f"{model_config}")
    # Presumably shaped (n_runs, n_samples, n_classes) — see compute_ece.
    prediction_probabilities = np.array(probabilities[1][model_config])
    mean, standard_deviation = compute_ece(labels, prediction_probabilities)
    ece_results[model_config] = (mean, standard_deviation)
    print(f"{mean=}, {standard_deviation=}")