Skip to content

Commit

Permalink
Merge branch 'adam-monkeying-with-cov_adj'
Browse files Browse the repository at this point in the history
  • Loading branch information
josherrickson committed Jul 22, 2022
2 parents fa5cf2c + 6f985b1 commit baa8270
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 13 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.0
RoxygenNote: 7.2.1
Suggests:
knitr,
rmarkdown,
Expand Down
43 changes: 34 additions & 9 deletions R/cov_adj.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
#' @include SandwichLayer.R
NULL

##' @title Covariance Adjustment for Treatment Estimation
##' @param model Any model that inherits from a \code{glm}, \code{lm}, or \code{
##' robustbase::lmrob} object
##' Covariance Adjustment for Treatment Estimation
##'
##' Prior to obtaining predicted values, \code{cov_adj()} tries to identify the
##' treatment variable (as specified in the \code{design}) and replace it with
##' the reference level. If the treatment is binary, this is \code{FALSE}. If
##' treatment is numeric, it is the smallest non-negative value (note that this
##' means for 0/1 binary, it uses a 0). Factor treatments are not currently
##' supported, but if we add them, it will use the first \code{level()} of the
##' factor, you may change this by using \code{relevel()} to adjust.
##' @param model Any model that inherits from a \code{glm}, \code{lm}, or
##' \code{robustbase::lmrob} object
##' @param newdata Optional; a data.frame of new data
##' @param design Optional \code{Design}.
##' @return Covariate adjusted outcomes
Expand All @@ -15,23 +23,40 @@ cov_adj <- function(model, newdata = NULL, design = NULL) {
newdata <- tryCatch(
.get_data_from_model("weights", form),
error = function(e) {
warning(paste("Could not find quasiexperimental data in the call stack,",
"using the covariance model data to generate the covariance",
"adjustments"))
warning(paste("Could not find quasiexperimental data in the call",
"stack, using the covariance model data to generate",
"the covariance adjustments"))
stats::model.frame(model)
})
}

if (is.null(design)) {
design <- .get_design(NULL_on_error = TRUE)
}

if (!is.null(design)) {
trt_name <- var_names(design,'t')
if (trt_name %in% names(newdata))
if (is.numeric(treatment(design)[, 1])) {
newdata[[trt_name]] <- min(abs(treatment(design)[, 1]))
} else if (is.logical(treatment(design)[, 1])) {
newdata[[trt_name]] <- FALSE
} else if (is.factor(treatment(design)[, 1])) {
newdata[[trt_name]] <- levels(treatment(design)[, 1])[1]
} else {
warning(paste("The treatment variable is in the covariance adjustment",
"model, and is neither logical or numeric; for now,",
"partial residuals only implemented for logical or",
"numeric treatments"))
}
}

ca_and_grad <- .get_ca_and_prediction_gradient(model, newdata)
psl <- new("PreSandwichLayer",
ca_and_grad$ca,
fitted_covariance_model = model,
prediction_gradient = ca_and_grad$prediction_gradient)

if (is.null(design)) {
design <- .get_design(NULL_on_error = TRUE)
}

if (is.null(design)) {
return(psl)
Expand Down
12 changes: 9 additions & 3 deletions man/cov_adj.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 47 additions & 0 deletions tests/testthat/test.cov_adj.R
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,50 @@ options(old_opt)
#STARdata, : variable lengths differ (found for '(weights)')
# mod <- lm(readk ~ birth + lunchk, data = STARdata,
# offset = cov_adj(cmod, newdata = STARdata), weights = ate(des))



test_that("Basics of replacing treatment variable with reference level", {
data(simdata)

# Binary treatment
des <- rct_design(z ~ cluster(cid1, cid2), data = simdata)
camod <- lm(y ~ x + z, data = simdata)
ca <- cov_adj(camod, newdata = simdata, design = des)

simdata2 <- simdata
simdata2$z <- 0

manual <- predict(camod, newdata = simdata2)
expect_true(all(manual == ca))

### Let's just make sure we're not getting spurious positive results
simdata2 <- simdata
manual <- predict(camod, newdata = simdata2)
expect_false(all(manual == ca))

# Numeric treatment
des <- rct_design(dose ~ cluster(cid1, cid2), data = simdata)
camod <- lm(y ~ x + dose, data = simdata)
ca <- cov_adj(camod, newdata = simdata, design = des)

simdata2 <- simdata
simdata2$dose <- 50

manual <- predict(camod, newdata = simdata2)
expect_true(all(manual == ca))

# Factor treatment
## simdata$dose <- as.factor(simdata$dose)
## des <- rct_design(dose ~ cluster(cid1, cid2), data = simdata)
## camod <- lm(y ~ x + dose, data = simdata)
## ca <- cov_adj(camod, newdata = simdata, design = des)

## simdata2 <- simdata
## simdata2$dose <- levels(simdata2$dose)[1]

## manual <- predict(camod, newdata = simdata2)
## expect_true(all(manual == ca))
# Current build does NOT allow factor treatment so this will fail

})

0 comments on commit baa8270

Please sign in to comment.