# Section 4


This notebook runs the experiments and returns the top-performing models as a dictionary of hyper-parameters. It also creates all figures in the paper (except for those specific to the causal PDP and CD diagrams).

In [None]:
# required modules (skip if already installed)
!pip install lightgbm
!pip install xgboost
!pip install optuna
!pip install scikit_posthocs
!pip install autorank
!pip install pytorch-tabnet

In [1]:
# autoreload mode 2: re-import all modules (e.g. the local _utils) before each cell runs
%load_ext autoreload
%autoreload 2
# render matplotlib figures inline in the cell output
%matplotlib inline

# imports
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import sys
import pickle
import warnings
warnings.filterwarnings("ignore")
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

# local imports
from _utils import *

# settings
# raise the pandas display limits: up to 500 rows in frame reprs, and 500 as
# the column/row thresholds consulted by DataFrame.info()
for _opt in ('display.max_rows', 'display.max_info_columns', 'display.max_info_rows'):
    pd.set_option(_opt, 500)
# print numpy arrays in full instead of summarizing them with ellipses
np.set_printoptions(threshold=sys.maxsize)

## Data preliminaries

In [8]:
# read the pipe-delimited dataset from ../data and preview its first five rows
data_path = os.path.join('..', 'data', 'lar_fl_adult.csv')
df = pd.read_csv(data_path, sep='|')
df.head()

Unnamed: 0,AGEP,COW,SCHL,MAR,OCCP,POBP,RELP,WKHP,SEX,RAC1P,Y
0,30.0,1.0,17.0,5.0,4030.0,315.0,0.0,30.0,2.0,8.0,False
1,78.0,1.0,21.0,1.0,2750.0,25.0,0.0,1.0,1.0,1.0,False
2,68.0,1.0,19.0,1.0,4900.0,44.0,1.0,12.0,2.0,1.0,False
3,20.0,1.0,19.0,5.0,4110.0,12.0,0.0,30.0,2.0,1.0,False
4,57.0,1.0,18.0,3.0,3930.0,12.0,0.0,42.0,2.0,2.0,True


In [11]:
# Data dictionary for the ten feature columns, kept as a bare (unused) string
# literal. The bracketed tag marks each feature as continuous ([cont.]) or
# categorical ([cat.]).
"""
AGEP (Age); [cont.] 
COW (Class of worker); [cat.]
SCHL (Educational attainment); [cat.: ordinal] 
MAR (Marital status); [cat.]
OCCP (Occupation); [cat.] 
POBP (Place of birth); [cat.]
RELP (Relationship); [cat.]
WKHP (Usual hours worked per week past 12 months); [cont.]
SEX (Sex); [cat.]
RAC1P (Recoded detailed race code) [cat.]
"""

# list column names, dtypes and memory usage of the loaded frame
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 97238 entries, 0 to 97237
Data columns (total 11 columns):
 #   Column  Dtype  
---  ------  -----  
 0   AGEP    float64
 1   COW     float64
 2   SCHL    float64
 3   MAR     float64
 4   OCCP    float64
 5   POBP    float64
 6   RELP    float64
 7   WKHP    float64
 8   SEX     float64
 9   RAC1P   float64
 10  Y       bool   
dtypes: bool(1), float64(10)
memory usage: 7.5 MB


In [12]:
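# TODO: the categorical columns (COW, SCHL, MAR, OCCP, POBP, RELP, SEX, RAC1P) are stored as float64; cast them to integer/categorical dtypes.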


In [5]:
# summary statistics for the numeric columns (the boolean target Y is not included by default)
df.describe()

Unnamed: 0,AGEP,COW,SCHL,MAR,OCCP,POBP,RELP,WKHP,SEX,RAC1P
count,97238.0,97238.0,97238.0,97238.0,97238.0,97238.0,97238.0,97238.0,97238.0,97238.0
mean,44.787172,2.166077,18.586787,2.511981,4066.467184,90.934378,2.083856,38.289876,1.486476,1.629404
std,15.150233,1.976364,3.296277,1.755314,2466.007185,124.480708,4.128277,12.584155,0.49982,1.755916
min,17.0,1.0,1.0,1.0,10.0,1.0,0.0,1.0,1.0,1.0
25%,32.0,1.0,16.0,1.0,2300.0,12.0,0.0,35.0,1.0,1.0
50%,46.0,1.0,19.0,1.0,4220.0,34.0,0.0,40.0,1.0,1.0
75%,57.0,3.0,21.0,5.0,5400.0,72.0,1.0,40.0,2.0,1.0
max,95.0,8.0,24.0,5.0,9830.0,554.0,17.0,99.0,2.0,9.0


## Run experiments