In [1]:
# Render matplotlib figures inline and reload edited modules before every cell executes.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os
import warnings
# Silence warnings in this kernel. Setting PYTHONWARNINGS as well means any Python
# subprocess started later inherits the same "ignore" filter; the environment variable
# itself has no effect on the interpreter that is already running.
# NOTE(review): blanket suppression can also hide deprecation warnings from pandas/virny.
warnings.filterwarnings('ignore')
os.environ["PYTHONWARNINGS"] = "ignore"

In [3]:
from virny.datasets import ACSEmploymentDataset, ACSPublicCoverageDataset
from virny.utils.protected_groups_partitioning import create_test_protected_groups

In [4]:
def get_proportions(protected_groups, X_data):
    """Print every protected group's share of the rows in ``X_data``.

    Args:
        protected_groups: Mapping of group name -> frame holding that group's rows.
        X_data: The full feature frame the groups were carved from.

    Each line is printed as ``<group name>: <share rounded to 3 decimals>``.
    """
    total_rows = len(X_data)
    for group_name, group_df in protected_groups.items():
        share = len(group_df) / total_rows
        print(f'{group_name}: {round(share, 3)}')


def get_base_rate(protected_groups, y_data, pos_label=1):
    """Print the positive-label rate of ``y_data`` within each protected group, then overall.

    Args:
        protected_groups: Mapping of group name -> frame whose *index labels* select
            that group's rows. The labels are assumed to align with ``y_data``'s index.
        y_data: Target labels (a Series is assumed).
        pos_label: Label value counted as positive. Defaults to 1, which matches the
            previous hard-coded behaviour.

    Each line is printed as ``<name>: <rate rounded to 3 decimals>``; the final line
    is ``overall: <rate>``. An empty group prints ``nan`` instead of raising
    ZeroDivisionError.
    """
    def _rate(labels):
        # Fraction of entries equal to pos_label; NaN when there is nothing to count.
        if len(labels) == 0:
            return float('nan')
        return float((labels == pos_label).sum()) / len(labels)

    for col_name, group_df in protected_groups.items():
        # group_df.index holds index *labels*, so select with label-based .loc.
        # Positional .iloc picks the wrong rows (or raises) whenever y_data's index
        # is not a plain 0..n-1 RangeIndex.
        group_labels = y_data.loc[group_df.index]
        print(f'{col_name}: {round(_rate(group_labels), 3)}')

    print(f'overall: {round(_rate(y_data), 3)}')

In [5]:
# Attribute -> category code(s) handed to create_test_protected_groups.
# NOTE(review): presumably these are the codes placed in the "_dis" group; the
# 'SEX & RAC1P' entry (None) is presumably the intersection of the two above — confirm
# against the virny documentation.
sensitive_attributes_dct = {
    'SEX': '2',
    'RAC1P': ['2', '3', '4', '5', '6', '7', '8', '9'],
    'SEX & RAC1P': None,
}

In [6]:
# Load the 2018 New York ACS public-coverage data, drawing a reproducible
# 50k-row subsample (seed 42) with null handling as configured below.
data_loader = ACSPublicCoverageDataset(
    state=['NY'],
    year=2018,
    with_nulls=False,
    subsample_size=50_000,
    subsample_seed=42,
)
data_loader.full_df.head()

Unnamed: 0,SCHL,MAR,SEX,DIS,ESP,CIT,MIG,MIL,ANC,NATIVITY,DEAR,DEYE,DREM,ESR,ST,FER,RAC1P,AGEP,PINCP,PUBCOV
0,19,3,2,2,0,1,1,4,3,1,2,2,2,1,36,0,1,52,22000.0,1
1,16,3,2,2,0,5,1,4,1,2,2,2,2,1,36,2,1,39,25000.0,1
2,16,5,1,2,0,1,1,4,1,1,2,2,2,6,36,0,1,62,10000.0,0
3,21,5,1,1,0,1,1,4,4,1,2,1,2,6,36,0,2,49,23800.0,1
4,16,1,2,2,0,4,1,4,1,2,2,2,2,1,36,2,6,44,7000.0,1


In [7]:
# (rows, columns) of the loaded frame; the row count should equal subsample_size.
data_loader.full_df.shape

(50000, 20)

In [8]:
# Build the '<attr>_priv' / '<attr>_dis' row subsets for every entry of sensitive_attributes_dct.
# NOTE(review): data_loader.X_data is passed as both the first and the second argument —
# confirm the second parameter is meant to receive X_data and not data_loader.full_df.
protected_groups = create_test_protected_groups(data_loader.X_data, data_loader.X_data, sensitive_attributes_dct)

In [9]:
# Absolute row count of every protected group.
for group_name, group_df in protected_groups.items():
    print(f'{group_name}: {len(group_df)}')

SEX_priv: 21785
SEX_dis: 28215
RAC1P_priv: 31233
RAC1P_dis: 18767
SEX&RAC1P_priv: 39710
SEX&RAC1P_dis: 10290


In [10]:
# Group names produced above: one '_priv' and one '_dis' entry per sensitive attribute.
protected_groups.keys()

dict_keys(['SEX_priv', 'SEX_dis', 'RAC1P_priv', 'RAC1P_dis', 'SEX&RAC1P_priv', 'SEX&RAC1P_dis'])

In [11]:
# Share of all rows that falls into each protected group.
get_proportions(protected_groups, data_loader.X_data)

SEX_priv: 0.436
SEX_dis: 0.564
RAC1P_priv: 0.625
RAC1P_dis: 0.375
SEX&RAC1P_priv: 0.794
SEX&RAC1P_dis: 0.206


In [12]:
# Fraction of rows with label 1 inside each protected group, followed by the overall rate.
get_base_rate(protected_groups, data_loader.y_data)

SEX_priv: 0.414
SEX_dis: 0.388
RAC1P_priv: 0.35
RAC1P_dis: 0.482
SEX&RAC1P_priv: 0.376
SEX&RAC1P_dis: 0.49
overall: 0.399
