Skip to content

Commit

Permalink
Merge pull request #12 from bonStats/stan4bart
Browse files Browse the repository at this point in the history
dev versions of stan4bart + bartCause supported, but CRAN problems
  • Loading branch information
bonStats committed May 18, 2023
2 parents 1b1fc37 + 2660fbb commit e740bca
Show file tree
Hide file tree
Showing 28 changed files with 1,079 additions and 116 deletions.
111 changes: 0 additions & 111 deletions .github/workflows/R-CMD-check.yaml

This file was deleted.

49 changes: 49 additions & 0 deletions .github/workflows/check-standard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

name: R-CMD-check

jobs:
R-CMD-check:
runs-on: ${{ matrix.config.os }}

name: ${{ matrix.config.os }} (${{ matrix.config.r }})

strategy:
fail-fast: false
matrix:
config:
- {os: macos-latest, r: 'release'}
- {os: windows-latest, r: 'release'}
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-latest, r: 'release'}
- {os: ubuntu-latest, r: 'oldrel-1'}

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes

steps:
- uses: actions/checkout@v3

- uses: r-lib/actions/setup-pandoc@v2

- uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}
http-user-agent: ${{ matrix.config.http-user-agent }}
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
with:
upload-snapshots: true
11 changes: 8 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: tidytreatment
Type: Package
Title: Tidy Methods for Bayesian Treatment Effect Models
Version: 0.2.2
Version: 0.3.0.1
Authors@R: person("Joshua J", "Bon", email = "joshuajbon@gmail.com",
role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2313-2949"))
Expand All @@ -17,18 +17,23 @@ Suggests:
knitr,
rmarkdown,
BART,
stan4bart,
bartCause,
ggplot2,
testthat (>= 3.0.0),
withr
VignetteBuilder: knitr
RoxygenNote: 7.1.1
RoxygenNote: 7.2.3
Imports:
tidybayes,
purrr,
tidyr,
dplyr,
readr,
rlang
rlang,
dbarts,
coda,
magrittr
Enhances:
bartMachine
Config/testthat/edition: 3
17 changes: 17 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
# Generated by roxygen2: do not edit by hand

S3method(covariate_importance,bartMachine)
S3method(covariate_importance,bartcFit)
S3method(covariate_importance,lbart)
S3method(covariate_importance,mbart)
S3method(covariate_importance,mbart2)
S3method(covariate_importance,pbart)
S3method(covariate_importance,stan4bartFit)
S3method(covariate_importance,wbart)
S3method(covariate_with_treatment_importance,bartMachine)
S3method(covariate_with_treatment_importance,lbart)
S3method(covariate_with_treatment_importance,mbart)
S3method(covariate_with_treatment_importance,mbart2)
S3method(covariate_with_treatment_importance,pbart)
S3method(covariate_with_treatment_importance,wbart)
S3method(epred_draws,bartcFit)
S3method(epred_draws,stan4bartFit)
S3method(fitted_draws,bartMachine)
S3method(fitted_draws,lbart)
S3method(fitted_draws,mbart)
S3method(fitted_draws,mbart2)
S3method(fitted_draws,pbart)
S3method(fitted_draws,wbart)
S3method(linpred_draws,bartcFit)
S3method(linpred_draws,stan4bartFit)
S3method(model.matrix,bartMachine)
S3method(predicted_draws,bartMachine)
S3method(predicted_draws,bartcFit)
S3method(predicted_draws,stan4bartFit)
S3method(predicted_draws,wbart)
S3method(print,lbart)
S3method(print,mbart)
Expand All @@ -30,9 +38,13 @@ S3method(print,wbart)
S3method(residual_draws,bartMachine)
S3method(residual_draws,pbart)
S3method(residual_draws,wbart)
S3method(tidy_draws,bartcFit)
S3method(tidy_draws,stan4bartFit)
S3method(treatment_effects,bartcFit)
S3method(treatment_effects,default)
S3method(variance_draws,bartMachine)
S3method(variance_draws,wbart)
export("%>%")
export(avg_treatment_effects)
export(covariate_importance)
export(covariate_with_treatment_importance)
Expand All @@ -43,6 +55,8 @@ export(tidy_ate)
export(tidy_att)
export(treatment_effects)
export(variance_draws)
importFrom(dbarts,extract)
importFrom(magrittr,"%>%")
importFrom(rlang,"!!")
importFrom(rlang,":=")
importFrom(rlang,.data)
Expand All @@ -52,7 +66,10 @@ importFrom(stats,terms)
importFrom(tidybayes,add_fitted_draws)
importFrom(tidybayes,add_predicted_draws)
importFrom(tidybayes,add_residual_draws)
importFrom(tidybayes,epred_draws)
importFrom(tidybayes,fitted_draws)
importFrom(tidybayes,linpred_draws)
importFrom(tidybayes,predicted_draws)
importFrom(tidybayes,residual_draws)
importFrom(tidybayes,tidy_draws)
importFrom(utils,methods)
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# tidytreatment

## tidytreatment 0.3.0.1

* Now supports bartCause and stan4bart (github versions)
* Will move to CRAN soon

## tidytreatment 0.2.2

* Updates to handle changes in ggdist required arguments.
Expand Down
29 changes: 28 additions & 1 deletion R/covariate-importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#' @param ... Arguments to pass to particular methods.
#'
#' @return Tidy data with counts of variable inclusion, when interacting with treatment variable.
#' @export
#'
#' @export
covariate_with_treatment_importance <- function(model, treatment, ...) {
UseMethod("covariate_with_treatment_importance")
}
Expand Down Expand Up @@ -155,3 +155,30 @@ covariate_with_treatment_importance.mbart2 <- function(model, treatment, ...) {
covariate_with_treatment_importance.mbart <- function(model, treatment, ...) {
covariate_with_treatment_importance_BART(model, treatment, ...)
}

#' @export
covariate_importance.stan4bartFit <- function(model, ...) {

# extract mcmc draws
vv <- dbarts::extract(model, type = "varcount", combine_chains = F, include_warmup = F)

res <- dplyr::tibble(
variable = dimnames(vv)$predictor,
avg_inclusion = rowMeans(vv)
)

res
}

#' @export
covariate_importance.bartcFit <- function(model, fitstage = c("response","assignment"), ...) {

fitstage <- match.arg(fitstage)

if(fitstage == "response"){
covariate_importance(model$fit.rsp, ...)
} else {
covariate_importance(model$fit.trt, ...)
}

}
24 changes: 24 additions & 0 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,27 @@
#' @format Object of type \code{BART::wbart}
#' @source \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw}
"bartmodel1_modelmatrix"

#' Example simulated dataset 2: with subject specific random effects
#'
#' Simulated with \code{simulate_su_hill_data(...)}, see details.
#'
#' \preformatted{set.seed(101)
#' suhillsim1 <- simulate_su_hill_data(n = 100, treatment_linear = FALSE, omega = 0, add_categorical = TRUE,
#' coef_categorical_treatment = c(0,0,1),
#' coef_categorical_nontreatment = c(-1,0,-1), sd_subjects = 2, n_subjects = 10)
#' }
#'
#' @format See \code{?simulate_su_hill_data} for output format.
#' @source \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw}
"suhillsim2_ranef"

#' Example model 2
#'
#' Model fit with simulated data from simulated dataset \code{suhillsim1}.
#'
#' Propensity score estimated and included \code{suhillsim1} for fitting the model.
#'
#' @format Object of type \code{stan4bartFit}
#' @source \url{https://github.com/bonStats/tidytreatment/tree/master/data-raw}
#"stan4bartmodel2" # too large
11 changes: 10 additions & 1 deletion R/simulate-su-hill.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#' @param response_aligned Response surface is aligned?
#' @param y_sd Observation noise.
#' @param add_categorical Should a categorical variable be added? (Not in Hill and Su)
#' @param n_subjects How many subjects are there? For repeated observations. (Hill and Su = 0, default)
#' @param sd_subjects Random effect intercept standard deviation for subjects. (Not in Hill and Su. Used if n_subjects > 0)
#' @param coef_categorical_treatment What are the coefficients of the categorical variable under treatment? (Not in Hill and Su)
#' @param coef_categorical_nontreatment What are the coefficients of the categorical variable under nontreatment? (Not in Hill and Su)
#' @return An object of class \code{suhillsim} that is a list with elements
Expand All @@ -35,7 +37,7 @@
#' \item{formulas}{Response formulas used to generate data}
#' \item{coefs}{Coefficients for the formulas}
#' @export
simulate_su_hill_data <- function(n, treatment_linear = TRUE, response_parallel = TRUE, response_aligned = TRUE, y_sd = 1, tau = 4, omega = 0, add_categorical = FALSE, coef_categorical_treatment = NULL, coef_categorical_nontreatment = NULL) {
simulate_su_hill_data <- function(n, treatment_linear = TRUE, response_parallel = TRUE, response_aligned = TRUE, y_sd = 1, tau = 4, omega = 0, add_categorical = FALSE, n_subjects = 0, sd_subjects = 1, coef_categorical_treatment = NULL, coef_categorical_nontreatment = NULL) {
fargs <- as.list(match.call())

coefs <- dplyr::tribble(
Expand Down Expand Up @@ -148,6 +150,13 @@ simulate_su_hill_data <- function(n, treatment_linear = TRUE, response_parallel
rdata <- cbind(data.frame(y = y, z = z), X)
}

if (n_subjects > 0) {
# add subject effects
rdata <- rdata %>% dplyr::mutate(subject_id = factor(sample(1:n_subjects, nrow(rdata), replace = T)))
subject_effect <- rnorm(n_subjects, sd = sd_subjects)
rdata$y <- rdata$y + subject_effect[rdata$subject_id]
}

# prepare formula's to describe simulation truth
formula_terms <- attributes(terms(su_hill_formula))$term.labels

Expand Down
Loading

0 comments on commit e740bca

Please sign in to comment.