simplify and optimize prcomp_irlba #52

Open
wants to merge 2 commits into
base: master
jan-glx commented Nov 18, 2019

This gives a significant speed-up for sparse matrices by avoiding the creation of dense intermediates.

jan-glx (Author) commented Nov 18, 2019

library(irlba)
#> Loading required package: Matrix
library(Matrix)

prcomp_irlba_new <- function(x, n = 3, retx = TRUE, center = TRUE, scale. = FALSE, ...)
{
  if (hasArg(tol))
    warning("The `tol` truncation argument from `prcomp` is not supported by
`prcomp_irlba`. If specified, `tol` is passed to the `irlba` function to
control that algorithm's convergence tolerance. See `?prcomp_irlba` for help.")

# Try to convert data frame to matrix...
  if (is.data.frame(x)) x <- as.matrix(x)

  col_means <- colMeans(x)
  center <- if (!is.logical(center)) center else if (center) col_means else 0

  col_vars <- (colMeans(x^2) - 2*col_means*center + center^2) / (1 - 1/nrow(x))
  scale. <- if (!is.logical(scale.)) scale. else if (scale.) sqrt(col_vars) else 1

  args <- list(A=x, nv=n)
  if(!isTRUE(all(center==0))) args$center <- center # center & scale are only supplied to irlba if
  if(!isTRUE(all(scale.==1))) args$scale <- scale.  # centering/scaling would actually be performed
  args <- c(args, list(...))

  s <- do.call(irlba, args=args)
  ans <-list(
    sdev = s$d / sqrt(nrow(x) - 1),
    rotation = s$v,
    center = if(is.null(args$center)) FALSE else args$center,
    scale = if(is.null(args$scale)) FALSE else args$scale
  )
  colnames(ans$rotation) <- paste("PC", seq_len(ncol(ans$rotation)), sep="")
  if (retx)
  {
    ans$x <- s$u %*% diag(s$d, nrow = length(s$d)) # explicit nrow keeps n = 1 working
    colnames(ans$x) <- colnames(ans$rotation)
  }
  ans$totalvar <- sum(col_vars/scale.^2)
  class(ans) <- c("irlba_prcomp", "prcomp")
  ans
}

prcomp_irlba_old <- function(x, n = 3, retx = TRUE, center = TRUE, scale. = FALSE, ...)
{
  a <- names(as.list(match.call()))
  ans <- list(scale=scale.)
  if ("tol" %in% a)
    warning("The `tol` truncation argument from `prcomp` is not supported by
`prcomp_irlba`. If specified, `tol` is passed to the `irlba` function to
control that algorithm's convergence tolerance. See `?prcomp_irlba` for help.")
# Try to convert data frame to matrix...
  if (is.data.frame(x)) x <- as.matrix(x)
  args <- list(A=x, nv=n)
  if (is.logical(center))
  {
    if (center) args$center <- colMeans(x)
  } else args$center <- center
  if (is.logical(scale.))
  {
      if (is.numeric(args$center))
      {
        f <- function(i) sqrt(sum((x[, i] - args$center[i]) ^ 2) / (nrow(x) - 1L))
        scale. <- vapply(seq(ncol(x)), f, pi, USE.NAMES=FALSE)
        if (ans$scale) ans$totalvar <- ncol(x)
        else ans$totalvar <- sum(scale. ^ 2)
      } else
      {
        if (ans$scale)
        {
          scale. <- apply(x, 2L, function(v) sqrt(sum(v ^ 2) / max(1, length(v) - 1L)))
          f <- function(i) sqrt(sum((x[, i] / scale.[i]) ^ 2) / (nrow(x) - 1L))
          ans$totalvar <- sum(vapply(seq(ncol(x)), f, pi, USE.NAMES=FALSE) ^ 2)
        } else
        {
          f <- function(i) sum(x[, i] ^ 2) / (nrow(x) - 1L)
          ans$totalvar <- sum(vapply(seq(ncol(x)), f, pi, USE.NAMES=FALSE))
        }
      }
      if (ans$scale) args$scale <- scale.
  } else
  {
    args$scale <- scale.
    f <- function(i) sqrt(sum((x[, i] / scale.[i]) ^ 2) / (nrow(x) - 1L))
    ans$totalvar <- sum(vapply(seq(ncol(x)), f, pi, USE.NAMES=FALSE))
  }
  if (!missing(...)) args <- c(args, list(...))

  s <- do.call(irlba, args=args)
  ans$sdev <- s$d / sqrt(max(1, nrow(x) - 1))
  ans$rotation <- s$v
  colnames(ans$rotation) <- paste("PC", seq(1, ncol(ans$rotation)), sep="")
  ans$center <- args$center
  if (retx)
  {
    ans <- c(ans, list(x = sweep(s$u, 2, s$d, FUN=`*`)))
    colnames(ans$x) <- paste("PC", seq(1, ncol(ans$rotation)), sep="")
  }
  class(ans) <- c("irlba_prcomp", "prcomp")
  ans
}

n <- 10000
p <- 1000
mat <- matrix(rpois(n = n*p, lambda = 0.005), n, p)
sparse_mat <- as(mat, "sparseMatrix")

(lb <- bench::mark(
  prcomp_irlba_old(mat, scale.=TRUE),
  prcomp_irlba_old(sparse_mat, scale.=TRUE),
  prcomp_irlba_new(mat, scale.=TRUE),
  prcomp_irlba_new(sparse_mat, scale.=TRUE),
  check = FALSE
))
#> Warning: Some expressions had a GC in every iteration; so filtering is disabled.
#> # A tibble: 4 x 6
#>   expression                                       min   median `itr/sec`
#>   <bch:expr>                                  <bch:tm> <bch:tm>     <dbl>
#> 1 prcomp_irlba_old(mat, scale. = TRUE)           1.84s    1.84s     0.542
#> 2 prcomp_irlba_old(sparse_mat, scale. = TRUE) 296.94ms 361.07ms     2.77 
#> 3 prcomp_irlba_new(mat, scale. = TRUE)            1.9s     1.9s     0.527
#> 4 prcomp_irlba_new(sparse_mat, scale. = TRUE)  30.73ms  32.94ms    20.8  
#> # ... with 2 more variables: mem_alloc <bch:byt>, `gc/sec` <dbl>
plot(lb)
#> Loading required namespace: tidyr

Created on 2019-11-18 by the reprex package (v0.3.0)

codecov-io commented Nov 18, 2019

Codecov Report

Merging #52 into master will decrease coverage by 0.33%.
The diff coverage is 100%.

@@            Coverage Diff             @@
##           master      #52      +/-   ##
==========================================
- Coverage    89.1%   88.76%   -0.34%     
==========================================
  Files           8        8              
  Lines         881      801      -80     
==========================================
- Hits          785      711      -74     
+ Misses         96       90       -6

Impacted Files   Coverage Δ
R/prcomp.R       89.18% <100%> (+5.51%) ⬆️
R/irlba.R        82.82% <0%> (-2.45%) ⬇️
R/ssvd.R         89.58% <0%> (-1.49%) ⬇️
src/utility.c    95.83% <0%> (-0.17%) ⬇️
src/irlb.c       96.27% <0%> (+0.6%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 2ad759c...12b4ced.

jan-glx (Author) commented Nov 18, 2019

This PR changes the order of the elements in the returned list to match that of stats::prcomp. Happy to change it back to the previous prcomp_irlba order if preferred.

LTLA (Contributor) commented Nov 18, 2019

Some comments from a developer perspective:

  • colMeans and colSums are executed even if center= and scale= are not requested. This is not desirable as it results in an unnecessary 1-5 minute penalty for very large matrices (e.g., file-backed).
  • The new scaling calculation is less numerically stable and can suffer from catastrophic cancellation.
x <- matrix(rnorm(100), 10, 10) + 1e9
center <- col_means <- colMeans(x)
col_vars <- (colMeans(x^2) - 2*col_means*center + center^2) / (1 - 1/nrow(x))
col_vars
## [1]  142.2222    0.0000 -142.2222  142.2222    0.0000    0.0000 -142.2222
## [8] -142.2222 -142.2222  142.2222
true_col_vars <- matrixStats::colVars(x)
true_col_vars
## [1] 0.5213468 0.7576393 1.7165307 0.9792770 0.9100585 0.7977485 1.0333882
## [8] 1.0694728 0.9085411 1.1655892

Perhaps this wouldn't be likely to occur for single-cell data, but people should be able to use irlba for other things with arbitrary location.
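
A quick way to see where the ±142.2222 values come from (a sketch that relies only on R's IEEE double precision): the terms of the one-pass formula are around 1e18, where adjacent doubles are 128 apart. Every term is therefore a multiple of 128, and so is their difference, before the division by (1 - 1/nrow(x)).

```r
# Spacing between adjacent doubles in [2^59, 2^60), the range containing 1e18
ulp <- 2^59 * .Machine$double.eps
print(ulp)
#> [1] 128

# Adding 1 to a number of this magnitude is lost entirely
lost <- (1e18 + 1) == 1e18
print(lost)
#> [1] TRUE

# Smallest non-zero magnitude the one-pass formula can report with 10 rows
smallest_10 <- ulp / (1 - 1/10)
print(smallest_10)
#> [1] 142.2222

# The same with 20 rows
smallest_20 <- ulp / (1 - 1/20)
print(smallest_20)
#> [1] 134.7368
```

The true variances here are of order 1, far below that resolution, so the formula can only return 0 or a multiple of ±142.2222.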

bwlewis (Owner) commented Nov 19, 2019

Thanks for this!

I agree a speedup is desirable in the sparse case, and the order alignment with prcomp is a good idea. I'm carefully considering @LTLA's comments. We can, for instance, easily make the scaling parameters optionally computed as required. We can also, hopefully, carefully address numerical stability. After that, I hope to merge and modify this. -bwl

jan-glx (Author) commented Nov 19, 2019

Thanks for the comments, Aaron, and for your consideration, @bwlewis!

  • colMeans and colSums are executed even if center= and scale= are not requested. This is not desirable as it results in an unnecessary 1-5 minute penalty for very large matrices (e.g., file-backed).

We can, for instance, easily make the scaling parameters optionally computed as required.

It seems to me that these were also computed before, and that this is necessary to compute totalvar? Or do you suggest adding a ret.totalvar=TRUE argument? In any case, I feel this overhead should be less than that of computing even just one PC.

The new scaling calculation is less numerically stable and can suffer from catastrophic cancellation.

Now this is more serious. I would be happy to use matrixStats (would it be OK to add it as a dependency in Imports?), but it seems to rely on the same flawed algorithm:

matrixStats::colVars(x, center=center)
##  [1]   0.0000 134.7368   0.0000   0.0000   0.0000 134.7368   0.0000   0.0000   0.0000 134.7368

jan-glx (Author) commented Nov 19, 2019

OK, figured it out: matrixStats::colVars(x) + (colMeans(x)-center)^2/(1-1/nrow(x))

# normal case
x <- matrix(rnorm(200), 20, 8) + 5
center <- colMeans(x) + 1

colMeans(t(t(x)-center)^2)/(1-1/nrow(x) ) # slow
#> [1] 2.276515 1.955017 1.922532 2.172460 2.162209 2.235109 2.140299 2.092379
matrixStats::colVars(x) + (colMeans(x)-center)^2/(1-1/nrow(x)) # ok
#> [1] 2.276515 1.955017 1.922532 2.172460 2.162209 2.235109 2.140299 2.092379
(colMeans(x^2) - 2*colMeans(x)*center + center^2) / (1 - 1/nrow(x)) # unstable
#> [1] 2.276515 1.955017 1.922532 2.172460 2.162209 2.235109 2.140299 2.092379

# toxic case
x <- matrix(rnorm(200), 20, 8) + 1e9
center <- colMeans(x) + 1

colMeans(t(t(x)-center)^2)/(1-1/nrow(x) ) # slow
#> [1] 1.575861 2.406821 2.810052 1.853597 1.428814 1.697791 2.128663 1.925782
matrixStats::colVars(x) + (colMeans(x)-center)^2/(1-1/nrow(x)) # ok
#> [1] 1.575861 2.406821 2.810052 1.853597 1.428814 1.697791 2.128663 1.925782
(colMeans(x^2) - 2*colMeans(x)*center + center^2) / (1 - 1/nrow(x)) # unstable
#> [1] -134.7368    0.0000    0.0000    0.0000 -134.7368    0.0000    0.0000
#> [8]    0.0000

# variance case
x <- matrix(rnorm(200), 20, 7) + 1e9
#> Warning in matrix(rnorm(200), 20, 7): data length [200] is not a sub-multiple or
#> multiple of the number of columns [7]
center <- colMeans(x) 

matrixStats::colVars(x) #variance only
#> [1] 1.1096721 0.9205107 1.1990229 1.0338478 0.9222606 1.0248208 0.8031727
colMeans(t(t(x)-center)^2)/(1-1/nrow(x) ) # slow
#> [1] 1.1096721 0.9205107 1.1990229 1.0338478 0.9222606 1.0248208 0.8031727
matrixStats::colVars(x) + (colMeans(x)-center)^2/(1-1/nrow(x)) # ok
#> [1] 1.1096721 0.9205107 1.1990229 1.0338478 0.9222606 1.0248208 0.8031727
(colMeans(x^2) - 2*colMeans(x)*center + center^2) / (1 - 1/nrow(x)) # unstable
#> [1]    0.0000    0.0000  134.7368    0.0000    0.0000 -134.7368  134.7368

The matrixStats::colVars(x, center=) doesn't actually do what we need. (Wondering if it should?)


EDIT: I just realized matrixStats does not support sparse matrices at all. There is https://github.com/const-ae/sparseMatrixStats/blob/f4221ff8fd8b655151565443570e87cf8f2c6ce1/src/methods.cpp#L246 but it uses the same flawed algorithm...
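
For reference, the identity behind the colVars(x) + (colMeans(x)-center)^2/(1-1/nrow(x)) workaround can be written out for a single column. The symbols are introduced here only for illustration: x_1, ..., x_n are the column entries, \bar{x} is their mean, s^2 is the sample variance and c is the requested center. The cross term vanishes because \sum_i (x_i - \bar{x}) = 0.

```latex
\begin{aligned}
\sum_{i=1}^{n} (x_i - c)^2
  &= \sum_{i=1}^{n} \bigl((x_i - \bar{x}) + (\bar{x} - c)\bigr)^2 \\
  &= \sum_{i=1}^{n} (x_i - \bar{x})^2
     + 2(\bar{x} - c)\sum_{i=1}^{n} (x_i - \bar{x})
     + n(\bar{x} - c)^2 \\
  &= (n-1)\,s^2 + n(\bar{x} - c)^2, \\[4pt]
\frac{1}{n-1}\sum_{i=1}^{n} (x_i - c)^2
  &= s^2 + \frac{n}{n-1}(\bar{x} - c)^2
   = s^2 + \frac{(\bar{x} - c)^2}{1 - 1/n}.
\end{aligned}
```

Both terms on the right are small, non-negative numbers, which is why this form does not suffer from the cancellation seen in the one-pass formula.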

LTLA (Contributor) commented Nov 19, 2019

There are several points here that warrant further discussion.

The matrixStats::colVars(x) + (colMeans(x)-center)^2/(1-1/nrow(x)) workaround is cute but, if this were my own code, I would hesitate to use it. It took me several minutes after reading it to confirm that it was correct in all circumstances, which is much longer than it took to read the original branching code for handling the scale. argument. Despite what the tidyverse folks may have you believe, shorter is not better for write-once-read-often package code. In addition, the workaround involves an extra colMeans() call, and as you have already noticed, colVars() does not support sparse matrices.

IMO, the best solution is to keep the original if/else branches and replace the vapply calls with an appropriate colVars generic (see below). If you feel it's too verbose, just shove the entire set of branches into a .define_center_and_scale() function and then it becomes a one-liner.

As for the colVars() generics, the correct approach would be to sit tight and wait for efforts related to Bioconductor/MatrixGenerics#5 to get incorporated into Matrix. (This is for medians right now, but the same effort is trivially applied to row/column variances.) However, if you have nothing better to do, then you can easily home-brew your own solution for the time being.

## NOTE: untested, but you should get the idea.
setGeneric(".my_colVars", function(x, ...) standardGeneric(".my_colVars"))

#' @importFrom Matrix t rowSums
setMethod(".my_colVars", "ANY", function(x, center=NULL) {
      if (!is.null(center)) {
          y <- t(x) - center
          rowSums(y^2)/(ncol(y)-1)
      } else {
          colSums(x^2)/(nrow(x)-1)
      }
}) 

#' @importFrom Matrix t colSums
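setMethod(".my_colVars", "dgCMatrix", function(x, center=NULL) {
     if (!is.null(center)) {
          nnz <- diff(x@p)              # number of stored (non-zero) entries per column
          expanded <- rep(center, nnz)
          x@x <- (x@x - expanded)^2     # squared deviations of the stored entries
          # each of the nrow(x) - nnz implicit zeros contributes center^2
          (colSums(x) + (nrow(x) - nnz) * center^2)/(nrow(x)-1)
     } else {
          colSums(x^2)/(nrow(x)-1)
     }
})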

Optimized implementations for other classes are left as an exercise for the reader.

I must admit that I never realized that prcomp_irlba was computing totalvar. I guess I am fortunate that BiocSingular always calls irlba() directly, so I've never paid the penalty.
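
A quick sanity check of the sparse idea above, as a minimal sketch: for a column-compressed matrix, each stored entry contributes (value - center)^2 and each implicit zero contributes center^2. The helper name sparse_col_vars and the test data are made up for illustration and are not part of irlba; only Matrix and base R are used.

```r
library(Matrix)

# Hypothetical helper (not part of irlba): per-column sum of squared
# deviations from `center`, divided by n - 1, for a dgCMatrix.
sparse_col_vars <- function(x, center) {
  nnz <- diff(x@p)                     # stored entries per column
  x@x <- (x@x - rep(center, nnz))^2    # squared deviations of stored entries
  zeros <- nrow(x) - nnz               # implicit zeros per column
  (colSums(x) + zeros * center^2) / (nrow(x) - 1)
}

set.seed(1)
dense <- matrix(as.numeric(rpois(200 * 5, lambda = 0.1)), 200, 5)
sparse <- Matrix(dense, sparse = TRUE) # column-compressed sparse copy
center <- colMeans(dense) + 0.5        # arbitrary center, deliberately not the mean

# Dense reference with explicit centering, like the "slow" variant in the thread
reference <- colSums(sweep(dense, 2, center)^2) / (nrow(dense) - 1)
sparse_result <- unname(sparse_col_vars(sparse, center))
stopifnot(isTRUE(all.equal(sparse_result, reference)))
```

The check only covers a small, well-conditioned case. It says nothing about other sparse classes or file-backed matrices.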

jan-glx (Author) commented Feb 1, 2021

@bwlewis Would you be happy relying on Bioconductor/MatrixGenerics?

bwlewis (Owner) commented Feb 2, 2021 via email

LTLA (Contributor) commented Feb 3, 2021

I'll be honest with you guys, I've forgotten most of what I suggested. But I would very much like to get back into this - particularly interested in seeing how we can improve the R-side IRLBA code, which I depend on a lot for my S4 matrix abstractions.
