In [1]:
using CSV
using Statistics
using Distributions
using Optim
using ForwardDiff
using Random

In [2]:
using IJulia

In [3]:
"""
    VertexLabels(u1, u2, w, x, y, z1, z3)

Vertex indices of the causal graph, grouped by role. Every field is a
`Vector{Int}` of vertex indices.

# Fields
- `u1`: latent variables, parents of colliders Z1 and outcome Y
- `u2`: latent variables, parents of colliders Z1 and treatment X
- `w`: instrument
- `x`: treatment
- `y`: outcome
- `z1`: colliders between X and Y
- `z3`: confounders between X and Y
"""
struct VertexLabels
    u1::Vector{Int}
    u2::Vector{Int}
    w::Vector{Int}
    x::Vector{Int}
    y::Vector{Int}
    z1::Vector{Int}
    z3::Vector{Int}
end

In [4]:
# Graph definitions: each role owns a contiguous range of vertex indices.
#   1      instrument W
#   2      treatment X
#   3      outcome Y
#   4:8    latents U1
#   9:13   latents U2
#   14:18  colliders Z1
#   19:25  confounders Z3
w = [1]
x = [2]
y = [3]
u1 = collect(4:8)
u2 = collect(9:13)
z1 = collect(14:18)
z3 = collect(19:25)
vertex_labels = VertexLabels(u1, u2, w, x, y, z1, z3)
# z3 holds the highest indices, so its last entry is the vertex count (25)
num_vertices = last(z3)
# number of samples
n = 90_000

90000

## helper function from app_linear.jl

In [5]:
# Build the 1×6 row of ATE estimates for one data split.
# `Sigma` is the covariance of the raw variables, `Sigma_phi` the 4×4 covariance returned by
# `build_phi_covariance`. Column order of the result:
#   [placeholder for the real ATE (0.0), best_Z, learned phi, all Z, selected Z, marginal]
function _ate_estimates(Sigma, Sigma_phi, x, y, Z, best_Z, sel_Z)
    ate_real = 0.0  # placeholder; the real ATE is not known inside this function
    # regress entry 2 of Sigma_phi on entries 3 and 4, keep the coefficient of entry 3
    ate_phi = (Sigma_phi[[3; 4], [3; 4]] \ Sigma_phi[[3; 4], 2])[1]
    # regress y on x plus an adjustment set, keep the coefficient of x
    ate_all_Z = (Sigma[[x; Z], [x; Z]] \ Sigma[[x; Z], y])[1]
    ate_best_Z = (Sigma[[x; best_Z], [x; best_Z]] \ Sigma[[x; best_Z], y])[1]
    ate_sel_Z = (Sigma[[x; sel_Z], [x; sel_Z]] \ Sigma[[x; sel_Z], y])[1]
    # no adjustment at all
    ate_marg = Sigma[x, y] / Sigma[x, x]
    return [ate_real ate_best_Z ate_phi ate_all_Z ate_sel_Z ate_marg]
end

