#### This notebook demonstrates the use of the Reject Option Classification (ROC) post-processing algorithm for bias mitigation.
- The debiasing function used is implemented in the `RejectOptionClassification` class.
- Divide the dataset into training, validation, and testing partitions.
- Train classifier on original training data.
- Estimate the optimal classification threshold that maximizes balanced accuracy without fairness constraints.
- Estimate the optimal classification threshold and the critical region boundary (ROC margin) using a validation set for the desired constraint on fairness. The best parameters are those that maximize balanced accuracy while satisfying the fairness constraints.
- The constraints can be used on the following fairness measures:
    * Statistical parity difference on the predictions of the classifier
    * Average odds difference for the classifier
    * Equal opportunity difference for the classifier
- Determine the prediction scores for testing data. Using the estimated optimal classification threshold, compute accuracy and fairness metrics.
- Using the determined optimal classification threshold and the ROC margin, adjust the predictions. Report accuracy and fairness metric on the new predictions.

In [1]:
%matplotlib inline
# Load all necessary packages
import sys
sys.path.append("../")
import numpy as np
from tqdm import tqdm
from warnings import warn

from aif360.datasets import BinaryLabelDataset
from aif360.datasets import AdultDataset, GermanDataset, CompasDataset
from aif360.metrics import ClassificationMetric, BinaryLabelDatasetMetric
from aif360.metrics.utils import compute_boolean_conditioning_vector
from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions\
        import load_preproc_data_adult, load_preproc_data_german, load_preproc_data_compas
from aif360.algorithms.postprocessing.reject_option_classification\
        import RejectOptionClassification
from common_utils import compute_metrics

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

from IPython.display import Markdown, display
import matplotlib.pyplot as plt
from ipywidgets import interactive, FloatSlider

pip install 'aif360[LawSchoolGPA]'
pip install 'aif360[Reductions]'
pip install 'aif360[Reductions]'
pip install 'aif360[Reductions]'


In [2]:
import gc

#### Load dataset and specify options

In [3]:
## import dataset
dataset_used = "adult" # "adult", "german", "compas"
protected_attribute_used = 1 # 1, 2

# For every supported dataset: its loader, the protected attribute selected by
# protected_attribute_used == 1, and the one selected by any other value.
dataset_options = {
    "adult": (load_preproc_data_adult, "sex", "race"),
    "german": (load_preproc_data_german, "sex", "age"),
    "compas": (load_preproc_data_compas, "sex", "race"),
}

# Fail early on a typo instead of leaving dataset_orig / the group lists
# undefined and hitting a NameError further down.
if dataset_used not in dataset_options:
    raise ValueError("dataset_used should be one of %s, got %r"
                     % (sorted(dataset_options), dataset_used))

load_dataset, first_attribute, second_attribute = dataset_options[dataset_used]
if protected_attribute_used == 1:
    protected_attribute = first_attribute
else:
    protected_attribute = second_attribute

# Group definitions in the list-of-dicts form passed to ClassificationMetric
# and RejectOptionClassification: value 1 is privileged, value 0 unprivileged.
privileged_groups = [{protected_attribute: 1}]
unprivileged_groups = [{protected_attribute: 0}]
dataset_orig = load_dataset([protected_attribute])


# Metric used (should be one of allowed_metrics)
metric_name = "Statistical parity difference"

# Upper and lower bound on the fairness metric used
metric_ub = 0.11
metric_lb = -0.11

# Seed NumPy's global RNG for reproducibility.
# NOTE(review): presumably this is what drives the shuffled dataset splits
# performed later in the notebook -- confirm.
np.random.seed(1)

# Verify metric name
allowed_metrics = ["Statistical parity difference",
                   "Average odds difference",
                   "Equal opportunity difference"]
if metric_name not in allowed_metrics:
    raise ValueError("Metric name should be one of allowed metrics")

#### Split into train, test and validation

In [4]:
# Number of independent random train/validation/test splits to average over.
N_RUNS = 50


