add estimate_counterfactual_outcomes for W in {0,1} #403 #411

Closed
wants to merge 14 commits into from
10 changes: 10 additions & 0 deletions REFERENCE.md
@@ -185,6 +185,16 @@ The `average_treatment_effect` function implements two types of doubly robust av
- `target.sample = "overlap"`: the overlap-weighted ATE `sum_{i = 1}^n e(Xi) (1 - e(Xi)) E[Y(1) - Y(0) | X = Xi] / sum_{i = 1}^n e(Xi) (1 - e(Xi))`,
where `e(x) = P[W_i = 1 | X_i = x]`. This last estimand is recommended by Li et al. (2017) in case of poor overlap (i.e., when the treatment propensities e(x) may be very close to 0 or 1), as it doesn't involve dividing by estimated propensities.

### Counterfactual Outcomes

The `est_counterfactual_outcomes` function estimates quantities of the form `E[Y | X = x, W = w]` for a trained causal forest with a binary treatment. It returns a list of vectors of pointwise estimates, keyed by the character representation of each treatment level (`"0"` and `"1"`).

```
Y.hats <- est_counterfactual_outcomes(forest, subset)
Y.hat.0 <- Y.hats$`0`
Y.hat.1 <- Y.hats$`1`
```
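The arithmetic behind these estimates can be checked without training a forest. As a rough sketch, assume the usual decomposition `m(x) = e(x) mu1(x) + (1 - e(x)) mu0(x)` with `tau(x) = mu1(x) - mu0(x)`; then `mu0 = m - e * tau` and `mu1 = m + (1 - e) * tau`, which mirrors the recovery step in this PR. The vectors below are made-up illustrative values, not grf output:

```r
# Hypothetical true surfaces at five points (illustrative values only).
mu0 <- c(0.2, 0.5, 1.0, -0.3, 0.0)   # E[Y | X, W = 0]
mu1 <- c(0.7, 0.4, 2.0,  0.1, 1.0)   # E[Y | X, W = 1]
e   <- c(0.5, 0.2, 0.9,  0.4, 0.6)   # propensity e(x) = P[W = 1 | X = x]

# Quantities that play the role of the forest's stored estimates.
m   <- e * mu1 + (1 - e) * mu0       # marginal outcome, analogous to Y.hat
tau <- mu1 - mu0                     # treatment effect, analogous to the pointwise predictions

# Recover the two regression surfaces from m, e, and tau.
Y.hat.0 <- m - e * tau
Y.hat.1 <- m + (1 - e) * tau

# Both surfaces are recovered up to floating-point error.
stopifnot(isTRUE(all.equal(Y.hat.0, mu0)), isTRUE(all.equal(Y.hat.1, mu1)))
```

In practice `m`, `e`, and `tau` are themselves estimates, so the recovered surfaces inherit their estimation error.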

## Additional Features

The following sections describe other features of GRF that may be of interest.
1 change: 1 addition & 0 deletions r-package/grf/NAMESPACE
@@ -17,6 +17,7 @@ export(average_treatment_effect)
export(boosted_regression_forest)
export(causal_forest)
export(custom_forest)
export(est_counterfactual_outcomes)
export(get_sample_weights)
export(get_tree)
export(instrumental_forest)
49 changes: 46 additions & 3 deletions r-package/grf/R/average_treatment_effect.R
@@ -181,9 +181,9 @@ average_treatment_effect <- function(forest,
stop("Invalid target sample.")
}

-  # Get estimates for the regress surfaces E[Y|X, W=0/1]
-  Y.hat.0 <- subset.Y.hat - subset.W.hat * tau.hat.pointwise
-  Y.hat.1 <- subset.Y.hat + (1 - subset.W.hat) * tau.hat.pointwise
+  Y.hats <- est_counterfactual_outcomes(forest, subset)
+  Y.hat.0 <- Y.hats$`0`
+  Y.hat.1 <- Y.hats$`1`

if (method == "TMLE") {
loaded <- requireNamespace("sandwich", quietly = TRUE)
@@ -321,6 +321,49 @@ average_treatment_effect <- function(forest,
return(c(estimate = tau.avg, std.err = tau.se))
}


#' Estimate counterfactual outcomes conditional on X and W given a trained
#' causal forest. Returns a list containing '0' and '1', the
#' Y.hat estimates for the W=0 and W=1 treatment cases.
#'
#' @param forest The trained forest.
#' @param subset Specifies the subset of the training examples over which we
#'               estimate the counterfactual outcomes. WARNING: For valid
#'               statistical performance, the subset should be defined only
#'               using features Xi, not using the treatment Wi or the outcome Yi.
#' @return A list containing '0' and '1', the estimates for the W=0 and W=1
#' treatment cases.
#' @export
est_counterfactual_outcomes <- function(forest, subset = NULL) {
  n <- length(forest$Y.hat)

  if (is.null(subset)) {
    subset <- seq_len(n)
  }

  # is.logical() with && avoids comparing a possibly multi-element class vector.
  if (is.logical(subset) && length(subset) == n) {
    subset <- which(subset)
  }

  if (!all(subset %in% seq_len(n))) {
    stop(paste("If specified, subset must be a vector contained in 1:n,",
               "or a boolean vector of length n."))
  }

  if (!all(forest$W.orig %in% c(0, 1))) {
    stop(paste("est_counterfactual_outcomes is only implemented for",
               "binary treatments; all values of forest$W.orig must be 0 or 1."))
  }

  subset.W.hat <- forest$W.hat[subset]
  subset.Y.hat <- forest$Y.hat[subset]
  tau.hat.pointwise <- predict(forest)$predictions[subset]

  # Get estimates for the regression surfaces E[Y|X, W=0/1]
  Y.hat.0 <- subset.Y.hat - subset.W.hat * tau.hat.pointwise
  Y.hat.1 <- subset.Y.hat + (1 - subset.W.hat) * tau.hat.pointwise
  list("0" = Y.hat.0, "1" = Y.hat.1)
}

observation_weights <- function(forest) {
sample.weights <- if (is.null(forest$sample.weights)) {
rep(1, length(forest$Y.orig))
12 changes: 12 additions & 0 deletions r-package/grf/tests/testthat/test_average_effect.R
@@ -363,3 +363,15 @@ test_that("average effect estimation doesn't error on data with a single feature
average_partial_effect(forest)
expect_true(TRUE) # so we don't get a warning about an empty test
})

test_that("est_counterfactual_outcomes yields all 1/0 when binary outcome == treatment", {
  n <- 50; p <- 1
  X <- matrix(rnorm(n * p), n, p)
  W <- rbinom(n, 1, 0.5)
  Y <- W
  c.forest <- causal_forest(X, Y, W)
  y.hats <- est_counterfactual_outcomes(c.forest)

  # Compare with a numeric tolerance rather than exact floating-point equality.
  expect_equal(y.hats$`0`, rep(0, n))
  expect_equal(y.hats$`1`, rep(1, n))
})
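
The test above exercises only the default `subset = NULL` path. A minimal base-R sketch of the subset normalization the function performs follows. It is pulled into a hypothetical helper, `normalize_subset` (an assumption for illustration, not part of grf), so it can be run without training a forest:

```r
# Hypothetical helper (not part of grf): turn a subset argument into indices in 1:n.
normalize_subset <- function(subset, n) {
  if (is.null(subset)) {
    return(seq_len(n))
  }
  # A logical vector of length n selects the positions that are TRUE.
  if (is.logical(subset) && length(subset) == n) {
    subset <- which(subset)
  }
  if (!all(subset %in% seq_len(n))) {
    stop("If specified, subset must be a vector contained in 1:n, or a boolean vector of length n.")
  }
  subset
}

n <- 5
idx.all  <- normalize_subset(NULL, n)                                  # every index
idx.bool <- normalize_subset(c(TRUE, FALSE, TRUE, FALSE, FALSE), n)    # logical mask
idx.int  <- normalize_subset(c(2, 4), n)                               # explicit indices
bad      <- tryCatch(normalize_subset(c(0, 6), n), error = function(e) "error")
```

Checks along these lines could be adapted into `test_that` cases for the logical, integer, and out-of-range subset inputs.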