In [1]:
# Enable IPython's autoreload extension so edits to the project helper modules imported
# below (data_utils, model_utils, attack_utils, ...) are picked up without restarting the kernel.
# (Comments are kept on separate lines: text after a line magic is passed to it as arguments.)
%load_ext autoreload
%autoreload 2

In [2]:
import warnings

warnings.filterwarnings('ignore')

In [3]:
from tqdm import tqdm
import os
import data_utils
import model_utils
from attack_utils import get_CSMIA_case_by_case_results, CSMIA_attack, LOMIA_attack
from data_utils import oneHotCatVars, filter_random_data_by_conf_score
from experiment_utils import MIAExperiment
from disparity_inference_utils import get_confidence_array, draw_confidence_array_scatter, get_indices_by_group_condition, get_corr_btn_sens_and_out_per_subgroup, get_slopes, get_angular_difference, calculate_stds, get_mutual_info_btn_sens_and_out_per_subgroup
from bcorr_utils import bcorr_sampling, evaluate, MLPClassifierFC
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.neural_network._base import ACTIVATIONS
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics import roc_curve, auc, roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.decomposition import PCA
from sklearn.inspection import permutation_importance
from fairlearn.metrics import equalized_odds_difference, demographic_parity_difference
import matplotlib.pyplot as plt
import seaborn as sns
import tabulate
import pickle
import copy

import matplotlib as mpl

# Load datasets and build the MIA experiments

In [4]:
# Registry of all experiments used below. Binary-subgroup experiments are keyed by their
# dataset name; multi-valued ones by "<dataset name>_multi_valued".
experiments = {}

# Target correlations between the sensitive attribute and the output, one per value of the
# binary subgroup column (in order). The next cell prints the realised values.
i, j = -0.4, -0.1

# Census19 (default dataset), binary subgroup: SEX.
experiment = MIAExperiment(
    sampling_condition_dict=dict(
        subgroup_col_name='SEX',
        n=25000,
        correlation_by_subgroup_values=[i, j],
    ),
    shortname=f"Corr_btn_sens_and_output_for_male_({i})_for_female_({j})",
    random_state=0,
)
experiments[experiment.name] = experiment

# Texas100, binary subgroup: SEX_CODE, with ETHNICITY as the sensitive column.
experiment_texas = MIAExperiment(
    sampling_condition_dict=dict(
        subgroup_col_name='SEX_CODE',
        n=25000,
        correlation_by_subgroup_values=[i, j],
    ),
    shortname=f"Corr_btn_sens_and_output_for_male_({i})_for_female_({j})",
    random_state=0,
    name="Texas100",
    sensitive_column='ETHNICITY',
)
experiments[experiment_texas.name] = experiment_texas

# Census19, multi-valued subgroup: ST.
# NOTE(review): `i` is rebound here and from now on holds the random_state (0), not a
# correlation; any later cell reading `i` sees 0.
i = 0
experiment_multi_valued = MIAExperiment(
    sampling_condition_dict=dict(subgroup_col_name='ST', n=1000),
    random_state=i,
    shortname=f"Corr_btn_sens_and_output_for_ST_ranging_from_0_to_-0.5_random_state_{i}",
)
experiments[f"{experiment_multi_valued.name}_multi_valued"] = experiment_multi_valued

# Texas100, multi-valued subgroup: PAT_STATUS restricted to the values below, with SEX_CODE
# as the sensitive column.
# NOTE(review): no random_state is passed, so MIAExperiment's default applies — confirm it
# matches the other experiments. Also, the shortname says "0_to_-0.5" but the printed
# correlations for this experiment are positive and increasing; verify the name/data pairing.
subgroup_vals = [1, 2, 3, 4, 6, 20, 50, 51, 62, 63]
experiment_multi_valued_texas = MIAExperiment(
    sampling_condition_dict=dict(subgroup_col_name='PAT_STATUS', subgroup_values=subgroup_vals, n=5000),
    shortname="Corr_btn_sens_and_output_for_PAT_STATUS_ranging_from_0_to_-0.5",
    name='Texas100',
    sensitive_column='SEX_CODE',
)
experiments[f"{experiment_multi_valued_texas.name}_multi_valued"] = experiment_multi_valued_texas

In [5]:
# For every experiment: record its subgroup column and the values that column takes in the
# (one-hot encoded) training features, then print the per-subgroup correlation between the
# sensitive attribute and the output, rounded to 2 decimals.
# NOTE: this mutates each experiment in place (adds `subgroup_col_name` and `subgroup_vals`);
# later cells presumably rely on these attributes.
for experiment_key, experiment in experiments.items():
    experiment.subgroup_col_name = experiment.sampling_condition_dict['subgroup_col_name']

    # Assumes one-hot columns are named "<col>_<value>" (the printed values '0', '1', ... fit
    # this). Matching the full "<col>_" prefix stops e.g. 'ST' from also catching unrelated
    # columns that merely start with "ST" (or 'SEX' from catching 'SEX_CODE_*'), and stripping
    # the prefix instead of splitting on '_' keeps values that themselves contain '_' intact.
    one_hot_prefix = f"{experiment.subgroup_col_name}_"
    experiment.subgroup_vals = [
        col[len(one_hot_prefix):]
        for col in experiment.X_train.columns
        if col.startswith(one_hot_prefix)
    ]
    # Fail loudly rather than silently reporting an empty correlation dict.
    assert experiment.subgroup_vals, (
        f"No columns with prefix {one_hot_prefix!r} found in X_train for {experiment_key!r}"
    )

    print(f"\nDataset: {experiment.name}, Subgroup: {experiment.subgroup_col_name}")
    correlations_dict = {
        val: round(
            get_corr_btn_sens_and_out_per_subgroup(
                experiment, experiment.X_train, experiment.y_tr, {experiment.subgroup_col_name: val}
            ),
            2,
        )
        for val in experiment.subgroup_vals
    }
    print(f"Correlations: {correlations_dict}")


Dataset: Census19, Subgroup: SEX
Correlations: {'0': -0.4, '1': -0.1}

Dataset: Texas100, Subgroup: SEX_CODE
Correlations: {'0': -0.4, '1': -0.1}

Dataset: Census19, Subgroup: ST
Correlations: {'0': 0.0, '1': -0.01, '2': -0.02, '3': -0.03, '4': -0.04, '5': -0.05, '6': -0.06, '7': -0.07, '8': -0.08, '9': -0.09, '10': -0.1, '11': -0.11, '12': -0.12, '13': -0.13, '14': -0.14, '15': -0.15, '16': -0.16, '17': -0.17, '18': -0.18, '19': -0.18, '20': -0.2, '21': -0.21, '22': -0.21, '23': -0.23, '24': -0.23, '25': -0.25, '26': -0.25, '27': -0.27, '28': -0.27, '29': -0.29, '30': -0.29, '31': -0.3, '32': -0.31, '33': -0.32, '34': -0.33, '35': -0.34, '36': -0.35, '37': -0.36, '38': -0.37, '39': -0.38, '40': -0.39, '41': -0.38, '42': -0.41, '43': -0.42, '44': -0.43, '45': -0.44, '46': -0.45, '47': -0.46, '48': -0.47, '49': -0.48, '50': -0.49}

Dataset: Texas100, Subgroup: PAT_STATUS
Correlations: {'1': -0.0, '2': 0.05, '3': 0.1, '4': 0.15, '6': 0.2, '20': 0.25, '50': 0.3, '51': 0.35, '62': 0.4, '6