In [12]:
import os
import numpy as np
from error_parity import RelaxedThresholdOptimizer
from error_parity.classifiers import RandomizedClassifier

import torch
import torch.nn as nn

from source.constants import RESULTS_PATH, PLOTS_PATH
from source.data.medical_imaging import get_chexpert
from source.utils.metrics import balanced_accuracy, aod, eod, spd

os.makedirs(PLOTS_PATH, exist_ok=True)

In [13]:
import warnings
# handle upstream FutureWarnings regarding solver in cvxpy used by error_parity. The default solver will be changed from ECOS to Clarabel in cvxpy 1.16.0
warnings.filterwarnings("ignore", category=FutureWarning)

In [14]:
method_seeds = [42, 142, 242, 342, 442]
dseed = 42

model = ["resnet18", "resnet34", "resnet50"][2]

verbose = False

pas = ["old", "woman", "white"]
pa = 0 # 0, 1, 2

In [15]:
# parameters
c = 0
constraint = ["demographic_parity", "true_positive_rate_parity", "average_odds"][c]

In [16]:
full_ds, _, _ = get_chexpert(load_to_ram=False)

run_path = os.path.join(RESULTS_PATH, f"chexpert_{model}_mseed{method_seeds[0]}_dseed{dseed}")
fair_inds = torch.load(os.path.join(run_path, "fair_inds.pt"))
val_inds = torch.load(os.path.join(run_path, "val_inds.pt"))

print(len(fair_inds), len(val_inds))

y_fair_t = full_ds.targets[fair_inds]
a_fair_t = full_ds.protected_attributes[pa, fair_inds]
y_val_t = full_ds.targets[val_inds]
a_val_t = full_ds.protected_attributes[pa, val_inds]

# switch 0 / 1 in protected attribute
pas = ["young", "man", "non-white"]
a_fair_t = 1 - a_fair_t
a_val_t = 1 - a_val_t

# switch label
# y_fair_t = 1 - y_fair_t
# y_val_t = 1 - y_val_t

p_a_fair = a_fair_t.float().mean().item() 

# patients general 65401
# patients with race 58010
24638 24638


In [17]:
# load probits
fair_probits, val_probits = list(), list()
for mseed in method_seeds:
    path = os.path.join(RESULTS_PATH, f"chexpert_{model}_mseed{mseed}_dseed{dseed}")

    # don't do fairness ensemble on medical imaging - use this split as test dataset
    fp = torch.load(os.path.join(path, f"fair_probits.pt"))
    #fp = 1 - fp
    fair_probits.append(fp)
    vp = torch.load(os.path.join(path, f"val_probits.pt"))
    #vp = 1 - vp
    val_probits.append(vp)

In [18]:
# calculate accuracies and fairness measures
fair_balanced_accuracys, val_balanced_accuracys = list(), list()
fair_spds, val_spds = list(), list()
fair_eods, val_eods = list(), list()
fair_aods, val_aods = list(), list()

for m in range(len(method_seeds)):
    fair_balanced_accuracys.append([balanced_accuracy(p.argmax(dim=1), y_fair_t) for p in fair_probits[m]])
    val_balanced_accuracys.append([balanced_accuracy(p.argmax(dim=1), y_val_t) for p in val_probits[m]])

    fair_spds.append([spd(p.argmax(dim=1), a_fair_t) for p in fair_probits[m]])
    val_spds.append([spd(p.argmax(dim=1), a_val_t) for p in val_probits[m]])

    fair_eods.append([eod(p.argmax(dim=1), y_fair_t, a_fair_t) for p in fair_probits[m]])
    val_eods.append([eod(p.argmax(dim=1), y_val_t, a_val_t) for p in val_probits[m]])

    fair_aods.append([aod(p.argmax(dim=1), y_fair_t, a_fair_t) for p in fair_probits[m]])
    val_aods.append([aod(p.argmax(dim=1), y_val_t, a_val_t) for p in val_probits[m]])
    

In [19]:
# method to do the fake predictions
class DummyPredictor(nn.Module):
    def __init__(self, probits):
        super(DummyPredictor, self).__init__()
        self.probits = probits

    def forward(self, indices:torch.Tensor):
        return self.probits[indices].numpy()

In [20]:
def get_thresholds(fair_clf, verbose=True):
    thresholds = list()
    for i in range(2):
        if verbose: print(f"Class {i}")
        if isinstance(fair_clf._realized_classifier.group_to_clf[i], RandomizedClassifier):
            thrs = list()
            for clf in fair_clf._realized_classifier.group_to_clf[i].classifiers:
                if verbose: print(clf.threshold)
                thrs.append(clf.threshold)
            thresholds.append(thrs)   
        else:
            thrs = fair_clf._realized_classifier.group_to_clf[i].threshold
            if verbose: print(thrs)
            thresholds.append([thrs, thrs])
    return thresholds