"""
    choose_lambdas(lambda_twos, lambda1, seeds, dat, vertex_labels, max_iter, lrs,
                   corr_boost, path_to_file)

Grid-search over `seeds`, `lambda_twos` and `lrs`, and evaluate the selected
configuration on a held-out test split.

The rows of `dat` are cut into three contiguous blocks of `size(dat, 1) ÷ 3` rows
(train, validation, test; leftover rows go to the test block). For every
configuration, `bd_learn_linear` is fitted on the training covariance, starting from
`lambda1`. While `ind_null_hypo` does not reject at level `0.01 / k` (with `k` the
attempt number), `lambda1` is doubled and the fit repeated. Once the null is rejected,
ATE estimates are computed on the training and validation splits and the configuration
is stored under its validation partial correlation `corr_px_valid`. If computing those
estimates throws, the configuration is skipped with a warning.

The helpers `bd_learn_linear`, `ind_null_hypo`, `build_phi_covariance` and
`partial_corr`, as well as the sample size `n`, are read from the enclosing global scope.

# Returns
`(best_setup, ATE_results_test)`: the stored configuration with the smallest
`corr_px_valid` (a `Dict` with keys `"seed"`, `"lr"`, `"lambda1"`, `"lambda2"`,
`"theta_hat"`, `"sel_Z"`, `"corr_px_train"`, `"corr_p_train"`, `"corr_px_valid"`,
`"ATE_results_train"`, `"ATE_results_valid"`) and the row of test-set ATE estimates in
the order [real-ATE placeholder, best_Z, learned phi, all Z, selected Z, marginal].

# Throws
An error if no configuration could be evaluated.
"""
function choose_lambdas(lambda_twos, lambda1, seeds, dat, vertex_labels, max_iter, lrs, corr_boost, path_to_file)
    initial_lambda1 = lambda1
    hypo_fail_all_lambdas = Dict()
    num_Z = length(vertex_labels.z1) + length(vertex_labels.z3)

    w, x, y = vertex_labels.w, vertex_labels.x, vertex_labels.y
    Z = [vertex_labels.z1; vertex_labels.z3]
    best_Z = vertex_labels.z3

    split = trunc(Int, (size(dat)[1] / 3))
    dat_train = dat[1:split, :]
    dat_valid = dat[split+1:split*2, :]
    dat_test = dat[(split*2)+1:end, :]

    # Covariances depend only on the split, so compute them once up front
    Sigma_hat = cov(dat_train)
    Sigma_hat_valid = cov(dat_valid)

    for seed in seeds
        Random.seed!(seed)
        for lambda2 in lambda_twos
            println("lambda2: ", lambda2)
            for lr in lrs
                println("lr: ", lr)
                hypo_fail = Dict()
                lambda1 = initial_lambda1
                bonferonni_correction = 1.0
                while true
                    # learn theta on training set
                    theta_hat, corr_px_theta_hat, corr_p_theta_hat =
                        bd_learn_linear(Sigma_hat, lambda1, lambda2, vertex_labels, 1, path_to_file,
                                        max_iter = max_iter, lr = lr, corr_boost = corr_boost)

                    # keep the Z vertices whose coefficient is non-negligible; indexing Z
                    # directly stays correct even if Z is not a contiguous index range
                    thresh = 1e-3
                    sel_Z = Z[findall(t -> abs(t) > thresh, theta_hat)]
                    println(sel_Z)

                    # test for the null hypothesis
                    # NOTE(review): `n` is the global sample size, not the number of rows in
                    # the training split that Sigma_hat was estimated from — confirm intent.
                    reject_null_hypothesis =
                        ind_null_hypo(n, num_Z, corr_p_theta_hat; significance_level=(0.01/bonferonni_correction))
                    if !reject_null_hypothesis
                        # not rejected: strengthen the penalty, tighten the level, refit
                        lambda1 = lambda1 * 2
                        bonferonni_correction += 1.0
                        continue
                    end

                    Sigma_wyx_phi = build_phi_covariance(theta_hat, w, y, x, Z, Sigma_hat)
                    try
                        # compute ATEs training set
                        ATE_results_train =
                            _ate_estimates(Sigma_hat, Sigma_wyx_phi, x, y, Z, best_Z, sel_Z)

                        # compute rho(W, Y | beta*Z, X) on the validation set
                        # NOTE(review): the training call above passes (w, y, x) while this
                        # call and the test-set call pass (w, x, y). build_phi_covariance is
                        # defined outside this file; confirm which order is intended. The
                        # calls are left exactly as they were.
                        Sigma_wyx_phi_valid =
                            build_phi_covariance(theta_hat, w, x, y, Z, Sigma_hat_valid)
                        corr_px_valid = abs(partial_corr([1; 2], [3; 4], Sigma_wyx_phi_valid)[1, 2])

                        # compute ATEs validation set
                        ATE_results_valid =
                            _ate_estimates(Sigma_hat_valid, Sigma_wyx_phi_valid, x, y, Z, best_Z, sel_Z)
                        println(ATE_results_valid[5])  # selected-Z estimate

                        # save corrs
                        hypo_fail["seed"] = seed
                        hypo_fail["corr_px_train"] = corr_px_theta_hat
                        hypo_fail["corr_p_train"] = corr_p_theta_hat
                        hypo_fail["corr_px_valid"] = corr_px_valid
                        hypo_fail["lr"] = lr
                        # parameters
                        hypo_fail["theta_hat"] = theta_hat
                        hypo_fail["sel_Z"] = sel_Z
                        hypo_fail["lambda1"] = lambda1
                        hypo_fail["lambda2"] = lambda2
                        # ATEs
                        hypo_fail["ATE_results_train"] = ATE_results_train
                        hypo_fail["ATE_results_valid"] = ATE_results_valid
                        hypo_fail_all_lambdas[corr_px_valid] = hypo_fail
                    catch err
                        # Never swallow a user interrupt.
                        err isa InterruptException && rethrow()
                        # The original code also left the loop here (the null had already
                        # been rejected), i.e. the configuration is skipped, not retried.
                        @warn "ATE computation failed; skipping configuration" seed lambda2 lr lambda1 exception = (err, catch_backtrace())
                    end
                    break
                end
            end
        end
    end

    isempty(hypo_fail_all_lambdas) &&
        error("choose_lambdas: no configuration could be evaluated")

    # the configuration with the smallest validation partial correlation wins
    best_setup = hypo_fail_all_lambdas[minimum(keys(hypo_fail_all_lambdas))]
    Sigma_hat_test = cov(dat_test)
    Sigma_wyx_phi_test = build_phi_covariance(best_setup["theta_hat"], w, x, y, Z, Sigma_hat_test)
    # compute ATEs test set
    ATE_results_test =
        _ate_estimates(Sigma_hat_test, Sigma_wyx_phi_test, x, y, Z, best_Z, best_setup["sel_Z"])
    println(ATE_results_test[5])  # selected-Z estimate
    return best_setup, ATE_results_test
