In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to the local
# helper modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import warnings

warnings.filterwarnings('ignore')

In [24]:
from tqdm import tqdm
import os
import data_utils
import model_utils
from attack_utils import get_CSMIA_case_by_case_results, CSMIA_attack, LOMIA_attack
from data_utils import oneHotCatVars, filter_random_data_by_conf_score
from vulnerability_score_utils import get_vulnerability_score, draw_hist_plot
from experiment_utils import MIAExperiment
from disparity_inference_utils import get_confidence_array, draw_confidence_array_scatter, get_indices_by_group_condition, get_corr_btn_sens_and_out_per_subgroup, get_slopes, get_angular_difference, calculate_stds, get_mutual_info_btn_sens_and_out_per_subgroup
from bcorr_utils import bcorr_sampling, evaluate
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.neural_network._base import ACTIVATIONS
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics import roc_curve, auc, roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.decomposition import PCA
from sklearn.inspection import permutation_importance
from fairlearn.metrics import equalized_odds_difference, demographic_parity_difference
import matplotlib.pyplot as plt
import seaborn as sns
import tabulate
import pickle
# import utils
import copy

import matplotlib as mpl

# Global matplotlib font settings (family, size, weight) applied to every
# figure created after this cell runs.
mpl.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': 8,
    'font.weight': 'light',
})

In [5]:
# Correlation values requested per SEX subgroup; per the shortname below,
# `i` is used for the "male" subgroup and `j` for the "female" subgroup.
i, j = -0.4, -0.1

experiment = MIAExperiment(
    sampling_condition_dict=dict(
        subgroup_col_name='SEX',
        n=25000,
        correlation_by_subgroup_values=[i, j],
        # fixed_corr_in_test_data=True,
    ),
    shortname=f"Corr_btn_sens_and_output_for_male_({i})_for_female_({j})",
    random_state=0,
)

  0%|          | 0/2 [00:00<?, ?it/s]

{0: {(0, 1): 8750, (0, 0): 3750, (1, 1): 3750, (1, 0): 8750}, 1: {}}


 50%|█████     | 1/2 [00:01<00:01,  1.14s/it]

{0: {(0, 1): 8750, (0, 0): 3750, (1, 1): 3750, (1, 0): 8750}, 1: {(0, 1): 6875, (0, 0): 5625, (1, 1): 5625, (1, 0): 6875}}


100%|██████████| 2/2 [00:02<00:00,  1.14s/it]


[12500, 12500, 12500, 12500]


In [10]:
def _train_corr_for_sex(sex_value):
    """Return the sensitive-attribute/output correlation, rounded to 2 decimals,
    for the given SEX subgroup of the original training split."""
    raw_corr = get_corr_btn_sens_and_out_per_subgroup(
        experiment, experiment.X_train, experiment.y_tr, {'SEX': sex_value}
    )
    return round(raw_corr, 2)


# Sanity check: report the correlation measured in each subgroup.
correlation = _train_corr_for_sex(0)
print(f"Correlation between sensitive attribute and output for subgroup 0: {correlation}")
correlation = _train_corr_for_sex(1)
print(f"Correlation between sensitive attribute and output for subgroup 1: {correlation}")

Correlation between sensitive attribute and output for subgroup 0: -0.4
Correlation between sensitive attribute and output for subgroup 1: -0.1


In [11]:
# Load the cached target model if one can be read from disk; otherwise train a
# fresh model on the training split and (optionally) cache it for later runs.
save_model=True
target_model_path = f'<PATH_TO_MODEL>/{experiment.ds.ds.filenameroot}_target_model.pkl'
print(f"Training classifier for experiment: {experiment}")
try:
    experiment.clf = model_utils.load_model(target_model_path)
    print(f"Loaded classifier for experiment from file: {experiment}")
except Exception as load_error:
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt/SystemExit still stop the cell instead of silently
    # starting a training run, and report why the cached model was not used.
    print(f"Could not load cached classifier ({type(load_error).__name__}: {load_error}); training a new one.")
    base_model = model_utils.get_model(max_iter=500)
    experiment.clf = copy.deepcopy(base_model)
    experiment.clf.fit(experiment.X_train, experiment.y_tr_onehot)

    if save_model:
        model_utils.save_model(experiment.clf, target_model_path)

Training classifier for experiment: Census19_subgroup_col_name_SEX_n_25000_correlation_by_subgroup_values_[-0.4, -0.1]_rs0


