# Section 2: Running Causal Forests on GSS data

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from copy import deepcopy

from sklearn.preprocessing import LabelEncoder, normalize
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from econml.grf import CausalForest
from sklearn.linear_model import LassoCV

import warnings
warnings.filterwarnings("ignore")

## Loading and Cleaning data

In [2]:
# w = 0 means the question had "assistance", w = 1 means the question had "welfare"
# y = 0 means the responder said no, y = 1 means yes
welfare_raw = pd.read_csv("welfarelabel.csv", low_memory=False)
labels = welfare_raw['y'].values
treatments = welfare_raw['w']
treatments = treatments.replace({0:1, 1:0}) # we want 1 to be assistance, and 0 to be welfare, so if the TE is positive then it means people responded favorably to assistance
welfare_raw.drop(columns=["_merge", 'y', 'id', 'w'], inplace=True)
welfare_raw

Unnamed: 0,year,wrkstat,hrs1,hrs2,evwork,occ,prestige,wrkslf,wrkgovt,commute,...,preteen_miss,teens_miss,adults_miss,unrelat_miss,earnrs_miss,income_miss,rincome_miss,income86_miss,partyid_miss,polviews_miss
0,1986,working fulltime,40.000000,38.613701,1.1395408,270.00000,44.000000,someone else,private,60,...,0,0,0,0,0,0,0,0,0,0
1,1986,keeping house,41.733318,38.613701,1,195.00000,51.000000,someone else,private,10,...,0,0,0,1,0,0,1,0,0,0
2,1986,working fulltime,40.000000,38.613701,1.1395408,184.00000,51.000000,someone else,private,35,...,0,0,0,1,0,0,0,0,0,0
3,1986,retired,41.733318,38.613701,1,311.00000,36.000000,someone else,1,25,...,0,0,0,0,0,0,1,0,0,0
4,1986,working parttime,41.733318,38.613701,1.1395408,449.41599,40.335918,someone else,1.8203658,25,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
36496,2010,retired,41.733318,38.613701,1,449.41599,40.335918,someone else,private,20.245865,...,0,0,0,0,0,0,1,1,0,0
36497,2010,retired,41.733318,38.613701,1,449.41599,40.335918,someone else,private,20.245865,...,0,0,0,0,0,0,1,1,0,0
36498,2010,working fulltime,40.000000,38.613701,1.1395408,449.41599,40.335918,someone else,private,20.245865,...,0,0,0,1,0,0,0,1,0,0
36499,2010,working fulltime,49.000000,38.613701,1.1395408,449.41599,40.335918,someone else,private,20.245865,...,0,0,0,1,0,0,1,1,0,0


In [3]:
def cleanWelfare(welfare_raw):
    welfare = welfare_raw.copy()

    toClean = set(['commute', 'childs', 'age', 'preteen', 'adults', 'unrelat', 'earnrs'])
    toEncode = set(['year', 'occ'])

    encoders = {} # want a dictionary to keep track of columns and their encoded values so can decode when done

    for column in welfare:
        if '_' not in column:
            if column in toEncode:
                le = LabelEncoder()
                le.fit(welfare[column])
                welfare[column] = le.transform(welfare[column])
                encoders[column] = le 
            elif column in toClean:
                if column == 'commute':
                    welfare[column] = pd.to_numeric(welfare[column].apply(lambda x: 97 if x == '97+ minutes' else x), errors='coerce')
                elif column == 'childs' or column == 'earnrs':
                    welfare[column] = pd.to_numeric(welfare[column].apply(lambda x: 8 if x == 'eight or more' else x), errors='coerce')
                elif column == 'age':
                    welfare[column] = pd.to_numeric(welfare[column].apply(lambda x: 89 if x == '89 or older' else x), errors='coerce')
                elif column == 'preteen' or column == 'adults' or column =='unrelat':
                    welfare[column] = pd.to_numeric(welfare[column].apply(lambda x: 8 if x == '8 or more' else x), errors='coerce')
                else:
                    continue # should never reach here
                welfare[column] = normalize(welfare[column].values.reshape(1, -1))[0] # once column converted to float, normalize
            else:
                if welfare[column].dtype == welfare['teens'].dtype: # float64 column
                    welfare[column] = normalize(welfare[column].values.reshape(1, -1))[0]
                elif welfare[column].dtype == welfare['polviews'].dtype: # object column
                    le = LabelEncoder()
                    le.fit(welfare[column])
                    welfare[column] = le.transform(welfare[column])
                    encoders[column] = le 
                else:
                    continue # should never reach here
    return welfare, encoders

