# HW06: Python 
(due November 3rd)

# Heterogeneous Treatment Effects with Keras MLP

In [1]:
import pandas as pd
import numpy as np
from tensorflow import keras
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping

In [2]:
# Dataset on RCT study of case management on mental health outcomes

# show variable labels
pd.read_stata('http://www.homepages.ucl.ac.uk/~rmjwiww/stata/missing/uk500.dta', iterator=True).variable_labels()

{'trialid': 'Trial ID',
 'centreid': 'Trial centre',
 'status': 'Patient status at baseline',
 'age': 'Age in years at baseline',
 'sex': 'Sex',
 'afcarib': 'Ethnic group',
 'ocfabth': "Father's social class at birth",
 'chron1l': 'Months since onset of psychosis, logged',
 'hos94': 'Days in hospital for psychiatric reasons: 2 years before baseline',
 'cprs94': 'Psychopathology at baseline (CPRS)',
 'das94': 'Disability at baseline (DAS)',
 'sat94': '(Dis)satisfaction with services at baseline',
 'rand': 'Randomised group',
 'hos96': 'Days in hospital for psychiatric reasons: 2 years after baseline',
 'cprs96': 'Psychopathology at 2 years (CPRS)',
 'sat96': '(Dis)satisfaction with services at 2 years'}

In [3]:
# Load data 
df = pd.read_stata('http://www.homepages.ucl.ac.uk/~rmjwiww/stata/missing/uk500.dta')
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 500 entries, 0 to 499
Data columns (total 16 columns):
 #   Column    Non-Null Count  Dtype   
---  ------    --------------  -----   
 0   trialid   500 non-null    float64 
 1   centreid  500 non-null    category
 2   status    500 non-null    category
 3   age       500 non-null    float64 
 4   sex       500 non-null    category
 5   afcarib   500 non-null    category
 6   ocfabth   413 non-null    category
 7   chron1l   500 non-null    float32 
 8   hos94     500 non-null    float64 
 9   cprs94    500 non-null    float64 
 10  das94     489 non-null    float64 
 11  sat94     399 non-null    float64 
 12  rand      500 non-null    category
 13  hos96     477 non-null    float64 
 14  cprs96    421 non-null    float64 
 15  sat96     349 non-null    float64 
dtypes: category(6), float32(1), float64(9)
memory usage: 44.7 KB


In [4]:
# the post-treatment outcomes to analyze
outcomes = ['sat96', 'hos96', 'cprs96']
df[outcomes].describe()

Unnamed: 0,sat96,hos96,cprs96
count,349.0,477.0,421.0
mean,17.102794,72.463312,18.624924
std,4.730995,113.007423,13.583216
min,9.0,0.0,0.0
25%,13.5,0.0,8.0
50%,17.0,22.0,16.0
75%,20.0,98.0,27.0
max,32.0,725.0,71.0


In [5]:
# variable describing treatment status
treatvar = 'rand'
df[treatvar].value_counts()

Intensive case management    251
Standard case management     249
Name: rand, dtype: int64

In [6]:
df.head()

Unnamed: 0,trialid,centreid,status,age,sex,afcarib,ocfabth,chron1l,hos94,cprs94,das94,sat94,rand,hos96,cprs96,sat96
0,222064.0,St Mary's,In hospital,59.0,male,Other,B,6.011267,13.0,12.0,1.285714,20.0,Standard case management,32.0,47.0,
1,107.0,St George's,Out-patient,27.0,male,Other,A,3.178054,80.0,4.0,0.285714,18.0,Intensive case management,27.0,3.0,22.0
2,222005.0,St Mary's,In hospital,41.0,male,Other,D,4.521789,240.0,6.0,0.75,15.0,Intensive case management,15.0,13.0,9.0
3,222018.0,St Mary's,In hospital,25.0,male,Other,C2,4.094345,48.0,12.0,0.125,18.0,Intensive case management,263.0,6.0,21.375
4,222049.0,St Mary's,In hospital,50.0,female,Other,C2,5.817111,63.0,25.0,0.5,20.0,Standard case management,5.0,8.0,


In [7]:
df = df.dropna()

float64var = ['trialid', 'age',  'hos94', 'cprs94', 'das94', 'sat94', 'hos96', 'cprs96', 'sat96']
catvar = ['centreid', 'status', 'sex', 'afcarib', 'ocfabth']

df['centreid'] = df['centreid'].astype('category').cat.codes
df['status'] = df['status'].astype('category').cat.codes
df['sex'] = df['sex'].astype('category').cat.codes
df['afcarib'] = df['afcarib'].astype('category').cat.codes
df['ocfabth'] = df['ocfabth'].astype('category').cat.codes