Optimize Ensemble for average member constraint

In [21]:
balanced_accuracys_bma, fairs_bma = list(), list()
balanced_accuracys_bma_pp, fairs_bma_pp = list(), list()
balanced_accuracys_avg, fairs_avg = list(), list()
thresholds_fairs_bma_pp = list()

for m in range(len(method_seeds)):

    if verbose: print("-"*20 + f"  seed {m}  " + "-"*20)

    val_m_probits = torch.mean(val_probits[m], dim=0)

    val_fairness = [val_spds[m], val_eods[m], val_aods[m]][c]
    test_fairness = [fair_spds[m], fair_eods[m], fair_aods[m]][c]

    model = DummyPredictor(val_m_probits)

    # Given any trained model that outputs real-valued scores
    if constraint == "average_odds":
        fair_clf = RelaxedThresholdOptimizer(
            predictor=lambda X: model(X)[:, -1],   # for sklearn API
            constraint="equalized_odds",
            l_p_norm=1,
            tolerance=2 * max(np.mean(val_fairness), 0), # fairness constraint tolerance, use twice because of norm
        )
    else:
        fair_clf = RelaxedThresholdOptimizer(
            predictor=lambda X: model(X)[:, -1],   # for sklearn API
            constraint=constraint,
            tolerance=max(np.mean(val_fairness), 0), # fairness constraint tolerance
        )

    # Fit the fairness adjustment on some data
    # This will find the optimal _fair classifier_
    fair_clf.fit(X=torch.tensor(range(len(y_val_t))), y=y_val_t.numpy(), group=a_val_t.numpy())

    # Get the thresholds for the optimal classifier
    thresholds_fairs_bma_pp.append(get_thresholds(fair_clf, verbose=verbose))

    # overwrite model for predictor
    ff_test_m_probits = torch.mean(fair_probits[m], dim=0)
    model.probits = ff_test_m_probits

    # Now you can use `fair_clf` as any other classifier
    # You have to provide group information to compute fair predictions
    y_pred_test = fair_clf(X=torch.tensor(range(len(y_fair_t))), group=a_fair_t.numpy())
    y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)

    if verbose: print("Avg Member")
    balanced_accuracys_avg.extend(fair_balanced_accuracys[m])
    if verbose: print(f"  {(fair_balanced_accuracys[0][m]):.3f} ")
    fairs_avg.extend(test_fairness)
    if verbose: print(f"  {test_fairness[0]:.3f} (val: {val_fairness[0]:.3f})")
    if verbose: print("BMA")
    balanced_accuracys_bma.append(balanced_accuracy(ff_test_m_probits.argmax(dim=1), y_fair_t).item())
    if verbose: print(f"  {(balanced_accuracy(ff_test_m_probits.argmax(dim=1), y_fair_t).item()):.3f}")
    if c == 0:
        fairs_bma.append(spd(ff_test_m_probits.argmax(dim=1), a_fair_t).item())
        if verbose: print(f"  {spd(ff_test_m_probits.argmax(dim=1), a_fair_t).item():.3f}")
    elif c == 1:
        fairs_bma.append(eod(ff_test_m_probits.argmax(dim=1), y_fair_t, a_fair_t).item())
        if verbose: print(f"  {eod(ff_test_m_probits.argmax(dim=1), y_fair_t, a_fair_t).item():.3f}")
    elif c == 2:
        fairs_bma.append(aod(ff_test_m_probits.argmax(dim=1), y_fair_t, a_fair_t).item())
        if verbose: print(f"  {aod(ff_test_m_probits.argmax(dim=1), y_fair_t, a_fair_t).item():.3f}")
    if verbose: print("BMA-PP")
    balanced_accuracys_bma_pp.append(balanced_accuracy(y_pred_test, y_fair_t).item())
    if verbose: print(f"  {(balanced_accuracy(y_pred_test, y_fair_t).item()):.3f} ")
    if c == 0:
        fairs_bma_pp.append(spd(y_pred_test, a_fair_t).item())
        if verbose: print(f"  {spd(y_pred_test, a_fair_t).item():.3f}")
    elif c == 1:
        fairs_bma_pp.append(eod(y_pred_test, y_fair_t, a_fair_t).item())
        if verbose: print(f"  {eod(y_pred_test, y_fair_t, a_fair_t).item():.3f}")
    elif c == 2:
        fairs_bma_pp.append(aod(y_pred_test, y_fair_t, a_fair_t).item())
        if verbose: print(f"  {aod(y_pred_test, y_fair_t, a_fair_t).item():.3f}")

thresholds_fairs_bma_pp = np.asarray(thresholds_fairs_bma_pp)