welfare, encoders = cleanWelfare(welfare_raw)
welfare

Unnamed: 0,year,wrkstat,hrs1,hrs2,evwork,occ,prestige,wrkslf,wrkgovt,commute,...,preteen_miss,teens_miss,adults_miss,unrelat_miss,earnrs_miss,income_miss,rincome_miss,income86_miss,partyid_miss,polviews_miss
0,0,7,0.004845,0.005228,1,135,0.005641,2,2,0.015315,...,0,0,0,0,0,0,0,0,0,0
1,0,1,0.005055,0.005228,0,106,0.006538,2,2,0.002552,...,0,0,0,1,0,0,1,0,0,0
2,0,7,0.004845,0.005228,1,99,0.006538,2,2,0.008934,...,0,0,0,1,0,0,0,0,0,0
3,0,3,0.005055,0.005228,0,142,0.004615,2,0,0.006381,...,0,0,0,0,0,0,1,0,0,0
4,0,8,0.005055,0.005228,1,211,0.005171,2,1,0.006381,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
36496,15,3,0.005055,0.005228,0,211,0.005171,2,2,0.005168,...,0,0,0,0,0,0,1,1,0,0
36497,15,3,0.005055,0.005228,0,211,0.005171,2,2,0.005168,...,0,0,0,0,0,0,1,1,0,0
36498,15,7,0.004845,0.005228,1,211,0.005171,2,2,0.005168,...,0,0,0,1,0,0,0,1,0,0
36499,15,7,0.005935,0.005228,1,211,0.005171,2,2,0.005168,...,0,0,0,1,0,0,1,1,0,0


## Estimating ATE and its CI

In [22]:
from econml.dml import CausalForestDML

# set parameters for causal forest 
causal_forest = CausalForestDML(criterion='het', 
                                n_estimators=1000,       
                                min_samples_leaf=10, 
                                max_depth=None, 
                                max_samples=0.5,
                                discrete_treatment=True,
                                honest=True,
                                inference=True,
                                cv=5,
                                )

X_train, X_test, Y_train, Y_test, T_train, T_test = train_test_split(welfare, labels, treatments, test_size=0.2)

causal_forest.fit(Y_train, T_train, X=X_train, W=None)

<econml.dml.causal_forest.CausalForestDML at 0x7fe35692d150>

In [25]:
causal_forest.ate_inference(X_test)

mean_point,stderr_mean,zstat,pvalue,ci_mean_lower,ci_mean_upper
0.328,0.037,8.786,0.0,0.267,0.389

std_point,pct_point_lower,pct_point_upper
0.053,0.233,0.412

stderr_point,ci_point_lower,ci_point_upper
0.065,0.218,0.433


In [29]:
causal_forest.cate_treatment_names()

['w_1']

In [None]:
X_train, X_test, Y_train, Y_test, T_train, T_test = train_test_split(X, y, treatments, test_size=test_size)
        
# specify hyperparams of model
est = CausalForest(criterion='mse', n_estimators=1000,       
                        min_samples_leaf=1, 
                        max_depth=100, max_samples=0.5,
                        honest=True, inference=True)

# fit model
est.fit(X_train, T_train, Y_train)

