# Sklearn cross-validation for selecting best hyperparams and averaging results across test splits

In [None]:
import os
# Switch the working directory to the parent folder, so that the relative paths
# used further down (e.g. os.path.join('results', ...)) are resolved from there.
# NOTE(review): every re-run of this cell moves one more level up; run it only
# once per kernel session.
os.chdir('..')

## Cross validate

In [None]:
from fair_robust_classifiers.metrics.scorers import SubgroupsMetricScorer, DDPMetricScorer, CounterfactualScorer
from sklearn.metrics import balanced_accuracy_score
        
# Scorers evaluated on every split, keyed by the metric name used in the results.
cv_scorers = {
    # Utility
    "balancedAccuracy": SubgroupsMetricScorer(
        balanced_accuracy_score,
        need_class_predictions=True,
    ),
    # Fairness: overall score plus one variant per evaluation class (1 and -1)
    "demographicParity": DDPMetricScorer(),
    "equalOpportunityPos": DDPMetricScorer(evaluation_class=1),
    "equalOpportunityNeg": DDPMetricScorer(evaluation_class=-1),
    "counterfactual": CounterfactualScorer(),
    "counterfactualPos": CounterfactualScorer(evaluation_class=1),
    "counterfactualNeg": CounterfactualScorer(evaluation_class=-1),
}

# Reduced set: the same scorers without any counterfactual entry.
cv_scorers_red = {
    name: scorer
    for name, scorer in cv_scorers.items()
    if "counterfactual" not in name
}

### Train and store results for each hyperparameter combination

#### Gurobi SVC

In [None]:

# One (dataset, target label, sensitive attribute) triple per cross-validation run.
# Commented-out rows are alternative label / attribute choices for the same dataset;
# the row marked "<-" is the one currently selected.
datasets = [
    ("arrhythmia", "hasArrhythmia", "sex"),

    # ("adult", "grossIncomeGEQ50k", "race"),
    ("adult", "grossIncomeGEQ50k", "sex"),  # <-
    # ("adult", "grossIncomeGEQ50k", "nativeCountry"),

    ("credit", "NoDefaultNextMonth", "Age"),

    # ("drug", "heroin", "gender"),
    # ("drug", "heroin", "ethnicity"),
    ("drug", "amphetamines", "gender"),  # <-
    # ("drug", "amphetamines", "ethnicity"),

    ("germanSex", "creditRisk", "sex"),

    # ("compas", "twoYearRecid", "sex"),
    ("compas", "twoYearRecid", "race"),  # <-

    ("taiwan", "defaultNextMonth", "sex"),
]

# Bias-mitigation strategies to train with.
# None is presumably the unmitigated baseline -- TODO confirm.
bias_mitigation = [
    None,
    "linearDP",
    "linearEOPpos",
    "linearEOPneg",
    "invertedHingesDP",
    "invertedHingesEOPpos",
    "invertedHingesEOPneg",
]

kernel = "gaussian"

In [None]:
from fair_robust_classifiers.cross_validation.methods import gurobi_simple_cv_bias_mitigation

# Launch one cross-validation run per (dataset, bias-mitigation strategy) pair.
for data_str, label_str, sensitive_str in datasets:
    for bm in bias_mitigation:
        print(f"\nBias mitigation: {bm} for dataset {data_str}")
        gurobi_simple_cv_bias_mitigation(
            data_str,
            label_str,
            sensitive_str,
            # model configuration
            bias_mitigation=bm,
            kernel=kernel,
            # data configuration
            include_sensitive=True,
            balance_classes=False,
            num_samples=1000,
            num_test_splits=2,
            train_percentage=.7,
            # metrics stored for each split
            evaluation_scorers=cv_scorers,
        )

### Evaluate utility and fairness and select best hyperparameters

In [None]:

from fair_robust_classifiers.cross_validation.evaluation import evaluate_bias_mitigation
from fair_robust_classifiers.cross_validation.cv_utils import cv_results_name

In [None]:
# Fairness metrics, in the order in which they appear in `selection_metrics`.
_selection_fairness_metrics = [
    "demographicParity",
    "equalOpportunityNeg",
    "equalOpportunityPos",
    "counterfactual",
    "counterfactualPos",
    "counterfactualNeg",
]

