In [42]:
from scipy.io import loadmat
import numpy as np

In [43]:
def _max_gap_from_pooled_rate(numerators, denominators):
    """Return max over groups of |numerators[g] / denominators[g] - pooled rate|.

    The pooled rate ("pivot") is sum(numerators) / sum(denominators), i.e. the
    rate computed over all groups combined. A group whose denominator is 0
    yields NaN (numpy float division), matching the previous per-scalar code.
    """
    num = np.asarray(numerators, dtype=float)
    den = np.asarray(denominators, dtype=float)
    pivot = num.sum() / den.sum()
    return np.max(np.abs(num / den - pivot))


def calc_fairness_metric(constraint, confu_mat, num_groups=2, num_classes=2):
    """Compute a group-fairness disparity from per-group binary confusion matrices.

    Each group's rate is compared against the pooled rate over all groups, and
    the largest absolute deviation is reported (pivot formulation from the
    FairBatch paper).

    Parameters
    ----------
    constraint : str
        'eopp' -- equal-opportunity disparity: max_g |TPR_g - pooled TPR|.
        'eo'   -- equalized-odds disparity: the larger of the FPR disparity
                  (max_g |FPR_g - pooled FPR|) and the TPR disparity.
        'dp'   -- not implemented.
    confu_mat : mapping
        Maps the group index as a string ('0', '1', ...) to that group's
        confusion matrix, e.g. the dict returned by ``scipy.io.loadmat``.
        Other keys are ignored. Each matrix is unpacked in row-major order as
        (tn, fp, fn, tp), i.e. the layout [[tn, fp], [fn, tp]]. This assumes
        rows are true labels and columns are predictions, as in sklearn --
        TODO confirm against the code that saved the matrices.
    num_groups : int, default 2
        Number of groups; keys '0' .. str(num_groups - 1) are read.
    num_classes : int, default 2
        Unused; only binary (2x2) confusion matrices are supported.

    Returns
    -------
    numpy.float64
        The disparity. 0 means every group matches the pooled rate. NaN if a
        group has no positives (or, for 'eo', no negatives).

    Raises
    ------
    NotImplementedError
        For constraint == 'dp'.
    ValueError
        For any other unrecognised constraint.
    """
    if constraint == 'dp':
        raise NotImplementedError("demographic parity ('dp') disparity is not implemented")
    if constraint not in ('eopp', 'eo'):
        raise ValueError("unknown constraint {!r}; expected 'eopp', 'eo' or 'dp'".format(constraint))

    # One row of (tn, fp, fn, tp) counts per group.
    counts = np.array(
        [np.asarray(confu_mat[str(g)]).ravel() for g in range(num_groups)],
        dtype=float,
    )
    tn, fp, fn, tp = counts.T

    # True-positive-rate gap (equal opportunity), from fairbatch paper.
    tpr_gap = _max_gap_from_pooled_rate(tp, fn + tp)
    if constraint == 'eopp':
        return tpr_gap

    # Equalized odds additionally requires matching false-positive rates.
    fpr_gap = _max_gap_from_pooled_rate(fp, tn + fp)
    return max(fpr_gap, tpr_gap)

def calc_acc(constraint, confu_mat, num_classes=2, num_groups=2):
    """Overall accuracy pooled across all groups' confusion matrices.

    Accuracy = (sum of diagonal entries over all groups) / (sum of all entries
    over all groups). For the binary case this is
    (tn + tp summed over groups) / (total count over groups).

    Parameters
    ----------
    constraint : str
        Unused; accepted so the signature mirrors ``calc_fairness_metric``.
    confu_mat : mapping
        Maps the group index as a string ('0', '1', ...) to that group's
        confusion matrix (e.g. the dict returned by ``scipy.io.loadmat``).
        Each matrix must contain num_classes * num_classes entries. It is
        reshaped to square, so a flattened 1x4 matrix also works.
    num_classes : int, default 2
        Number of classes (side length of each confusion matrix).
    num_groups : int, default 2
        Number of groups; keys '0' .. str(num_groups - 1) are read.

    Returns
    -------
    numpy.float64
        The pooled accuracy (NaN if all matrices are empty).
    """
    correct = 0
    total = 0
    for g in range(num_groups):
        mat = np.asarray(confu_mat[str(g)]).reshape(num_classes, num_classes)
        correct += np.trace(mat)  # correctly classified samples lie on the diagonal
        total += mat.sum()
    return correct / total


In [37]:

# Which experiment to analyse: debiasing method, dataset and fairness constraint.
method = "reweighting"
dataset = "adult"
constraint = "eo"

