Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 93d8526
Showing
52 changed files
with
32,207 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
Package: bartBMA | ||
Type: Package | ||
Title: Bayesian Additive Regression Trees using Bayesian Model | ||
Averaging | ||
Version: 1.0 | ||
Date: 2020-02-05 | ||
Author: Belinda Hernandez [aut, cre] | ||
Adrian E. Raftery [aut] | ||
Stephen R Pennington [aut] | ||
Andrew C. Parnell [aut] | ||
Eoghan O'Neill [ctb] | ||
Maintainer: Belinda Hernandez <HERNANDB@tcd.ie> | ||
Description: "BART-BMA Bayesian Additive Regression Trees using Bayesian Model Averaging" (Hernandez B, Raftery A.E., Parnell A.C. (2018) <doi:10.1007/s11222-017-9767-1>) is an extension to the original BART sum-of-trees model (Chipman et al 2010). BART-BMA differs to the original BART model in two main aspects in order to implement a greedy model which | ||
will be computationally feasible for high dimensional data. Firstly BART-BMA uses a greedy search for the best split points and variables when growing decision trees within each sum-of-trees | ||
model. This means trees are only grown based on the most predictive set of split rules. Also rather than using Markov chain Monte Carlo (MCMC), BART-BMA uses a greedy implementation of Bayesian Model Averaging called Occam's Window | ||
which take a weighted average over multiple sum-of-trees models to form its overall prediction. This means that only the set of sum-of-trees for which there is high support from the data | ||
are saved to memory and used in the final model. | ||
License: GPL (>= 2) | ||
Imports: Rcpp (>= 1.0.0), mvnfast, Rdpack | ||
RdMacros: Rdpack | ||
LinkingTo: Rcpp, RcppArmadillo, BH | ||
RoxygenNote: 7.0.2 | ||
Encoding: UTF-8 | ||
NeedsCompilation: yes | ||
Packaged: 2020-03-05 17:24:23 UTC; HERNANDB | ||
Repository: CRAN | ||
Date/Publication: 2020-03-13 11:50:05 UTC |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
e79e25a3ca77b49109ac20c74932c886 *DESCRIPTION | ||
050c1fece042b3a719216d2848e814ca *NAMESPACE | ||
6a12c83aef7560428f3d29c2c144d1e4 *R/ITEs_CATT_bartBMA_exact_par.R | ||
92ff2fa242c053b94651e8b646cf06b3 *R/ITEs_bartBMA.R | ||
7bd973c88993ce2dac864eb9a753cf2b *R/ITEs_bartBMA_exact_par.R | ||
0bb2a28f04260a5d13f0ae86bea1d465 *R/RcppExports.R | ||
b469aeb2d1d9df01bed0c42abbc9cf3a *R/bartBMA.R | ||
97d916be4ab7171ee253805cc456a66e *R/bartBMA_with_ITEs_exact_par.R | ||
d1d66c192ce6ba9469676a5d489d1923 *R/pred_expectation_intervals_bbma_GS.R | ||
e24bafa48615c0f4dac8534a1e4078fc *R/pred_intervals_GS.R | ||
056b89f61600dad068e7f399bd6aebba *R/pred_intervals_new_initials_GS.R | ||
328d77c4357b9e1b056c55b7dab13dcd *R/pred_ints_exact.R | ||
a1d1ca44d8039ef4f12ca2a8b20280a2 *R/pred_ints_exact_par.R | ||
75a3eceb040ee0b38e3889a7ace10c22 *R/pred_means_bbma_GS.R | ||
5b4a9bc8f2241125b7d4ef4ac5e019be *R/pred_means_new_initials_GS.R | ||
bbf6c8d6e546d65ac5abcc3450436350 *R/predict_bartBMA.R | ||
ddb16845ae209507e80ce977d4b8a212 *R/predict_probit_bartBMA.R | ||
a91ede8823bc890a85d4e15d829035b8 *R/preds_bbma_lin_alg.R | ||
fbcc63b72bb12bcd48d56d09f9480aff *R/printbartBMA.R | ||
7518970253acc7f578e62e8a7252c396 *R/probit_bartBMA.R | ||
b64587a6d4bb889d5b49ed2a34531af2 *R/varImpScores.R | ||
880a57c4dc5377fa41fc07934dd2049f *R/varIncProb.R | ||
ac71c648b512113b453627ee6abe2715 *README.md | ||
0ba955f4ce5c2dc74fcc88583ecdfc4d *build/partial.rdb | ||
e76313c15640dd3cb06af10c654c6265 *inst/REFERENCES.bib | ||
9e1662c7a418c3b34668f8b349511ccf *man/ITEs_CATT_bartBMA_exact_par.Rd | ||
92fbe8d5d7cbf498405db3373b083fca *man/ITEs_bartBMA.Rd | ||
53cbfe1318b36acfe0443de07b74bb61 *man/ITEs_bartBMA_exact_par.Rd | ||
40b7f0effe06deaebe6c1a039b4bcc69 *man/bartBMA.Rd | ||
b21688a47846f395a32dfc6fe141b05a *man/bartBMA_with_ITEs_exact_par.Rd | ||
60534399cef508d15915e7bbe65e8941 *man/pred_expectation_intervals_bbma_GS.Rd | ||
bdbfd7702601d391e6c572c46ae5edc4 *man/pred_intervals_bbma_GS.Rd | ||
3c44c2348ec50a28ee6ab996109940ad *man/pred_intervals_new_initials_GS.Rd | ||
6b0178edbae0d6fe45fd4b763ea16986 *man/pred_ints_exact.Rd | ||
fb1ef610de44192dc7e9a3a2dcf0f54a *man/pred_ints_exact_par.Rd | ||
501ffd048a6ed9f26912e08ededd36f4 *man/pred_means_bbma_GS.Rd | ||
fb670a97fd7d45009b7a8518ca0acba2 *man/pred_means_bbma_new_initials_GS.Rd | ||
7646d42ef36befdd42219d6af14950f6 *man/predict_bartBMA.Rd | ||
9217efa5ce4ef018684a39b2f1f8b4e4 *man/predict_probit_bartBMA.Rd | ||
e58339f97f0eecb8b5ae18568fd5fe72 *man/preds_bbma_lin_alg.Rd | ||
650b23959ba3b4b0906fd8a32933c5f4 *man/probit_bartBMA.Rd | ||
7f6e136681063b52c38dc63ee6ad267f *man/varImpScores.Rd | ||
1e82e3ed46c543494b1f6aa2c5054d9a *man/varIncProb.Rd | ||
54030fd124e8b4d989c15438630619ae *src/161014_predict_function.cpp | ||
91c3add20d1afe917750f259b28ca1e4 *src/201014_BART_BMA_var_importance.cpp | ||
ab9c51353963493f7c72113cd5c3411e *src/BARTBMA_SumTreeLikelihood.cpp | ||
4abb712c55946e7fb752a08868edeaf9 *src/GibbsSamplerTrainingTestData.cpp | ||
61059660eb073d93e00e8ee054237071 *src/Makevars | ||
61059660eb073d93e00e8ee054237071 *src/Makevars.win | ||
1a8175f4305d0f4e56116609f9531738 *src/RcppExports.cpp | ||
d41d8cd98f00b204e9800998ecf8427e *src/bbma_preds_samples.cpp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
S3method(bartBMA,default) | ||
S3method(probit_bartBMA,default) | ||
export(ITEs_CATT_bartBMA_exact_par) | ||
export(ITEs_bartBMA) | ||
export(ITEs_bartBMA_exact_par) | ||
export(bartBMA) | ||
export(bartBMA.default) | ||
export(bartBMA_with_ITEs_exact_par) | ||
export(pred_expectation_intervals_bbma_GS) | ||
export(pred_intervals_bbma_GS) | ||
export(pred_intervals_new_initials_GS) | ||
export(pred_ints_exact) | ||
export(pred_ints_exact_par) | ||
export(pred_means_bbma_GS) | ||
export(pred_means_bbma_new_initials_GS) | ||
export(predict_bartBMA) | ||
export(predict_probit_bartBMA) | ||
export(preds_bbma_lin_alg) | ||
export(probit_bartBMA) | ||
export(probit_bartBMA.default) | ||
export(varImpScores) | ||
export(varIncProb) | ||
importFrom(Rcpp,evalCpp) | ||
importFrom(Rdpack,reprompt) | ||
importFrom(stats,pnorm) | ||
importFrom(stats,qchisq) | ||
importFrom(stats,qnorm) | ||
importFrom(stats,quantile) | ||
importFrom(stats,sd) | ||
importFrom(utils,head) | ||
useDynLib(bartBMA, .registration = TRUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
#' @title Estimate ITEs, CATE, CATT, CATNT and obtain credible intervals (in-sample or out-of-sample). | ||
#' | ||
#' @description This function takes a set of sum of tree models obtained from ITEs_bartBMA, and then estimates ITEs, and the CATE, CATT, and CATNT and obtains prediction intervals | ||
#' @param object Output from ITEs_bartBMA of class ITE_ests.bartBMA. | ||
#' @param l_quant Lower quantile of credible intervals for the ITEs, CATT, CATNT. | ||
#' @param u_quant Upper quantile of credible intervals for the ITEs, CATT, CATNT. | ||
#' @param newdata Test data for which predictions are to be produced. Default = NULL. If NULL, then produces prediction intervals for training data if no test data was used in producing the bartBMA object, or produces prediction intervals for the original test data if test data was used in producing the bartBMA object. | ||
#' @param update_resids Option for whether to update the partial residuals in the gibbs sampler. If equal to 1, updates partial residuals, if equal to zero, does not update partial residuals. The defaullt setting is to update the partial residuals. | ||
#' @param num_cores Number of cores used in parallel. | ||
#' @param root_alg_precision The algorithm should obtain approximate bounds that are within the distance root_alg_precision of the true quantile for the chosen average of models. | ||
#' @param training_data The training data matrix | ||
#' @param zvec The treatment indicator vector. Training data treatment vector for insample predictions, test data treatment vector for out of sample predictions. | ||
#' @export | ||
#' @return The output is a list of length 8: | ||
#' \item{ITE_intervals}{A 3 by n matrix, where n is the number of observations. The first row gives the l_quant*100 quantiles of the individual treatment effects. The second row gives the medians of the ITEs. The third row gives the u_quant*100 quantiles of the ITEs.} | ||
#' \item{ITE_estimates}{An n by 1 matrix containing the Individual Treatment Effect estimates.} | ||
#' \item{CATE_estimate}{The Conditional Average Treatment Effect Estimate} | ||
#' \item{CATE_Interval}{A 3 by 1 matrix. The first element is the l_quant*100 quantile of the CATE distribution, the second element is the median of the CATE distribution, and the thied element is the u_quant*100 quantile of the CATE distribution.} | ||
#' \item{CATT_estimate}{The Conditional Average Treatment Effect on the Treated Estimate} | ||
#' \item{CATT_Interval}{A 3 by 1 matrix. The first element is the l_quant*100 quantile of the CATT distribution, the second element is the median of the CATT distribution, and the thied element is the u_quant*100 quantile of the CATT distribution.} | ||
#' \item{CATNT_estimate}{The Conditional Average Treatment Effect on the Not Treated Estimate} | ||
#' \item{CATNT_Interval}{A 3 by 1 matrix. The first element is the l_quant*100 quantile of the CATNT distribution, the second element is the median of the CATNT distribution, and the thied element is the u_quant*100 quantile of the CATNT distribution.} | ||
#' @examples | ||
#' \dontrun{ | ||
#' #Example of BART-BMA for ITE estimation | ||
#' # Applied to data simulations from Hahn et al. (2020, Bayesian Analysis) | ||
#' # "Bayesian Regression Tree Models for Causal Inference: Regularization, | ||
#' # Confounding, and Heterogeneous Effects | ||
#' n <- 250 | ||
#' x1 <- rnorm(n) | ||
#' x2 <- rnorm(n) | ||
#' x3 <- rnorm(n) | ||
#' x4 <- rbinom(n,1,0.5) | ||
#' x5 <- as.factor(sample( LETTERS[1:3], n, replace=TRUE)) | ||
#' | ||
#' p= 0 | ||
#' xnoise = matrix(rnorm(n*p), nrow=n) | ||
#' x5A <- ifelse(x5== 'A',1,0) | ||
#' x5B <- ifelse(x5== 'B',1,0) | ||
#' x5C <- ifelse(x5== 'C',1,0) | ||
#' | ||
#' x_covs_train <- cbind(x1,x2,x3,x4,x5A,x5B,x5C,xnoise) | ||
#' | ||
#' #Treatment effect | ||
#' #tautrain <- 3 | ||
#' tautrain <- 1+2*x_covs_train[,2]*x_covs_train[,4] | ||
#' | ||
#' #Prognostic function | ||
#' mutrain <- 1 + 2*x_covs_train[,5] -1*x_covs_train[,6]-4*x_covs_train[,7] + | ||
#' x_covs_train[,1]*x_covs_train[,3] | ||
#' sd_mtrain <- sd(mutrain) | ||
#' utrain <- runif(n) | ||
#' #pitrain <- 0.8*pnorm((3*mutrain/sd_mtrain)-0.5*x_covs_train[,1])+0.05+utrain/10 | ||
#' pitrain <- 0.5 | ||
#' ztrain <- rbinom(n,1,pitrain) | ||
#' ytrain <- mutrain + tautrain*ztrain | ||
#' #pihattrain <- pbart(x_covs_train,ztrain )$prob.train.mean | ||
#' | ||
#' #set lower and upper quantiles for intervals | ||
#' lbound <- 0.025 | ||
#' ubound <- 0.975 | ||
#' | ||
#' | ||
#' trained_bbma <- ITEs_bartBMA(x_covariates = x_covs_train, | ||
#' z_train = ztrain, | ||
#' y_train = ytrain) | ||
#' | ||
#' example_output <- ITEs_CATT_bartBMA_exact_par(trained_bbma[[2]], | ||
#' l_quant = lbound, | ||
#' u_quant= ubound, | ||
#' training_data = x_covs_train, | ||
#' zvec = ztrain, | ||
#' num_cores = 1) | ||
#'} | ||
|
||
|
||
ITEs_CATT_bartBMA_exact_par <-function(object,#min_possible,max_possible, | ||
l_quant,u_quant,newdata=NULL,update_resids=1, | ||
num_cores=1, | ||
root_alg_precision=0.00001, | ||
training_data,zvec){ | ||
|
||
|
||
if(is.null(newdata) && length(object)==16){ | ||
#if test data specified separately | ||
ret<-pred_ints_ITE_CATT_outsamp_par(object$sumoftrees,object$obs_to_termNodesMatrix,object$response,object$bic,#min_possible, max_possible, | ||
object$nrowTrain, | ||
nrow(object$test_data),object$a,object$sigma,0,object$nu, | ||
object$lambda,#diff_inital_resids, | ||
object$test_data,l_quant,u_quant,num_cores, | ||
root_alg_precision,training_data,zvec | ||
) | ||
|
||
|
||
|
||
}else{if(is.null(newdata) && length(object)==14){ | ||
#else return Pred Ints for training data | ||
ret <- pred_ints_ITE_CATT_insamp_par(object$sumoftrees, | ||
object$obs_to_termNodesMatrix, | ||
object$response,object$bic,#min_possible, max_possible, | ||
object$nrowTrain,#nrow(object$test_data), | ||
object$a,object$sigma,0,object$nu, | ||
object$lambda,#diff_inital_resids,object$test_data, | ||
l_quant,u_quant, | ||
num_cores,root_alg_precision,training_data,zvec | ||
) | ||
|
||
}else{ | ||
#if test data included in call to object | ||
ret<-pred_ints_ITE_CATT_outsamp_par(object$sumoftrees,object$obs_to_termNodesMatrix,object$response,object$bic,#min_possible, max_possible, | ||
object$nrowTrain, | ||
nrow(newdata), object$a,object$sigma,0,object$nu, | ||
object$lambda,#diff_inital_resids, | ||
newdata,l_quant,u_quant,num_cores, | ||
root_alg_precision,training_data,zvec | ||
) | ||
|
||
}} | ||
|
||
#PI<-apply(draws_from_mixture,2,function(x)quantile(x,probs=c(l_quant,0.5,u_quant))) | ||
|
||
|
||
|
||
#each row is a vector drawn from the mixture distribution | ||
|
||
|
||
names(ret)<-c("ITE_intervals", | ||
"ITE_estimates", | ||
"CATE_estimate", | ||
"CATE_Interval", | ||
"CATT_estimate", | ||
"CATT_Interval", | ||
"CATNT_estimate", | ||
"CATNT_Interval") | ||
|
||
ret | ||
} |
Oops, something went wrong.