In [1]:
import os
import numpy as np

import torch

from sklearn.calibration import calibration_curve

from source.constants import RESULTS_PATH, PLOTS_PATH
from source.data.medical_imaging import get_chexpert

# Create PLOTS_PATH (presumably where this notebook's figures are written) if it
# is missing; exist_ok=True makes this a no-op when the directory already exists.
os.makedirs(PLOTS_PATH, exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seeds of the independently trained runs to aggregate (the "mseed" part of each
# result folder name) and the "dseed" shared by all of those folder names.
method_seeds = [42, 142, 242, 342, 442]
dseed = 42

# Backbone whose saved results are analysed. Available options:
# "resnet18", "resnet34", "resnet50", "regnet", "efficientnet",
# "efficientnet_mcdropout".
model = "resnet50"

# Names of the protected attributes; pa selects one of them (0, 1 or 2) and is
# used as the row index into full_ds.protected_attributes
# (presumably rows follow this same order — confirm).
pas = ["old", "woman", "white"]
pa = 0

In [3]:
# Load the CheXpert dataset (load_to_ram=False) and the index splits stored in
# the first method seed's result folder.
# NOTE(review): only method_seeds[0]'s folder is read, which assumes every
# method seed shares the same fair/val split for this dseed — confirm.
full_ds, _, _ = get_chexpert(load_to_ram=False)

run_path = os.path.join(RESULTS_PATH, f"chexpert_{model}_mseed{method_seeds[0]}_dseed{dseed}")
fair_inds, val_inds = (
    torch.load(os.path.join(run_path, fname)) for fname in ("fair_inds.pt", "val_inds.pt")
)

print(len(fair_inds), len(val_inds))

# Labels of both splits.
y_fair_t, y_val_t = full_ds.targets[fair_inds], full_ds.targets[val_inds]

# Protected attribute `pa` of both splits, flipped with `1 - a` (swaps 0 and 1,
# assuming the attribute is 0/1-coded); the group names are renamed to match.
pas = ["young", "man", "non-white"]
a_fair_t = 1 - full_ds.protected_attributes[pa, fair_inds]
a_val_t = 1 - full_ds.protected_attributes[pa, val_inds]

# Optional: flip the label as well by uncommenting these two lines.
# y_fair_t = 1 - y_fair_t
# y_val_t = 1 - y_val_t

# Means (in %) of the flipped attribute and of the label on the fair split,
# i.e. group and positive-class prevalence if both are 0/1.
p_a_fair = 100 * a_fair_t.float().mean().item()
p_y_fair = 100 * y_fair_t.float().mean().item()

# patients general 65401
# patients with race 58010
24638 24638


In [4]:
# Load the saved probits of every method seed for both splits.
# Medical imaging does not use a fairness ensemble, so the "fair" split acts as
# the test set here. To flip the predicted classes, use `1 - probits` on load.
fair_probits, val_probits = [], []
for mseed in method_seeds:
    path = os.path.join(RESULTS_PATH, f"chexpert_{model}_mseed{mseed}_dseed{dseed}")
    fair_probits.append(torch.load(os.path.join(path, "fair_probits.pt")))
    val_probits.append(torch.load(os.path.join(path, "val_probits.pt")))

In [5]:
def ece(y_probs, y_trues, n_bins):
    """Expected Calibration Error (ECE) of binary probabilistic predictions.

    Probabilities are assigned to ``n_bins`` equal-width bins on [0, 1]. Bins are
    right-closed: a probability lying exactly on an interior edge falls into the
    lower bin, and 0.0 falls into the first bin. The ECE is

        sum_b (n_b / N) * |fraction_of_positives_b - mean_probability_b|,

    which simplifies to ``sum_b |sum(labels in b) - sum(probs in b)| / N``.
    Empty bins contribute nothing.

    Args:
        y_probs: predicted probability of the positive class, one per sample
            (sequence, numpy array or CPU tensor), values in [0, 1].
        y_trues: binary labels (0/1), same number of samples as ``y_probs``.
        n_bins: number of equal-width bins, at least 1.

    Returns:
        float: the expected calibration error.

    Raises:
        ValueError: if the inputs are empty or differ in length, if
            ``n_bins < 1``, if a probability lies outside [0, 1], or if a label
            is not 0/1.
    """
    probs = np.asarray(y_probs, dtype=float).reshape(-1)
    labels = np.asarray(y_trues, dtype=float).reshape(-1)

    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if probs.size == 0:
        raise ValueError("ece requires at least one sample")
    if probs.size != labels.size:
        raise ValueError(
            f"y_probs and y_trues differ in length: {probs.size} vs {labels.size}"
        )
    if probs.min() < 0.0 or probs.max() > 1.0:
        raise ValueError("y_probs has values outside [0, 1]")
    if not np.isin(labels, (0.0, 1.0)).all():
        raise ValueError("y_trues must contain only 0/1 labels")

    # digitize(right=True) maps (e_k, e_{k+1}] to k + 1; subtracting 1 gives the
    # bin index, and the clip moves p == 0.0 (index -1) into the first bin.
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.clip(np.digitize(probs, bin_edges, right=True) - 1, 0, n_bins - 1)

    # Per-bin sums over ALL bins (minlength), so empty bins stay aligned with
    # their index. sklearn's calibration_curve omits empty bins, so its output
    # cannot be multiplied element-wise with per-bin weights of length n_bins.
    label_sums = np.bincount(bin_ids, weights=labels, minlength=n_bins)
    prob_sums = np.bincount(bin_ids, weights=probs, minlength=n_bins)

    return float(np.abs(label_sums - prob_sums).sum() / probs.size)

# y_true = np.asarray([0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1])
# y_prob = np.asarray([0.1, 0.4, 0.35, 0.8, 0.1, 0.4, 0.25, 0.5, 0.1, 0.4, 0.35, 0.9])
# ece = ece(y_prob, y_true, n_bins=5)
# print(f"Expected Calibration Error: {ece:.4f}")

In [6]:
# Compare calibration of individual ensemble members with calibration of the
# ensemble-averaged probabilities, pooled over all method seeds.
N_ECE_BINS = 10  # number of equal-width bins used for every ECE below

ensemble_members = list(range(1, len(fair_probits[0]) + 1))

fair_m_eces, fair_eces = list(), list()

for m in range(len(method_seeds)):
    # ECE of each member separately. fair_probits[m] is indexed as
    # (member, sample, class); column 1 is presumably P(y = 1) — confirm.
    fair_eces.append([ece(p[:, 1], y_fair_t, n_bins=N_ECE_BINS) for p in fair_probits[m]])

    # ECE of the ensemble: average the member probabilities first, then calibrate.
    probs = torch.mean(fair_probits[m], dim=0)[:, 1]
    fair_m_eces.append([ece(probs, y_fair_t, n_bins=N_ECE_BINS)])

fair_m_eces = np.asarray(fair_m_eces).reshape(-1)
fair_eces = np.asarray(fair_eces).reshape(-1)

# "\\pm" (escaped) prints a literal "\pm" for LaTeX; an unescaped "\p" in a
# non-raw string is an invalid escape sequence (SyntaxWarning on Python 3.12+).
print(f"{fair_eces.mean(axis=0):.3f} $\\pm$ {fair_eces.std(axis=0):.3f}")
print(f"{fair_m_eces.mean(axis=0):.3f} $\\pm$ {fair_m_eces.std(axis=0):.3f}")

0.012 $\pm$ 0.003
0.011 $\pm$ 0.001