def group_rate_gap(y_true, y_pred, group, true_label):
    """Absolute between-group gap in the rate of predicted 1s.

    Only samples whose true label equals ``true_label`` are considered. Among
    those, the fraction predicted as 1 is computed separately for ``group == 0``
    and ``group == 1``, and the absolute difference of the two fractions is
    returned. With ``true_label=1`` this is a true-positive-rate gap, with
    ``true_label=0`` a false-positive-rate gap.

    Assumes labels and the protected attribute are encoded as 0/1.

    Args:
        y_true: 1-D array of ground-truth labels.
        y_pred: 1-D array of predicted labels, aligned with ``y_true``.
        group: 1-D array of protected-attribute values, aligned with ``y_true``.
        true_label: ground-truth label value that selects the samples.

    Returns:
        float: ``|rate(group 0) - rate(group 1)|``.

    Raises:
        ZeroDivisionError: if a group has no sample with ``true_label``.
    """
    rates = []
    for group_value in (0, 1):
        members = (y_true == true_label) & (group == group_value)
        # Plain Python ints keep an empty group a loud ZeroDivisionError
        # instead of a silent NaN.
        n_members = int(np.count_nonzero(members))
        n_predicted_pos = int(np.count_nonzero(y_pred[members] == 1))
        rates.append(n_predicted_pos / n_members)
    return abs(rates[0] - rates[1])


# Running sums over all runs; they are divided by N_RUNS after the loop.
# NOTE: despite its name, mis_mean accumulates the *accuracy* (fraction of
# correctly classified test samples), not the misclassification rate.
mis_mean = 0
deoo_mean = 0   # gap in predicted-positive rate among true positives
dpe_mean = 0    # gap in predicted-positive rate among true negatives
for step in tqdm(range(N_RUNS)):

    # Fresh shuffled split: 70% train, then the remainder halved into
    # validation and test.
    dataset_orig_train, dataset_orig_vt = dataset_orig.split([0.7], shuffle=True)
    dataset_orig_valid, dataset_orig_test = dataset_orig_vt.split([0.5], shuffle=True)

    # Logistic regression classifier on standardized features
    scale_orig = StandardScaler()
    X_train = scale_orig.fit_transform(dataset_orig_train.features)
    y_train = dataset_orig_train.labels.ravel()

    lmod = LogisticRegression()
    lmod.fit(X_train, y_train)

    # Column of predict_proba that corresponds to the favorable label
    pos_ind = np.where(lmod.classes_ == dataset_orig_train.favorable_label)[0][0]

    # Validation scores (the scaler is fitted on the training data only)
    dataset_orig_valid_pred = dataset_orig_valid.copy(deepcopy=True)
    X_valid = scale_orig.transform(dataset_orig_valid_pred.features)
    dataset_orig_valid_pred.scores = lmod.predict_proba(X_valid)[:, pos_ind].reshape(-1, 1)

    # Test scores. The true labels are copied because ravel() can return a
    # view and the labels of dataset_orig_test_pred are overwritten below.
    dataset_orig_test_pred = dataset_orig_test.copy(deepcopy=True)
    X_test = scale_orig.transform(dataset_orig_test_pred.features)
    Y_test = dataset_orig_test_pred.labels.ravel().copy()
    dataset_orig_test_pred.scores = lmod.predict_proba(X_test)[:, pos_ind].reshape(-1, 1)

    # Unconstrained baseline: sweep thresholds on the validation set and keep
    # the one with the highest balanced accuracy.
    num_thresh = 100
    ba_arr = np.zeros(num_thresh)
    class_thresh_arr = np.linspace(0.01, 0.99, num_thresh)
    for idx, class_thresh in enumerate(class_thresh_arr):

        fav_inds = dataset_orig_valid_pred.scores > class_thresh
        dataset_orig_valid_pred.labels[fav_inds] = dataset_orig_valid_pred.favorable_label
        dataset_orig_valid_pred.labels[~fav_inds] = dataset_orig_valid_pred.unfavorable_label

        classified_metric_orig_valid = ClassificationMetric(dataset_orig_valid,
                                                dataset_orig_valid_pred,
                                                unprivileged_groups=unprivileged_groups,
                                                privileged_groups=privileged_groups)

        ba_arr[idx] = 0.5*(classified_metric_orig_valid.true_positive_rate()
                           + classified_metric_orig_valid.true_negative_rate())

    # First threshold that attains the maximum balanced accuracy
    best_ind = np.argmax(ba_arr)
    best_class_thresh = class_thresh_arr[best_ind]

    # Fit the reject-option post-processor on the validation set under the
    # configured fairness constraint.
    ROC = RejectOptionClassification(unprivileged_groups=unprivileged_groups,
                                    privileged_groups=privileged_groups,
                                    low_class_thresh=0.01, high_class_thresh=0.99,
                                    num_class_thresh=100, num_ROC_margin=50,
                                    metric_name=metric_name,
                                    metric_ub=metric_ub, metric_lb=metric_lb)
    ROC = ROC.fit(dataset_orig_valid, dataset_orig_valid_pred)

    # Label the test predictions with the unconstrained baseline threshold
    fav_inds = dataset_orig_test_pred.scores > best_class_thresh
    dataset_orig_test_pred.labels[fav_inds] = dataset_orig_test_pred.favorable_label
    dataset_orig_test_pred.labels[~fav_inds] = dataset_orig_test_pred.unfavorable_label

    # Post-process the test predictions with the fitted ROC
    dataset_transf_test_pred = ROC.predict(dataset_orig_test_pred)

    # Accuracy and group gaps of the post-processed test predictions
    y_pred_transf = dataset_transf_test_pred.labels.ravel()
    prot_attr = dataset_transf_test_pred.protected_attributes[:, 0]

    mis_mean += int(np.count_nonzero(y_pred_transf == Y_test)) / len(Y_test)
    deoo_mean += group_rate_gap(Y_test, y_pred_transf, prot_attr, true_label=1)
    dpe_mean += group_rate_gap(Y_test, y_pred_transf, prot_attr, true_label=0)

    # Running sums so far (not means)
    print(mis_mean)
    print(deoo_mean)
    print(dpe_mean)

    # Drop the per-run objects before the next split is materialised so two
    # runs' worth of data are never alive at the same time.
    del dataset_orig_train, dataset_orig_vt, dataset_orig_valid, dataset_orig_test
    del X_train, y_train, lmod, pos_ind
    del dataset_orig_valid_pred, X_valid
    del dataset_orig_test_pred, X_test, Y_test
    del ROC, dataset_transf_test_pred, y_pred_transf, prot_attr
    gc.collect()

