In [1]:
library(grf)

In [26]:
# Run the Monte Carlo study over every row of a parameter grid.
#
# param.set:  grid of settings; each row is passed to run_iteration() as a
#             vector (as.matrix coercion by apply(), so columns should be
#             numeric). run_iteration() reads it positionally as
#             (n, p, noise).
# dgp:        data-generating function handed unchanged to run_iteration().
# iterations: number of replications averaged per grid row.
#
# Returns a matrix with one row per grid row; each row holds the averages
# (over replications) of the vector run_iteration() returns.
run_simulation <- function(param.set, dgp, iterations) {
    # Worker for one grid row: replicate the experiment and average each
    # component of the per-iteration result across replications.
    average_over_replications <- function(setting) {
        draws <- replicate(iterations, run_iteration(setting, dgp))
        rowMeans(draws)
    }

    per_setting <- apply(param.set, MARGIN = 1, FUN = average_over_replications)
    # apply() stacks results as columns; transpose to one row per setting.
    t(per_setting)
}

# One replication: simulate data, fit a boosted and a non-boosted causal
# forest, and score each against the true effects.
#
# params: vector read positionally as (n, p, noise) -- sample size, number
#         of covariates, and standard deviation of the outcome noise.
# dgp:    function(X) returning a list with elements W (treatment),
#         tau (true effect) and E.Y (outcome mean).
#
# Returns an unnamed numeric vector c(mse_boosted, mse_plain) of the mean
# squared error between the true tau and the forests' predict() output.
run_iteration <- function(params, dgp) {
    n <- params[1]
    p <- params[2]
    noise <- params[3]

    X <- matrix(rnorm(n * p), n, p)
    simulated <- dgp(X)
    Y <- simulated$E.Y + rnorm(n, sd = noise)
    W <- simulated$W
    TAU <- simulated$tau

    # Both forests are fitted before any prediction, in the same order as
    # before (boosted first), so random-number consumption is unchanged.
    forests <- list(
        boosted = causal_forest(X, Y, W, tune.parameters = TRUE, boosting = TRUE),
        plain = causal_forest(X, Y, W, tune.parameters = TRUE, boosting = FALSE)
    )

    errors <- vapply(forests, function(forest) {
        tau.hat <- predict(forest)$predictions
        mean((TAU - tau.hat)^2)
    }, numeric(1))
    unname(errors)
}

# Nonlinear data-generating process with heterogeneous treatment effects.
#
# X: covariate matrix; columns 1-6 are used, so it needs at least 6 columns.
#
# Returns a list with
#   W   - Bernoulli treatment draws with propensity logistic(X1 + X2),
#   tau - per-unit effect logistic(X3),
#   E.Y - outcome mean: max(X2 + X3, 0) + mean(X4..X6) / 2 + W * tau.
dgp_nonlinear <- function(X) {
    logistic <- function(z) 1 / (1 + exp(-z))

    effect <- logistic(X[, 3])
    propensity <- logistic(X[, 1] + X[, 2])
    treated <- rbinom(length(propensity), 1, propensity)
    baseline <- pmax(X[, 2] + X[, 3], 0) + rowMeans(X[, 4:6]) / 2

    list(
        W = treated,
        tau = effect,
        E.Y = baseline + treated * effect
    )
}

# Linear data-generating process with a constant treatment effect and a
# completely randomized treatment assignment.
#
# X:          covariate matrix (or data.frame); columns 1 and 2 enter the
#             outcome, so at least 2 columns are required.
# tau:        treatment effect, a scalar or a length-nrow(X) vector
#             (recycled). Default 2, as in the original experiment.
# propensity: treatment probability, a scalar or a length-nrow(X) vector
#             in [0, 1] (recycled). Default 0.5.
#
# Returns a list with
#   W   - Bernoulli(propensity) treatment draws,
#   tau - per-unit effect vector of length nrow(X) (same length as W and
#         E.Y, matching dgp_nonlinear; previously a bare scalar),
#   E.Y - outcome mean: W * tau + X1 + 2 * X2.
dgp_linear <- function(X, tau = 2, propensity = 0.5) {
    stopifnot(ncol(X) >= 2)
    n <- nrow(X)

    TAU <- rep_len(tau, n)
    E.W <- rep_len(propensity, n)
    stopifnot(all(E.W >= 0 & E.W <= 1))

    W <- rbinom(n, 1, E.W)
    E.Y <- W * TAU + X[, 1] + 2 * X[, 2]

    list(W = W, tau = TAU, E.Y = E.Y)
}

In [None]:
# Simulation grid: sample size (n), number of covariates (p) and outcome
# noise standard deviation (sigma). Seeded for reproducibility.
set.seed(1)
n.opt <- c(100, 1000, 10000)
p.opt <- c(6, 12, 60)
# Noise levels previously tried: c(1, 5, 20)
noise.opt <- c(1, 5, 10)

# Every combination of the options above; columns are read positionally
# as (n, p, noise) by run_iteration().
param.set <- expand.grid(n.opt, p.opt, noise.opt)

# 100 replications per setting under the linear DGP.
mse.table <- run_simulation(param.set, dgp_linear, 100)

results <- setNames(
    cbind(param.set, mse.table),
    c("n", "p", "sigma", "mse.tau.boost", "mse.tau.forest")
)
results
# TODO: report the mean MSE of each estimator.
# TODO: plot histograms of both MSE columns.