harsh-parikh/MALTS

Introduction

PyMALTS is a learning-to-match method for interpretable causal inference. It implements the MALTS algorithm proposed by Harsh Parikh, Cynthia Rudin and Alexander Volfovsky in their 2019 paper "MALTS: Matching After Learning to Stretch".

Getting Started

PyMALTS is a Python 3 library that requires NumPy, pandas, scikit-learn, matplotlib and seaborn. First, import the necessary libraries.

import pymalts
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
np.random.seed(0)  # fix the random seed for reproducibility

Reading the data

We use pandas to read the data from CSV files into DataFrames, but any file reader will work.

df_train = pd.read_csv('example/example_training.csv',index_col=0)
df_est = pd.read_csv('example/example_estimation.csv',index_col=0)

Let's look at the first few rows of the training and estimation data.

df_train.head()
X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 outcome treated
452 1.819048 1.647819 1.308953 1.734019 0.855118 1.977119 1.350102 1.680756 1.461630 0.182236 1.814832 2.414663 1.349697 1.689332 1.908062 0.213754 1.995093 0.229470 65.357913 1
470 0.391609 1.680689 0.800403 0.966430 0.520767 0.509781 1.089567 1.135084 1.573440 2.104325 2.164498 0.770808 2.872847 1.335996 1.922559 1.194346 1.144125 0.786834 -8.773791 0
311 2.002644 1.936058 1.684250 1.817321 2.370025 2.194189 1.506828 2.023798 -1.818878 -0.173793 -1.103308 -1.182929 -0.728503 1.144823 1.438420 2.330082 1.169831 0.955778 -15.044913 0
243 0.944754 1.165708 0.932396 1.807640 0.532053 0.527254 0.537563 0.692490 -1.447944 -1.974136 2.079438 1.496016 0.418019 3.421552 -1.644336 0.551145 0.577592 0.263000 -5.346794 0
155 2.390528 1.509044 1.675889 1.589362 1.636131 1.678246 1.755178 1.312119 -2.048745 0.335748 2.007166 2.626542 1.414703 0.826678 2.482560 1.616941 0.059490 0.780916 -17.352554 0
df_est.head()
X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 outcome treated
1355 1.881335 1.684164 0.532332 2.002254 1.435032 1.450196 1.974763 1.321659 0.709443 -1.141244 0.883130 0.956721 2.498229 2.251677 0.375271 -0.545129 3.334220 0.081259 -15.679894 0
1320 0.666476 1.263065 0.657558 0.498780 1.096135 1.002569 0.881916 0.740392 2.780857 -0.765889 1.230980 -1.214324 -0.040029 1.554477 4.235513 3.596213 0.959022 0.513409 -7.068587 0
1233 -0.193200 0.961823 1.652723 1.117316 0.590318 0.566765 0.775715 0.938379 -2.055124 1.942873 -0.606074 3.329552 -1.822938 3.240945 2.106121 0.857190 0.577264 -2.370578 -5.133200 0
706 1.378660 1.794625 0.701158 1.815518 1.129920 1.188477 0.845063 1.217270 5.847379 0.566517 -0.045607 0.736230 0.941677 0.835420 -0.560388 0.427255 2.239003 -0.632832 39.684984 1
438 0.434297 0.296656 0.545785 0.110366 0.151758 -0.257326 0.601965 0.499884 -0.973684 -0.552586 -0.778477 0.936956 0.831105 2.060040 3.153799 0.027665 0.376857 -1.221457 -2.954324 0

Using MALTS

Distance metric learning

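Setting up the model for learning the distance metric:

  1. Variable name for the outcome variable: 'outcome'.
  2. Variable name for the treatment variable: 'treated'.
  3. Training data used to learn the metric: df_train.
  4. List of discrete covariates: none in this example (discrete=[]).
  5. Number of nearest neighbors used while learning the metric: k=10.

m = pymalts.malts(outcome='outcome', treatment='treated', data=df_train, discrete=[], k=10)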
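"Learning to stretch" means learning how much each covariate should count when measuring similarity between units. The sketch below shows one way such a metric can look: each covariate axis is scaled by a non-negative weight before taking a Euclidean distance. The diagonal-stretch form and the helper name `stretched_distance` are assumptions for illustration only; see the MALTS paper and the pymalts source for the exact metric, including how discrete covariates are treated.

```python
import numpy as np


def stretched_distance(x, y, m):
    """Hypothetical helper: Euclidean distance between covariate vectors x and y
    after stretching each axis j by a non-negative weight m[j]."""
    x, y, m = (np.asarray(v, dtype=float) for v in (x, y, m))
    # Scale each coordinate difference by its stretch factor, then take the norm.
    return float(np.sqrt(np.sum((m * (x - y)) ** 2)))


# With a large weight on X1 and zero weight on X2, units that differ only in X2 look identical.
m = np.array([3.0, 0.0])
print(stretched_distance([1.0, 5.0], [1.0, -2.0], m))  # 0.0
print(stretched_distance([1.0, 5.0], [2.0, 5.0], m))   # 3.0
```

A weight near zero effectively drops a covariate from the matching, which is what makes the learned metric interpretable.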

Fitting the model

res = m.fit()
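Conceptually, fitting chooses the stretch so that, on the training data, each unit's outcome is well predicted by its nearest neighbours in the same treatment arm. The sketch below is a simplified stand-in for that idea, not the pymalts objective: it replaces hard k-nearest-neighbour selection with softmax weights over negative stretched distances (so a gradient-based optimiser can be used), adds an L2 penalty with a hypothetical strength `C`, and uses SciPy's L-BFGS-B. The names `loo_kernel_loss` and `fit_stretch` are hypothetical.

```python
import numpy as np
from scipy.optimize import minimize


def loo_kernel_loss(m, X, y, t, C=1e-3):
    """Hypothetical MALTS-style objective: leave-one-out squared error of predicting
    each unit's outcome from the other units in its own treatment arm, weighted by
    a softmax over negative stretched distances, plus an L2 penalty on m."""
    Xs = X * m  # stretch each covariate axis
    sq_err = 0.0
    for arm in (0, 1):
        Xa, ya = Xs[t == arm], y[t == arm]
        d = np.sqrt(((Xa[:, None, :] - Xa[None, :, :]) ** 2).sum(axis=-1))
        np.fill_diagonal(d, np.inf)  # a unit may not predict its own outcome
        # Shift by the row minimum before exponentiating for numerical stability.
        w = np.exp(-(d - d.min(axis=1, keepdims=True)))
        w /= w.sum(axis=1, keepdims=True)
        sq_err += np.sum((ya - w @ ya) ** 2)
    return sq_err / len(y) + C * np.sum(m ** 2)


def fit_stretch(X, y, t, C=1e-3):
    """Hypothetical helper: learn non-negative stretch factors, starting from all ones."""
    p = X.shape[1]
    res = minimize(loo_kernel_loss, x0=np.ones(p), args=(X, y, t, C),
                   method="L-BFGS-B", bounds=[(0.0, None)] * p)
    return res.x, res


# Toy data: the outcome depends only on the first covariate.
rng = np.random.default_rng(0)
X = rng.uniform(size=(80, 2))
t = np.arange(80) % 2
y = 5.0 * X[:, 0]
m_hat, _ = fit_stretch(X, y, t)
print(m_hat)  # the first weight should end up larger than the second
```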

Getting Matched Groups

mg = m.get_matched_groups(df_estimation=df_est, k=20)
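Once a metric is learned, a matched group for a unit can be formed from the closest treated and control units under that metric. The sketch below assumes a matched group consists of the `k` nearest treated and the `k` nearest control units under a diagonal stretch `m`; the helper `matched_group` is hypothetical, and pymalts' exact rules (which pool is searched, how `k` is counted) may differ.

```python
import numpy as np


def matched_group(x0, X, t, m, k=20):
    """Hypothetical helper: indices of the k nearest treated and k nearest control
    units to x0 under the diagonal stretch m."""
    d = np.sqrt((((X - x0) * m) ** 2).sum(axis=1))
    treated_idx = np.flatnonzero(t == 1)
    control_idx = np.flatnonzero(t == 0)
    # Rank each treatment arm by distance and keep its k closest units.
    nearest_t = treated_idx[np.argsort(d[treated_idx], kind="stable")[:k]]
    nearest_c = control_idx[np.argsort(d[control_idx], kind="stable")[:k]]
    return nearest_t, nearest_c


X = np.array([[0.0, 0.0], [0.1, 5.0], [2.0, 0.0], [0.2, -3.0], [3.0, 0.0]])
t = np.array([1, 1, 1, 0, 0])
m = np.array([1.0, 0.0])  # the second covariate is ignored
nt, nc = matched_group(np.array([0.0, 0.0]), X, t, m, k=1)
print(nt, nc)  # [0] [3]
```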

Estimating Conditional Average Treatment Effect (CATE)

Fit separate linear regression models to the control and treated units inside each matched group, then use the difference between the predicted treated and control potential outcomes for the unit of interest as its CATE estimate.

cate = m.CATE(mg, model='linear')

Let's look at the estimated CATEs for the first few units:

cate.head()
CATE outcome treatment
0 72.115699 -15.679894 0
1 23.664641 -7.068587 0
2 29.161183 -5.133200 0
3 52.595566 39.684984 1
4 4.643755 -2.954324 0
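The linear CATE estimate described above can be sketched for a single matched group with scikit-learn. The helper `linear_cate` and its inputs (covariates `X_mg`, outcomes `y_mg`, treatment indicators `t_mg` for one group) are hypothetical; this illustrates the idea rather than reproducing pymalts' internals.

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def linear_cate(x0, X_mg, y_mg, t_mg):
    """Hypothetical helper: CATE at x0 from separate linear fits on the treated
    and control units of one matched group."""
    treated, control = t_mg == 1, t_mg == 0
    mu1 = LinearRegression().fit(X_mg[treated], y_mg[treated])
    mu0 = LinearRegression().fit(X_mg[control], y_mg[control])
    x0 = np.asarray(x0, dtype=float).reshape(1, -1)
    # Difference of the predicted potential outcomes at the unit of interest.
    return float(mu1.predict(x0)[0] - mu0.predict(x0)[0])


# Toy group: treated outcome 2x + 3, control outcome x.
X_mg = np.array([[0.0], [1.0], [2.0], [0.0], [1.0], [2.0]])
y_mg = np.array([3.0, 5.0, 7.0, 0.0, 1.0, 2.0])
t_mg = np.array([1, 1, 1, 0, 0, 0])
print(linear_cate([1.5], X_mg, y_mg, t_mg))  # approximately 4.5
```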

PyMALTS can also estimate CATEs as the simple difference between the average outcomes of the treated and control units inside each matched group:

cate_mean = m.CATE(mg, model='mean')
cate_mean.head()
CATE outcome treatment
0 64.714527 -15.679894 0
1 28.313943 -7.068587 0
2 26.927607 -5.133200 0
3 52.828022 39.684984 1
4 5.914307 -2.954324 0
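The mean-based estimate reduces to a difference in means within one matched group. A minimal sketch, with `mean_cate` as a hypothetical helper taking the group's outcomes and treatment indicators:

```python
import numpy as np


def mean_cate(y_mg, t_mg):
    """Hypothetical helper: mean treated outcome minus mean control outcome in one matched group."""
    y_mg, t_mg = np.asarray(y_mg, dtype=float), np.asarray(t_mg)
    return float(y_mg[t_mg == 1].mean() - y_mg[t_mg == 0].mean())


print(mean_cate([3.0, 5.0, 7.0, 0.0, 1.0, 2.0], [1, 1, 1, 0, 0, 0]))  # 4.0
```

Unlike the linear model, this estimate does not depend on where the unit of interest sits within its matched group, which is why the two sets of CATEs above differ somewhat.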

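Plotting the estimated distribution of the CATEs, with the mean marked in red and the median in blue:

sns.histplot(cate['CATE'], kde=True, stat='density')  # sns.distplot is deprecated in recent seaborn
plt.axvline(np.mean(cate['CATE']), c='red')
plt.axvline(np.median(cate['CATE']), c='blue')
plt.show()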

Estimating the Average Treatment Effect (ATE)

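We use the law of iterated expectations to estimate the ATE: the ATE equals the expected CATE over the covariate distribution, so we average the unit-level CATE estimates: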

ate = np.mean(cate['CATE'])
ate
41.709461029592596
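Written out, the step above is the following identity, where $\tau(x)$ denotes the CATE at covariate value $x$ and $\hat{\tau}(x_i)$ its estimate for each of the $n$ units in the estimation set (this notation is introduced here for illustration):

```latex
\begin{aligned}
\mathrm{ATE} &= \mathbb{E}\big[Y(1) - Y(0)\big]
  = \mathbb{E}_X\Big[\mathbb{E}\big[Y(1) - Y(0) \mid X\big]\Big]
  = \mathbb{E}_X\big[\tau(X)\big], \\
\widehat{\mathrm{ATE}} &= \frac{1}{n} \sum_{i=1}^{n} \hat{\tau}(x_i).
\end{aligned}
```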

Visualization

PyMALTS also provides tools to visualize the units inside each matched group. The tool draws a scatter plot of the outcome against each covariate, with color indicating whether each unit is treated or control.

df_mg1 = m.visualizeMG(MG=mg, a=10)
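If you want a similar plot for an arbitrary DataFrame of matched units, the sketch below draws outcome against each covariate, colored by treatment status, using plain matplotlib. It is a hypothetical stand-in, not pymalts' `visualizeMG`; the helper name `plot_matched_group` and the toy data are assumptions.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd


def plot_matched_group(df_mg, covariates, outcome="outcome", treatment="treated"):
    """Hypothetical helper: one scatter panel per covariate, outcome on the y-axis,
    treated and control units in different colors."""
    fig, axes = plt.subplots(1, len(covariates), figsize=(4 * len(covariates), 3), squeeze=False)
    for ax, col in zip(axes[0], covariates):
        for arm, color, label in ((0, "tab:blue", "control"), (1, "tab:red", "treated")):
            sub = df_mg[df_mg[treatment] == arm]
            ax.scatter(sub[col], sub[outcome], color=color, label=label)
        ax.set_xlabel(col)
        ax.set_ylabel(outcome)
    axes[0][0].legend()
    fig.tight_layout()
    return fig


df_toy = pd.DataFrame({"X1": [0.1, 0.4, 0.2, 0.5], "X2": [1.0, 0.8, 1.2, 0.9],
                       "outcome": [3.0, 4.0, 1.0, 1.5], "treated": [1, 1, 0, 0]})
fig = plot_matched_group(df_toy, ["X1", "X2"])
```

Call `fig.savefig(...)` to write the figure to disk, or use an interactive backend and `plt.show()` to display it.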

About

Matching After Learning to Stretch Code