# Means over all runs
print(mis_mean / N_RUNS)
print(deoo_mean / N_RUNS)
print(dpe_mean / N_RUNS)

  2%|▏         | 1/50 [00:10<08:20, 10.21s/it]

0.6834993858332196
0.023940942424715228
0.015054750562392472


  4%|▍         | 2/50 [00:20<08:01, 10.03s/it]

1.4154497065647604
0.11510011816741605
0.03682241842422365


  6%|▌         | 3/50 [00:29<07:47,  9.95s/it]

2.1044083526682136
0.13257489175050485
0.057065539515779834


  8%|▊         | 4/50 [00:39<07:35,  9.90s/it]

2.838269414494336
0.1420004260714367
0.068522257888251


 10%|█         | 5/50 [00:49<07:25,  9.89s/it]

3.5244984304626725
0.1516629119410604
0.08791839967543691


 12%|█▏        | 6/50 [00:59<07:16,  9.91s/it]

4.199399481370275
0.1935196059716675
0.09604406174462382


 14%|█▍        | 7/50 [01:09<07:05,  9.89s/it]

4.874573495291388
0.20514196709323795
0.10991110433254486


 16%|█▌        | 8/50 [01:19<06:58,  9.95s/it]

5.552067694827351
0.21430208393923655
0.11575540424184516


 18%|█▊        | 9/50 [01:29<06:49,  9.99s/it]

6.222464856012011
0.2372381855033805
0.1341247880698248


 20%|██        | 10/50 [01:39<06:39,  9.99s/it]

6.925344615804559
0.31150807594466
0.15345097007364328


 22%|██▏       | 11/50 [01:49<06:31, 10.04s/it]

7.65893271461717
0.3536429367123418
0.17022085426562378


 24%|██▍       | 12/50 [02:00<06:25, 10.16s/it]

8.337109321686912
0.41500765402379003
0.1875876145330711


 26%|██▌       | 13/50 [02:10<06:14, 10.11s/it]

9.011055002047224
0.42457154854711654
0.20810271524419843


 28%|██▊       | 14/50 [02:20<06:01, 10.05s/it]

9.687730312542651
0.4601705712709919
0.22270483868024415


 30%|███       | 15/50 [02:30<05:50, 10.01s/it]

10.427460079159275
0.5269005561324793
0.2341524811245444


 32%|███▏      | 16/50 [02:40<05:40, 10.01s/it]

11.155588917701651
0.5514133271317858
0.24446175237402784


 34%|███▍      | 17/50 [02:50<05:34, 10.13s/it]

11.827487375460626
0.5966565013420831
0.26885044285367976


 36%|███▌      | 18/50 [03:00<05:27, 10.23s/it]

12.549338064692234
0.6176133094921391
0.2715968333867886


 38%|███▊      | 19/50 [03:11<05:20, 10.33s/it]

13.227514671761977
0.64564615756978
0.2967711368731532


 40%|████      | 20/50 [03:22<05:13, 10.46s/it]

