# Double Machine Learning with CausalELM
In this notebook we will see how to estimate average treatment effects with CausalELM using double machine learning. Double machine learning is a method for estimating treatment effects when the data is high dimensional or the functional form of the model is unknown. The idea is to fit a treatment model and an outcome model (in CausalELM, each is an ensemble of extreme learning machines, or ELMs), use their out-of-fold predictions to compute residuals, and then regress the outcome residuals on the treatment residuals in a final model.

For more details, see Chernozhukov, Victor, Denis Chetverikov, Mert Demirer, Esther Duflo, Christian Hansen, Whitney Newey, and James Robins. "Double/debiased machine learning for treatment and structural parameters." The Econometrics Journal 21, no. 1 (2018): C1-C68.
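To make the residual-on-residual idea concrete, here is a minimal from-scratch sketch in Julia. It is not CausalELM's implementation. The helper names `fit_ols` and `dml_ate` are hypothetical, ordinary least squares stands in for the ELM ensembles as the nuisance model, and the final stage assumes a single constant treatment effect. Each fold's residuals come from models fit on the other folds (cross-fitting), and the effect is the no-intercept regression coefficient of the outcome residuals on the treatment residuals.

```julia
using LinearAlgebra, Random

# Hypothetical nuisance learner: OLS with an intercept. It stands in for the
# ELM ensembles that CausalELM uses; it returns a prediction function.
function fit_ols(X::AbstractMatrix, y::AbstractVector)
    beta = hcat(ones(size(X, 1)), X) \ y
    return Xnew -> hcat(ones(size(Xnew, 1)), Xnew) * beta
end

# Cross-fitted "partialling out" estimate of a constant treatment effect.
function dml_ate(X, t, y; folds=5, rng=Random.default_rng())
    n = length(y)
    fold_id = shuffle(rng, repeat(1:folds, cld(n, folds))[1:n])
    t_res = zeros(n)
    y_res = zeros(n)
    for k in 1:folds
        test = fold_id .== k
        train = .!test
        # Nuisance models are fit on the other folds only...
        m_t = fit_ols(X[train, :], t[train])
        m_y = fit_ols(X[train, :], y[train])
        # ...and used to residualize the held-out fold.
        t_res[test] = t[test] .- m_t(X[test, :])
        y_res[test] = y[test] .- m_y(X[test, :])
    end
    # Final stage: regress outcome residuals on treatment residuals (no intercept).
    return dot(t_res, y_res) / dot(t_res, t_res)
end

# Simulated data with confounding through X and a true effect of 2.0.
rng = MersenneTwister(42)
n = 2_000
X = randn(rng, n, 3)
t = Float64.(X[:, 1] .+ randn(rng, n) .> 0)   # treatment depends on X
y = 2.0 .* t .+ X * [1.0, -1.0, 0.5] .+ randn(rng, n)
theta_hat = dml_ate(X, t, y; rng=rng)
```

Because the simulated outcome is linear in the treatment and covariates, the OLS nuisance models are enough here and `theta_hat` should land close to 2. Swapping `fit_ols` for a flexible learner is what lets double machine learning cope with nonlinear confounding.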

## Setup
We will load CausalELM along with the necessary packages to read and manipulate the data. We will also set the seed here.

In [1]:
using CSV
using DataFrames
using CausalELM
using Random

In [2]:
Random.seed!(2024)

TaskLocalRNG()

## Data
This data comes from a 1994 study by Poterba et al. that examined the effect of eligibility for a 401(k), a pension plan in the US, on net financial assets. The dependent variable is net total financial assets, which is the net_tfa column in the dataframe. The treatment variable is an indicator of whether an individual was eligible to enroll in a 401(k) plan, which corresponds to the e401 column. The study controlled for age, income, family size, years of education, marital status, a two-earner status indicator, a defined benefit pension indicator, an IRA participation indicator, and a home ownership indicator, which correspond to the age, inc, fsize, educ, marr, twoearn, db, pira, and hown columns. Note that the column selection below does not include educ, so years of education is not among our covariates. There are also one-hot encoded versions of some of these variables, which we will not use.

In [3]:
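pension_df = CSV.read("data/pension.csv", DataFrame)
# Keep the outcome, the treatment, and the covariates used in this analysis (selected by name)
pension_df = pension_df[:, [:net_tfa, :e401, :age, :inc, :fsize, :marr, :twoearn, :db, :pira, :hown]]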

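| Row | net_tfa | e401 | age | inc | fsize | marr | twoearn | db | pira | hown |
|---|---|---|---|---|---|---|---|---|---|---|
| | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 | Int64 |
| 1 | -3300 | 0 | 31 | 28146 | 5 | 1 | 0 | 0 | 0 | 1 |
| 2 | 61010 | 0 | 52 | 32634 | 5 | 0 | 0 | 0 | 0 | 1 |
| 3 | 8849 | 0 | 50 | 52206 | 3 | 1 | 1 | 0 | 1 | 1 |
| 4 | -6013 | 0 | 28 | 45252 | 4 | 1 | 1 | 0 | 0 | 0 |
| 5 | -2375 | 0 | 42 | 33126 | 3 | 0 | 0 | 1 | 0 | 1 |
| 6 | -11000 | 0 | 49 | 76860 | 6 | 1 | 1 | 1 | 0 | 1 |
| 7 | -16901 | 0 | 40 | 57477 | 4 | 1 | 1 | 1 | 0 | 1 |
| 8 | 1000 | 0 | 58 | 14637 | 1 | 0 | 0 | 0 | 0 | 0 |
| 9 | 0 | 0 | 29 | 6573 | 4 | 0 | 0 | 0 | 0 | 0 |
| 10 | 6400 | 0 | 50 | 5766 | 1 | 0 | 0 | 0 | 1 | 0 |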


It's always a good idea to look at your data before fitting any models, so below we print summary statistics for each column.

In [4]:
describe(pension_df)

| Row | variable | mean | min | median | max | nmissing | eltype |
|---|---|---|---|---|---|---|---|
| | Symbol | Float64 | Int64 | Float64 | Int64 | Int64 | DataType |
| 1 | net_tfa | 18051.5 | -502302 | 1499.0 | 1536798 | 0 | Int64 |
| 2 | e401 | 0.371357 | 0 | 0.0 | 1 | 0 | Int64 |
| 3 | age | 41.0602 | 25 | 40.0 | 64 | 0 | Int64 |
| 4 | inc | 37200.6 | -2652 | 31476.0 | 242124 | 0 | Int64 |
| 5 | fsize | 2.86586 | 1 | 3.0 | 13 | 0 | Int64 |
| 6 | marr | 0.604841 | 0 | 1.0 | 1 | 0 | Int64 |
| 7 | twoearn | 0.380837 | 0 | 0.0 | 1 | 0 | Int64 |
| 8 | db | 0.271004 | 0 | 0.0 | 1 | 0 | Int64 |
| 9 | pira | 0.242158 | 0 | 0.0 | 1 | 0 | Int64 |
| 10 | hown | 0.635199 | 0 | 1.0 | 1 | 0 | Int64 |


## Preprocessing
We need to pass in the covariates, the treatment variable, and the outcome variable separately, so we split them up here and then normalize some of the covariates.

In [5]:
covariates, treatment, outcome = pension_df[:, 3:end], pension_df[:, 2], pension_df[:, 1]

(9915×8 DataFrame
  Row │ age    inc    fsize  marr   twoearn  db     pira   hown  
      │ Int64  Int64  Int64  Int64  Int64    Int64  Int64  Int64 
──────┼──────────────────────────────────────────────────────────
    1 │    31  28146      5      1        0      0      0      1
    2 │    52  32634      5      0        0      0      0      1
    3 │    50  52206      3      1        1      0      1      1
    4 │    28  45252      4      1        1      0      0      0
    5 │    42  33126      3      0        0      1      0      1
    6 │    49  76860      6      1        1      1      0      1
    7 │    40  57477      4      1        1      1      0      1
    8 │    58  14637      1      0        0      0      0      0
  ⋮   │   ⋮      ⋮      ⋮      ⋮       ⋮       ⋮      ⋮      ⋮
 9909 │    28  31926      2      1        1      

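Normalizing the covariates is a good idea in general, but especially with extreme learning machines: their hidden-layer weights are assigned randomly rather than learned, so a covariate on a much larger scale than the others (such as income, measured in dollars, next to 0/1 indicators) would otherwise dominate the inputs to the hidden layer.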

In [6]:
# Min-max scale a column to the [0, 1] range
normalize(col) = (col .- minimum(col)) ./ (maximum(col) - minimum(col))

normalize (generic function with 1 method)
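One caveat with `normalize`: if a column is constant, `maximum(col) - minimum(col)` is zero, every entry becomes `0/0`, and Julia evaluates that to `NaN`. A minimal guarded variant might look like the following. The name `safe_normalize` and the choice to map a constant column to zeros are assumptions made for illustration, not something CausalELM requires.

```julia
# Hypothetical helper: min-max scale a column to [0, 1], returning zeros for a
# constant column instead of the NaNs that 0/0 would produce.
function safe_normalize(col::AbstractVector{<:Real})
    lo, hi = extrema(col)
    hi == lo && return zeros(length(col))
    return (col .- lo) ./ (hi - lo)
end

scaled = safe_normalize([2, 4, 6])      # [0.0, 0.5, 1.0]
constant = safe_normalize([3, 3, 3])    # [0.0, 0.0, 0.0]
```

It can be dropped into the `combine` call below in place of `normalize`, e.g. `[:age, :inc, :fsize, :marr] .=> safe_normalize`.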

In [7]:
covariates = combine(
    covariates, 
    [:age, :inc, :fsize, :marr] .=> normalize, 
    :twoearn, 
    :db, 
    :pira, 
    :hown, 
    renamecols=false
)

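| Row | age | inc | fsize | marr | twoearn | db | pira | hown |
|---|---|---|---|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Float64 | Int64 | Int64 | Int64 | Int64 |
| 1 | 0.153846 | 0.125821 | 0.333333 | 1.0 | 0 | 0 | 0 | 1 |
| 2 | 0.692308 | 0.144156 | 0.333333 | 0.0 | 0 | 0 | 0 | 1 |
| 3 | 0.641026 | 0.224115 | 0.166667 | 1.0 | 1 | 0 | 1 | 1 |
| 4 | 0.0769231 | 0.195705 | 0.25 | 1.0 | 1 | 0 | 0 | 0 |
| 5 | 0.435897 | 0.146166 | 0.166667 | 0.0 | 0 | 1 | 0 | 1 |
| 6 | 0.615385 | 0.324836 | 0.416667 | 1.0 | 1 | 1 | 0 | 1 |
| 7 | 0.384615 | 0.245649 | 0.25 | 1.0 | 1 | 1 | 0 | 1 |
| 8 | 0.846154 | 0.0706319 | 0.0 | 0.0 | 0 | 0 | 0 | 0 |
| 9 | 0.102564 | 0.0376875 | 0.25 | 0.0 | 0 | 0 | 0 | 0 |
| 10 | 0.641026 | 0.0343906 | 0.0 | 0.0 | 0 | 0 | 1 | 0 |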


## Fitting a model
Since this dataset has a moderate number of covariates whose relationships with the treatment and the outcome may be nonlinear, it is a good candidate for double machine learning. Fitting a model is very simple.

In [8]:
dml_model = DoubleMachineLearning(covariates, treatment, outcome)

DoubleMachineLearning([0.8461538461538461 0.09652907147759585 … 0.0 1.0; 0.6923076923076923 0.1634473968036082 … 0.0 1.0; … ; 0.9743589743589743 0.05228453769977449 … 0.0 1.0; 0.1794871794871795 0.11143249338170409 … 1.0 1.0], [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0  …  1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0], [41499.0, 14825.0, -6400.0, -900.0, -16122.0, 350.0, 27029.0, 5400.0, 92384.0, -500.0  …  -200.0, 24654.0, 89509.0, 65950.0, 550.0, -1601.0, 55608.0, 3000.0, -1691.0, 12650.0], "ATE", false, "regression", CausalELM.swish, 9915, 50, 6, 24, NaN, 5)

In [9]:
estimate_causal_effect!(dml_model)

8721.139649868901

## Getting a summary
Just as in R, we can get a summary of the model and see basic information about it. Note that the standard error and p-value are NaN here. They are not calculated by default, but you can calculate them by setting the inference keyword argument to true. However, since p-values and standard errors are calculated with randomization inference, which requires re-estimating the effect many times, this will take a long time.

In [10]:
# We can also use the British spelling: summarise(dml_model)
summarize(dml_model)

Dict{Any, Any} with 11 entries:
  "Activation Function"    => swish
  "Quantity of Interest"   => "ATE"
  "Sample Size"            => 9915
  "Number of Machines"     => 50
  "Causal Effect"          => 8721.14
  "Number of Neurons"      => 24
  "Task"                   => "regression"
  "Time Series/Panel Data" => false
  "Standard Error"         => NaN
  "p-value"                => NaN
  "Number of Features"     => 6

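Randomization inference is slow because the effect has to be re-estimated once for every permutation of the treatment labels. The sketch below illustrates the general idea under stated assumptions: a simple difference in means (`diff_in_means`, a hypothetical stand-in) replaces the double machine learning estimator, the p-value is the share of permuted estimates at least as large in absolute value as the observed one, and reporting the standard deviation of the permuted estimates as a standard error is an illustrative choice. CausalELM's exact procedure may differ.

```julia
using Random, Statistics

# Hypothetical stand-in for a fitted causal model: difference in mean outcomes.
diff_in_means(t, y) = mean(y[t .== 1]) - mean(y[t .== 0])

# Permutation (randomization) inference for any estimator(t, y) -> effect.
function randomization_inference(estimator, t, y; n_perm=1_000, rng=Random.default_rng())
    observed = estimator(t, y)
    # Shuffling the treatment labels breaks any link between treatment and
    # outcome, so these estimates approximate the distribution under no effect.
    null_effects = [estimator(shuffle(rng, t), y) for _ in 1:n_perm]
    p_value = mean(abs.(null_effects) .>= abs(observed))
    std_error = std(null_effects)  # illustrative assumption, see lead-in
    return (effect=observed, p_value=p_value, std_error=std_error)
end

# Simulated randomized experiment with a true effect of 3.0.
rng = MersenneTwister(1)
t = rand(rng, 0:1, 500)
y = 3.0 .* t .+ randn(rng, 500)
res = randomization_inference(diff_in_means, t, y; n_perm=200, rng=rng)
```

With a real model each of the `n_perm` permutations means a full refit, which is why turning on inference for a double machine learning model takes so long.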
## Model Validation
Often the most important part of estimating causal effects is checking whether the model meets its assumptions and seeing how sensitive it is to violations of those assumptions. Using the validate method, we can easily see how sensitive our model is to violations of the counterfactual consistency assumption and to hidden confounders, and verify that there is overlap between the treatment and control groups. validate returns a tuple. Its first element is a dictionary of average estimated causal effects under several simulated violations of the counterfactual consistency assumption. The second element is the E-value, which indicates how strong hidden confounding would need to be to change the results. Finally, the last element is a matrix of observations that have a zero or near-zero probability of being assigned to either the treatment or the control group.

In [11]:
validate(dml_model)

(Dict("0.1 Standard Deviations from Observed Outcomes" => 7903.440095883949, "0.075 Standard Deviations from Observed Outcomes" => 7564.405940233361, "0.025 Standard Deviations from Observed Outcomes" => 8810.321764706268, "0.05 Standard Deviations from Observed Outcomes" => 7793.5616595259535), 2.505410444617261, Matrix{Float64}(undef, 0, 9))
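The E-value and the overlap check can both be sketched in a few lines. The E-value formula below is the one proposed by VanderWeele and Ding for a risk ratio; how CausalELM converts a continuous-outcome estimate to that scale is not shown here. The function names `evalue` and `flag_overlap_violations`, the default `threshold` of 0.01, and the assumption that propensity scores are already available are all hypothetical.

```julia
# E-value for a risk ratio: the minimum strength of association, on the
# risk-ratio scale, that an unmeasured confounder would need with both the
# treatment and the outcome to fully explain away the estimate.
function evalue(rr::Real)
    rr > 0 || throw(ArgumentError("risk ratio must be positive"))
    rr = rr < 1 ? 1 / rr : rr   # protective effects are inverted first
    return rr + sqrt(rr * (rr - 1))
end

# Hypothetical overlap (positivity) check: return the rows of X whose
# estimated propensity score lies within `threshold` of 0 or 1.
function flag_overlap_violations(X::AbstractMatrix, propensity::AbstractVector; threshold=0.01)
    bad = (propensity .< threshold) .| (propensity .> 1 - threshold)
    return X[bad, :]
end

e = evalue(2.0)    # 2 + sqrt(2), about 3.41
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
flagged = flag_overlap_violations(X, [0.5, 0.001, 0.999])   # rows 2 and 3
```

An empty result from the overlap check, like the 0-row matrix returned by `validate` above, means no observations were flagged.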