In [1]:
%load_ext autoreload
%autoreload 2
import os
import pickle as pkl
from typing import Dict, Any
from collections import Counter

import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500) 
pd.set_option('display.max_columns', 50) 
import sklearn as sk
from tqdm import tqdm
from sklearn.metrics import (
    accuracy_score, roc_auc_score, roc_curve, precision_score, recall_score, 
    average_precision_score, precision_recall_curve, plot_roc_curve, plot_precision_recall_curve, make_scorer
)
from sklearn.model_selection import KFold, cross_validate, train_test_split
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.datasets import fetch_openml
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.dpi'] = 250
import dvu  # for visualization

import imodels

# change working directory to project root
if os.getcwd().split('/')[-1] == 'notebooks':
    os.chdir('../..')

from experiments.models.stablelinear import StableLinearClassifier
from experiments.models.stableskope import StableSkopeClassifier
from experiments.notebooks import viz
from experiments.util import (
    get_comparison_result, get_openml_dataset, get_clean_dataset, 
    get_best_model_under_complexity, MODEL_COMPARISON_PATH, get_best_accuracy
)

np.random.seed(0)

/accounts/projects/binyu/keyan3/imodels/experiments/notebooks


# readmission

In [27]:
readmission_rules = get_comparison_result(MODEL_COMPARISON_PATH, 'stbl_l1_mm0', 'readmission', 'test')['rule_df']['readmission']
readmission_rules_sorted = sorted(readmission_rules[5], key=lambda x: x.args[0], reverse=True)
for r in readmission_rules_sorted:
    print(r, r.args)

number_emergency > 0.5 [0.3918810077421773]
number_inpatient > 0.0 [0.35492935302451323]
number_inpatient > 1.5 [0.3227765083757946]
number_inpatient > 3.5 [0.13667475416635858]
number_diagnoses > 5.5 [0.12486111996498804]
number_inpatient > 0.5 [0.01831205379127073]
num_medications <= 11.0 [-0.02091595518155013]
number_inpatient <= 1.5 [-0.062179431262614424]
num_medications <= 11.0 and num_medications > 1.0 and repaglinide:No > 0.5 [-0.09282230349542264]
glipizide:No > 0.5 and insulin:No > 0.5 [-0.09873605540875244]
number_diagnoses <= 5.5 [-0.10136093981002635]
number_diagnoses <= 5.5 and number_inpatient <= 0.5 [-0.16316008983251556]
number_inpatient <= 0.5 [-0.1731373089320735]
number_inpatient <= 2.625 [-0.1989254450474448]


# credit

In [31]:
credit_rules = get_comparison_result(MODEL_COMPARISON_PATH, 'stbl_l1_mm1', 'credit', 'test')['rule_df']['credit']
credit_rules_sorted = sorted(credit_rules[17], key=lambda x: x.args[0], reverse=True)
for r in credit_rules_sorted:
    print(r, r.args)

PAY_2 > 1.5 [0.6418371916262372]
PAY_0 > 1.5 [0.498799119552249]
PAY_0 <= 0.5 and PAY_0 > -0.8 [-0.05001925400087011]
SEX > 1.875 [-0.05932988875046715]
PAY_AMT6 > 0.0 [-0.1060379239229436]
PAY_AMT2 > 0.0 [-0.12247692949145599]
PAY_AMT5 > 0.0 [-0.14722586470622806]
PAY_0 <= 0.5 [-0.7017497392372162]
PAY_0 <= 1.5 [-0.7176520313207045]


In [33]:
pd.read_csv('experiments/data/credit_card/credit_card_clean.csv')

Unnamed: 0,limit_bal,age,pay_0,pay_2,pay_3,pay_4,pay_5,pay_6,bill_amt1,bill_amt2,bill_amt3,bill_amt4,bill_amt5,bill_amt6,pay_amt1,pay_amt2,pay_amt3,pay_amt4,pay_amt5,pay_amt6,sex:1,sex:2,education:0,education:1,education:2,education:3,education:4,education:5,education:6,marriage:0,marriage:1,marriage:2,marriage:3,default.payment.next.month
0,20000.0,24,2,2,-1,-1,-2,-2,3913.0,3102.0,689.0,0.0,0.0,0.0,0.0,689.0,0.0,0.0,0.0,0.0,0,1,0,0,1,0,0,0,0,0,1,0,0,1
1,120000.0,26,-1,2,0,0,0,2,2682.0,1725.0,2682.0,3272.0,3455.0,3261.0,0.0,1000.0,1000.0,1000.0,0.0,2000.0,0,1,0,0,1,0,0,0,0,0,0,1,0,1
2,90000.0,34,0,0,0,0,0,0,29239.0,14027.0,13559.0,14331.0,14948.0,15549.0,1518.0,1500.0,1000.0,1000.0,1000.0,5000.0,0,1,0,0,1,0,0,0,0,0,0,1,0,0
3,50000.0,37,0,0,0,0,0,0,46990.0,48233.0,49291.0,28314.0,28959.0,29547.0,2000.0,2019.0,1200.0,1100.0,1069.0,1000.0,0,1,0,0,1,0,0,0,0,0,1,0,0,0
4,50000.0,57,-1,0,-1,0,0,0,8617.0,5670.0,35835.0,20940.0,19146.0,19131.0,2000.0,36681.0,10000.0,9000.0,689.0,679.0,1,0,0,0,1,0,0,0,0,0,1,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29995,220000.0,39,0,0,0,0,0,0,188948.0,192815.0,208365.0,88004.0,31237.0,15980.0,8500.0,20000.0,5003.0,3047.0,5000.0,1000.0,1,0,0,0,0,1,0,0,0,0,1,0,0,0
29996,150000.0,43,-1,-1,-1,-1,0,0,1683.0,1828.0,3502.0,8979.0,5190.0,0.0,1837.0,3526.0,8998.0,129.0,0.0,0.0,1,0,0,0,0,1,0,0,0,0,0,1,0,0
29997,30000.0,37,4,3,2,-1,0,0,3565.0,3356.0,2758.0,20878.0,20582.0,19357.0,0.0,0.0,22000.0,4200.0,2000.0,3100.0,1,0,0,0,1,0,0,0,0,0,0,1,0,1
29998,80000.0,41,1,-1,0,0,0,-1,-1645.0,78379.0,76304.0,52774.0,11855.0,48944.0,85900.0,3409.0,1178.0,1926.0,52964.0,1804.0,1,0,0,0,0,1,0,0,0,0,1,0,0,1


