In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.impute import KNNImputer
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split, KFold
from xgboost import XGBRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import root_mean_squared_error

#We may want these at some point for transforming our output:
#from scipy.special import logit, expit

pd.set_option('display.max_columns', None) #forces all columns to be displayed
pd.set_option('display.max_rows', None) #forces all rows to be displayed

In [2]:
filepath = '../data/data_reduced_train.csv'
data = pd.read_csv(filepath)

In [3]:
data.head()

Unnamed: 0,FIPS,State,County,% Adults with Diabetes,% Adults Reporting Currently Smoking,% Adults with Obesity,Food Environment Index,% Physically Inactive,% With Access to Exercise Opportunities,% Excessive Drinking,...,% 65 and Over,% Black,% American Indian or Alaska Native,% Asian,% Native Hawaiian or Other Pacific Islander,% Hispanic,% Non-Hispanic White,% Not Proficient in English,% Female,% Rural
0,17027,Illinois,Clinton,8.7,16.7,34.8,9.0,25.8,63.809317,18.658612,...,18.662115,3.243111,0.403696,0.655667,0.08399,3.541142,91.007613,0.535157,47.787802,80.216266
1,42071,Pennsylvania,Lancaster,8.2,16.7,35.2,8.8,23.4,80.948635,17.168046,...,19.735587,3.722587,0.474823,2.778152,0.114798,11.589407,80.19363,2.150077,50.747266,27.87549
2,46003,South Dakota,Aurora,8.2,17.0,38.8,7.8,23.1,3.349108,21.089477,...,21.125227,0.907441,2.903811,0.943739,0.0,8.45735,85.880218,1.615576,48.566243,100.0
3,46027,South Dakota,Clay,9.2,16.3,35.6,7.6,22.2,85.56825,22.062377,...,12.624346,1.760471,4.01178,2.729058,0.065445,3.331152,86.302356,0.406533,50.425393,22.101958
4,13205,Georgia,Mitchell,15.9,22.8,42.2,6.8,34.5,59.91726,13.354151,...,18.407842,45.87043,0.672476,0.975564,0.0663,5.479257,46.12616,0.725704,49.720591,75.775684


In [None]:
#we may not even need to drop these anymore
#data_drop_null_state_features = data.drop(columns=['% Voter Turnout', 'School Funding Adequacy']) 

In [4]:
len(data.columns)

52

In [5]:
features = data.columns[4:]
target = '% Adults with Diabetes'

In [6]:
data_train, data_holdout = train_test_split(data,
                                        shuffle=True,
                                        random_state=42,
                                        test_size=0.2)

In [7]:
X_train = data_train[features]
y_train = data_train[target]
X_holdout = data_holdout[features]
y_holdout = data_holdout[target]

In [None]:
xgb_pipe = Pipeline([('impute', KNNImputer()),
                     ('xgb', XGBRegressor())])

xgb_pipe.fit(X_train, y_train)
xgb_train_preds = xgb_pipe.predict(X_train)

In [19]:
xgb_model = XGBRegressor()
xgb_model.fit(X_train, y_train)
xgb_model_train_preds = xgb_model.predict(X_train)

In [None]:
print('XGB Training RMSE with kNNImputer:', root_mean_squared_error(y_train, xgb_train_preds))
print('XGB Training RMSE without imputing:', root_mean_squared_error(y_train, xgb_model_train_preds))

Train RMSE with kNNImputer: 0.015976972268687
Train RMSE without imputing: 0.016927218715798713


## Below, we do a quick sanity check to ensure that the imputer is not giving us wild values

In [None]:
imputer = KNNImputer()

array = imputer.fit_transform(X_train)

In [16]:
np.isnan(array).sum().sum()

np.int64(0)

In [22]:
X_train_imputed = pd.DataFrame(data=array, columns=X_train.columns)

In [26]:
X_train_imputed.describe()

