### This notebook demonstrates how to use DADT (domain-adaptive decision trees)

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# global imports
import sys
import time
import pickle
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import numpy as np
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# local imports
import utils
from utils import cm_metrics
from experiments import run_test

# general settings  
plt.style.use('seaborn-whitegrid')
plt.rc('font', size=14)
plt.rc('legend', fontsize=14)
plt.rc('lines', linewidth=2)
plt.rc('axes', linewidth=2)
plt.rc('axes', edgecolor='k')
plt.rc('xtick.major', width=2)
plt.rc('xtick.major', size=6)
plt.rc('ytick.major', width=2)
plt.rc('ytick.major', size=6)
plt.rc('pdf', fonttype=42)
plt.rc('ps', fonttype=42)

In [2]:
%%time
# subset of attributes
attributes = utils.get_attributes('subset1')
print('Attributes', attributes)
states = utils.states
nstates = len(states)
print('No States', nstates)

# load distances (see 01_CalculateDistances.ipynb)
dists = pickle.load( open("results/distances.pkl", "rb" ) )
# restrict to subset and states
dists = { k:{att:d for att, d in v.items() if att in attributes} for k, v in dists.items() if k[0] in states and k[1] in states and k[0] != k[1]}

# load data
X_train_s, X_test_s, y_train_s, y_test_s, X_train_t, X_test_t, y_train_t, y_test_t = utils.load_ACSPublicCoverage(attributes, states)

Attributes ['SCHL', 'MAR', 'AGEP', 'SEX', 'CIT', 'RAC1P']
No States 50
AK AL AR AZ CA CO CT DE FL GA HI IA ID IL IN KS KY LA MA MD ME MI MN MO MS MT NC ND NE NH NJ NM NV NY OH OK OR PA RI SC SD TN TX UT VA VT WA WI WV WY Wall time: 1min 49s


In [3]:
# source and target states
source = 'AL'
target = 'OR'

# decision tree parameters
max_depth = 8
min_pct = 0.05  # minimum number of cases as a fraction of the training-set size, not an absolute count

# per-attribute distances for the chosen (source, target) pair
_dict = dists[(source, target)]
# pick the attribute with the smallest 'w_y_cond' distance (usage as in Eq. 13 of the paper);
# to fix the attribute directly instead, assign e.g. att_xw = 'MAR'
att_xw = min(_dict, key=lambda att: _dict[att]['w_y_cond'])
print('att_xw', att_xw)
maxdepth_td = max_depth
fairness_metric = None  # alternatives: 'true_positive_rate_parity', 'demographic_parity'

# derived parameters
size = len(X_train_s[source])
min_cases = int(size * min_pct)

# debug output: (attribute, w_y_cond, w_y) for every selected attribute
[(att, _dict[att]['w_y_cond'], _dict[att]['w_y']) for att in attributes]

att_xw MAR


[('SCHL', 0.0945665713757643, 0.09430977740840138),
 ('MAR', 0.0603242820532249, 0.08149840055558766),
 ('AGEP', 0.06311084562224045, 0.07759983660222014),
 ('SEX', 0.07598000753010789, 0.07598000753010795),
 ('CIT', 0.09012550000573281, 0.0880482163035855),
 ('RAC1P', 0.11982475732494548, 0.11274269312258639)]

In [4]:
%%time
# train on target, test on target
clf_t, cm_t, cm_unprotected_t, cm_protected_t = run_test(X_train_s[target], y_train_s[target], X_test_t[target], 
    y_test_t[target], X_td=None, max_depth=max_depth, min_cases=int(len(X_train_s[target])*min_pct), fairness_metric=fairness_metric)

Wall time: 1.27 s


In [5]:
# confusion matrices: overall, unprotected group, protected group
print('cm', cm_t)
print('cm_unprotected', cm_unprotected_t)
print('cm_protected', cm_protected_t)
# evaluate each confusion matrix once; index 0 is reported as accuracy, and the
# group differences at index 4 / index 1 as the 'dp' / 'eop' gaps
# (presumably positive rate and TPR -- confirm in utils.cm_metrics)
metrics_t = cm_metrics(cm_t)
metrics_unprotected_t = cm_metrics(cm_unprotected_t)
metrics_protected_t = cm_metrics(cm_protected_t)
print('acc', metrics_t[0])
print('dp', metrics_unprotected_t[4] - metrics_protected_t[4])
print('eop', metrics_unprotected_t[1] - metrics_protected_t[1])