# juvenile

In [35]:
juvenile_rules = get_comparison_result(MODEL_COMPARISON_PATH, 'stbl_l1_mm1', 'juvenile', 'test')['rule_df']['juvenile']
juvenile_rules_sorted = sorted(juvenile_rules[17], key=lambda x: x.args[0], reverse=True)
for r in juvenile_rules_sorted:
    print(r, r.args)

fr_suggest_agnts_law:1 > 0.5 [1.0550330803183612]
1_failing_grade:1 > 0.5 [0.7697771637354273]
any_victimization:0 <= 0.5 [0.4535564597910702]
mar_ab_dep:0 > 0.5 [-0.1450779661670562]
ever_attcked_weapon:2 > 0.5 [-0.23627186883960527]
friends_broken_in_steal:2 > 0.5 [-0.3748466431049729]
friends_sold_drugs:2 > 0.5 [-0.39652661821608404]
neill:0 > 0.5 [-0.5408000275003031]
seen_thre_w/gun_knife:2 > 0.5 [-0.6119216522674046]
ever_attcked_weapon:1 <= 0.5 and friends_broken_in_steal:1 <= 0.5 [-0.832971330225257]
mar_ab_dep:0 > 0.5 and physically_ass:0 > 0.5 [-0.9139768518302449]


In [36]:
pd.read_csv("experiments/data/ICPSR_03986/DS0001/data_clean.csv")

Unnamed: 0,age,#_in_household,weighting_95,total_school_cuttings,total_school_sexual_assaults,total_school_muggings,total_school_threats,total_school_beatings,total_acts_at_home,total_acts_somewhere_else,total_acts_neigh,total_acts_at_school,sex:1,sex:2,grade:10,grade:11,grade:12,grade:13,grade:5,grade:6,grade:7,grade:8,grade:9,violence_in_school:1,violence_in_school:2,...,mar_ab_dep:0,mar_ab_dep:1,hard_drug_ab_dep:0,hard_drug_ab_dep:1,alcohol_dependence_or_abuse:0,alcohol_dependence_or_abuse:1,my_drinking_measure:0,my_drinking_measure:1,experimental_marijuana:0,experimental_marijuana:1,nonexpermental_marijuana:0,nonexpermental_marijuana:1,ever_illict_w/o_marijuana:0,ever_illict_w/o_marijuana:1,neill:0,neill:1,my_check_of_wv:0,my_check_of_wv:1,school_shooting:0,school_shooting:1,sch_viol_prob:0,sch_viol_prob:1,commviol_prob:0,commviol_prob:1,any_deviance
0,16.0,6.0,0.8497,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0
1,17.0,8.0,2.4505,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0
2,13.0,5.0,0.8952,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0
3,12.0,3.0,0.5868,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0
4,15.0,3.0,0.4343,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3635,12.0,4.0,0.5535,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0
3636,17.0,5.0,0.4788,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0
3637,12.0,2.0,1.0495,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0
3638,15.0,4.0,0.2316,0.0,0.0,0.0,0.0,0.0,0.0,1.0,2.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,1.0,1.0


# recidivism

In [48]:
recidivism_rules = get_comparison_result(MODEL_COMPARISON_PATH, 'stbl_l1_mm0', 'recidivism', 'test')['rule_df']['recidivism']
recidivism_rules_sorted = sorted(recidivism_rules[10], key=lambda x: x.args[0], reverse=True)
for r in recidivism_rules_sorted:
    print(r, r.args)

priors_count > 1.5 [0.2205056879474306]
decile_score > 2.5 [0.09863528053313589]
decile_score > 5.5 [0.08877310574768842]
decile_score > 5.5 and priors_count > 1.5 [0.03052336655973006]
score_text:Low > 0.5 [-0.12833766205056604]
decile_score <= 5.5 and priors_count <= 2.5 [-0.3992839085331373]
