# UQ metrics

In [44]:
import rasterio, yaml, os
import numpy as np
from pathlib import Path
import scipy
from typing import *

PREDICTIONS_DIR = Path("results/dev/2023-03-14_15-45-23")
PKL_DIR = Path('data/pkl/2021-05-18_10-57-45')
GT_DIR = Path('data/preprocessed')

EAST = ['346', '9', '341', '354', '415', '418', '416', '429', '439', '560', '472', '521', '498',
        '522', '564', '764', '781', '825', '796', '805', '827', '891', '835', '920', '959', '1023', '998',
        '527', '477', '542', '471']
WEST = ['528', '537', '792', '988', '769']
NORTH = ['819', '909', '896']
ALL = EAST + WEST + NORTH

with (PKL_DIR / 'stats.yaml').open() as fh:
    # load training set statistics for data normalization
    stats = yaml.safe_load(fh)
    labels_mean = np.array(stats['labels_mean'])

projects = [f.stem.split("_")[0] for f in PREDICTIONS_DIR.glob('*_mean.tif') if f.stem.split("_")[0] in ALL]

## Quantitative metrics

Let $\mathcal{D}=\left\{(\mathbf{x}_i, \mathbf{y}_i) \in \mathcal{X}\times\mathcal{Y}\right\}_{i=1,\ldots,N}$ be the test set and $\mathcal{P}=\left\{(\hat\mu, \hat\sigma^2)  \in \mathcal{X}\times\mathcal{Y}\right\}_{i=1,\ldots,N}$ be the corresponding pixel-wise predicted mean and variance.

In [90]:
def nan_uce(variance, mean, gt, n_bins):
    d = gt.shape[0]
    # Mask nan
    mask = (~np.isnan(variance) & ~np.isnan(mean)).all(0)
    variance = variance[:,mask]
    mean = mean[:,mask]
    gt = gt[:,mask]
    # Compute UCE for each variables
    uce = np.empty((d,))
    prop_in_bins = np.empty((d, n_bins))
    uncertainty_in_bins = np.empty((d, n_bins))
    variance_in_bins = np.empty((d, n_bins))
    for i, (var, mu, tgt) in enumerate(zip(variance, mean, gt)):
        # Linear binning
        bins = np.linspace(var.min(), var.max(), n_bins)
        # Get variance bin indexes
        bins_ids = np.digitize(var, bins=bins)
        # Loop on bins to compute statistics
        _uce = 0
        for bin_id in np.unique(bins_ids)-1:
            # Select bin
            pos = bins_ids==bin_id+1
            prop_in_bin = pos.astype("float").mean() # bin_size / N
            bin_var = var[pos]
            bin_mean = mu[pos]
            bin_tgt = tgt[pos]
            # Compute stats
            mean_uncertainty = bin_var.mean()
            mean_variance = ((bin_mean-bin_tgt)**2).mean()
            _uce += prop_in_bin * np.abs(mean_variance - mean_uncertainty)
            # keep result
            prop_in_bins[i,bin_id] = prop_in_bin
            uncertainty_in_bins[i,bin_id] = mean_uncertainty
            variance_in_bins[i,bin_id] = mean_variance
        uce[i] = _uce
    return uce, uncertainty_in_bins, variance_in_bins, prop_in_bins

def nan_ence(variance, mean, gt, n_bins):
    _, bins_mean_mse, bins_mean_variance, bins_proportions = nan_uce(variance, mean, gt, n_bins)
    bins_mean_rmse, bins_mean_std = np.sqrt(bins_mean_mse), np.sqrt(bins_mean_variance)
    ence = (np.abs(bins_mean_std-bins_mean_rmse) / bins_mean_std).mean(1)
    return ence, bins_mean_mse, bins_mean_variance, bins_proportions

### 1. Uncertainty Calibration Error (UCE) [Laves 2020, Laves 2021, Levi 2019, Becker 2023] and Expected Normalized Calibration Error (ENCE) [Levi 2019, Zhou 2021a]

The code above is correct BUT it computes UCE/ENCE on a single sample which does not really make sense. We need to compute it on all the predictions/gt pairs. To do so, we could for instance, create a 1d vector for each variable that contains the values for all the predicted pixels. Then, we can use a similar code to compute the metrics.

In [91]:
project = projects[0]
mean_file = os.path.join(PREDICTIONS_DIR, f"{project}_mean.tif")
with rasterio.open(mean_file) as fh:
    mean = fh.read(fh.indexes)
with rasterio.open(PREDICTIONS_DIR / (project + '_variance.tif')) as fh:
    variance = fh.read(fh.indexes)
with rasterio.open(GT_DIR / (project + '.tif')) as fh:
    gt = fh.read(fh.indexes)
    gt_mask = fh.read_masks(1).astype(bool)

In [93]:
nan_uce(variance, mean, gt, 20)[0], nan_ence(variance, mean, gt, 20)[0]

(array([3.57581110e+00, 1.91647382e+00, 2.77684128e+03, 1.15106952e-03,
        6.44200623e+03]),
 array([0.41494112, 0.55113869, 0.99615498, 0.63229099, 0.99686558]))