Skip to content

Commit

Permalink
make sure combine.distcompare works even though currently internal
Browse files Browse the repository at this point in the history
  • Loading branch information
ericdunipace committed Jan 10, 2024
1 parent 61411d6 commit 1e4ef14
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 19 deletions.
54 changes: 47 additions & 7 deletions R/combine.dist.compare.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ methods::setClass("combine.distcompare")

#' Combine distance calculations from the distCompare function
#'
#' @param distances A list of `distcompare` objects that are the result of [distCompare()]
#' @param ... `distcompare` objects that are the result of [distCompare()]
#'
#' @return an object of class `combine.distcompare`, the combined `distcompare` class objects as returned by [distCompare()] function
#' @keywords internal
Expand Down Expand Up @@ -37,12 +37,52 @@ methods::setClass("combine.distcompare")
# combine <- combine.distcompare(dc1, dc2)
combine.distcompare <- function(...) {

if( (...length() == 1) && is.list(...) && !is.distcompare(...)) {
distances <- ...elt(1L)
} else if ( (...length() == 1) && is.distcompare(...)) {
message("Only one `distcompare` object was passed to the function. Returning original object")
return(...elt(1L))
} else {
distances <- list(...)
}

stopifnot(is.list(distances))
if (!all(sapply(distances, is.distcompare))) {
stop("All members must be distcompare object")
}
niter <- length(distances)
cmb <- list(parameters = NULL, predictions = NULL, p = NULL)

ps <- sapply(distances, function(d) d$p)
stopifnot(all(diff(ps)==0))

cmb$p <- ps[1]
post <- do.call("rbind", lapply(distances, function(d) d$parameters))
predictions <- do.call("rbind", lapply(distances, function(d) d$predictions))
# methods <- do.call("rbind", lapply(distances, function(d) d$method))

if(! is.null(post)) {
cmb$parameters <- post
}

if (!is.null(predictions)) {
cmb$predictions <- predictions
}

class(cmb) <- class(distances[[1L]])

return(cmb)
}


combine_and_augment_distcompare <- function(...) {

distances <- list(...)
if(is.list(...)) distances <- unlist(distances, recursive = FALSE)

stopifnot(is.list(distances))
if (!all(sapply(distances, is.distcompare))) {
stop("All members must be distcompare object")
stop("All members must be distcompare object")
}
niter <- length(distances)
length.each <- sapply(distances, function(i) nrow(i$predictions))
Expand Down Expand Up @@ -71,15 +111,15 @@ combine.distcompare <- function(...) {
if(!is.null(cmb$predictions)) cmb$predictions <- cbind(cmb$predictions, ranks = ranks.predictions, iter = iter)
}

class(cmb) <- c("combine.distcompare","WpProj")
class(cmb) <- c("combine_distcompare","WpProj")

return(cmb)
}