cm [[2057  351]
 [ 760  501]]
cm_unprotected [[820 169]
 [336 250]]
cm_protected [[1237  182]
 [ 424  251]]
acc 0.6971926955573726
dp 0.05925046618456359
eop 0.054769308557704455


In [6]:
%%time
# train on source, test on target without domain adaptation
clf_s, cm_s, cm_unprotected_s, cm_protected_s = run_test(X_train_s[source], y_train_s[source], X_test_t[target], 
    y_test_t[target], max_depth=max_depth, min_cases=min_cases, fairness_metric=fairness_metric,
    X_td=X_train_t[target], y_td=y_train_t[target]) # X_td and y_td only to compute w_dist() 

Wall time: 1.51 s


In [7]:
# confusion matrices: overall, unprotected group, protected group
print('cm', cm_s)
print('cm_unprotected', cm_unprotected_s)
print('cm_protected', cm_protected_s)
# evaluate each confusion matrix once; index 0 is reported as accuracy, and the
# group differences at index 4 / index 1 as the 'dp' / 'eop' gaps
# (presumably positive rate and TPR -- confirm in utils.cm_metrics)
metrics_s = cm_metrics(cm_s)
metrics_unprotected_s = cm_metrics(cm_unprotected_s)
metrics_protected_s = cm_metrics(cm_protected_s)
print('acc', metrics_s[0])
print('dp', metrics_unprotected_s[4] - metrics_protected_s[4])
print('eop', metrics_unprotected_s[1] - metrics_protected_s[1])

cm [[2276  132]
 [1056  205]]
cm_unprotected [[934  55]
 [481 105]]
cm_protected [[1342   77]
 [ 575  100]]
acc 0.6762060506950123
dp 0.017060080956929097
eop 0.031032739223865513


In [8]:
%%time
# train on source, test on target, with domain adaptation
clf_da, cm_da, cm_unprotected_da, cm_protected_da = run_test(X_train_s[source], y_train_s[source], X_test_t[target], 
    y_test_t[target], max_depth=max_depth, min_cases=min_cases, fairness_metric=fairness_metric,
    da=True, X_td=X_train_t[target], y_td=y_train_t[target], att_xw =att_xw, maxdepth_td=maxdepth_td)

Wall time: 5.16 s


In [9]:
# confusion matrices: overall, unprotected group, protected group
print('cm', cm_da)
print('cm_unprotected', cm_unprotected_da)
print('cm_protected', cm_protected_da)
# evaluate each confusion matrix once; index 0 is reported as accuracy, and the
# group differences at index 4 / index 1 as the 'dp' / 'eop' gaps
# (presumably positive rate and TPR -- confirm in utils.cm_metrics)
metrics_da = cm_metrics(cm_da)
metrics_unprotected_da = cm_metrics(cm_unprotected_da)
metrics_protected_da = cm_metrics(cm_protected_da)
print('acc', metrics_da[0])
print('dp', metrics_unprotected_da[4] - metrics_protected_da[4])
print('eop', metrics_unprotected_da[1] - metrics_protected_da[1])

cm [[2304  104]
 [1083  178]]
cm_unprotected [[944  45]
 [495  91]]
cm_protected [[1360   59]
 [ 588   87]]
acc 0.6764786045243936
dp 0.016626188202119446
eop 0.02640121350018962


In [10]:
# tree learned by the target-trained model, as nested dicts: split nodes carry
# 'split_col', 'cutoff', 'gain', 'left', 'right'; every node carries a case count
# 'tot' and a 'dist' array (appears to be the class distribution at that node)
clf_t.tree