In [30]:
# Run bcorr_sampling on the training split (grouped by SEX) and store the
# returned features, labels and one-hot labels on the experiment object.
(
    experiment.X_train_balanced_corr,
    experiment.y_tr_balanced_corr,
    experiment.y_tr_onehot_balanced_corr,
) = bcorr_sampling(
    experiment,
    experiment.X_train,
    experiment.y_tr,
    experiment.y_tr_onehot,
    subgroup_col_name='SEX',
)

{0: {(0, 1): 4125, (0, 0): 3375, (1, 1): 3375, (1, 0): 4125}, 1: {(0, 1): 6187, (0, 0): 5062, (1, 1): 5063, (1, 0): 6188}}


100%|██████████| 2/2 [00:00<00:00, 21.74it/s]


In [31]:
def _balanced_corr_for_sex(sex_value):
    """Return the sensitive-attribute/output correlation, rounded to 2 decimals,
    for the given SEX subgroup of the bcorr-sampled training data."""
    raw_corr = get_corr_btn_sens_and_out_per_subgroup(
        experiment,
        experiment.X_train_balanced_corr,
        experiment.y_tr_balanced_corr,
        {'SEX': sex_value},
    )
    return round(raw_corr, 2)


# Sanity check: report the correlation measured in each subgroup after sampling.
correlation = _balanced_corr_for_sex(0)
print(f"Correlation between sensitive attribute and output for subgroup 0: {correlation}")
correlation = _balanced_corr_for_sex(1)
print(f"Correlation between sensitive attribute and output for subgroup 1: {correlation}")

Correlation between sensitive attribute and output for subgroup 0: -0.1
Correlation between sensitive attribute and output for subgroup 1: -0.1


In [18]:
# Load the cached model trained on the bcorr-sampled data if one can be read
# from disk; otherwise train a fresh model on that data and (optionally) cache it.
save_model=True
bcorr_model_path = f'<PATH_TO_MODEL>/{experiment.ds.ds.filenameroot}_target_model_bcorr.pkl'
print(f"Training classifier for experiment: {experiment}")
try:
    experiment.clf_balanced_corr = model_utils.load_model(bcorr_model_path)
    print(f"Loaded classifier for experiment from file: {experiment}")
except Exception as load_error:
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt/SystemExit still stop the cell instead of silently
    # starting a training run, and report why the cached model was not used.
    print(f"Could not load cached classifier ({type(load_error).__name__}: {load_error}); training a new one.")
    base_model = model_utils.get_model(max_iter=500)
    experiment.clf_balanced_corr = copy.deepcopy(base_model)
    # NOTE(review): the unbalanced model is fit on `experiment.y_tr_onehot`, whereas
    # this one is fit on the non-one-hot `y_tr_balanced_corr` even though
    # `y_tr_onehot_balanced_corr` is available. Left unchanged; confirm the intent.
    experiment.clf_balanced_corr.fit(experiment.X_train_balanced_corr, experiment.y_tr_balanced_corr)

    if save_model:
        # Save the balanced model itself. The previous code wrote `experiment.clf`
        # (the unbalanced model) to the `_bcorr` path, so any later run would load
        # the wrong model as `clf_balanced_corr`. A `_bcorr.pkl` file written by
        # the old code is stale and should be deleted before re-running.
        model_utils.save_model(experiment.clf_balanced_corr, bcorr_model_path)

Training classifier for experiment: Census19_subgroup_col_name_SEX_n_25000_correlation_by_subgroup_values_[-0.4, -0.1]_rs0


In [28]:
# Evaluate the model trained on the original data and the model trained on the
# bcorr-sampled data against the same test split, then tabulate one row each.
metrics_without_bcorr = evaluate(
    experiment,
    experiment.clf,
    experiment.X_train,
    experiment.y_tr,
    experiment.X_test,
    experiment.y_te,
    subgroup_col_name='SEX',
)
metrics_with_bcorr = evaluate(
    experiment,
    experiment.clf_balanced_corr,
    experiment.X_train_balanced_corr,
    experiment.y_tr_balanced_corr,
    experiment.X_test,
    experiment.y_te,
    subgroup_col_name='SEX',
)

res_dict = {
    'w/o BCorr': metrics_without_bcorr,
    'w Bcorr': metrics_with_bcorr,
}
res_dict_df = pd.DataFrame.from_dict(res_dict, orient='index')

In [29]:
# Display the comparison table built from res_dict (one row per configuration).
res_dict_df

Unnamed: 0,ASRD_CSMIA,ASRD_LOMIA,EOD,DPD,MA
w/o BCorr,11.8,14.65,0.0726,0.1284,73.904
w Bcorr,1.75,2.19,0.0415,0.0887,72.442
