PyMALTS is an interpretable causal inference library that matches units after learning a distance metric. It implements the MALTS algorithm proposed by Harsh Parikh, Cynthia Rudin, and Alexander Volfovsky in their 2019 paper "MALTS: Matching After Learning to Stretch".
PyMALTS is a Python 3 library and requires numpy, pandas, scikit-learn, matplotlib, and seaborn. Let's first import the necessary libraries.
```python
import pymalts
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(0)  # fix the random seed for reproducibility
```
We use pandas to read the data from a CSV file into a DataFrame, but you can use your favorite file reader.
```python
df_train = pd.read_csv('example/example_training.csv', index_col=0)
df_est = pd.read_csv('example/example_estimation.csv', index_col=0)
```
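As an optional sanity check (nothing PyMALTS-specific), one can confirm the split sizes and the treatment balance before fitting:

```python
# Optional sanity checks on the loaded data
print(df_train.shape, df_est.shape)        # rows x columns of each split
print(df_train['treated'].value_counts())  # treated vs. control counts
```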
Looking at the first few rows of the training and estimation data.
```python
df_train.head()
```
|     | X1 | X2 | X3 | X4 | X5 | X6 | X7 | X8 | X9 | X10 | X11 | X12 | X13 | X14 | X15 | X16 | X17 | X18 | outcome | treated |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 452 | 1.819048 | 1.647819 | 1.308953 | 1.734019 | 0.855118 | 1.977119 | 1.350102 | 1.680756 | 1.461630 | 0.182236 | 1.814832 | 2.414663 | 1.349697 | 1.689332 | 1.908062 | 0.213754 | 1.995093 | 0.229470 | 65.357913 | 1 |
| 470 | 0.391609 | 1.680689 | 0.800403 | 0.966430 | 0.520767 | 0.509781 | 1.089567 | 1.135084 | 1.573440 | 2.104325 | 2.164498 | 0.770808 | 2.872847 | 1.335996 | 1.922559 | 1.194346 | 1.144125 | 0.786834 | -8.773791 | 0 |
| 311 | 2.002644 | 1.936058 | 1.684250 | 1.817321 | 2.370025 | 2.194189 | 1.506828 | 2.023798 | -1.818878 | -0.173793 | -1.103308 | -1.182929 | -0.728503 | 1.144823 | 1.438420 | 2.330082 | 1.169831 | 0.955778 | -15.044913 | 0 |
| 243 | 0.944754 | 1.165708 | 0.932396 | 1.807640 | 0.532053 | 0.527254 | 0.537563 | 0.692490 | -1.447944 | -1.974136 | 2.079438 | 1.496016 | 0.418019 | 3.421552 | -1.644336 | 0.551145 | 0.577592 | 0.263000 | -5.346794 | 0 |
| 155 | 2.390528 | 1.509044 | 1.675889 | 1.589362 | 1.636131 | 1.678246 | 1.755178 | 1.312119 | -2.048745 | 0.335748 | 2.007166 | 2.626542 | 1.414703 | 0.826678 | 2.482560 | 1.616941 | 0.059490 | 0.780916 | -17.352554 | 0 |
```python
df_est.head()
```
|      | X1 | X2 | X3 | X4 | X5 | X6 | X7 | X8 | X9 | X10 | X11 | X12 | X13 | X14 | X15 | X16 | X17 | X18 | outcome | treated |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1355 | 1.881335 | 1.684164 | 0.532332 | 2.002254 | 1.435032 | 1.450196 | 1.974763 | 1.321659 | 0.709443 | -1.141244 | 0.883130 | 0.956721 | 2.498229 | 2.251677 | 0.375271 | -0.545129 | 3.334220 | 0.081259 | -15.679894 | 0 |
| 1320 | 0.666476 | 1.263065 | 0.657558 | 0.498780 | 1.096135 | 1.002569 | 0.881916 | 0.740392 | 2.780857 | -0.765889 | 1.230980 | -1.214324 | -0.040029 | 1.554477 | 4.235513 | 3.596213 | 0.959022 | 0.513409 | -7.068587 | 0 |
| 1233 | -0.193200 | 0.961823 | 1.652723 | 1.117316 | 0.590318 | 0.566765 | 0.775715 | 0.938379 | -2.055124 | 1.942873 | -0.606074 | 3.329552 | -1.822938 | 3.240945 | 2.106121 | 0.857190 | 0.577264 | -2.370578 | -5.133200 | 0 |
| 706 | 1.378660 | 1.794625 | 0.701158 | 1.815518 | 1.129920 | 1.188477 | 0.845063 | 1.217270 | 5.847379 | 0.566517 | -0.045607 | 0.736230 | 0.941677 | 0.835420 | -0.560388 | 0.427255 | 2.239003 | -0.632832 | 39.684984 | 1 |
| 438 | 0.434297 | 0.296656 | 0.545785 | 0.110366 | 0.151758 | -0.257326 | 0.601965 | 0.499884 | -0.973684 | -0.552586 | -0.778477 | 0.936956 | 0.831105 | 2.060040 | 3.153799 | 0.027665 | 0.376857 | -1.221457 | -2.954324 | 0 |
Setting up the model for learning the distance metric:
- Variable name of the outcome: 'outcome'.
- Variable name of the treatment indicator: 'treated'.
- `discrete=[]`, since this dataset has no discrete (categorical) covariates.
- `k=10`, the number of nearest neighbors matched to each unit while learning the metric.
```python
m = pymalts.malts(outcome='outcome', treatment='treated', data=df_train, discrete=[], k=10)
```
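If the data had categorical covariates, they would be listed by name in the `discrete` argument. The snippet below is a hypothetical illustration only: this dataset has no such column, and `X19` is made up.

```python
# Hypothetical: 'X19' is a made-up categorical column, not present in df_train
m_alt = pymalts.malts(outcome='outcome', treatment='treated',
                      data=df_train, discrete=['X19'], k=10)
```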
Fitting the model.

```python
res = m.fit()
```
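What `fit` learns is a stretch metric: for continuous covariates, the distance between two units is a weighted Euclidean distance with one learned weight per covariate, so covariates that matter more for the outcome are stretched more. A minimal sketch of that distance, where `weights` is a hypothetical stand-in for the learned diagonal of the stretch matrix (not an attribute of the fitted model):

```python
import numpy as np

def malts_distance(x_i, x_j, weights):
    """Weighted Euclidean distance ||W (x_i - x_j)||_2 with per-covariate
    stretch weights W -- a sketch of the metric MALTS learns."""
    diff = np.asarray(weights) * (np.asarray(x_i) - np.asarray(x_j))
    return float(np.sqrt(np.sum(diff ** 2)))
```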
Constructing matched groups for each unit in the estimation set, here using its k=20 nearest neighbors under the learned metric:

```python
mg = m.get_matched_groups(df_estimation=df_est, k=20)
```
For the `linear` CATE estimator, PyMALTS fits linear regression models on the control and treated units inside each matched group, and uses the difference between the estimated treated and control potential outcomes of the unit of interest as its CATE estimate.
```python
cate = m.CATE(mg, model='linear')
```
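To make the mechanics concrete, here is a minimal sketch of that idea for a single matched group, written with scikit-learn rather than PyMALTS's internals. `group` (the matched units), `query_x` (the covariate vector of the unit of interest), and `covariates` (the covariate column names) are assumed inputs, and `linear_cate_sketch` is a hypothetical helper, not a library function.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def linear_cate_sketch(group, query_x, covariates,
                       outcome='outcome', treatment='treated'):
    """Fit separate linear models on the treated and control units of one
    matched group and difference their predictions for the query unit."""
    treated = group[group[treatment] == 1]
    control = group[group[treatment] == 0]
    model_t = LinearRegression().fit(treated[covariates], treated[outcome])
    model_c = LinearRegression().fit(control[covariates], control[outcome])
    x = np.asarray(query_x).reshape(1, -1)
    # Difference of estimated treated and control potential outcomes
    return float(model_t.predict(x)[0] - model_c.predict(x)[0])
```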
Looking at the estimated CATEs for the first few units:
```python
cate.head()
```
|   | CATE | outcome | treatment |
|---|---|---|---|
| 0 | 72.115699 | -15.679894 | 0 |
| 1 | 23.664641 | -7.068587 | 0 |
| 2 | 29.161183 | -5.133200 | 0 |
| 3 | 52.595566 | 39.684984 | 1 |
| 4 | 4.643755 | -2.954324 | 0 |
PyMALTS can also estimate CATEs as the vanilla difference in mean outcomes between the treated and control units inside each matched group:
```python
cate_mean = m.CATE(mg, model='mean')
cate_mean.head()
```
|   | CATE | outcome | treatment |
|---|---|---|---|
| 0 | 64.714527 | -15.679894 | 0 |
| 1 | 28.313943 | -7.068587 | 0 |
| 2 | 26.927607 | -5.133200 | 0 |
| 3 | 52.828022 | 39.684984 | 1 |
| 4 | 5.914307 | -2.954324 | 0 |
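For intuition, the `mean` estimator reduces to a difference in average outcomes within each matched group. A sketch with a hypothetical helper (not the library's internals), where `group` is again one unit's matched neighbors:

```python
def mean_cate_sketch(group, outcome='outcome', treatment='treated'):
    """Difference in mean outcomes between the treated and control units
    of one matched group."""
    y_t = group.loc[group[treatment] == 1, outcome].mean()
    y_c = group.loc[group[treatment] == 0, outcome].mean()
    return y_t - y_c
```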
Plotting the probability distribution of the estimated CATEs, marking the mean in red and the median in blue.
```python
sns.distplot(cate['CATE'])                      # density estimate of the CATEs
plt.axvline(np.mean(cate['CATE']), c='red')     # mark the mean CATE
plt.axvline(np.median(cate['CATE']), c='blue')  # mark the median CATE
```
By the law of iterated expectations, the ATE is the expectation of the CATEs, so we estimate it by averaging the individual CATE estimates:
```python
ate = np.mean(cate['CATE'])
ate
```

41.709461029592596
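Since the sample mean can be sensitive to a few extreme CATE estimates, the median (the blue line in the plot above) is a useful companion summary:

```python
np.median(cate['CATE'])  # robust summary of the CATE distribution
```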
PyMALTS also provides tools to look inside each matched group. The visualizer plots a scatter plot of the outcome against each covariate, with color indicating whether a unit is treated or control.
```python
df_mg1 = m.visualizeMG(MG=mg, a=10)
```