# Doubly Robust Estimation with CausalELM
In this notebook we will see how to estimate the conditional average treatment effect (CATE) with CausalELM using doubly robust estimation. Doubly robust estimation is a method for estimating treatment effects when the data is high dimensional or the functional form of the model is unknown. Like other metalearners, it models the outcome and the propensity score and then combines them in a final model. Unlike some other estimators, however, it predicts outcomes for the treatment and control groups separately. The main benefit is that if either the outcome model or the propensity score model is correctly specified, it will still generate accurate estimates of the CATE.

For more details see Kennedy, Edward H. "Towards optimal doubly robust estimation of heterogeneous causal effects." Electronic Journal of Statistics 17, no. 2 (2023): 3008-3049.
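
To make "combines them in a final model" a little more concrete, here is a rough sketch of the pseudo-outcome used by the DR-learner in the paper cited above. With treatment $A$, outcome $Y$, covariates $X$, an estimated propensity score $\hat{\pi}(X)$, and estimated outcome regressions $\hat{\mu}_0(X)$ and $\hat{\mu}_1(X)$, each observation gets

$$\hat{\varphi} = \frac{A - \hat{\pi}(X)}{\hat{\pi}(X)\{1 - \hat{\pi}(X)\}}\{Y - \hat{\mu}_A(X)\} + \hat{\mu}_1(X) - \hat{\mu}_0(X)$$

and the CATE is estimated by regressing $\hat{\varphi}$ on $X$. The function name and toy numbers below are illustrative assumptions, and this is not CausalELM's internal code.

```julia
# Doubly robust pseudo-outcome for a single observation.
# a: treatment indicator (0 or 1), y: observed outcome,
# pi_hat: estimated propensity score, mu0_hat / mu1_hat: predicted outcomes
# under control and treatment. All names here are hypothetical.
function dr_pseudo_outcome(a, y, pi_hat, mu0_hat, mu1_hat)
    mu_a = a == 1 ? mu1_hat : mu0_hat            # prediction for the observed arm
    correction = (a - pi_hat) / (pi_hat * (1 - pi_hat)) * (y - mu_a)
    return correction + mu1_hat - mu0_hat
end

# Treated unit whose outcome is one unit above its predicted value.
treated_example = dr_pseudo_outcome(1, 3.0, 0.5, 1.0, 2.0)

# Control unit predicted perfectly: only the plug-in difference remains.
control_example = dr_pseudo_outcome(0, 1.0, 0.5, 1.0, 2.0)
```

When the outcome model is exactly right, the correction term averages out to zero; when the propensity model is right, the correction term offsets errors in the outcome model. That is where the "doubly robust" name comes from.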

## Setup
We will load the necessary packages and set the seed below.

In [1]:
using CSV
using DataFrames
using CausalELM
using Random

In [2]:
Random.seed!(2024)

TaskLocalRNG()

## Data
The data we are going to use is from a 2019 study by Jason Lyall on how payments by the US Agency for International Development to civilian victims of the war in Afghanistan affected the number of Taliban attacks. Each observation represents a village, the treatment represents whether a case of civilian victimization was found eligible for compensation in that village, and the outcome we will look at is the change in the number of Taliban attacks in a village in the 30 days after being given or denied compensation, relative to the 30 days before (the diff30a column). There are several covariates for geographical, social, ethnic, and economic characteristics, making this a good candidate for doubly robust estimation.

In [3]:
civcas_df = CSV.read("data/CIVCAS.tab", DataFrame)

Row,acapproject,eligible,date,day,month,year,pplid,pplname,latitude,longitude,distid,distname,provid,provname,altitude,population,logpop,logelev,langcode,language,pashto,neighbors,cdcfamilies,cdcprojects,cdcpercapita,killed,wounded,combined,killeddum,woundeddum,property,incidenttype,airstrike,cfmilop,crossfire,milop,sb,ied,accident,eof,indirect,cfindirect,isafinit,isafresp,talibresp,pre7a,preied7a,post7a,postied7a,preisaf7a,diff7a,diffied7a,pre7b,preied7b,post7b,postied7b,diff7b,diffied7b,pre7c,preied7c,post7c,postied7c,diff7c,diffied7c,pre30a,preied30a,post30a,postied30a,preisaf30a,diff30a,diffied30a,pre30b,preied30b,post30b,postied30b,diff30b,diffied30b,pre30c,preied30c,post30c,postied30c,diff30c,diffied30c,pre90a,preied90a,post90a,postied90a,preisaf90a,diff90a,diffied90a,pre90b,preied90b,post90b,postied90b,diff90b,diffied90b,pre90c,preied90c,post90c,postied90c,⋯
Unnamed: 0_level_1,Int64,Int64,String15,Int64,Int64,Int64,Int64,String,Float64,Float64,Int64,String31,Int64,String15,Int64,Int64,Float64,Float64,Int64,String15?,Int64,Int64,Int64,Int64,Float64,Int64,Int64,Int64,Int64,Int64,Int64,String31,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64?,Int64?,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,⋯
1,1,1,6/5/12,5,6,2012,22515,NOW BOR,35.491,63.359,1906,Bala Murghab,19,Badghis,661,583,6.36819,6.49375,2,Pashto,1,8,0,0,0.0,8,11,19,1,1,0,Airstrike,1,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,2,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,5,0,2,0,1,-3,0,0,0,0,0,0,0,0,0,0,0,⋯
2,2,1,1/13/13,13,1,2013,2441,AZIZAN,35.0526,63.2788,1907,Muqur,19,Badghis,957,338,5.82305,6.8638,2,Pashto,1,14,0,0,0.0,1,0,1,1,0,0,CF Military Operation,0,1,0,0,0,0,0,0,0,0,1,1,0,1,0,0,0,0,-1,0,1,0,0,0,-1,0,0,0,0,0,0,0,1,0,0,0,0,-1,0,1,0,0,0,-1,0,0,0,0,0,0,0,1,0,3,1,0,2,1,1,0,0,0,-1,0,0,0,0,0,⋯
3,3,1,1/13/13,13,1,2013,26757,RASHID KHAN,35.5026,63.3748,1906,Bala Murghab,19,Badghis,522,234,5.45532,6.25767,2,Pashto,1,8,0,0,0.0,0,8,8,0,1,1,IED,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6,1,0,6,1,2,1,5,0,3,-1,0,0,0,0,⋯
4,4,1,5/9/13,9,5,2013,22155,NOOR KHAIL ABDUL WAHAB,35.0397,63.2892,1907,Muqur,19,Badghis,1005,599,6.39526,6.91274,2,Pashto,1,13,150,1,300.0,0,1,1,0,1,0,CF Military Operation,0,1,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,0,-1,0,0,0,0,0,0,0,3,1,9,1,0,6,0,1,0,0,0,-1,0,0,0,0,0,⋯
5,5,1,6/10/13,10,6,2013,37982,ABLA-I-MIRANZA'I,35.0477,63.296,1907,Muqur,19,Badghis,929,0,7.08339,6.83411,1,Dari,0,0,0,0,0.0,1,1,2,1,1,1,IED,0,0,0,0,0,1,0,0,0,0,0,0,1,2,0,1,0,0,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,11,2,0,9,2,1,0,0,0,-1,0,0,0,0,0,⋯
6,6,1,6/10/13,10,6,2013,15503,KHAL ZARDAK,35.0502,63.3538,1907,Muqur,19,Badghis,1031,1181,7.07412,6.93828,2,Pashto,1,3,0,0,0.0,1,0,1,1,0,1,IED,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,-1,-1,0,0,0,0,0,0,0,0,0,0,⋯
7,7,1,8/21/13,21,8,2013,35395,QALA-E-NAW,34.9874,63.1275,1901,Qala-I- Naw,19,Badghis,962,9000,9.10498,6.86901,1,Dari,0,14,0,0,0.0,0,1,1,0,1,0,CF Military Operation,0,1,0,0,0,0,0,0,0,0,1,1,0,1,1,0,0,0,-1,-1,0,0,0,0,0,0,0,0,0,0,0,0,2,1,0,0,0,-2,-1,0,0,0,0,0,0,0,0,0,0,0,0,6,2,2,0,0,-4,-2,1,0,1,0,0,0,0,0,1,0,⋯
8,8,1,7/21/13,21,7,2013,34275,YOMAL,36.8087,71.0694,1125,Warduj,11,Badakhshan,2171,667,6.50279,7.68294,1,Dari,0,6,0,0,0.0,2,0,2,1,0,1,Airstrike,1,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,4,0,0,0,0,-4,0,0,0,0,0,0,0,0,0,0,0,⋯
9,9,1,10/6/12,6,10,2012,35092,BAGHLAN-E-JADID,36.1221,68.6861,1311,Baghlani Jadid,13,Baghlan,657,56200,10.9367,6.48768,2,Pashto,1,14,0,0,0.0,1,0,1,1,0,0,Accident,0,0,0,0,0,0,1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,-1,0,0,0,0,0,0,0,0,0,0,0,0,0,4,1,2,1,0,-2,0,0,0,0,0,0,0,0,0,0,0,⋯
10,10,1,11/26/12,26,11,2012,9518,GARDAB MUGHUL,36.3294,68.841,1311,Baghlani Jadid,13,Baghlan,569,411,6.01859,6.34388,2,Pashto,1,19,0,0,0.0,0,4,4,0,1,0,Accident,0,0,0,0,0,0,1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,-1,-1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,-1,0,0,0,0,0,0,0,0,0,0,⋯


In [4]:
covariates = select(
    civcas_df, 
    [
        :acapproject, 
        :distid, 
        :provid,
        :latitude, 
        :longitude,  
        :altitude, 
        :population, 
        :langcode, 
        :neighbors, 
        :cdcfamilies, 
        :killed, 
        :wounded, 
        :property, 
        :airstrike, 
        :cfmilop, 
        :crossfire, 
        :milop, 
        :sb, 
        :ied, 
        :accident, 
        :eof, 
        :indirect, 
        :cfindirect, 
        :isafinit,  
        :kabul, 
        :pakistan, 
        :dcdistkm, 
        :basedistkm, 
        :fightingseason, 
        :beneficiaries, 
        :north, 
        :south, 
        :southwest, 
        :west, 
        :kunar, 
        :helmand, 
        :khost, 
        :logar, 
        :east, 
        :logroadkm, 
        :timesaid
    ]
)

Row,acapproject,distid,provid,latitude,longitude,altitude,population,langcode,neighbors,cdcfamilies,killed,wounded,property,airstrike,cfmilop,crossfire,milop,sb,ied,accident,eof,indirect,cfindirect,isafinit,kabul,pakistan,dcdistkm,basedistkm,fightingseason,beneficiaries,north,south,southwest,west,kunar,helmand,khost,logar,east,logroadkm,timesaid
Unnamed: 0_level_1,Int64,Int64,Int64,Float64,Float64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Float64,Int64,Int64,Int64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Int64
1,1,1906,19,35.491,63.359,661,583,2,8,0,8,11,0,1,0,0,0,0,0,0,0,0,0,1,0,0,13.01,12396,1,41,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,3.9925,1
2,2,1907,19,35.0526,63.2788,957,338,2,14,0,1,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,2.281,3811,0,7,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,-4.60517,1
3,3,1906,19,35.5026,63.3748,522,234,2,8,0,0,8,1,0,0,0,0,0,1,0,0,0,0,0,0,0,12.144,11740,0,61,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,3.9925,1
4,4,1907,19,35.0397,63.2892,1005,599,2,13,150,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0.27,5824,1,12,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,-4.60517,1
5,5,1907,19,35.0477,63.296,929,0,1,0,0,1,1,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1.22,5194,1,6,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,-4.60517,1
6,6,1907,19,35.0502,63.3538,1031,1181,2,3,0,1,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1.443,4517,1,12,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,-4.60517,1
7,7,1901,19,34.9874,63.1275,962,9000,1,14,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,0,2.697,3027,1,8,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,4.1281,1
8,8,1125,11,36.8087,71.0694,2171,667,1,6,0,2,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,14.167,15136,1,1,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-4.60517,1
9,9,1311,13,36.1221,68.6861,657,56200,2,14,0,1,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,2.204,10214,0,20,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.30244,1
10,10,1311,13,36.3294,68.841,569,411,2,19,0,0,4,0,0,0,0,0,0,0,1,0,0,0,1,0,0,24.247,23415,0,25,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.30244,1


In [5]:
treatment, outcome = select(civcas_df, :eligible), select(civcas_df, :diff30a)

(1061×1 DataFrame
  Row │ eligible
      │ Int64
──────┼──────────
    1 │        1
    2 │        1
    3 │        1
    4 │        1
    5 │        1
    6 │        1
    7 │        1
    8 │        1
  ⋮   │    ⋮
 1055 │        0
 1056 │        0
 1057 │        0
 1058 │        0
 1059 │        0
 1060 │        0
 1061 │        0
1046 rows omitted, 1061×1 DataFrame
  Row │ diff30a
      │ Int64
──────┼─────────
    1 │       1
    2 │      -1
    3 │       0
    4 │       1
    5 │       0
    6 │       0
    7 │      -2
    8 │      -1
  ⋮   │    ⋮
 1055 │      -4
 1056 │       1
 1057 │      -1
 1058 │      -1
 1059 │      -2
 1060 │       2
 1061 │      -3
1046 rows omitted)

It's always a good idea to look at the summary statistics of your data before doing any modeling, and we can see some descriptive statistics below.

In [6]:
describe(covariates)

Row,variable,mean,min,median,max,nmissing,eltype
Unnamed: 0_level_1,Symbol,Float64,Real,Float64,Real,Int64,DataType
1,acapproject,531.0,1,531.0,1061,0,Int64
2,distid,1476.31,101,1014.0,3307,0,Int64
3,provid,14.7003,1,10.0,33,0,Int64
4,latitude,33.9036,30.5517,33.999,37.1696,0,Float64
5,longitude,68.2909,61.2076,69.0311,71.537,0,Float64
6,altitude,1381.67,307,1266.0,2965,0,Int64
7,population,46066.3,0,1148.0,2536300,0,Int64
8,langcode,2.00848,1,2.0,9,0,Int64
9,neighbors,23.8511,0,22.0,91,0,Int64
10,cdcfamilies,82.6277,0,0.0,988,0,Int64


In [7]:
describe(outcome)

Row,variable,mean,min,median,max,nmissing,eltype
Unnamed: 0_level_1,Symbol,Float64,Int64,Float64,Int64,Int64,DataType
1,diff30a,0.189444,-20,0.0,76,0,Int64


## Preprocessing
Before we can estimate the CATE we need to do a little bit of preprocessing. First, we will recode the district and province IDs so that they start at zero.

In [8]:
covariates.distid = covariates.distid .- minimum(covariates.distid)
covariates.provid = covariates.provid .- minimum(covariates.provid)

1061-element Vector{Int64}:
 18
 18
 18
 18
 18
 18
 18
 10
 12
 12
  ⋮
 24
 24
 24
 24
 24
 24
 24
 24
 24

We also want our variables to be normalized, but the road variable for each village (logroadkm) is logged, so we will exponentiate it to be able to normalize it later.

In [9]:
covariates.logroadkm = exp.(covariates.logroadkm)

1061-element Vector{Float64}:
 54.19000598376728
  0.009999999859880915
 54.19000598376728
  0.009999999859880915
  0.009999999859880915
  0.009999999859880915
 62.060021161207885
  0.009999999859880915
 73.87998852414373
 73.87998852414373
  ⋮
  0.009999999859880915
  0.009999999859880915
 60.27000195907922
 81.48000885210132
  0.009999999859880915
  0.009999999859880915
 60.27000195907922
 60.27000195907922
  0.009999999859880915

Now we can normalize our data.

In [10]:
normalize(col) = (col .- minimum(col)) / (maximum(col) - minimum(col))

normalize (generic function with 1 method)
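
As a quick sanity check on what this min-max rescaling does, the sketch below applies the same one-line definition to toy vectors. The toy data is made up for illustration. It also shows an edge case worth knowing about: a column whose values are all identical has a zero denominator and comes back as NaN.

```julia
# Same min-max rescaling as the notebook's normalize function.
normalize(col) = (col .- minimum(col)) / (maximum(col) - minimum(col))

# The smallest value maps to 0 and the largest to 1.
toy = [2.0, 4.0, 6.0]
scaled = normalize(toy)

# A constant column has maximum == minimum, so every entry becomes 0/0 = NaN.
constant = normalize([3.0, 3.0])
```

None of the columns normalized in this notebook are constant, but it is worth checking for that before reusing the function on other data.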

In [11]:
vars_to_normalize = [
    :acapproject,
    :altitude, 
    :population, 
    :neighbors, 
    :cdcfamilies, 
    :killed, 
    :wounded, 
    :dcdistkm, 
    :basedistkm, 
    :beneficiaries, 
    :logroadkm
]

covariates = combine(
    covariates, 
    vars_to_normalize .=> normalize, 
    [sym for sym ∈ propertynames(covariates) if sym ∉ vars_to_normalize]...,
    renamecols=false
)

Row,acapproject,altitude,population,neighbors,cdcfamilies,killed,wounded,dcdistkm,basedistkm,beneficiaries,logroadkm,distid,provid,latitude,longitude,langcode,property,airstrike,cfmilop,crossfire,milop,sb,ied,accident,eof,indirect,cfindirect,isafinit,kabul,pakistan,fightingseason,north,south,southwest,west,kunar,helmand,khost,logar,east,timesaid
Unnamed: 0_level_1,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Int64,Int64,Float64,Float64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Int64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Int64
1,0.0,0.133183,0.000229862,0.0879121,0.0,0.177778,0.0709677,0.186002,0.206363,0.00916816,0.260895,1805,18,35.491,63.359,2,0,1,0,0,0,0,0,0,0,0,0,1,0,0,1,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1
2,0.000943396,0.244545,0.000133265,0.153846,0.0,0.0222222,0.0,0.0325285,0.0630623,0.0015653,0.0,1806,18,35.0526,63.2788,2,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1
3,0.00188679,0.0808879,9.22604e-5,0.0879121,0.0,0.0,0.0516129,0.173614,0.195413,0.0136404,0.260895,1805,18,35.5026,63.3748,2,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1
4,0.00283019,0.262603,0.000236171,0.142857,0.151822,0.0,0.00645161,0.00376209,0.0966633,0.00268336,0.0,1806,18,35.0397,63.2892,2,0,0,1,0,0,0,0,0,0,0,0,1,0,0,1,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1
5,0.00377358,0.234011,0.0,0.0,0.0,0.0222222,0.00645161,0.0173514,0.0861473,0.00134168,0.0,1806,18,35.0477,63.296,1,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1
6,0.00471698,0.272385,0.000465639,0.032967,0.0,0.0222222,0.0,0.0205413,0.0748469,0.00268336,0.0,1806,18,35.0502,63.3538,2,1,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1
7,0.00566038,0.246426,0.00354848,0.153846,0.0,0.0,0.00645161,0.0384791,0.0499758,0.00178891,0.298791,1800,18,34.9874,63.1275,1,0,0,1,0,0,0,0,0,0,0,0,1,0,0,1,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1
8,0.00660377,0.701279,0.000262982,0.0659341,0.0,0.0444444,0.0,0.202552,0.252099,0.000223614,0.0,1024,10,36.8087,71.0694,1,1,1,0,0,0,0,0,0,0,0,0,1,0,0,1,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
9,0.00754717,0.131678,0.0221583,0.153846,0.0,0.0222222,0.0,0.031427,0.169941,0.00447227,0.355708,1210,12,36.1221,68.6861,2,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
10,0.00849057,0.0985704,0.000162047,0.208791,0.0,0.0,0.0258065,0.346741,0.390292,0.00559034,0.355708,1210,12,36.3294,68.841,2,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1


## Fitting a Model
We finally get to the fun part.

In [12]:
first_dr_model = DoublyRobustLearner(covariates, treatment, outcome)

DoublyRobustLearner([0.5443396226415095 0.754326561324304 … 1.0 1.0; 0.4660377358490566 0.31339352896914974 … 1.0 1.0; … ; 0.559433962264151 0.07938299473288186 … 0.0 0.0; 0.8981132075471698 0.24341610233258087 … 1.0 0.0], [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0  …  0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], [1.0, 2.0, 0.0, 5.0, 21.0, 0.0, -3.0, -4.0, -3.0, 3.0  …  1.0, 5.0, 7.0, -7.0, 0.0, 0.0, 2.0, 4.0, 0.0, -5.0], "CATE", false, "regression", CausalELM.swish, 1061, 50, 31, 124, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN  …  NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 2)

Next we estimate the CATE. Something is off here: the predicted CATEs below are far outside the range of the outcomes, which run from -20 to 76.

In [13]:
estimate_causal_effect!(first_dr_model)

1061-element Vector{Float64}:
      4.0173501192016995e8
      3.958240955318924e8
      1.166402131685069e8
      1.0744118726707132e9
      7.919148311275198e8
      4.685951678156178e8
      1.9973323646165748e9
      5.548531393866843e7
      4.606785323248733e8
      4.054043995605459e9
      ⋮
      6.791247513052346e7
      1.972344239783436e8
      4.2882865279086566e8
 878640.1310400628
      6.350807173502354e7
      5.481038181720447e11
     -2.2168463430087344e12
      3.084744895070118e6
      2.32223197328209e7
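
One simple way to flag estimates like these is to compare them with the largest effect the outcome data could plausibly support. The sketch below is only a heuristic, and it assumes both potential outcomes of a village lie within the observed outcome range; the variable names are illustrative.

```julia
# Outcome bounds taken from describe(outcome) above.
outcome_min, outcome_max = -20.0, 76.0

# If both potential outcomes fall inside the observed range, their
# difference cannot be larger in magnitude than the span of that range.
max_plausible = outcome_max - outcome_min

# A few of the estimates printed above.
sample_estimates = [4.0173501192016995e8, 878640.1310400628, -2.2168463430087344e12]

# Count the estimates whose magnitude exceeds the plausible bound.
implausible = count(e -> abs(e) > max_plausible, sample_estimates)
```

All three sampled estimates exceed the bound by many orders of magnitude, which points to a modeling problem rather than a real effect.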

We could add more machines to the underlying ELM ensemble, consider more features for each machine, or change the number of neurons in each machine. We could also write a function to do cross validation and hopefully find a much better model. But we don't actually need to do any of this. Since extreme learning machines use randomized weights instead of gradient descent, seeing predicted values far outside the range you would expect is a strong indicator that the post-activation weights are too large. This is why CausalELM has 13 built-in activation functions: to constrain the output weights. CausalELM uses the swish activation by default, but in this case we need a more aggressive activation function to attenuate our weights. A good place to start is the fourier or binary step function.

In [14]:
better_dr_model = DoublyRobustLearner(covariates, treatment, outcome, activation=fourier)

DoublyRobustLearner([0.13018867924528302 0.41723100075244546 … 1.0 2.0; 0.8122641509433962 0.32091798344620015 … 1.0 0.0; … ; 0.16226415094339622 0.49435665914221216 … 1.0 1.0; 0.6566037735849056 0.2584650112866817 … 0.0 0.0], [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0  …  0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0], [-3.0, 1.0, -2.0, -1.0, -4.0, 0.0, 0.0, -2.0, 13.0, 4.0  …  1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 4.0, -2.0, 6.0, -3.0], "CATE", false, "regression", CausalELM.fourier, 1061, 50, 31, 124, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN  …  NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 2)

In [15]:
estimate_causal_effect!(better_dr_model)

1061-element Vector{Float64}:
 -1.1248208493165057
 -0.33658630048075666
  0.5502186364208397
  0.44192281327345584
 -0.5255426352201444
 -0.012783507830087582
 -0.7809308817892282
 -1.3024783381733798
 -0.7028089822663802
 -0.1506935580311216
  ⋮
 -0.8777671389587851
 -0.9401054875993395
  1.0027280715847917
  0.6316318890842755
 -0.7387045473829954
 -0.6970800035383887
  1.1510660707945082
  0.2809723136292491
  0.1051366508528884
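
To get some intuition for why the choice of activation matters so much here, the sketch below compares an unbounded activation with a bounded one on a few large inputs. The two functions are hypothetical stand-ins written for this illustration: toy_swish follows the usual swish formula, and toy_sine assumes that the fourier activation is a sine transform, which is an assumption rather than something shown in this notebook.

```julia
# Hypothetical stand-ins, not CausalELM's own definitions.
toy_swish(x) = x / (1 + exp(-x))   # x * sigmoid(x): roughly x for large positive x
toy_sine(x) = sin(x)               # always stays within [-1, 1]

# Pre-activation values such as randomized weights might produce.
pre_activations = [-50.0, -1.0, 0.0, 1.0, 50.0, 500.0]

swish_out = toy_swish.(pre_activations)
sine_out = toy_sine.(pre_activations)

# Largest magnitude each activation passes on to the output layer.
largest_swish = maximum(abs, swish_out)
largest_sine = maximum(abs, sine_out)
```

With the unbounded activation, a single large pre-activation value flows straight through to the output layer, while the bounded one caps every hidden unit at magnitude one.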

## Getting a Summary
Just as in R, we can get a summary of the model and see basic information about it. Note that the p-value and standard error are NaN here. By default they are not calculated, but you can calculate them by setting the inference keyword argument to true. However, since p-values and standard errors are calculated with randomization inference, this will take a long time.

In [16]:
summarize(better_dr_model)

Dict{Any, Any} with 11 entries:
  "Activation Function"    => fourier
  "Quantity of Interest"   => "CATE"
  "Sample Size"            => 1061
  "Number of Machines"     => 50
  "Causal Effect"          => [-1.12482, -0.336586, 0.550219, 0.441923, -0.5255…
  "Number of Neurons"      => 124
  "Task"                   => "regression"
  "Time Series/Panel Data" => false
  "Standard Error"         => NaN
  "p-value"                => NaN
  "Number of Features"     => 31
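
To see why randomization inference is slow, here is a minimal sketch of the general idea on toy data. It uses a plain difference in means as the estimator, whereas the real procedure would have to refit the whole model for each reshuffled treatment vector. The function names, the number of permutations, and the use of the standard deviation of the null distribution as a standard error are all assumptions made for this illustration, not CausalELM's implementation.

```julia
using Random, Statistics

# Difference in mean outcomes between treated (1) and control (0) units.
diff_in_means(t, y) = mean(y[t .== 1]) - mean(y[t .== 0])

function randomization_inference(t, y; n_permutations=1000, rng=MersenneTwister(2024))
    observed = diff_in_means(t, y)
    # Re-estimate the effect under many random reassignments of treatment.
    null_effects = [diff_in_means(shuffle(rng, t), y) for _ in 1:n_permutations]
    # Share of reshuffled effects at least as extreme as the observed one.
    p_value = mean(abs.(null_effects) .>= abs(observed))
    standard_error = std(null_effects)
    return (effect=observed, p_value=p_value, standard_error=standard_error)
end

toy_treatment = [1, 1, 1, 1, 0, 0, 0, 0]
toy_outcome = [5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0]
result = randomization_inference(toy_treatment, toy_outcome)
```

Each permutation costs one full re-estimation, so with an ensemble of ELMs the total run time grows in proportion to the number of permutations.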

## Model Validation
Often the most important part of estimating causal effects is making sure the model meets some assumptions or seeing how sensitive it is to violations of those assumptions. Using the validate method we can easily see how sensitive our model is to violations of the counterfactual consistency assumption and to hidden confounders, and verify that there is overlap between the treatment and control groups. The first object returned is a dictionary that shows the average estimated causal effects for some simulated violations of the counterfactual consistency assumption. The second item is the E-value, which tells us how much hidden confounding there would need to be to change the results. Finally, the last item is a matrix with observations that have a zero or near zero probability of being assigned to either the treatment or control group; an empty matrix with zero rows, as in the output below, means no such observations were found.

In [17]:
validate(better_dr_model)

(Dict("0.1 Standard Deviations from Observed Outcomes" => 0.39944244256073913, "0.075 Standard Deviations from Observed Outcomes" => 0.6241976412639244, "0.025 Standard Deviations from Observed Outcomes" => -0.39599066747960493, "0.05 Standard Deviations from Observed Outcomes" => -0.4523769820827194), 6.204438360546767, Matrix{Float64}(undef, 0, 42))
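
For readers who want a feel for the last two checks, the sketch below shows the standard E-value formula for a risk ratio and a simple overlap filter on propensity scores. How CausalELM converts the estimated effect into a risk ratio, and which tolerance it uses for overlap, are not shown in this notebook, so the function names and the tol default are hypothetical.

```julia
# E-value for a risk ratio: the minimum strength of association a hidden
# confounder would need with both treatment and outcome to explain it away.
function e_value(rr)
    rr > 0 || throw(ArgumentError("risk ratio must be positive"))
    r = rr >= 1 ? rr : 1 / rr      # protective effects are inverted first
    return r + sqrt(r * (r - 1))
end

# Indices of observations whose estimated propensity score is within
# tol of 0 or 1, which suggests a lack of overlap.
overlap_violations(propensities; tol=1e-5) =
    findall(p -> p < tol || p > 1 - tol, propensities)

no_effect = e_value(1.0)
doubled_risk = e_value(2.0)
flagged = overlap_violations([0.5, 0.0, 0.999999999, 0.3])
```

A risk ratio of one needs no confounding to explain, so its E-value is one; larger E-values, like the one reported above, mean that stronger hidden confounding would be required to overturn the result.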