end

choose_lambdas (generic function with 1 method)

### NHS data

In [6]:
# Read the NHS dataset into a dense n × num_vertices matrix of Float64.
# The first field of every CSV row is dropped (row[2:end]); the rest is parsed as Float64.
dat_small_var = zeros(n, num_vertices);
# enumerate supplies the row index, so the loop does not have to mutate a global counter
# (which only works under the notebook's soft-scope rules).
for (m, row) in enumerate(CSV.Rows("../../code/real_world_nhs_dataset/nhs_data_smaller_var.csv", datarow=2))
    dat_small_var[m, :] = [parse(Float64, field) for field in row[2:end]]
end

In [7]:
# Split the loaded dataset in half for the Entner baseline:
# first half for training, second half for validation
split = div(size(dat_small_var, 1), 2)
dat_train = dat_small_var[1:split, :]
dat_valid = dat_small_var[(split + 1):end, :];

In [8]:
# Register a Jupyter kernelspec named "Julia nodeps" that starts Julia with --depwarn=no
IJulia.installkernel("Julia nodeps", "--depwarn=no")
# Load the project code for the NHS experiment.
# NOTE(review): presumably these files define bd_learn_linear, ind_null_hypo,
# build_phi_covariance and partial_corr, which choose_lambdas calls — confirm.
include("../../code/real_world_nhs_dataset/simulate.jl")
include("../../code/real_world_nhs_dataset/learn_linear.jl")
include("../../code/real_world_nhs_dataset/util.jl")

┌ Info: Installing Julia nodeps kernelspec in /Users/lgultchin/Library/Jupyter/kernels/julia-nodeps-1.2
└ @ IJulia /Users/lgultchin/.julia/packages/IJulia/fRegO/deps/kspec.jl:78
┌ Info: Loading DataFrames support into Gadfly.jl
└ @ Gadfly /Users/lgultchin/.julia/packages/Gadfly/09PWZ/src/mapping.jl:228


ATE_learned

In [9]:
# Run the proposed algorithm on the NHS dataset
lambda1 = 5e-4                    # starting value; doubled inside choose_lambdas as needed
lambda_twos = [1.0, 0.1, 0.01]
lrs = [5e-5, 2e-5, 5e-4]
seeds = [20, 40, 60]
max_iter = 1000
corr_boost = 1
best_setup, ATEs = choose_lambdas(
    lambda_twos, lambda1, seeds, dat_small_var, vertex_labels,
    max_iter, lrs, corr_boost, "../../code/real_world_nhs_dataset"
)