# Criteria used to pick the best hyperparameters: the utility metric, each
# fairness metric on its own, and one "90_balancedAccuracy_min_<fairness>"
# combination per fairness metric.
selection_metrics = [
    "balancedAccuracy",
    *_selection_fairness_metrics,
    *(f"90_balancedAccuracy_min_{metric}" for metric in _selection_fairness_metrics),
]

# Reduced list without any counterfactual-based criterion.
selection_metrics_red = [
    metric for metric in selection_metrics if "counterfactual" not in metric
]

#### Gurobi SVC

In [None]:

# Folder from which the stored grid-search results are loaded.
load_path = os.path.join('results', 'gurobiSVC', 'grid_search_results_simple')

# NOTE: this rebinds `bias_mitigation`, which the training cell iterates with
# None (instead of 'noMitigation') as its first entry. Re-run the cell that
# defines the training list before re-running the training loop.
bias_mitigation = ['noMitigation',
                   'linearDP', 'linearEOPpos', 'linearEOPneg',
                   'invertedHingesDP', 'invertedHingesEOPpos', 'invertedHingesEOPneg',
                  ]

kernel = 'gaussian' # None, 'gaussian'

for data in ['arrhythmia',
            #'adult', 'arrhythmia','bail','compas','credit','drug','germanSex','taiwan'
            ]:
    for incl_sens in [True, False]:
        # Without the sensitive attribute, the reduced (counterfactual-free)
        # metric and scorer sets are used.
        active_metrics = selection_metrics if incl_sens else selection_metrics_red
        active_scorers = cv_scorers if incl_sens else cv_scorers_red

        for bm in bias_mitigation:
            # The results file does not depend on the selection metric, so its
            # name is built once per (data, incl_sens, bm) combination.
            # NOTE(review): assumes cv_results_name only builds a name and has
            # no side effects -- confirm.
            file_name = cv_results_name(data, include_sens=incl_sens, bias_mitigation=bm,
                                        balance_classes=False, kernel=kernel)

            for sm in active_metrics:
                #print(f"{data}, sens:{incl_sens}, {bm}, {sm}")
                evaluate_bias_mitigation(result_load_path=load_path,
                                         result_file_name=file_name,
                                         selection_metric=sm,
                                         evaluation_scorers=active_scorers,
                                         selection_phase='validation',
                                         verbose=0)

## Plots

### Methods scatterplots

In [None]:
import os

# Configuration dictionary handed to the scatter-plot helper.
pl_d = {
    # Folder with the stored grid-search results.
    'file_path': os.path.join("results", 'gurobiSVC', "grid_search_results_simple"),
    # Dataset name -> 'Pos' / 'Neg' tag.
    # NOTE(review): the table section headings further down label adult as
    # 'Neg' and arrhythmia as 'Pos', the opposite of this mapping -- confirm
    # which one is correct.
    'datasets': {
        'adult': 'Pos',
        'arrhythmia': 'Neg',
        'compas': 'Pos',
        'credit': 'Neg',
        'drug': 'Neg',
        'germanSex': 'Neg',
        'taiwan': 'Pos',
    },
    'phase': 'test',
    # A single row of settings; the entries differ only in the fairness
    # function and the mitigation family.
    'settings': [
        [
            {
                'percent': 90,
                'util_fn': 'balancedAccuracy',
                'fair_fn': fair_fn,
                'mitig': mitig,
                'sensitive': True,
            }
            for fair_fn, mitig in (
                ('demographicParity', 'DP'),
                ('counterfactual', 'DP'),
                ('equalOpportunity', 'EOp'),
            )
        ],
    ],
}

In [None]:
from fair_robust_classifiers.utilities.plot_utils import normalized_accuracy_fairness_plot_cum

# Produce the plot from the `pl_d` configuration (results folder, datasets,
# phase and settings). The exact figure layout is defined by the helper itself.
normalized_accuracy_fairness_plot_cum(pl_d)

### Tables - best hyper-parameters selection

In [None]:
from fair_robust_classifiers.utilities.plot_utils import make_fairness_results_table

# Fairness metrics, in the order shared by the two lists built from them.
_table_fairness_metrics = [
    "demographicParity",
    "equalOpportunityPos",
    "equalOpportunityNeg",
    "counterfactual",
    "counterfactualPos",
    "counterfactualNeg",
]