df.head()

Unnamed: 0,trialid,centreid,status,age,sex,afcarib,ocfabth,chron1l,hos94,cprs94,das94,sat94,rand,hos96,cprs96,sat96
1,107.0,0,0,27.0,0,0,0,3.178054,80.0,4.0,0.285714,18.0,Intensive case management,27.0,3.0,22.0
2,222005.0,2,1,41.0,0,0,4,4.521789,240.0,6.0,0.75,15.0,Intensive case management,15.0,13.0,9.0
3,222018.0,2,1,25.0,0,0,3,4.094345,48.0,12.0,0.125,18.0,Intensive case management,263.0,6.0,21.375
5,312015.0,3,0,31.0,1,0,0,4.787492,60.0,28.0,2.375,20.0,Intensive case management,45.0,19.0,17.0
6,221023.0,2,1,35.0,0,1,3,4.430817,60.0,25.0,1.571428,24.0,Intensive case management,58.0,27.0,19.125


In [8]:
# covariates for predicting the outcome conditional on treatment
covariates = ['status', 'sex', 'sat94', 'ocfabth', 'hos94', 'das94', 'cprs94', 'age', 'afcarib']
df[covariates].describe()

Unnamed: 0,status,sex,sat94,ocfabth,hos94,das94,cprs94,age,afcarib
count,246.0,246.0,246.0,246.0,246.0,246.0,246.0,246.0,246.0
mean,0.394309,0.455285,18.837907,2.715447,94.776423,1.072794,19.362691,38.593496,0.276423
std,0.489698,0.499012,4.907599,1.188519,94.375038,0.820939,13.35019,11.050044,0.44814
min,0.0,0.0,9.0,0.0,1.0,0.0,0.0,20.0,0.0
25%,0.0,0.0,15.75,2.0,33.25,0.428571,9.0,30.0,0.0
50%,0.0,0.0,19.0,3.0,63.0,1.0,17.0,36.0,0.0
75%,1.0,1.0,22.0,4.0,126.0,1.5,27.0,47.0,1.0
max,1.0,1.0,36.0,5.0,730.0,4.714283,67.0,65.0,1.0


In [9]:
# Subset the dataset by treatment (intensive) and control (standard)
df_treat = df[df[treatvar] == 'Intensive case management']
df_control = df[df[treatvar] == 'Standard case management']


In [10]:
# build an MLP model with at least 2 hidden layers, ReLU activation, batch normalization, dropout
model_tr = keras.models.Sequential(name = "Treatment")

model_tr.add(keras.layers.BatchNormalization())
model_tr.add(keras.layers.Dense(input_dim=len(covariates),units=256, activation="relu"))
model_tr.add(keras.layers.Dropout(0.2))
model_tr.add(keras.layers.Dense(128, activation="relu"))
model_tr.add(keras.layers.Dense(1))

In [11]:
model_cn = keras.models.Sequential(name = "Control")

#model_cn.add(keras.layers.BatchNormalization())
model_cn.add(keras.layers.Dense(input_dim=len(covariates),units=128, activation="relu"))
#model_cn.add(keras.layers.Dropout(0.2))
model_cn.add(keras.layers.Dense(64, activation="relu"))
model_cn.add(keras.layers.Dense(1))

In [12]:
# compile the model
model_tr.compile(loss="mean_squared_error",
              optimizer=keras.optimizers.Adam(),
              metrics="mean_squared_error")
model_cn.compile(loss="mean_squared_error",
              optimizer=keras.optimizers.Adam(),
              metrics="mean_squared_error")

In [None]:
# choose one of the three outcomes to analyze.
# fit separate models on the treatment dataset and control dataset
# use early stopping

X_treat = df_treat[covariates]
X_control = df_control[covariates]

y_treat = df_treat['hos96']
y_control = df_control['hos96']

X_tr_val_train, X_tr_test, y_tr_val_train, y_tr_test = train_test_split(X_treat, y_treat)
X_cn_val_train, X_cn_test, y_cn_val_train, y_cn_test = train_test_split(X_control, y_control)

X_tr_train, X_tr_val, y_tr_train, y_tr_val = train_test_split(X_tr_val_train, y_tr_val_train)
X_cn_train, X_cn_val, y_cn_train, y_cn_val = train_test_split(X_cn_val_train, y_cn_val_train)

earlystop = EarlyStopping(monitor='loss', patience=30)