Unnamed: 0,% Adults Reporting Currently Smoking,% Adults with Obesity,Food Environment Index,% Physically Inactive,% With Access to Exercise Opportunities,% Excessive Drinking,% Driving Deaths with Alcohol Involvement,% Uninsured,Dentist Rate,% with Annual Mammogram,% Vaccinated,% Completed High School,% Unemployed,% Children in Poverty,Income Ratio,% Children in Single-Parent Households,Social Association Rate,Average Daily PM2.5,% Severe Housing Problems,% Drive Alone to Work,% Long Commute - Drives Alone,% Food Insecure,% Limited Access to Healthy Foods,% Insufficient Sleep,% Uninsured Children,Other Primary Care Provider Rate,School Funding Adequacy,Gender Pay Gap,Median Household Income,% Household Income Required for Child Care Expenses,% Voter Turnout,% Census Participation,Traffic Volume,% Homeowners,% Households with Severe Cost Burden,% Households with Broadband Access,Population,% Less than 18 Years of Age,% 65 and Over,% Black,% American Indian or Alaska Native,% Asian,% Native Hawaiian or Other Pacific Islander,% Hispanic,% Non-Hispanic White,% Not Proficient in English,% Female,% Rural
count,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0
mean,19.080846,37.366318,7.570677,26.64204,61.576795,16.858568,27.375944,11.440834,47.831436,41.730348,40.107662,88.303045,3.582974,19.349751,4.555425,23.916172,11.297751,7.509075,12.898025,77.685944,33.003234,11.412139,8.385946,34.479831,6.480667,109.106245,-1884.920319,0.782768,63283.775124,27.740609,65.614183,58.232488,46.708529,72.73145,10.975752,82.345228,106708.2,21.683466,20.450801,8.847672,2.601479,1.695258,0.148835,10.194657,75.060171,1.568745,49.488312,64.27671
std,4.174719,4.555487,1.177204,5.192249,22.919207,2.611362,15.154552,5.174442,32.008737,8.139939,10.406799,5.622573,1.212205,8.26686,0.823503,10.406909,5.887015,1.707951,4.548764,7.820264,12.619063,3.552768,7.353854,3.641291,3.575343,72.868849,7648.579193,0.101792,16349.712888,6.207529,10.028628,12.142842,85.092645,8.448679,3.542817,7.230942,351785.8,3.581355,4.822241,13.945457,8.322901,3.082494,0.396069,13.461868,20.109038,2.558139,2.297528,33.679398
min,7.0,17.4,0.0,12.0,0.0,9.038096,0.0,2.389743,0.0,6.0,2.0,54.906334,0.859107,2.4,2.410432,0.0,0.0,2.2,2.526316,5.66839,0.0,2.2,0.0,23.8,0.684932,0.0,-63405.88,0.315462,28972.0,7.064892,32.939271,1.2,0.000332,19.96376,0.595238,47.336501,233.0,3.546634,6.299112,0.0,0.0,0.0,0.0,0.732961,2.693887,0.0,27.039544,0.0
25%,16.4,35.1,7.0,23.1,46.742619,15.050318,18.518519,7.465734,27.322985,37.0,33.0,85.227729,2.724356,13.4,4.015577,16.961015,7.87341,6.5,10.219265,74.906346,23.6,8.9,3.703208,31.9,4.072028,66.29132,-4425.575152,0.727727,52522.25,23.684258,58.922135,50.7,8.726006,68.372645,8.552134,78.526993,10997.75,19.602721,17.313146,0.843511,0.444085,0.537928,0.040193,2.81879,63.039319,0.244788,48.963481,35.612502
50%,18.7,37.7,7.7,26.3,64.068279,16.854673,26.315789,10.307064,42.568895,42.0,41.0,89.436084,3.410447,18.1,4.434182,22.156167,10.720041,7.7,12.352907,78.999341,32.4,11.1,6.536077,34.5,5.494948,93.76549,-845.38235,0.781011,60833.5,27.326022,65.660292,59.8,24.715576,73.93378,10.494738,83.424161,25996.5,21.678597,20.104667,2.32508,0.747352,0.818378,0.076756,5.100209,82.132857,0.68282,49.870434,67.47811
75%,21.6,40.2,8.4,30.0,78.846767,18.465536,33.333333,14.539297,64.014583,47.0,47.9,92.477047,4.203793,23.875,4.953463,28.720613,14.13601,8.8,14.7857,82.253066,41.6,13.8,10.702867,36.8,7.882996,135.788755,2013.453687,0.833934,70497.0,31.303697,72.581944,67.4,51.692614,78.424408,12.865166,87.230051,68013.75,23.638154,23.220098,9.722577,1.522712,1.520247,0.14191,11.06431,91.420332,1.741557,50.613384,100.0
max,43.0,51.0,10.0,47.0,100.0,26.134563,100.0,32.747535,349.65035,64.0,68.0,99.690402,12.910382,55.0,10.532338,78.519291,59.435364,15.6,61.142857,94.245723,73.3,26.3,62.599918,46.9,31.179735,905.64147,32379.685,1.728074,167605.0,80.474882,100.0,83.4,1753.934509,96.985816,31.169239,100.0,9721138.0,42.099541,42.459275,82.122697,92.111621,46.380224,10.414296,95.291657,97.539986,32.084894,57.76242,100.0


