Use predicted probabilities and labels to calculate:
1) utility
   1) AUC
   2) ACC
2) fairness
   1) equalized odds
   2) multiaccuracy
   3) multicalibration

In [1]:
import pickle
import torch 
import numpy as np 
import pandas as pd
import sklearn.metrics as sklm
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score


In [2]:
def load_model(datapath, model_name):
    """Load the saved predictions of one model.

    Reads ``{datapath}/predictions_{model_name}.pkl`` and returns its
    ``'probs'``, ``'label'`` and ``'patientid'`` entries.

    Args:
        datapath: Directory that contains the prediction pickle.
        model_name: Model identifier used in the pickle file name.

    Returns:
        Tuple ``(probs, labels, patientid)``. ``probs`` and ``labels`` are
        returned exactly as stored. If the patient ids are torch tensors they
        are converted to a list of Python scalars; otherwise they are
        returned as stored.

    Raises:
        FileNotFoundError: If the pickle file does not exist.
        KeyError: If one of the expected entries is missing from the pickle.
    """
    model_path = f"{datapath}/predictions_{model_name}.pkl"
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # prediction files that come from a trusted source.
    with open(model_path, 'rb') as file:
        predictions = pickle.load(file)
    probs = predictions['probs']
    labels = predictions['label']
    patientid = predictions['patientid']
    # Peek at the first id to detect tensors, but guard against an empty
    # collection so that an empty prediction file does not raise IndexError.
    if len(patientid) > 0 and isinstance(patientid[0], torch.Tensor):
        patientid = [pid.item() for pid in patientid]
    return probs, labels, patientid

In [3]:
# multiaccuracy
def compute_multiaccuracy(pred_probs, labels):
    """Return the absolute value of the mean residual ``pred_probs - labels``.

    Positive and negative residuals cancel before the absolute value is
    taken, so this measures the overall signed bias of the predictions, not
    the mean absolute error.
    """
    mean_residual = np.mean(pred_probs - labels)
    return np.abs(mean_residual)

def compute_group_multiaccuracy(pred_probs, labels, group_indices):
    """Multiaccuracy restricted to the samples selected by ``group_indices``.

    ``group_indices`` is used to index both ``pred_probs`` and ``labels``.
    An empty selection yields ``0.0`` instead of a mean over no samples.
    """
    if len(group_indices) != 0:
        group_probs = pred_probs[group_indices]
        group_labels = labels[group_indices]
        return compute_multiaccuracy(group_probs, group_labels)
    # No members in this subgroup: report zero residual.
    return 0.0

# multicalibration
def expected_calibration_error(pred_probs, labels, num_bins=10, metric_variant="abs", quantile_bins=False):
    """
    Binned calibration error between predicted probabilities and labels.

    Predictions are assigned to ``num_bins`` bins (``pd.cut`` by default,
    ``pd.qcut`` when ``quantile_bins`` is true). For every non-empty bin the
    gap between the mean prediction and the mean label is transformed
    (absolute value for "abs", square for "squared"/"rmse") and the gaps are
    averaged with weights equal to the fraction of samples in each bin.
    For "rmse" the square root of that average is returned.

    Raises ValueError for any other ``metric_variant``.
    Sample weights are not supported.

    See http://arxiv.org/abs/1706.04599 and https://arxiv.org/abs/1904.01685.
    https://github.com/MLforHealth/CXR_Fairness/blob/c2a0e884171d6418e28d59dca1ccfb80a3f125fe/cxr_fairness/metrics.py#L1557
    """
    if metric_variant == "abs":
        penalty = np.abs
    elif metric_variant in ("squared", "rmse"):
        penalty = np.square
    else:
        raise ValueError("provided metric_variant not supported")

    binner = pd.qcut if quantile_bins else pd.cut
    bin_ids = binner(pred_probs, num_bins, labels=False, retbins=False)

    frame = pd.DataFrame(
        {"pred_probs": pred_probs, "labels": labels, "bin_id": bin_ids})
    per_bin = frame.groupby("bin_id")

    # Per-bin statistics; bins that received no samples do not appear.
    mean_pred = per_bin["pred_probs"].mean()
    mean_label = per_bin["labels"].mean()
    bin_weight = per_bin["pred_probs"].size() / frame.shape[0]
    bin_err = penalty(mean_pred - mean_label)

    result = np.average(bin_err.values, weights=bin_weight)
    if metric_variant == "rmse":
        result = np.sqrt(result)
    return result