tr_model = model_tr.fit(X_tr_train, y_tr_train, epochs=512,
                    validation_data=(X_tr_val, y_tr_val), callbacks=[earlystop], verbose=1)
cn_model = model_cn.fit(X_cn_train, y_cn_train, epochs=512,
                    validation_data=(X_cn_val, y_cn_val), callbacks=[earlystop], verbose=1)

print(len(tr_model.history['val_loss']))
print(tr_model.history['val_loss'])
print(len(cn_model.history['val_loss']))
print(cn_model.history['val_loss'])

model_tr.summary()
model_cn.summary()

Epoch 1/512


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Epoch 2/512
Epoch 3/512
Epoch 4/512
Epoch 5/512
Epoch 6/512
Epoch 7/512
Epoch 8/512
Epoch 9/512
Epoch 10/512
Epoch 11/512
Epoch 12/512
Epoch 13/512
Epoch 14/512
Epoch 15/512
Epoch 16/512
Epoch 17/512
Epoch 18/512
Epoch 19/512
Epoch 20/512
Epoch 21/512
Epoch 22/512
Epoch 23/512
Epoch 24/512
Epoch 25/512
Epoch 26/512
Epoch 27/512
Epoch 28/512
Epoch 29/512
Epoch 30/512
Epoch 31/512
Epoch 32/512
Epoch 33/512
Epoch 34/512
Epoch 35/512
Epoch 36/512
Epoch 37/512
Epoch 38/512
Epoch 39/512
Epoch 40/512
Epoch 41/512
Epoch 42/512
Epoch 43/512
Epoch 44/512
Epoch 45/512
Epoch 46/512
Epoch 47/512
Epoch 48/512
Epoch 49/512
Epoch 50/512
Epoch 51/512
Epoch 52/512
Epoch 53/512
Epoch 54/512
Epoch 

In [None]:
# form predicted outcomes for each individual, for both control and treatment
y_tr_pred = model_tr.predict(X_tr_test)
y_tr_pred_cn = model_tr.predict(X_cn_test)
y_cn_pred = model_cn.predict(X_cn_test)
y_cn_pred_tr = model_cn.predict(X_tr_test)

import seaborn as sns
import matplotlib.pyplot as plt
sns.set()

plt.figure(figsize=(8,8))
sns.scatterplot(x=y_tr_test, y=y_tr_pred.reshape(1,-1)[0])
sns.scatterplot(x=y_cn_test, y=y_tr_pred_cn.reshape(1,-1)[0])
plt.figure(figsize=(8,8))
sns.scatterplot(x=y_cn_test, y=y_cn_pred.reshape(1,-1)[0])
sns.scatterplot(x=y_tr_test, y=y_cn_pred_tr.reshape(1,-1)[0])

from sklearn.metrics import mean_squared_error, r2_score

mse_y_tr_tr = mean_squared_error(y_tr_test, y_tr_pred.reshape(1,-1)[0])
r2_y_tr_tr = r2_score(y_tr_test, y_tr_pred.reshape(1,-1)[0])

mse_y_tr_cn = mean_squared_error(y_cn_test, y_tr_pred_cn.reshape(1,-1)[0])
r2_y_tr_cn = r2_score(y_cn_test, y_tr_pred_cn.reshape(1,-1)[0])

mse_y_cn_cn = mean_squared_error(y_cn_test, y_cn_pred.reshape(1,-1)[0])
r2_y_cn_cn = r2_score(y_cn_test, y_cn_pred.reshape(1,-1)[0])

mse_y_cn_tr = mean_squared_error(y_tr_test, y_cn_pred_tr.reshape(1,-1)[0])
r2_y_cn_tr = r2_score(y_tr_test, y_cn_pred_tr.reshape(1,-1)[0])

print("MSE y tr -> tr:           {:>10.4}".format(mse_y_tr_tr))
print("R2  y tr -> tr:           {:>10.4}".format(r2_y_tr_tr))

print("MSE y tr -> cn:           {:>10.4}".format(mse_y_tr_cn))
print("R2  y tr -> cn:           {:>10.4}".format(r2_y_tr_cn))

print("MSE y cn -> cn:           {:>10.4}".format(mse_y_cn_cn))
print("R2  y cn -> cn:           {:>10.4}".format(r2_y_cn_cn))

print("MSE y cn -> tr:           {:>10.4}".format(mse_y_cn_tr))
print("R2  y cn -> tr:           {:>10.4}".format(r2_y_cn_tr))

In [None]:
# explore what features matter for the predicted difference between control and treatment
#TODO