Generalized Calibration Error

Copyright 2022 Carnegie Mellon University.

NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING INSTITUTE
MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO
WARRANTIES OF ANY KIND, EITHER EXPRESSED OR IMPLIED, AS TO ANY MATTER 
INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS FOR PURPOSE OR 
MERCHANTABILITY, EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF THE MATERIAL. 
CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT
TO FREEDOM FROM PATENT, TRADEMARK, OR COPYRIGHT INFRINGEMENT.

Released under a MIT (SEI)-style license, please see license.txt or contact 
permission@sei.cmu.edu for full terms.

[DISTRIBUTION STATEMENT A] This material has been approved for public release 
and unlimited distribution.  Please see Copyright notice for non-US Government 
use and distribution.

This Software includes and/or makes use of the following Third-Party Software 
subject to its own license:

1. calibration (https://github.com/uu-sml/calibration/blob/master/LICENSE) 
Copyright 2019 Carl Andersson, David Widmann.

2. NumPy (https://github.com/numpy/numpy/blob/main/LICENSE.txt) 
Copyright 2005-2022 NumPy Developers.

# Example Notebook for the Generalized Calibration Error Library
Each of the blocks below provides an example of one component of a calibration error and shows how to use it.  For a discussion of these components, and for examples matched to calibration error use cases, see ["What is Your Metric Telling You? Evaluating Classifier Calibration under Context-Specific Definitions of Reliability"](https://arxiv.org/abs/2205.11454).  API documentation can be found in the docs directory (see the README for instructions on building the docs).

Each lettered section is meant to be self-contained and can be run without running other sections' code (except for the initial imports and loading of data).

## Initial imports and loading of data (always do these before running any of the cases below)

In [None]:
import numpy as np
from generalized_calibration_error import gce # Generalized calibration error class to create and use a calibration metric

In [None]:
# probs holds the outputs, over the test set, of a simple ResNet trained on CIFAR100; labels holds the corresponding labels for those instances.
probs = np.load("example_data/ResNet50_CIFAR100_test_outputs.npy")
labels = np.load("example_data/CIFAR100_test_labels.npy")

## A) Lenses

### 1) Top-1 Lens

In [None]:
# Top-1 lens as defined by component function
from generalized_calibration_error.components.lenses import top_1_lens # components is the module that holds the building blocks for designing a calibration error; its lenses submodule contains the predefined lenses
top1_ece_v1 = gce(lens = top_1_lens)
top1_ece_v1(probs, labels)

In [None]:
# Top-1 lens as defined by the more general Top-k lens (Should be identical to the above code block)
from generalized_calibration_error.components.lenses import top_k_lens
top1_ece_v2 = gce(lens = top_k_lens(k=1))
top1_ece_v2(probs,labels)
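
To make the two cells above concrete, here is a minimal NumPy sketch of what a top-1 expected calibration error typically computes. It is not the library's implementation: the function name `top1_ece_sketch`, the toy data, and the choice of 15 uniform-width bins with an absolute accuracy/confidence gap are assumptions made for illustration, and `gce` may differ in details such as bin-edge handling.

```python
import numpy as np

def top1_ece_sketch(probs, labels, n_bins=15):
    """Top-1 ECE with uniform-width bins (illustrative; not the library's code)."""
    confidences = probs.max(axis=1)              # top-1 view: keep only the largest probability
    correct = probs.argmax(axis=1) == labels     # ... and whether the top prediction was right
    interior_edges = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    bin_ids = np.digitize(confidences, interior_edges)   # values in 0 .. n_bins - 1
    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            # Gap between accuracy and mean confidence, weighted by the bin's share of the data
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap
    return ece

# Hypothetical toy data: 4 instances, 3 classes
toy_probs = np.array([[0.90, 0.10, 0.00],
                      [0.65, 0.25, 0.10],
                      [0.20, 0.50, 0.30],
                      [0.10, 0.15, 0.75]])
toy_labels = np.array([0, 1, 1, 2])
toy_ece = top1_ece_sketch(toy_probs, toy_labels)
print(toy_ece)
```

In this toy set every instance falls in its own bin, so the result is simply the average of the four per-instance gaps.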

### 2) Full Lens

In [None]:
full_gece_new = gce() # Default lens is the "full" lens
full_gece_new(probs, labels)

### 3) Top-5 Lens

In [None]:
from generalized_calibration_error.components.lenses import top_k_lens
top5_ece = gce(lens = top_k_lens(k=5))
top5_ece(probs,labels)

### 4) Grouping Lens

In [None]:
# Group first 50 versus last 50 (resulting in a binary classification problem)
from generalized_calibration_error.components.lenses import group_lens
group_ece = gce(lens = group_lens(groups = [[*range(50)], [*range(50,100)]]))
group_ece(probs, labels)

In [None]:
# Group each class in its own group (should be identical to Full lens case)
from generalized_calibration_error.components.lenses import group_lens
group_ece = gce(lens = group_lens(groups = [[num] for num in range(100)]))
group_ece(probs, labels)
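
As a rough picture of what grouping does, the sketch below sums class probabilities within each group and relabels each instance with the index of the group containing its class. The helper name `group_lens_sketch` and the toy data are hypothetical, and the library's `group_lens` may represent its outputs differently.

```python
import numpy as np

def group_lens_sketch(probs, labels, groups):
    """Collapse class probabilities and labels onto coarser groups (illustrative only)."""
    # Probability of a group is the sum of the probabilities of its member classes
    grouped_probs = np.stack([probs[:, g].sum(axis=1) for g in groups], axis=1)
    # Lookup table from original class index to group index (assumes every class is in one group)
    class_to_group = np.empty(probs.shape[1], dtype=int)
    for group_idx, g in enumerate(groups):
        class_to_group[g] = group_idx
    return grouped_probs, class_to_group[labels]

# Hypothetical toy data: 2 instances, 4 classes, grouped into {0, 1} and {2, 3}
toy_probs = np.array([[0.5, 0.2, 0.2, 0.1],
                      [0.1, 0.1, 0.3, 0.5]])
toy_labels = np.array([1, 3])
g_probs, g_labels = group_lens_sketch(toy_probs, toy_labels, groups=[[0, 1], [2, 3]])
```

Under this reading, putting each class in its own group leaves the probabilities unchanged, and a one-versus-rest split such as `[[0], [1, 2, 3]]` gives a binary problem for a single class.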

### 5) Class-Marginal Lens

In [None]:
# Class 0 marginal lens (class 0 versus rest), using the specialized class marginal class
from generalized_calibration_error.components.lenses import class_marginal_lens
class_0_marginal_gece_v1 = gce(lens = class_marginal_lens(class_num = 0))
class_0_marginal_gece_v1(probs, labels)

In [None]:
# Class 0 marginal lens (class 0 versus rest), using the more general grouping lens (should be identical to the code block above)
from generalized_calibration_error.components.lenses import group_lens
class_0_marginal_gece_v2 = gce(lens = group_lens(groups = [[0], [*range(1,100)]]))
class_0_marginal_gece_v2(probs, labels)

## B) Selection Ops

### Lens for this section (run this for ops examples)

In [None]:
from generalized_calibration_error.components.lenses import top_1_lens

### 1) Label Selection

In [None]:
# Select instances labeled with class 0, compute the top-1 ECE
from generalized_calibration_error.components.selection_ops import label_selection_op
first_class_selection_op = label_selection_op(classes = [0])
label_selection_ece = gce(lens=top_1_lens, selection_op = first_class_selection_op, preselection = True) # Preselection set to True means the selection operator is applied BEFORE the lens is applied
label_selection_ece(probs, labels)

In [None]:
label_selection_ece = gce(lens=top_1_lens, selection_op = first_class_selection_op, preselection = False) # Preselection set to False means the selection operator is applied AFTER the lens is applied
label_selection_ece(probs, labels)

### 2) Output Selection

In [None]:
# After a top-1 lens, select all instances with max probability greater than or equal to 0.999999
import operator
from generalized_calibration_error.components.selection_ops import output_selection_op
high_conf_selection_op = output_selection_op(operator = operator.ge, rhs_value = 0.999999, membership_dim = 0)
output_selection_gece = gce(lens=top_1_lens, selection_op = high_conf_selection_op, preselection = False)
output_selection_gece(probs,labels)

In [None]:
# Same as above but select instances with max probability less than or equal to 0.2
low_conf_selection_op = output_selection_op(operator = operator.le, rhs_value = 0.2, membership_dim = 0)
output_selection_ece = gce(lens=top_1_lens, selection_op = low_conf_selection_op, preselection = False)
output_selection_ece(probs,labels)

In [None]:
# Same as above but select instances with max probability between 0.45 and 0.55, inclusive
high_end_selection_op = output_selection_op(operator = operator.le, rhs_value = 0.55, membership_dim = 0)
low_end_selection_op = output_selection_op(operator = operator.ge, rhs_value = 0.45, membership_dim = 0)
output_selection_ece = gce(lens=top_1_lens, selection_op = lambda a,b: high_end_selection_op(*low_end_selection_op(a,b)), preselection = False)
output_selection_ece(probs,labels)
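
The lambda in the cell above chains two selection operators, which only works if each operator takes an `(outputs, labels)` pair and returns a filtered pair. The sketch below mimics that pattern in plain NumPy. The factory `output_selection_sketch` is hypothetical, and it assumes `membership_dim` refers to a column of the lensed outputs, which the source does not state.

```python
import operator
import numpy as np

def output_selection_sketch(op, rhs_value, dim):
    """Build a selector keeping rows whose value in column `dim` satisfies op(value, rhs_value)."""
    def select(outputs, labels):
        keep = op(outputs[:, dim], rhs_value)   # boolean mask over instances
        return outputs[keep], labels[keep]
    return select

at_least = output_selection_sketch(operator.ge, 0.45, dim=0)
at_most = output_selection_sketch(operator.le, 0.55, dim=0)

def interval_select(outputs, labels):
    # Apply the lower bound first, then feed the surviving pair to the upper bound
    return at_most(*at_least(outputs, labels))

# Hypothetical lensed outputs: column 0 plays the role of the max probability
toy_outputs = np.array([[0.30, 0.70],
                        [0.45, 0.55],
                        [0.50, 0.50],
                        [0.90, 0.10]])
toy_labels = np.array([1, 0, 1, 0])
kept_outputs, kept_labels = interval_select(toy_outputs, toy_labels)
```

Because both bounds are inclusive, the instance sitting exactly at 0.45 is kept along with the one at 0.50.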

## C) Aggregation Functions

### Lens for this section (run these for aggregation examples)

In [None]:
from generalized_calibration_error.components.lenses import top_1_lens

### 1) Expectation

In [None]:
# Perform top-1 EXPECTED calibration error (In this case, due to histogram binning being the estimation scheme, the expectation is over bins).
from generalized_calibration_error.components.aggregation_fns import expectation
ece = gce(lens = top_1_lens, aggregation_fn = expectation)
ece(probs,labels)

### 2) Maximum

In [None]:
# Perform top-1 MAXIMUM calibration error (In this case, due to histogram binning being the estimation scheme, the maximum is over bins)
from generalized_calibration_error.components.aggregation_fns import maximum
mce = gce(lens = top_1_lens, aggregation_fn = maximum)
mce(probs,labels)
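
The difference between the two aggregation functions is easiest to see on per-bin quantities. The numbers below are made up for illustration; the sketch assumes the expectation weights each bin's error by the share of instances in that bin, which is the usual convention for histogram-based ECE.

```python
import numpy as np

# Hypothetical per-bin calibration gaps and bin sizes
bin_gaps = np.array([0.02, 0.10, 0.30])
bin_counts = np.array([700, 250, 50])

# Expectation: average gap, weighted by how many instances fall in each bin
expected_error = np.average(bin_gaps, weights=bin_counts)

# Maximum: the single worst bin, regardless of how few instances it holds
maximum_error = bin_gaps.max()

print(expected_error, maximum_error)
```

Here a small, badly calibrated bin barely moves the expectation but fully determines the maximum, which is why the two metrics can rank models differently.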

## D) Estimation Schemes

### Lens for this section (run these for estimation examples)

In [None]:
from generalized_calibration_error.components.lenses import top_1_lens

### 1) Uniform Binning Histogram Estimation

In [None]:
# Estimation is done via histogram binning with 1000 uniformly sized bins
from generalized_calibration_error.components.estimation_schemes.histogram_binning_estimators import uniform_histogram_binning_estimator
uniform_hist_top1_ece = gce(lens = top_1_lens, estimation_scheme = uniform_histogram_binning_estimator(bins=1000))
uniform_hist_top1_ece(probs,labels)

### 2) Adaptive Binning Histogram Estimation

In [None]:
# Estimation is done via histogram binning with adaptive bins, each holding approximately 0.001 of the data (about 10 instances per bin)
from generalized_calibration_error.components.estimation_schemes.histogram_binning_estimators import adaptive_histogram_binning_estimator
adaptive_hist_top1_ece = gce(lens = top_1_lens, estimation_scheme = adaptive_histogram_binning_estimator(frac_per_bin=0.001))
adaptive_hist_top1_ece(probs,labels)
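
The two binning schemes differ only in where the bin edges go. The following sketch contrasts them on synthetic confidences. The helper names are hypothetical, and using quantiles for the adaptive edges is an assumption about how equal-mass bins are usually built, not a description of the library's estimator.

```python
import numpy as np

def uniform_bin_edges(n_bins):
    """Edges of equal-width bins on [0, 1]."""
    return np.linspace(0.0, 1.0, n_bins + 1)

def adaptive_bin_edges(confidences, frac_per_bin):
    """Edges chosen at quantiles so each bin holds roughly frac_per_bin of the data."""
    n_bins = int(round(1.0 / frac_per_bin))
    return np.quantile(confidences, np.linspace(0.0, 1.0, n_bins + 1))

# Synthetic confidences skewed toward 1, as is common for overconfident classifiers
rng = np.random.default_rng(0)
confidences = rng.beta(8, 1, size=1000)

u_counts, _ = np.histogram(confidences, bins=uniform_bin_edges(10))
a_counts, _ = np.histogram(confidences, bins=adaptive_bin_edges(confidences, 0.1))
print("uniform :", u_counts)
print("adaptive:", a_counts)
```

With skewed confidences, uniform bins leave most low-confidence bins nearly empty and crowd the top bin, while adaptive bins spread the instances evenly at the cost of unequal bin widths.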

## E) Distance Functions

### Top-1 lens for the next two subsections (run this first); these examples also assume no selection operator and uniform binning with 15 bins

In [None]:
from generalized_calibration_error.components.lenses import top_1_lens

### 1) Total Variation Distance

In [None]:
# Distance used for error calculation is total variation distance (should be identical to A.1)
from generalized_calibration_error.components.distance_fns import tvd
tvd_ece = gce(lens = top_1_lens, distance_fn = tvd)
tvd_ece(probs, labels)

### 2) Generalized Mahalanobis Distance

In [None]:
# Distance used for error calculation is a Mahalanobis distance with identity as the precision matrix (equivalent to l2 distance)
from generalized_calibration_error.components.distance_fns import generalized_mahalanobis_distance
GMD_ece = gce(lens = top_1_lens, distance_fn = generalized_mahalanobis_distance(W = np.eye(2)))
GMD_ece(probs, labels)

In [None]:
# Same as above block, but with a non-identity precision matrix
GMD_ece = gce(lens = top_1_lens, distance_fn = generalized_mahalanobis_distance(W = np.array([[1,0.5],[0.05,1]])))
GMD_ece(probs, labels)
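
For reference, the two distances can be written out directly for a single bin. This is a sketch under assumptions: the function names are hypothetical, the distance is applied to a two-element mean output and mean label vector, and the Mahalanobis form includes a square root, which matches the note above that the identity matrix gives the l2 distance.

```python
import numpy as np

def tvd_sketch(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * np.abs(p - q).sum()

def mahalanobis_sketch(p, q, W):
    """sqrt((p - q)^T W (p - q)); reduces to the l2 distance when W is the identity."""
    diff = p - q
    return float(np.sqrt(diff @ W @ diff))

# Hypothetical per-bin averages after a top-1 lens
mean_output = np.array([0.8, 0.2])
mean_label = np.array([0.6, 0.4])

tvd_value = tvd_sketch(mean_output, mean_label)
l2_value = mahalanobis_sketch(mean_output, mean_label, np.eye(2))

# A quadratic form only depends on the symmetric part of W
W_asym = np.array([[1, 0.5], [0.05, 1]])
W_sym = (W_asym + W_asym.T) / 2
asym_value = mahalanobis_sketch(mean_output, mean_label, W_asym)
sym_value = mahalanobis_sketch(mean_output, mean_label, W_sym)
```

The last two values agree, so under this formula an asymmetric `W` behaves exactly like its symmetrized version.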

### Grouping Lens, Interval Selection operators, and adaptive binning, for the next two subsections

In [None]:
# Group first 50 versus last 50
from generalized_calibration_error.components.lenses import group_lens
group_lens = group_lens(groups = [[*range(50)], [*range(50,100)]])

# Select all instances with probability between 0.4 and 0.6, inclusive
import operator
from generalized_calibration_error.components.selection_ops import output_selection_op
high_end_selection_op = output_selection_op(operator = operator.le, rhs_value = 0.6, membership_dim = 0)
low_end_selection_op = output_selection_op(operator = operator.ge, rhs_value = 0.4, membership_dim = 0)
interval_selection_op =  lambda a,b: high_end_selection_op(*low_end_selection_op(a,b))

# Adaptive binning
from generalized_calibration_error.components.estimation_schemes.histogram_binning_estimators import adaptive_histogram_binning_estimator
adaptive_estimation_scheme = adaptive_histogram_binning_estimator(frac_per_bin=0.1)

### 3) Total Interval TVD

In [None]:
# Distance used is the standard total variation distance, except for bins where the mean label is inside the interval [0.4, 0.6], in which case it is 0 (incurs error only for bins that violate an interval assumption)
from generalized_calibration_error.components.distance_fns import interval_tvd
medium_interval_tvd = interval_tvd(interval = [0.4,0.6], inclusivity = [True, True])

medium_interval_ece = gce(lens = group_lens, 
                         selection_op = interval_selection_op, 
                         preselection = False, 
                         estimation_scheme = adaptive_estimation_scheme, 
                         distance_fn = medium_interval_tvd)
medium_interval_ece(probs, labels)

### 4) Inter-Interval Distance

In [None]:
# Distance used is the distance from the mean label to the closest boundary of the interval [0.4, 0.6], or 0 if the mean label is within the interval (incurs error only for bins that violate an interval assumption, proportional to the degree of violation)
from generalized_calibration_error.components.distance_fns import inter_interval_distance
medium_interval_distance = inter_interval_distance(interval = [0.4,0.6])

medium_interval_ece = gce(lens = group_lens, 
                         selection_op = interval_selection_op, 
                         preselection = False, 
                         estimation_scheme = adaptive_estimation_scheme, 
                         distance_fn = medium_interval_distance)
medium_interval_ece(probs, labels)
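
The two interval-based distances can be sketched for a binary problem using only the descriptions in the comments above. The function names, the scalar inputs, and the use of an absolute difference as the binary total variation distance are assumptions for illustration; the library's `interval_tvd` and `inter_interval_distance` may operate on full vectors and handle interval inclusivity differently.

```python
def interval_tvd_sketch(mean_output, mean_label, interval=(0.4, 0.6)):
    """Binary TVD, except that bins whose mean label lies inside the interval incur no error."""
    low, high = interval
    if low <= mean_label <= high:
        return 0.0
    return abs(mean_output - mean_label)

def inter_interval_distance_sketch(mean_label, interval=(0.4, 0.6)):
    """Distance from the mean label to the nearest interval boundary, or 0 if inside."""
    low, high = interval
    return max(low - mean_label, 0.0, mean_label - high)

# Hypothetical bins: one satisfies the interval assumption, one violates it
inside_tvd = interval_tvd_sketch(mean_output=0.5, mean_label=0.55)
outside_tvd = interval_tvd_sketch(mean_output=0.5, mean_label=0.75)
inside_dist = inter_interval_distance_sketch(0.55)
outside_dist = inter_interval_distance_sketch(0.75)
```

The first distance charges a violating bin its full calibration gap, while the second charges only the amount by which the mean label overshoots the interval.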