def binary_classification_report(pred, y):
    """Build a dictionary of binary-classification metrics.

    ``pred`` holds positive-class scores and ``y`` the true labels. Scores
    strictly greater than 0.5 are counted as positive predictions.

    Returned keys: "auc", "acc", "ae" (multiaccuracy), "ece", "tpr", "tnr"
    and the raw confusion-matrix counts "tn", "fp", "fn", "tp".

    NOTE(review): unpacking the confusion matrix into four values assumes it
    is 2x2, i.e. both classes appear in ``y`` or the thresholded predictions.
    """
    auc = roc_auc_score(y, pred)
    ece = expected_calibration_error(pred, y)
    multiacc = compute_multiaccuracy(pred, y)

    hard_pred = (pred > 0.5).astype(int)
    tn, fp, fn, tp = confusion_matrix(y, hard_pred).ravel()

    return dict(
        auc=auc,
        acc=(tp + tn) / (tn + fp + fn + tp),
        ae=multiacc,
        ece=ece,
        tpr=tp / (tp + fn),
        tnr=tn / (tn + fp),
        tn=tn,
        fp=fp,
        fn=fn,
        tp=tp,
    )

def evaluate_binary(pred, Y, A):
    """Evaluate predictions overall and separately for each group in ``A``.

    Args:
        pred: Positive-class scores, indexable by an integer index array.
        Y: True labels, aligned with ``pred``.
        A: Group membership per sample, aligned with ``pred``.

    Returns:
        ``(overall_metrics, subgroup_metrics)``. ``overall_metrics`` is the
        report for all samples. ``subgroup_metrics`` maps each metric name to
        a list holding one value per group, in ``np.unique(A)`` order.
    """
    # overall
    overall_metrics = dict(binary_classification_report(pred, Y))

    # subgroup
    subgroup_metrics = {}
    for group in np.unique(A):
        members = np.where(A == group)[0]
        group_report = binary_classification_report(pred[members], Y[members])

        # multiaccuracy of the group, computed by indexing the full arrays
        group_report["ae"] = compute_group_multiaccuracy(pred, Y, members)

        for metric_name, metric_value in group_report.items():
            if metric_name not in subgroup_metrics:
                subgroup_metrics[metric_name] = []
            subgroup_metrics[metric_name].append(metric_value)

    return overall_metrics, subgroup_metrics

def organize_results(overall_metrics, subgroup_metrics):
    """Condense overall and per-group metrics into one flat summary dict.

    Gaps are ``max - min`` of the per-group values. "worst-auc" is the lowest
    per-group AUC, "eo" is the TPR gap and "eod" is one minus the average of
    the TPR gap and the TNR gap.
    """
    def gap(values):
        # Spread between the best and the worst group.
        return max(values) - min(values)

    metric_names = ("auc", "acc", "ece", "tpr", "tnr", "ae")
    auc_by_group, acc_by_group, ece_by_group, tpr_by_group, tnr_by_group, ae_by_group = (
        subgroup_metrics[name] for name in metric_names
    )

    summary = {f"overall-{name}": overall_metrics[name] for name in ("auc", "acc", "ae", "ece")}
    summary["worst-auc"] = min(auc_by_group)
    summary["auc-gap"] = gap(auc_by_group)
    summary["acc-gap"] = gap(acc_by_group)
    summary["ae-gap"] = gap(ae_by_group)
    summary["ece-gap"] = gap(ece_by_group)
    summary["eod"] = 1 - (gap(tpr_by_group) + gap(tnr_by_group)) / 2
    summary["eo"] = gap(tpr_by_group)
    return summary


## CXP

In [None]:
# CXP: load the test-set metadata and derive the sensitive attributes.
datapath = "./no_finding/CXP"

metadata_path = f"{datapath}/test_with_metadata.csv"
metadata = pd.read_csv(metadata_path)
# Patient ids in metadata order; used later to look up each patient's attribute.
metapatientid = metadata['Patient'].values
age = metadata['Age'].values
# Binarize age: 1 when age >= age_threshold, otherwise 0.
age_threshold = 65
age_group_mask = np.array(age) >= age_threshold
age_binary = age_group_mask.astype(int)
sex = metadata['Sex'].values

# Sensitive attributes by name; the evaluation loop in this notebook reads only 'sex'.
sensitive_groups = {'sex':sex,
                    'age':age_binary}