plot.combine.distcompare <- function (x, ylim = NULL, ylabs = c(NULL,NULL), facet.group = NULL, ...) {
setOldClass("combine_distcompare")
plot.combine_distcompare <- function (x, ylim = NULL, ylabs = c(NULL,NULL), facet.group = NULL, ...) {
distances <- x
stopifnot(inherits(distances, "combine.dist.compare"))
stopifnot(inherits(distances, "combine_distcompare"))
dots <- list(...)
alpha <- dots$alpha
base_size <- dots$base_size
Expand Down Expand Up @@ -394,6 +434,6 @@ print.plotrank <- function(x,...) {
# combine <- combine.distcompare(list(dc1, dc2))
# plot(combine)
# }
methods::setMethod("plot", c("x" ="combine.distcompare"), plot.combine.distcompare)
methods::setMethod("plot", c("x" ="combine_distcompare"), plot.combine_distcompare)
methods::setMethod("print", c("x" ="plotcombine"), print.plotcombine)
methods::setMethod("print", c("x" ="plotrank"), print.plotrank)
11 changes: 7 additions & 4 deletions R/distanceCompare.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ methods::setClass("distcompare",

#' Compares Optimal Transport Distances Between WpProj and Original Models
#'
#' @param models Models from WpProj methods
#' @param models A list of models from WpProj methods
#' @param target The target to compare the methods to. Should be a list with slots "parameters" to compare the parameters and "predictions" to compare predictions
#' @param power The power parameter of the Wasserstein distance.
#' @param method Which approximation to the Wasserstein distance to use. Should be one of "exact", "sinkhorn", "greenkhorn", "gandkhorn", "randkhorn", or "hilbert".
Expand All @@ -14,7 +14,10 @@ methods::setClass("distcompare",
#' @param transform Transformation function for the predictions.
#' @param ... other options passed to the wasserstein distance function
#'
#' @return an object of class `distcompare` with slots `parameters`, `predictions`, and `p`.
#' @return an object of class `distcompare` with slots `parameters`, `predictions`, and `p`. The slots `parameters` and `predictions` are data frames. See the details for more info. The slot `p` is the power parameter of the Wasserstein distance used in the distance calculation.
#'
#' @details
#' For the data frames, `dist` is the Wasserstein distance, `nactive` is the number of active variables in the model, `groups` is the name distinguishing the model, and `method` is the method used to calculate the distance (i.e., exact, sinkhorn, etc.). If the list in `models` is named, these will be used as the group names otherwise the group names will be created based on the call from the `WpProj` method.
#'
#' @export
#' @examples
Expand All @@ -35,7 +38,7 @@ methods::setClass("distcompare",
#' method = "binary program", solver = "lasso",
#' options = list(solver.options = list(penalty = "mcp"))
#' )
#' dc <- distCompare(models = list(fit1, fit2),
#' dc <- distCompare(models = list("L1" = fit1, "BP" = fit2),
#' target = list(parameters = post_beta, predictions = post_mu))
#' plot(dc)
#' }
Expand Down Expand Up @@ -336,7 +339,7 @@ set_equal_y_limits.distcompare <- function(x){
#'
#' @keywords internal
#'
#' @return The ranks of a `distcompare` object
#' @return The ranks of a `distcompare` object as a list containing slots "predictions" and "parameters".
rank_distcompare <- function(distances) {
if(!is.distcompare(distances)) stop("Must be distcompare object")
rank.fun <- function(distance, quant) {
Expand Down
2 changes: 1 addition & 1 deletion man/combine.distcompare.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions man/distCompare.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/rank_distcompare.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions tests/testthat/test-combinedistcompare.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
test_that("combine.distcompare works", {
n <- 32
p <- 10
s <- 21
# covariates and coefficients
x <- matrix( stats::rnorm( p * n ), nrow = n, ncol = p )
beta <- (1:10)/10
#outcome
y <- x %*% beta + stats::rnorm(n)
# fake posterior
post_beta <- matrix(beta, nrow=p, ncol=s) + stats::rnorm(p*s, 0, 0.1)
post_mu <- x %*% post_beta #posterior predictive distributions
# fit models
## L1 model
fit.p2 <- WpProj(X=x, eta=post_mu, power = 2.0,
method = "L1", #default
solver = "lasso" #default
)
## approximate binary program
fit.p2.bp <- WpProj(X=x, eta=post_mu, theta = post_beta, power = 2.0,
method = "binary program",
solver = "lasso" #default because approximate algorithm is faster
)
## compare performance by measuring distance from full model
dc <- distCompare(models = list("L1" = fit.p2, "BP" = fit.p2.bp))

testthat::expect_silent(cc1 <- combine.distcompare(dc,dc))
testthat::expect_silent(cc2 <- combine.distcompare(list(dc,dc)))
testthat::expect_equal(cc1, cc2)
testthat::expect_true(inherits(cc1, "distcompare"))
testthat::expect_message(cc <- combine.distcompare(dc))
testthat::expect_equal(cc, dc)
})

0 comments on commit 1e4ef14

Please sign in to comment.