In [27]:
X_train.describe()

Unnamed: 0,% Adults Reporting Currently Smoking,% Adults with Obesity,Food Environment Index,% Physically Inactive,% With Access to Exercise Opportunities,% Excessive Drinking,% Driving Deaths with Alcohol Involvement,% Uninsured,Dentist Rate,% with Annual Mammogram,% Vaccinated,% Completed High School,% Unemployed,% Children in Poverty,Income Ratio,% Children in Single-Parent Households,Social Association Rate,Average Daily PM2.5,% Severe Housing Problems,% Drive Alone to Work,% Long Commute - Drives Alone,% Food Insecure,% Limited Access to Healthy Foods,% Insufficient Sleep,% Uninsured Children,Other Primary Care Provider Rate,School Funding Adequacy,Gender Pay Gap,Median Household Income,% Household Income Required for Child Care Expenses,% Voter Turnout,% Census Participation,Traffic Volume,% Homeowners,% Households with Severe Cost Burden,% Households with Broadband Access,Population,% Less than 18 Years of Age,% 65 and Over,% Black,% American Indian or Alaska Native,% Asian,% Native Hawaiian or Other Pacific Islander,% Hispanic,% Non-Hispanic White,% Not Proficient in English,% Female,% Rural
count,2010.0,2010.0,1983.0,2010.0,1981.0,2010.0,1991.0,2010.0,1958.0,1998.0,1999.0,2010.0,2010.0,2010.0,1999.0,2010.0,2010.0,1991.0,2010.0,2010.0,2010.0,2010.0,1983.0,2008.0,2010.0,2002.0,1958.0,2007.0,2010.0,2010.0,1989.0,2007.0,1992.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0,2010.0
mean,19.080846,37.366318,7.567625,26.64204,61.803914,16.858568,27.373329,11.440834,48.055012,41.733233,40.123562,88.303045,3.582974,19.349751,4.556363,23.916172,11.297751,7.525565,12.898025,77.685944,33.003234,11.412139,8.41131,34.483416,6.480667,109.18953,-1880.567395,0.782694,63283.775124,27.740609,65.518577,58.247583,47.07042,72.73145,10.975752,82.345228,106708.2,21.683466,20.450801,8.847672,2.601479,1.695258,0.148835,10.194657,75.060171,1.568745,49.488312,64.27671
std,4.174719,4.555487,1.180937,5.192249,22.96397,2.611362,15.194709,5.174442,32.301401,8.156482,10.421125,5.622573,1.212205,8.26686,0.825153,10.406909,5.887015,1.704615,4.548764,7.820264,12.619063,3.552768,7.390631,3.64133,3.575343,72.979658,7649.918424,0.101809,16349.712888,6.207529,10.001493,12.144958,85.389829,8.448679,3.542817,7.230942,351785.8,3.581355,4.822241,13.945457,8.322901,3.082494,0.396069,13.461868,20.109038,2.558139,2.297528,33.679398
min,7.0,17.4,0.0,12.0,0.0,9.038096,0.0,2.389743,0.0,6.0,2.0,54.906334,0.859107,2.4,2.410432,0.0,0.0,2.2,2.526316,5.66839,0.0,2.2,0.0,23.8,0.684932,0.0,-63405.88,0.315462,28972.0,7.064892,32.939271,1.2,0.000332,19.96376,0.595238,47.336501,233.0,3.546634,6.299112,0.0,0.0,0.0,0.0,0.732961,2.693887,0.0,27.039544,0.0
25%,16.4,35.1,6.95,23.1,47.185161,15.050318,18.518519,7.465734,27.337245,37.0,34.0,85.227729,2.724356,13.4,4.014312,16.961015,7.87341,6.5,10.219265,74.906346,23.6,8.9,3.67322,31.9,4.072028,66.29132,-4425.575152,0.727733,52522.25,23.684258,58.895887,50.7,8.889837,68.372645,8.552134,78.526993,10997.75,19.602721,17.313146,0.843511,0.444085,0.537928,0.040193,2.81879,63.039319,0.244788,48.963481,35.612502
50%,18.7,37.7,7.7,26.3,64.536631,16.854673,26.315789,10.307064,42.71728,42.0,41.0,89.436084,3.410447,18.1,4.433995,22.156167,10.720041,7.7,12.352907,78.999341,32.4,11.1,6.620099,34.5,5.494948,93.76549,-868.9473,0.781007,60833.5,27.326022,65.536954,59.8,25.105147,73.93378,10.494738,83.424161,25996.5,21.678597,20.104667,2.32508,0.747352,0.818378,0.076756,5.100209,82.132857,0.68282,49.870434,67.47811
75%,21.6,40.2,8.4,30.0,79.116829,18.465536,33.333333,14.539297,64.454988,47.0,48.0,92.477047,4.203793,23.875,4.956377,28.720613,14.13601,8.8,14.7857,82.253066,41.6,13.8,10.767675,36.825,7.882996,135.842655,2004.103336,0.833872,70497.0,31.303697,72.447183,67.4,52.178669,78.424408,12.865166,87.230051,68013.75,23.638154,23.220098,9.722577,1.522712,1.520247,0.14191,11.06431,91.420332,1.741557,50.613384,100.0
max,43.0,51.0,10.0,47.0,100.0,26.134563,100.0,32.747535,349.65035,64.0,68.0,99.690402,12.910382,55.0,10.532338,78.519291,59.435364,15.6,61.142857,94.245723,73.3,26.3,62.599918,46.9,31.179735,905.64147,32379.685,1.728074,167605.0,80.474882,100.0,83.4,1753.934509,96.985816,31.169239,100.0,9721138.0,42.099541,42.459275,82.122697,92.111621,46.380224,10.414296,95.291657,97.539986,32.084894,57.76242,100.0