print("-"*30)
print(f"${np.mean(balanced_accuracys_avg):.3f}_{'{'}\pm {np.std(balanced_accuracys_avg):.3f}{'}'}$", end=" & ")
print(f"${np.mean(fairs_avg):.3f}_{'{'}\pm {np.std(fairs_avg):.3f}{'}'}$")
print(f"${np.mean(balanced_accuracys_bma):.3f}_{'{'}\pm {np.std(balanced_accuracys_bma):.3f}{'}'}$", end=" & ")
print(f"${np.mean(fairs_bma):.3f}_{'{'}\pm {np.std(fairs_bma):.3f}{'}'}$")
print(f"${np.mean(balanced_accuracys_bma_pp):.3f}_{'{'}\pm {np.std(balanced_accuracys_bma_pp):.3f}{'}'}$", end=" & ")
print(f"${np.mean(fairs_bma_pp):.3f}_{'{'}\pm {np.std(fairs_bma_pp):.3f}{'}'}$")
print("-"*30)
for i in range(2):
    print(f"Group {i}")
    print(f"${np.mean(thresholds_fairs_bma_pp[:, i, 0]):.3f}_{'{'}\pm {np.std(thresholds_fairs_bma_pp[:, i, 0]):.3f}{'}'}$")
    print(f"${np.mean(thresholds_fairs_bma_pp[:, i, 1]):.3f}_{'{'}\pm {np.std(thresholds_fairs_bma_pp[:, i, 1]):.3f}{'}'}$")

------------------------------
$0.783_{\pm 0.008}$ & $0.138_{\pm 0.004}$
$0.786_{\pm 0.004}$ & $0.139_{\pm 0.001}$
$0.801_{\pm 0.004}$ & $0.122_{\pm 0.019}$
------------------------------
Group 0
$0.457_{\pm 0.021}$
$0.440_{\pm 0.010}$
Group 1
$0.509_{\pm 0.049}$
$0.496_{\pm 0.058}$


In [22]:
balanced_accuracys_bma_pp, fairs_bma_pp = list(), list()
balanced_accuracys_member_pp, fairs_member_pp = list(), list()
thresholds_fairs_bma_pp = list()
thresholds_fairs_member_pp = list()

for m in range(len(method_seeds)):

    if verbose: print("-"*20 + f"  seed {m}  " + "-"*20)

    val_m_probits = torch.mean(val_probits[m], dim=0)

    model = DummyPredictor(val_m_probits)

    # Given any trained model that outputs real-valued scores
    if constraint == "average_odds":
        fair_clf = RelaxedThresholdOptimizer(
            predictor=lambda X: model(X)[:, -1],   # for sklearn API
            constraint="equalized_odds",
            l_p_norm=1,
            tolerance=2 * 0.05, # fairness constraint tolerance, use twice because of norm
        )
    else:
        fair_clf = RelaxedThresholdOptimizer(
            predictor=lambda X: model(X)[:, -1],   # for sklearn API
            constraint=constraint,
            tolerance=0.05, # fairness constraint tolerance
        )

    # Fit the fairness adjustment on some data
    # This will find the optimal _fair classifier_
    fair_clf.fit(X=torch.tensor(range(len(y_val_t))), y=y_val_t.numpy(), group=a_val_t.numpy())

    thresholds_fairs_bma_pp.append(get_thresholds(fair_clf, verbose=verbose))

    # overwrite model for predictor
    ff_test_m_probits = torch.mean(fair_probits[m], dim=0)
    model.probits = ff_test_m_probits

    # Now you can use `fair_clf` as any other classifier
    # You have to provide group information to compute fair predictions
    y_pred_test = fair_clf(X=torch.tensor(range(len(y_fair_t))), group=a_fair_t.numpy())
    y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)

    if verbose: print("BMA-PP")
    balanced_accuracys_bma_pp.append(balanced_accuracy(y_pred_test, y_fair_t).item())
    if verbose: print(f"  {(balanced_accuracy(y_pred_test, y_fair_t).item()):.3f}")
    if c == 0:
        fairs_bma_pp.append(spd(y_pred_test, a_fair_t).item())
        if verbose: print(f"  {spd(y_pred_test, a_fair_t).item():.3f}")
    elif c == 1:
        fairs_bma_pp.append(eod(y_pred_test, y_fair_t, a_fair_t).item())
        if verbose: print(f"  {eod(y_pred_test, y_fair_t, a_fair_t).item():.3f}")
    elif c == 2:
        fairs_bma_pp.append(aod(y_pred_test, y_fair_t, a_fair_t).item())
        if verbose: print(f"  {aod(y_pred_test, y_fair_t, a_fair_t).item():.3f}")

    for mem in range(len(val_probits[m])):
        val_m_probits = val_probits[m][mem]

        model = DummyPredictor(val_m_probits)

        # Given any trained model that outputs real-valued scores
        if constraint == "average_odds":
            fair_clf = RelaxedThresholdOptimizer(
                predictor=lambda X: model(X)[:, -1],   # for sklearn API
                constraint="equalized_odds",
                l_p_norm=1,
                tolerance=2 * 0.05, # fairness constraint tolerance, use twice because of norm
            )
        else:
            fair_clf = RelaxedThresholdOptimizer(
                predictor=lambda X: model(X)[:, -1],   # for sklearn API
                constraint=constraint,
                tolerance=0.05, # fairness constraint tolerance
            )

        # Fit the fairness adjustment on some data
        # This will find the optimal _fair classifier_
        fair_clf.fit(X=torch.tensor(range(len(y_val_t))), y=y_val_t.numpy(), group=a_val_t.numpy())

        thresholds_fairs_member_pp.append(get_thresholds(fair_clf, verbose=verbose))

        # overwrite model for predictor
        ff_test_m_probits = fair_probits[m][0]
        model.probits = ff_test_m_probits

        # Now you can use `fair_clf` as any other classifier
        # You have to provide group information to compute fair predictions
        y_pred_test = fair_clf(X=torch.tensor(range(len(y_fair_t))), group=a_fair_t.numpy())
        y_pred_test = torch.tensor(y_pred_test, dtype=torch.long)

        if mem == 0 and verbose : print("Member-PP")
        balanced_accuracys_member_pp.append(balanced_accuracy(y_pred_test, y_fair_t).item())
        if mem == 0 and verbose : print(f"  {(balanced_accuracy(y_pred_test, y_fair_t).item()):.3f} ")
        if c == 0:
            fairs_member_pp.append(spd(y_pred_test, a_fair_t).item())
            if mem == 0 and verbose : print(f"  {spd(y_pred_test, a_fair_t).item():.3f}")
        elif c == 1:
            fairs_member_pp.append(eod(y_pred_test, y_fair_t, a_fair_t).item())
            if mem == 0 and verbose : print(f"  {eod(y_pred_test, y_fair_t, a_fair_t).item():.3f}")
        elif c == 2:
            fairs_member_pp.append(aod(y_pred_test, y_fair_t, a_fair_t).item())
            if mem == 0 and verbose : print(f"  {aod(y_pred_test, y_fair_t, a_fair_t).item():.3f}")

