In [1]:
# Packages
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import mean_squared_error, r2_score

import PyGRF

### Read the data and split it into training and test sets

In [2]:
# Load the tract-level obesity data. The target frame keeps the tract ID next to the
# obesity rate so each row stays traceable after splitting.
data_obesity = pd.read_csv("../Data/Obesity.csv")
y = data_obesity.loc[:, ["Census tract code", "obesity rate"]]

# 70/30 split; in this section X_train is the sample handed to the bandwidth search
X_train, X_test, y_train, y_test = train_test_split(
    data_obesity,
    y,
    test_size=0.3,
    random_state=42,
)

### Search for the optimal bandwidth and local model weight using incremental spatial autocorrelation (ISA)

In [3]:
# Run the incremental spatial autocorrelation search on the training split's target
# values and coordinates; unpacks the selected bandwidth, local-model weight and p-value.
# NOTE(review): the printed Moran's I (0.4488...) appears to be the value later
# hard-coded as the local weight in the cross-validation cell — confirm.
bandwidth, local_weight, p_value = PyGRF.search_bw_lw_ISA(
    X_train["obesity rate"],
    X_train[['Lon', 'Lat']],
)

 There are 410 disconnected components.
 There are 16 disconnected components.
 There are 4 disconnected components.
 There are 2 disconnected components.
  self.seI_norm = self.VI_norm ** (1 / 2.0)
  self.seI_rand = VIR ** (1 / 2.0)


bandwidth: 152, moran's I: 0.44881501000042584, p-value: 0.0


### Evaluate PyGRF performance using 10-fold cross-validation

In [4]:
# function for standardizing independent variables
def standarize_data(data, stats):
    """Standardize (z-score) each column of ``data`` using precomputed statistics.

    The (misspelled) function name is kept as-is because later cells call it.

    Parameters
    ----------
    data : pandas.DataFrame
        Values to scale. Its columns are aligned by name with the index of ``stats``.
    stats : pandas.DataFrame
        Per-column statistics with ``'mean'`` and ``'std'`` columns, e.g. the output
        of ``train_df.describe().transpose()``. Pass training-set statistics for both
        the training and the test data so the test data is scaled consistently.

    Returns
    -------
    pandas.DataFrame
        ``(data - mean) / std`` computed column by column. A column whose std is 0
        produces inf/NaN values.
    """
    column_means = stats['mean']
    column_stds = stats['std']
    centered = data - column_means
    return centered / column_stds

In [5]:
# Feature columns: every column except the tract ID, the coordinates and the target.
# These are the independent variables passed to PyGRF.
columns_to_exclude = ['Census tract code', 'Lon', 'Lat', 'obesity rate']
X_columns = [column for column in data_obesity.columns if column not in columns_to_exclude]

# Hyperparameters taken from the ISA search output above ("bandwidth: 152, moran's I: 0.4488...").
# They are kept as literals so this cell reproduces the reported results.
# NOTE(review): consider passing `bandwidth` / `local_weight` from the search cell instead,
# so the two cells cannot drift apart if the data changes.
BAND_WIDTH = 152
LOCAL_WEIGHT = 0.4488

y_predict = []
y_true = []
# Per-fold feature importances are collected in lists and concatenated once after the
# loop. This avoids repeated pd.concat calls, which are quadratic and warn on an empty start frame.
local_fi_per_fold = []
global_fi_per_fold = []

K_fold = KFold(n_splits=10, shuffle=True, random_state=42)

# NOTE: the loop reuses the names X_train / X_test / y_train / y_test and overwrites the
# 70/30 split from the data-loading cell. Re-run that cell before re-running the ISA search.
for i, (train_index, test_index) in enumerate(K_fold.split(data_obesity)):
    print("fold:", i)

    # get the training and test data in each fold
    X_train_all, X_test_all = data_obesity.iloc[train_index], data_obesity.iloc[test_index]
    y_train, y_test = X_train_all['obesity rate'], X_test_all['obesity rate']
    X_train = X_train_all[X_columns]
    X_test = X_test_all[X_columns]
    xy_coord = X_train_all[['Lon', 'Lat']]
    coords_test = X_test_all[['Lon', 'Lat']]

    # standardize independent variables using statistics from the training fold only,
    # so the scaling carries no information from the test fold
    training_stat = X_train.describe().transpose()
    X_scaled_train = standarize_data(X_train, training_stat)
    X_scaled_test = standarize_data(X_test, training_stat)

    # create a PyGRF model
    pygrf = PyGRF.PyGRFBuilder(n_estimators=400, max_features=1/3, band_width=BAND_WIDTH, train_weighted=True,
                               predict_weighted=True, bootstrap=False, resampled=True, random_state=42)

    # fit the model and use it to make predictions
    pygrf.fit(X_scaled_train, y_train, xy_coord)
    predict_combined, predict_global, predict_local = pygrf.predict(X_scaled_test, coords_test,
                                                                    local_weight=LOCAL_WEIGHT)

    # feature importance output by the local models
    local_fi = pygrf.get_local_feature_importance()
    local_fi_per_fold.append(local_fi)

    # feature importance output by the global random forest model (one vector per fold)
    global_fi = pygrf.global_model.feature_importances_
    global_fi_per_fold.append(global_fi)

    # predict_combined is presumably a list; the previous `list + predict_combined`
    # concatenation only worked for lists, and extend() gives the same result
    y_predict.extend(predict_combined)
    y_true.extend(y_test.tolist())