In [5]:
# Evaluate every model on CXP with sex as the sensitive attribute.
# The patient id -> sex lookup does not depend on the model, so build it once.
sen_dict = dict(zip(metapatientid, sensitive_groups['sex']))

cxp_rows = []
for model_name in ["CLIP", 'MedCLIP', "BiomedCLIP", "PubMedCLIP"]:
    # load model predictions
    probs, labels, patientid = load_model(datapath, model_name)
    # reorder the sensitive attribute to follow the order of the predictions
    sen = np.array([sen_dict[pid] for pid in patientid])

    # column 1 of probs is passed as the positive-class score
    overall_metrics, subgroup_metrics = evaluate_binary(pred=probs[:, 1], Y=labels, A=sen)
    result = organize_results(overall_metrics, subgroup_metrics)
    # one single-row frame per model, indexed by the model name
    result_df = pd.DataFrame(result, index=[model_name])
    cxp_rows.append(result_df)

# Concatenate once at the end instead of growing the frame inside the loop
# from an initially empty DataFrame.
cxp_result_df = pd.concat(cxp_rows)

In [6]:
# Show the CXP summary table rounded to 4 decimals (one row per model).
cxp_result_df.round(4)

Unnamed: 0,overall-auc,overall-acc,overall-ae,overall-ece,worst-auc,auc-gap,acc-gap,ae-gap,ece-gap,eod,eo
CLIP,0.5328,0.1177,0.6123,0.6123,0.5086,0.0465,0.0038,0.005,0.005,0.9987,0.0023
MedCLIP,0.8173,0.8291,0.3969,0.397,0.807,0.0202,0.0064,0.0038,0.0014,0.981,0.0341
BiomedCLIP,0.6918,0.4571,0.4912,0.4916,0.674,0.0359,0.0423,0.0511,0.0478,0.9127,0.1133
PubMedCLIP,0.6587,0.7733,0.3414,0.3414,0.6518,0.0134,0.0234,0.0042,0.0042,0.9724,0.0278


# MIMIC

In [None]:
# MIMIC: load the test-set metadata and derive the sensitive attributes.
datapath = "./no_finding/MIMIC"

metadata_path = f"{datapath}/test.csv"
metadata = pd.read_csv(metadata_path)
# Patient ids in metadata order; used later to look up each patient's attribute.
metapatientid = metadata['subject_id'].values
age = metadata['anchor_age'].values
# Binarize age: 1 when age >= age_threshold, otherwise 0.
age_threshold = 65
age_group_mask = np.array(age) >= age_threshold
age_binary = age_group_mask.astype(int)
sex = metadata['gender'].values


# Sensitive attributes by name; the evaluation loop in this notebook reads only 'sex'.
sensitive_groups = {'sex':sex,
                    'age':age_binary}


In [6]:
# Evaluate every model on MIMIC with sex as the sensitive attribute.
# The patient id -> sex lookup does not depend on the model, so build it once.
sen_dict = dict(zip(metapatientid, sensitive_groups['sex']))

mimic_rows = []
for model_name in ["CLIP", 'MedCLIP', "BiomedCLIP", "PubMedCLIP"]:
    # load model predictions
    probs, labels, patientid = load_model(datapath, model_name)
    # reorder the sensitive attribute to follow the order of the predictions
    sen = np.array([sen_dict[pid] for pid in patientid])

    # column 1 of probs is passed as the positive-class score
    overall_metrics, subgroup_metrics = evaluate_binary(pred=probs[:, 1], Y=labels, A=sen)
    result = organize_results(overall_metrics, subgroup_metrics)
    # one single-row frame per model, indexed by the model name
    result_df = pd.DataFrame(result, index=[model_name])
    mimic_rows.append(result_df)

# Concatenate once at the end instead of growing the frame inside the loop
# from an initially empty DataFrame.
mimic_result_df = pd.concat(mimic_rows)

In [7]:
# Show the MIMIC summary table (one row per model); unlike the CXP and NIH cells, no rounding is applied here.
mimic_result_df