In [15]:
linear_pipe = Pipeline([('impute', KNNImputer()),
                        ('scale', StandardScaler()), #For coefficient size comparison
                        ('linreg', LinearRegression())])

linear_pipe.fit(X_train, y_train)
linear_train_preds = linear_pipe.predict(X_train)

In [18]:
forest_pipe = Pipeline([('impute', KNNImputer()),
                        ('rfr', RandomForestRegressor())])

forest_pipe.fit(X_train, y_train)
forest_train_preds = forest_pipe.predict(X_train)

In [None]:
print('XGB training RMSE: ', root_mean_squared_error(y_train, xgb_train_preds))
print('Linear training RMSE: ', root_mean_squared_error(y_train, linear_train_preds))
print('Random Forest training RMSE: ', root_mean_squared_error(y_train, forest_train_preds))

XGB RMSE:  0.015976972268687
Linear RMSE:  0.4719816306215972
Random Forest RMSE:  0.19995143440196442


In [21]:
xgb_holdout_preds = xgb_pipe.predict(X_holdout)
linear_holdout_preds = linear_pipe.predict(X_holdout)
forest_holdout_preds = forest_pipe.predict(X_holdout)

In [22]:
print('XGB holdout RMSE: ', root_mean_squared_error(y_holdout, xgb_holdout_preds))
print('Linear holdout RMSE: ', root_mean_squared_error(y_holdout, linear_holdout_preds))
print('Random Forest holdout RMSE: ', root_mean_squared_error(y_holdout, forest_holdout_preds))

XGB holdout RMSE:  0.55734045722899
Linear holdout RMSE:  0.49654249928710603
Random Forest holdout RMSE:  0.5700927028658015


In [None]:
#Possibly include cross validation here