lambda2: 1.0
lr: 5.0e-5
[14, 15, 20, 21, 22, 23, 24]
[15, 16, 17, 18, 19, 20, 21, 22, 23]
-0.1023688604165896
lr: 2.0e-5
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
-0.27396984822314113
lr: 0.0005
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
-0.27396984822314113
lambda2: 0.1
lr: 5.0e-5
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
-0.27396984822314113
lr: 2.0e-5
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
-0.27396984822314113
lr: 0.0005
[14, 16, 17, 19, 21, 22, 23, 24, 25]
2.1707132667199054
lambda2: 0.01
lr: 5.0e-5
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
9.262140018745267
lr: 2.0e-5
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
-0.27396984822314113
lr: 0.0005
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
-0.27396984822314113
lambda2: 1.0
lr: 5.0e-5
[16, 17, 20, 21, 23, 24]
3.6695709510196775
lr: 2.0e-5
[14, 15, 16, 18, 19, 20, 22, 23, 24, 25]
1.8581889336209754
lr: 0.0005
[14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
[14, 15, 16, 17, 18, 19, 20, 21, 22, 

(Dict{Any,Any}("lr" => 5.0e-5,"ATE_results_valid" => [0.0 0.036316494492162635 … -0.1023688604165896 1.7733204733140853],"ATE_results_train" => [0.0 0.03612441828304275 … -0.11833420333802329 1.7815848263285687],"sel_Z" => [15, 16, 17, 18, 19, 20, 21, 22, 23],"lambda2" => 1.0,"lambda1" => 0.001,"corr_px_valid" => 0.5912905368409451,"theta_hat" => [-3.6727527729830745e-6, 0.545617219575932, 0.05334283112597428, -0.019592604357356536, 0.2898270188110904, -0.08246154060627756, -0.18436372081037034, -0.02816810113084275, -0.042175589063625625, 0.3077244351998149, -6.32423222191002e-5, 5.017706241915715e-5],"seed" => 20,"corr_p_train" => 0.11575514640347653…), [0.0 0.03647508648596374 … -0.12664746354762166 1.783437366524365])

### process results

In [10]:
# Learning rate and lambda2 of the setup with the smallest validation partial correlation
best_setup["lr"],best_setup["lambda2"]

(5.0e-5, 1.0)

In [11]:
# Validation-set ATE estimates of the selected setup. Columns:
# real-ATE placeholder (0.0), best_Z, learned phi, all Z, selected Z, marginal
best_setup["ATE_results_valid"]

1×6 Array{Float64,2}:
 0.0  0.0363165  0.331118  -0.27397  -0.102369  1.77332

In [12]:
# These are the ATEs computed on the test set. Columns:
# real-ATE placeholder (0.0), best_Z, learned phi, all Z, selected Z, marginal
ATEs

1×6 Array{Float64,2}:
 0.0  0.0364751  0.333002  -0.287561  -0.126647  1.78344

In [13]:
# Flat vector of results: entries 1-6 are the columns of the test-set row `ATEs`,
# entry 7 is the Entner baseline. A concrete element type avoids a Vector{Any}.
ATEs_results = Float64[]
append!(ATEs_results, vec(ATEs))
# Entner baseline ATE estimate (hard-coded here); its error is computed in the next cell
push!(ATEs_results, -0.7343839853467186)
# Overwrite the 0.0 placeholder in entry 1 with the pre-computed real ATE, which the
# next cell uses for the ATE error figures presented in Table 1 of the paper
ATEs_results[1] = 0.03613994943765849;

### These are the results in Table 1 of the manuscript

In [14]:
# Absolute ATE errors against the real ATE (entry 1), rounded to 3 digits, in the order:
# selected Z (5), all Z (4), marginal (6), Entner baseline (7)
map(k -> round(abs(ATEs_results[1] - ATEs_results[k]), digits=3), (5, 4, 6, 7))

(0.163, 0.324, 1.747, 0.771)