{'type': 'split',
 'gain': 0.021032382642826253,
 'split_col': 'MAR',
 'cutoff': 3.0,
 'tot': 11004,
 'dist': array([0.64930934, 0.35069066]),
 'left': {'type': 'split',
  'gain': 0.02162465177084877,
  'split_col': 'AGEP',
  'cutoff': 53.0,
  'tot': 1181,
  'dist': array([0.4081287, 0.5918713]),
  'left': {'type': 'leaf',
   'tot': 610,
   'dist': array([0.49016393, 0.50983607])},
  'right': {'type': 'leaf',
   'tot': 571,
   'dist': array([0.32049037, 0.67950963])}},
 'right': {'type': 'split',
  'gain': 0.012607475267715884,
  'split_col': 'MAR',
  'cutoff': 1.0,
  'tot': 9823,
  'dist': array([0.67830602, 0.32169398]),
  'left': {'type': 'split',
   'gain': 0.016373729205872767,
   'split_col': 'SCHL',
   'cutoff': 21.0,
   'tot': 4167,
   'dist': array([0.74970002, 0.25029998]),
   'left': {'type': 'leaf',
    'tot': 773,
    'dist': array([0.87839586, 0.12160414])},
   'right': {'type': 'split',
    'gain': 0.009122801548116954,
    'split_col': 'SEX',
    'cutoff': 1.0,
    'tot

In [11]:
# (w_dist() of the source-only model -- computed from the X_td/y_td passed to run_test --,
#  its learned tree as nested split/leaf dicts)
(clf_s.w_dist(), clf_s.tree)

(0.11044574847988227,
 {'type': 'split',
  'gain': 0.016543420941708575,
  'split_col': 'RAC1P',
  'cutoff': 2.0,
  'tot': 14284,
  'dist': array([0.72556707, 0.27443293]),
  'left': {'type': 'split',
   'gain': 0.017551274185939547,
   'split_col': 'AGEP',
   'cutoff': 54.0,
   'tot': 4054,
   'dist': array([0.61618155, 0.38381845]),
   'left': {'type': 'split',
    'gain': 0.011780225452817539,
    'split_col': 'AGEP',
    'cutoff': 20.0,
    'tot': 3160,
    'dist': array([0.65696203, 0.34303797]),
    'left': {'type': 'leaf',
     'tot': 737,
     'dist': array([0.54545455, 0.45454545])},
    'right': {'type': 'split',
     'gain': 0.007918206166596553,
     'split_col': 'SEX',
     'cutoff': 1.0,
     'tot': 2423,
     'dist': array([0.69087908, 0.30912092]),
     'left': {'type': 'leaf',
      'tot': 1153,
      'dist': array([0.7415438, 0.2584562])},
     'right': {'type': 'leaf',
      'tot': 1270,
      'dist': array([0.64488189, 0.35511811])}}},
   'right': {'type': 'leaf',
 

In [12]:
# (w_dist() of the domain-adapted model, its learned tree as nested split/leaf dicts)
(clf_da.w_dist(), clf_da.tree)

(0.11304092169912443,
 {'type': 'split',
  'gain': 0.0598708535859489,
  'split_col': 'RAC1P',
  'cutoff': 2.0,
  'tot': 14284,
  'dist': array([0.73080774, 0.26919226]),
  'left': {'type': 'split',
   'gain': 0.023039778935278643,
   'split_col': 'AGEP',
   'cutoff': 53.0,
   'tot': 4054,
   'dist': array([0.63036668, 0.36963332]),
   'left': {'type': 'split',
    'gain': 0.013052541809176588,
    'split_col': 'AGEP',
    'cutoff': 20.0,
    'tot': 3093,
    'dist': array([0.67059539, 0.32940461]),
    'left': {'type': 'leaf',
     'tot': 737,
     'dist': array([0.54293629, 0.45706371])},
    'right': {'type': 'split',
     'gain': 0.019527053471736833,
     'split_col': 'SEX',
     'cutoff': 1.0,
     'tot': 2356,
     'dist': array([0.70381489, 0.29618511]),
     'left': {'type': 'leaf',
      'tot': 1129,
      'dist': array([0.74340755, 0.25659245])},
     'right': {'type': 'leaf',
      'tot': 1227,
      'dist': array([0.67293623, 0.32706377])}}},
   'right': {'type': 'leaf',
 