In [4]:
def estimate_grf(y, X, treatments, test_size=0.2, criterion='mse'):
    # split data into train and test sets 
    X_train, X_test, Y_train, Y_test, T_train, T_test = train_test_split(X, y, treatments, test_size=test_size)
        
    # specify hyperparams of model
    est = CausalForest(criterion='mse', n_estimators=1000,       
                          min_samples_leaf=1, 
                          max_depth=100, max_samples=0.5,
                          honest=True, inference=True)

    # fit model
    est.fit(X_train, T_train, Y_train)
    
    ites, lbs, ubs = est.predict(X_test, interval=True)
    
    return est, ites, lbs, ubs, X_test

In [5]:
def aggregateITES(ites, lbs, ubs):
    ate = ites.mean()
    lb = lbs.mean()
    ub = ubs.mean()
    return ate, lb, ub

In [6]:
est_ATE, ites_ATE, lbs_ATE, ubs_ATE, X_test_ATE = estimate_grf(labels, welfare, treatments)
ate_ATE, lb_ATE, ub_ATE = aggregateITES(ites_ATE, lbs_ATE, ubs_ATE)
print("The test ATE is equal to:", ate_ATE)
print("With a 95% confidence interval of (" + str(lb_ATE) + ", " + str(ub_ATE) + ")")

The test ATE is equal to: 0.33067534013305305
With a 95% confidence interval of (0.22433458486886984,0.4370160953972363)


In [15]:
est_ATE.predict_moment(X_test_ATE)

AttributeError: 'CausalForest' object has no attribute 'predict_moment'

## CATES for Party Identification

In [7]:
# view all partyid types
welfare_raw['partyid'].value_counts()

not str democrat      7093
not str republican    5969
independent           5859
strong democrat       5666
ind,near dem          4110
strong republican     3880
ind,near rep          3196
other party            512
2.8216343              216
Name: partyid, dtype: int64

In [8]:
# view the corresponding numerical encodings for the cleaned data
welfare['partyid'].value_counts()

4    7093
5    5969
3    5859
7    5666
1    4110
8    3880
2    3196
6     512
0     216
Name: partyid, dtype: int64

In [11]:
partyid_le = encoders['partyid']
partyids = list(welfare_raw['partyid'].value_counts().index)[:-1] # drop last one because it's a nonmeaningful number
notStrongDem 
notStrongRep


['not str democrat',
 'not str republican',
 'independent',
 'strong democrat',
 'ind,near dem',
 'strong republican',
 'ind,near rep',
 'other party']

In [12]:
partyid_le = encoders['partyid']


In [12]:
est.feature_importances()

array([6.77124612e-02, 2.59356593e-02, 1.13682537e-02, 1.25588969e-07,
       8.25922436e-04, 5.16261503e-04, 6.18278268e-04, 5.02850348e-04,
       6.30179504e-04, 4.00157686e-05, 6.24363622e-03, 3.56612378e-03,
       3.50695068e-03, 3.75480525e-04, 1.63536528e-03, 3.17623538e-04,
       2.77305091e-04, 3.90356471e-03, 1.76013660e-03, 0.00000000e+00,
       6.69925600e-05, 1.07542026e-02, 1.76370612e-03, 5.79108434e-03,
       2.47369986e-03, 8.62069507e-04, 8.14868114e-03, 1.53541012e-03,
       5.54400837e-03, 2.47052016e-03, 1.73622849e-03, 1.39597599e-03,
       1.61799844e-03, 1.80106213e-03, 8.78475161e-04, 8.48296906e-03,
       2.26973683e-04, 5.07863760e-01, 1.37394636e-03, 1.25501719e-03,
       8.18551094e-04, 1.73133793e-03, 3.82325892e-04, 1.35844972e-04,
       4.43645648e-04, 9.64355088e-04, 1.26252755e-03, 7.92911847e-04,
       2.71508445e-04, 2.29415599e-04, 3.12208323e-04, 4.70882245e-04,
       2.58490127e-04, 2.53191884e-03, 6.65136911e-03, 1.18878390e-02,
      