# Sensitive attribute used for each supported dataset.
sen_attr_dict = dict(
    adult="sex",
    compas="sex",
    retiring_adult="race",
    retiring_adult_coverage="race",
)
sen_attr = sen_attr_dict[dataset]  # KeyError for an unsupported dataset


In [44]:

# Shared filename prefix; placeholders are (sen_attr, seed, lr, decay).
filename_epi = "mlp_{}_seed{}_epochs5_bs128_lr{}_decay{:.4f}"
seed_arr = [0]
date = "20220214"  # results sub-directory under ./results/

if method == "reweighting":
    # Appended placeholders: (constraint, eta, iteration). lr and eta are swept.
    filename = filename_epi + "_constraint{}_eta{}_iter{}_test_confu"
    lr_arr = [0.0001, 0.0003, 0.0005]
    eta_arr = [0.1, 0.2, 0.3, 0.5, 1.0, 2.0, 3.0, 5.0, 10.0, 20.0, 30.0]
    decay = 0.0005
    iteration = 10

elif method == "adv_debiasing":
    # Appended placeholders: (lamb, eta, constraint). lr is fixed; lamb and eta are swept.
    filename = filename_epi + "_adv_lamb{}_eta{}_constraint{}_test_confu"
    lr = 0.0005
    decay = 0.0005
    lamb_arr = [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0, 20.0, 30.0, 50.0, 100.0]
    eta_arr = [0.001, 0.003, 0.005, 0.01]

else:
    # Fail here rather than with a confusing NameError in a later cell.
    raise ValueError("unknown method {!r}; expected 'reweighting' or 'adv_debiasing'".format(method))
    


In [45]:

# Grid rows are the swept hyper-parameter (lr for reweighting, lamb for
# adv_debiasing); columns are eta. Values are averaged over seeds at the end.
if method == "reweighting":
    row_arr = lr_arr
elif method == "adv_debiasing":
    row_arr = lamb_arr
else:
    raise ValueError("unknown method {!r}; expected 'reweighting' or 'adv_debiasing'".format(method))


def _result_path(seed, row_value, eta):
    """Path of the saved test confusion matrices for one (seed, row, eta) run."""
    if method == "reweighting":
        run_name = filename.format(sen_attr, seed, row_value, decay, constraint, eta, iteration)
    else:  # adv_debiasing: lr is fixed and row_value is lamb
        run_name = filename.format(sen_attr, seed, lr, decay, row_value, eta, constraint)
    return "./results/{}/{}/{}/{}".format(date, dataset, method, run_name)


results_fair = np.zeros((len(seed_arr), len(row_arr), len(eta_arr)))
results_acc = np.zeros_like(results_fair)

for s, seed in enumerate(seed_arr):
    for i, row_value in enumerate(row_arr):
        for j, eta in enumerate(eta_arr):
            confu_mat = loadmat(_result_path(seed, row_value, eta), appendmat=False)
            # Index by position `s`, not by the seed value, so e.g. seed_arr = [1, 2] works.
            results_fair[s, i, j] = calc_fairness_metric(constraint, confu_mat)
            results_acc[s, i, j] = calc_acc(constraint, confu_mat)

# Average over seeds -> shape (len(row_arr), len(eta_arr)).
results_fair = np.mean(results_fair, axis=0)
results_acc = np.mean(results_acc, axis=0)

# NOTE(review): in the saved output of this cell every eta column is identical
# for a given lr, i.e. eta had no visible effect -- confirm the saved runs
# really differ by eta before drawing conclusions.
print(results_fair)
print(results_acc)


[[0.01784794 0.01784794 0.01784794 0.01784794 0.01784794 0.01784794
  0.01784794 0.01784794 0.01784794 0.01784794 0.01784794]
 [0.05500794 0.05500794 0.05500794 0.05500794 0.05500794 0.05500794
  0.05500794 0.05500794 0.05500794 0.05500794 0.05500794]
 [0.11845271 0.11845271 0.11845271 0.11845271 0.11845271 0.11845271
  0.11845271 0.11845271 0.11845271 0.11845271 0.11845271]]
[[0.55105348 0.55105348 0.55105348 0.55105348 0.55105348 0.55105348
  0.55105348 0.55105348 0.55105348 0.55105348 0.55105348]
 [0.57617504 0.57617504 0.57617504 0.57617504 0.57617504 0.57617504
  0.57617504 0.57617504 0.57617504 0.57617504 0.57617504]
 [0.61264182 0.61264182 0.61264182 0.61264182 0.61264182 0.61264182
  0.61264182 0.61264182 0.61264182 0.61264182 0.61264182]]