13.921113689095128
0.6603199418535572
0.29955213955844257


 42%|████▏     | 21/50 [03:34<05:17, 10.94s/it]

14.654155861880716
0.7132769926696216
0.3122285723132164


 44%|████▍     | 22/50 [03:52<06:11, 13.26s/it]

15.320185614849189
0.7891752790156692
0.3453258075040101


 46%|████▌     | 23/50 [04:08<06:19, 14.04s/it]

16.018834447932306
0.8458647821261503
0.35724382800914667


 48%|████▊     | 24/50 [04:19<05:36, 12.93s/it]

16.770847550156954
0.8978978505337575
0.3994553146386307


 50%|█████     | 25/50 [04:29<05:04, 12.17s/it]

17.442200081888906
0.9256609146553361
0.4406701508551352


 52%|█████▏    | 26/50 [04:39<04:39, 11.63s/it]

18.1239252081343
0.9334318758005303
0.459225959417824


 54%|█████▍    | 27/50 [04:50<04:18, 11.25s/it]

18.805240889859427
0.9683114634660418
0.47061140282793157


 56%|█████▌    | 28/50 [05:00<04:03, 11.05s/it]

19.536918247577457
1.0139500746417376
0.48694427039507954


 58%|█████▊    | 29/50 [05:11<03:49, 10.95s/it]

20.267913197761708
1.0461055699278023
0.510831104656192


 60%|██████    | 30/50 [05:22<03:39, 10.98s/it]

20.950457213047635
1.0464471259050319
0.5206032996742886


 62%|██████▏   | 31/50 [05:33<03:29, 11.01s/it]

21.62822437559711
1.0469114961562185
0.5331505502284857


 64%|██████▍   | 32/50 [05:44<03:16, 10.93s/it]

22.363996178517816
1.092366041610764
0.5354538142219767


 66%|██████▌   | 33/50 [05:55<03:06, 10.95s/it]

23.044765934215917
1.151956288900989
0.5556246649013019


 68%|██████▊   | 34/50 [08:21<13:44, 51.53s/it]

23.77930940357582
1.1704428247474987
0.5607271336882763


 70%|███████   | 35/50 [08:33<09:53, 39.56s/it]

24.450116009280748
1.1886046679565279
0.5788288236649664


 72%|███████▏  | 36/50 [08:43<07:12, 30.88s/it]

25.135799099222062
1.2351642247432313
0.5825789357454341


 74%|███████▍  | 37/50 [08:54<05:24, 24.93s/it]

25.813975706291803
1.2534980036403907
0.6035087145877529


 76%|███████▌  | 38/50 [09:10<04:25, 22.09s/it]

26.555479732496252
1.3127531890583668
0.6214516164025222


 78%|███████▊  | 39/50 [09:31<04:00, 21.90s/it]

27.235567080660577
1.3201729415810826
0.6348914607815932


 80%|████████  | 40/50 [09:48<03:23, 20.37s/it]

27.917838132932992
1.3344994384633604
0.6506820366750994


 82%|████████▏ | 41/50 [10:07<02:58, 19.86s/it]

28.589190664664944
1.401696827310234
0.6724617457001837


 84%|████████▍ | 42/50 [10:25<02:34, 19.25s/it]

29.256994677221243
1.4227109624354506
0.6906108135506703
29.938992766480148
1.468473888034442
0.6957332546460286


 88%|████████▊ | 44/50 [10:57<01:44, 17.37s/it]

30.618943633137714
1.5337270109798724
0.7245915502007344


 90%|█████████ | 45/50 [11:08<01:18, 15.62s/it]

31.34188617442337
1.565348395060273
0.7351626064040572


 92%|█████████▏| 46/50 [11:24<01:02, 15.57s/it]

32.0092807424594
1.6361025622126215
0.7747022353914155


 94%|█████████▍| 47/50 [11:45<00:52, 17.38s/it]

32.68854920158319
1.6715816991315489
0.7779850884736368


 96%|█████████▌| 48/50 [12:07<00:37, 18.65s/it]

33.36495154906511
1.6988701038386467
0.7907756324766575


 98%|█████████▊| 49/50 [12:29<00:19, 19.81s/it]

34.11136890951277
1.7442411243504714
0.7983415248837186


100%|██████████| 50/50 [12:51<00:00, 15.42s/it]

34.77617032892044
1.764951120060467
0.8386069047808814
0.6955234065784088
0.03529902240120934
0.016772138095617627



