Skip to content

Commit

Permalink
speed up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ericdunipace committed Feb 1, 2024
1 parent f24a28a commit 280c6e2
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 22 deletions.
Binary file modified src/selectionVarMeanGeneration.o
Binary file not shown.
Binary file modified src/trans_univariate_approx_pwr.o
Binary file not shown.
9 changes: 5 additions & 4 deletions tests/testthat/test-W2L1.R
Original file line number Diff line number Diff line change
Expand Up @@ -659,13 +659,14 @@ testthat::test_that("W2L1 function for projection", {
testthat::test_that("W2L1 function for grouped projection", {
set.seed(283947)

n <- 256
p <- 100
n <- 32
p <- 20
g <- 10
s <- 21

x <- matrix(stats::rnorm(p*n), nrow=n, ncol=p)
beta <- rep((1:10)/10, 10)
groups <- rep(1:10, 10)
beta <- rep((1:g)/g, p/g)
groups <- rep(1:g, p/g)
y <- x %*% beta + stats::rnorm(n)

#posterior
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-WPL1.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ testthat::test_that("WPL1 works for W1", {

n <- 32
p <- 10
s <- 99
s <- 21

x <- matrix(stats::rnorm(p*n), nrow=n, ncol=p)
beta <- (1:10)/10
Expand Down Expand Up @@ -265,7 +265,7 @@ testthat::test_that("WPL1 works for WInf", {

n <- 32
p <- 10
s <- 99
s <- 21

x <- matrix(stats::rnorm(p*n), nrow=n, ncol=p)
beta <- (1:10)/10
Expand Down
14 changes: 7 additions & 7 deletions tests/testthat/test-WPR2.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ wpr2.prep <- function(n, p, s) {
test_that("WPR2 works", {
set.seed(203402)

n <- 128
n <- 32
p <- 10
s <- 100
s <- 21

x <- matrix( stats::rnorm( p * n ), nrow = n, ncol = p )
x_ <- t(x)
Expand Down Expand Up @@ -101,9 +101,9 @@ test_that("WPR2 works", {
testthat::test_that("WPR2 combining works", {
set.seed(203402)

n <- 128
n <- 32
p <- 10
s <- 100
s <- 21

out1 <- wpr2.prep(n,p,s)
out2 <- wpr2.prep(n,p,s)
Expand All @@ -119,10 +119,10 @@ testthat::test_that("WPR2 combining works", {
testthat::test_that("WPR2 plotting works", {
set.seed(203402)

n <- 128
n <- 64
p <- 10
s <- 100
reps <- 10
s <- 50
reps <- 3
out <- lapply(1:reps, function(i) wpr2.prep(n,p,s))
# debugonce(combine.WPR2)
comb <- combine.WPR2(out)
Expand Down
18 changes: 9 additions & 9 deletions tests/testthat/test-wasserstein.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ testthat::test_that("wasserstein matches transport package for shortsimplex", {
testthat::test_that("wasserstein from sp matches transport package",{
testthat::skip_if_not_installed("transport")
set.seed(32857)
A <- matrix(stats::rnorm(1000*1024),nrow=1024,ncol=1000)
B <- matrix(stats::rnorm(1000*1024),nrow=1024,ncol=1000)
A <- matrix(stats::rnorm(100*104),nrow=104,ncol=100)
B <- matrix(stats::rnorm(100*104),nrow=104,ncol=100)
at <- t(A)
bt <- t(B)
cost <- cost_calc(at,bt,2)
Expand All @@ -146,7 +146,7 @@ testthat::test_that("wasserstein from sp matches transport package",{
# microbenchmark::microbenchmark(wasserstein_(tplan$mass, cost, p = 2, tplan$from, tplan$to), unit = "us")
# microbenchmark::microbenchmark(sinkhorn_(mass_a, mass_b, cost^2, 0.05*median(cost^2), 100), unit="ms")

C <- t(A[1:100,,drop = FALSE])
C <- t(A[1:10,,drop = FALSE])
D <- t(B[1:2,,drop = FALSE])

cost2 <- cost_calc(C,D,2)
Expand All @@ -160,8 +160,8 @@ testthat::test_that("wasserstein from sp matches transport package",{

testthat::test_that("make sure wass less than all other transports", {
set.seed(32857)
A <- matrix(stats::rnorm(1000*1024),nrow=1024,ncol=1000)
B <- matrix(stats::rnorm(1000*1024),nrow=1024,ncol=1000)
A <- matrix(stats::rnorm(100*124),nrow=124,ncol=100)
B <- matrix(stats::rnorm(100*124),nrow=124,ncol=100)
at <- t(A)
bt <- t(B)
cost <- cost_calc(at,bt,2)
Expand Down Expand Up @@ -190,8 +190,8 @@ testthat::test_that("make sure wass less than all other transports", {

testthat::test_that("make sure sinkhorn outputs agree and are less than wass", {
set.seed(32857)
A <- matrix(stats::rnorm(1000*1024),nrow=1024,ncol=1000)
B <- matrix(stats::rnorm(1000*1024),nrow=1024,ncol=1000)
A <- matrix(stats::rnorm(100*104),nrow=104,ncol=100)
B <- matrix(stats::rnorm(100*104),nrow=104,ncol=100)
at <- t(A)
bt <- t(B)
cost <- cost_calc(at,bt,2)
Expand All @@ -211,8 +211,8 @@ testthat::test_that("make sure sinkhorn outputs agree and are less than wass", {

testthat::test_that("give error when p < 1", {
set.seed(32857)
A <- matrix(stats::rnorm(1000*1024),nrow=1024,ncol=1000)
B <- matrix(stats::rnorm(1000*1024),nrow=1024,ncol=1000)
A <- matrix(stats::rnorm(100*124),nrow=124,ncol=100)
B <- matrix(stats::rnorm(100*124),nrow=124,ncol=100)
ground_p <- 2
p <- 0

Expand Down

0 comments on commit 280c6e2

Please sign in to comment.