df_local_fi = pd.concat(local_fi_per_fold)
# One row per fold, indexed by fold number. The old per-fold concat gave every row index 0.
df_global_fi = pd.DataFrame(global_fi_per_fold, columns=X_columns)

# KFold's test folds partition the data, so every row gets exactly one out-of-fold prediction
assert len(y_true) == len(y_predict) == len(data_obesity)

fold: 0
fold: 1
fold: 2
fold: 3
fold: 4
fold: 5
fold: 6
fold: 7
fold: 8
fold: 9


In [6]:
# Compute the RMSE and R^2 over the pooled out-of-fold predictions from the CV loop.
# RMSE is taken as the square root of the MSE. The `squared=False` argument of
# mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6.
rmse = mean_squared_error(y_true, y_predict) ** 0.5
r2 = r2_score(y_true, y_predict)
print("rmse: " + str(round(rmse, 4)), "r2: " + str(round(r2, 4)))

rmse: 1.6341 r2: 0.9229


### Examine the resulting feature importances

In [7]:
# Local-model feature importance stacked across all folds (one row per entry of
# `model_index` in each fold). Print the overall size, then preview the first rows.
print(df_local_fi.shape)
df_local_fi.head(n=5)

(17955, 22)


Unnamed: 0,model_index,% Black,% Ame Indi and AK Native,% Asian,% Nati Hawa and Paci Island,% Hispanic or Latino,% male,% married,% age 18-29,% age 30-39,...,% age >=60,% <highschool,median income,% unemployment,% below poverty line,% food stamp/SNAP,median value units built,median year units built,% renter-occupied housing units,population density
0,0,0.311021,0.002785,0.210766,0.000297,0.024867,0.004729,0.088088,0.004334,0.010742,...,0.018941,0.050142,0.014501,0.008568,0.011491,0.055468,0.138184,0.006007,0.010586,0.015362
1,1,0.326607,0.004213,0.097843,0.000439,0.034703,0.005782,0.03037,0.006006,0.008659,...,0.028549,0.15319,0.014234,0.016159,0.02193,0.119729,0.070204,0.00614,0.030484,0.0143
2,2,0.315451,0.003027,0.188863,0.00038,0.022814,0.004493,0.046391,0.005409,0.007128,...,0.012401,0.054442,0.013475,0.023643,0.012675,0.094314,0.133486,0.005514,0.028606,0.018284
3,3,0.318866,0.003513,0.147762,0.001417,0.03036,0.0037,0.087574,0.005744,0.013925,...,0.017146,0.054844,0.013198,0.012265,0.013483,0.072873,0.145519,0.005182,0.016203,0.022448
4,4,0.311025,0.003404,0.179981,0.000582,0.029167,0.003886,0.057562,0.009107,0.007423,...,0.011285,0.069788,0.011834,0.021366,0.025068,0.095424,0.113794,0.006453,0.01781,0.009728


In [8]:
# Global random-forest feature importance, one row per CV fold.
# Print the overall size, then preview the first rows.
print(df_global_fi.shape)
df_global_fi.head(n=5)

(10, 21)


Unnamed: 0,% Black,% Ame Indi and AK Native,% Asian,% Nati Hawa and Paci Island,% Hispanic or Latino,% male,% married,% age 18-29,% age 30-39,% age 40-49,...,% age >=60,% <highschool,median income,% unemployment,% below poverty line,% food stamp/SNAP,median value units built,median year units built,% renter-occupied housing units,population density
0,0.310024,0.002279,0.125044,0.000605,0.052558,0.007214,0.031269,0.011756,0.012159,0.005952,...,0.015568,0.069677,0.069877,0.009793,0.032149,0.072003,0.109989,0.015399,0.014452,0.023637
0,0.307121,0.002107,0.124237,0.000632,0.053164,0.006846,0.032326,0.011463,0.012076,0.00552,...,0.014876,0.071732,0.073431,0.010324,0.031172,0.076012,0.104762,0.016543,0.013913,0.023622
0,0.311305,0.002254,0.109196,0.000481,0.053045,0.007096,0.028317,0.011317,0.011336,0.005812,...,0.01658,0.066776,0.078061,0.010493,0.034705,0.071332,0.119501,0.016218,0.014194,0.02417
0,0.311867,0.002148,0.11664,0.000543,0.054791,0.006927,0.030061,0.011235,0.012311,0.006027,...,0.015424,0.072949,0.076302,0.010176,0.032683,0.06958,0.107259,0.016566,0.015308,0.023724
0,0.31298,0.002335,0.121388,0.000601,0.056305,0.006994,0.032938,0.010862,0.011751,0.005596,...,0.015384,0.065075,0.07899,0.010789,0.032612,0.068442,0.107614,0.01473,0.013997,0.022693