# Metrics reported in the tables: the utility metric followed by every fairness metric.
eval_metrics = ["balancedAccuracy", *_table_fairness_metrics]
eval_metrics_red = [metric for metric in eval_metrics if "counterfactual" not in metric]

# Hyperparameter-selection criteria: one "90_balancedAccuracy_min_<fairness>"
# entry per fairness metric.
hyp_sel_metrics = [
    f"90_balancedAccuracy_min_{metric}" for metric in _table_fairness_metrics
]
hyp_sel_metrics_red = [
    metric for metric in hyp_sel_metrics if "counterfactual" not in metric
]

#### Gurobi SVC

In [None]:
# Methods shown in the tables: the 'noMitigation' entry first, then for each
# constraint (DP, EOPpos, EOPneg) its "linear" and "invertedHinges" variants.
mitigation_methods = ["noMitigation"] + [
    f"{relaxation}{constraint}"
    for constraint in ("DP", "EOPpos", "EOPneg")
    for relaxation in ("linear", "invertedHinges")
]

# Folder with the stored grid-search results read by the table helper.
f_path = os.path.join("results", "gurobiSVC", "grid_search_results_simple")

##### Adult - 'Neg'

In [None]:
# Adult | no kernel (kernel=None) | sensitive attribute included | test phase
make_fairness_results_table(
    'adult',
    kernel=None,
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

In [None]:
# Adult | gaussian kernel | sensitive attribute included | test phase
make_fairness_results_table(
    'adult',
    kernel='gaussian',
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

##### Arrhythmia - 'Pos'

In [None]:
# Arrhythmia | no kernel (kernel=None) | sensitive attribute included | test phase
make_fairness_results_table(
    'arrhythmia',
    kernel=None,
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

In [None]:
# Arrhythmia | gaussian kernel | sensitive attribute included | test phase
make_fairness_results_table(
    'arrhythmia',
    kernel='gaussian',
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

##### Bail - 'Neg'

In [None]:
# Bail | no kernel (kernel=None) | sensitive attribute included | test phase
make_fairness_results_table(
    'bail',
    kernel=None,
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

In [None]:
# Bail | gaussian kernel | sensitive attribute included | test phase
make_fairness_results_table(
    'bail',
    kernel='gaussian',
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

##### Compas - 'Pos'

In [None]:
# Compas | no kernel (kernel=None) | sensitive attribute included | test phase
make_fairness_results_table(
    'compas',
    kernel=None,
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

In [None]:
# Compas | gaussian kernel | sensitive attribute included | test phase
make_fairness_results_table(
    'compas',
    kernel='gaussian',
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

##### Credit - 'Neg'

In [None]:
# Credit | no kernel (kernel=None) | sensitive attribute included | test phase
make_fairness_results_table(
    'credit',
    kernel=None,
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

In [None]:
# Credit | gaussian kernel | sensitive attribute included | test phase
make_fairness_results_table(
    'credit',
    kernel='gaussian',
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

##### Drug - 'Neg'

In [None]:
# Drug | no kernel (kernel=None) | sensitive attribute included | test phase
make_fairness_results_table(
    'drug',
    kernel=None,
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

In [None]:
# Drug | gaussian kernel | sensitive attribute included | test phase
make_fairness_results_table(
    'drug',
    kernel='gaussian',
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

##### German - 'Neg'

In [None]:
# German (germanSex) | no kernel (kernel=None) | sensitive attribute included | test phase
make_fairness_results_table(
    'germanSex',
    kernel=None,
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

In [None]:
# German (germanSex) | gaussian kernel | sensitive attribute included | test phase
make_fairness_results_table(
    'germanSex',
    kernel='gaussian',
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

##### Taiwan - 'Pos'

In [None]:
# Taiwan Credit | no kernel (kernel=None) | sensitive attribute included | test phase
make_fairness_results_table(
    'taiwan',
    kernel=None,
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)

In [None]:
# Taiwan Credit | gaussian kernel | sensitive attribute included | test phase
make_fairness_results_table(
    'taiwan',
    kernel='gaussian',
    include_sensitive=True,
    balance_classes=False,
    phase='test',
    folder_path=f_path,
    selection_metrics=hyp_sel_metrics,
    mitigation_methods=mitigation_methods,
    evaluation_metrics=eval_metrics,
)