thresholds_fairs_bma_pp = np.asarray(thresholds_fairs_bma_pp)
thresholds_fairs_member_pp = np.asarray(thresholds_fairs_member_pp).reshape((-1, 2, 2))

print("-"*30)
print(f"${np.mean(balanced_accuracys_bma_pp):.3f}_{'{'}\pm {np.std(balanced_accuracys_bma_pp):.3f}{'}'}$", end=" & ")
print(f"${np.mean(fairs_bma_pp):.3f}_{'{'}\pm {np.std(fairs_bma_pp):.3f}{'}'}$")
print(f"${np.mean(balanced_accuracys_member_pp):.3f}_{'{'}\pm {np.std(balanced_accuracys_member_pp):.3f}{'}'}$", end=" & ")
print(f"${np.mean(fairs_member_pp):.3f}_{'{'}\pm {np.std(fairs_member_pp):.3f}{'}'}$")
print("-"*30)
for i in range(2):
    print(f"Group {i}")
    print(f"${np.mean(thresholds_fairs_bma_pp[:, i, 0]):.3f}_{'{'}\pm {np.std(thresholds_fairs_bma_pp[:, i, 0]):.3f}{'}'}$")
    print(f"${np.mean(thresholds_fairs_bma_pp[:, i, 1]):.3f}_{'{'}\pm {np.std(thresholds_fairs_bma_pp[:, i, 1]):.3f}{'}'}$")
print("-"*30)
for i in range(2):
    print(f"Group {i}")
    print(f"${np.mean(thresholds_fairs_member_pp[:, i, 0]):.3f}_{'{'}\pm {np.std(thresholds_fairs_member_pp[:, i, 0]):.3f}{'}'}$")
    print(f"${np.mean(thresholds_fairs_member_pp[:, i, 1]):.3f}_{'{'}\pm {np.std(thresholds_fairs_member_pp[:, i, 1]):.3f}{'}'}$")

------------------------------
$0.788_{\pm 0.004}$ & $0.057_{\pm 0.002}$
$0.782_{\pm 0.010}$ & $0.060_{\pm 0.005}$
------------------------------
Group 0
$0.439_{\pm 0.012}$
$0.439_{\pm 0.012}$
Group 1
$0.704_{\pm 0.034}$
$0.627_{\pm 0.034}$
------------------------------
Group 0
$0.451_{\pm 0.023}$
$0.450_{\pm 0.023}$
Group 1
$0.726_{\pm 0.037}$
$0.657_{\pm 0.042}$