Unnamed: 0,overall-auc,overall-acc,overall-ae,overall-ece,worst-auc,auc-gap,acc-gap,ae-gap,ece-gap,eod,eo
CLIP,0.483954,0.461457,0.28638,0.286898,0.460741,0.034455,0.054813,0.033922,0.033454,0.996349,0.004475
MedCLIP,0.78119,0.724546,0.040573,0.220499,0.768251,0.022865,0.005229,0.055581,0.009934,0.91969,0.097023
BiomedCLIP,0.618294,0.5395,0.241211,0.34558,0.613878,0.005332,0.02625,0.006302,0.027661,0.9282,0.081018
PubMedCLIP,0.612408,0.582261,0.002351,0.02396,0.608444,0.005529,0.01568,0.003415,0.00805,0.93999,0.072587


# NIH

In [None]:
# NIH: load the test-set metadata and derive the sensitive attributes.
datapath = "./no_finding/NIH"

metadata_path = f"{datapath}/test_meta_FM.csv"
metadata = pd.read_csv(metadata_path)
# Patient ids in metadata order; used later to look up each patient's attribute.
metapatientid = metadata['patientid'].values
age = metadata['Patient Age'].values
# Binarize age: 1 when age >= age_threshold, otherwise 0.
age_threshold = 65
age_group_mask = np.array(age) >= age_threshold
age_binary = age_group_mask.astype(int)
sex = metadata['gender'].values


# Sensitive attributes by name; the evaluation loop in this notebook reads only 'sex'.
sensitive_groups = {'sex':sex,
                    'age':age_binary}


In [9]:
# Evaluate every model on NIH with sex as the sensitive attribute.
# The patient id -> sex lookup does not depend on the model, so build it once.
sen_dict = dict(zip(metapatientid, sensitive_groups['sex']))

nih_rows = []
for model_name in ["CLIP", 'MedCLIP', "BiomedCLIP", "PubMedCLIP"]:
    # load model predictions
    probs, labels, patientid = load_model(datapath, model_name)
    # reorder the sensitive attribute to follow the order of the predictions
    sen = np.array([sen_dict[pid] for pid in patientid])

    # column 1 of probs is passed as the positive-class score
    overall_metrics, subgroup_metrics = evaluate_binary(pred=probs[:, 1], Y=labels, A=sen)
    result = organize_results(overall_metrics, subgroup_metrics)
    # one single-row frame per model, indexed by the model name
    result_df = pd.DataFrame(result, index=[model_name])
    nih_rows.append(result_df)

# Concatenate once at the end instead of growing the frame inside the loop
# from an initially empty DataFrame.
nih_result_df = pd.concat(nih_rows)

In [11]:
# Show the NIH summary table rounded to 4 decimals (one row per model).
nih_result_df.round(4)

Unnamed: 0,overall-auc,overall-acc,overall-ae,overall-ece,worst-auc,auc-gap,acc-gap,ae-gap,ece-gap,eod,eo
CLIP,0.4921,0.5382,0.263,0.2636,0.4804,0.0277,0.0062,0.0006,0.0012,0.9999,0.0002
MedCLIP,0.7309,0.6868,0.0393,0.188,0.7231,0.0141,0.0098,0.0061,0.0136,0.9771,0.0123
BiomedCLIP,0.7266,0.6576,0.1946,0.2753,0.7168,0.0173,0.0033,0.0004,0.0021,0.9917,0.0015
PubMedCLIP,0.6357,0.5884,0.0575,0.0695,0.6286,0.0125,0.0001,0.0038,0.0007,0.9245,0.0706


# Find common diseases

In [21]:
# Finding / disease label names listed for each dataset (CXP, MIMIC, NIH).
# The spellings are compared verbatim when the lists are intersected, so
# e.g. 'Pleural Effusion' and 'Effusion' count as different labels.
cxp_disease = ['No Finding',
       'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
       'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
       'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',]
mimic_disease = ['Atelectasis', 'Cardiomegaly', 'Consolidation',
       'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
       'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
       'Pneumonia', 'Pneumothorax']
nih_disease = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema',
       'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening',
       'Cardiomegaly', 'Nodule', 'Mass', 'Hernia', 'No Finding']

In [22]:
# Labels that appear, with identical spelling, in all three datasets.
intersection = set(mimic_disease) & set(cxp_disease) & set(nih_disease)
# Sets have no stable iteration order, so sort to make the list (and the
# printed output) reproducible from run to run.
intersection_list = sorted(intersection)
print(intersection_list)

['Edema', 'Atelectasis', 'Pneumothorax', 'Consolidation', 'No Finding', 'Cardiomegaly', 'Pneumonia']
