

Work on perf in gg_miss for large dataset #101

Closed
ColinFay opened this issue Aug 25, 2017 · 10 comments

@ColinFay
Contributor

The gg_miss() family seems to take a (very) long time to compute on large datasets.

Here's an example w/ diamonds:

mb <- microbenchmark::microbenchmark(gg_miss_case(diamonds), times = 10)
mb

[screenshot: microbenchmark results for gg_miss_case(diamonds)]

Is this ggplot2-related or naniar-related?

@njtierney njtierney added this to Priority in CRAN Version 0.2.0 Aug 29, 2017
@ColinFay
Contributor Author

ColinFay commented Oct 4, 2017

The current miss_case_summary.default is:

miss_case_summary.default <- function(data, order = FALSE, ...){

  res <- purrrlyr::by_row(.d = data,
                          ..f = function(x) (mean(is.na(x)) * 100),
                          .collate = "row",
                          .to = "pct_miss") %>%
    purrrlyr::by_row(.d = .,
                     ..f = function(x) (sum(is.na(x))),
                     .collate = "row",
                     .to = "n_miss") %>%
    dplyr::mutate(case = 1:nrow(data),
                  n_miss_cumsum = cumsum(n_miss)) %>%
    dplyr::select(case,
                  n_miss,
                  pct_miss,
                  n_miss_cumsum)

  if (order) {
    return(dplyr::arrange(res, -n_miss))
  } else {
    return(res)
  }
}

Here's a quick shot at rewriting miss_case_summary in base R:

miss_case_summary_base <- function(data, order = FALSE, ...){
  
  res <- data
  res$pct_miss <- apply(X = data, MARGIN = 1, FUN = function(x) (mean(is.na(x)) * 100))
  res$n_miss <- apply(X = data, MARGIN = 1, FUN = function(x) (sum(is.na(x))))
  res$case <- 1:nrow(data)
  res$n_miss_cumsum <- cumsum(res$n_miss)
  res <- dplyr::select(res, 
                       case,
                       n_miss,
                       pct_miss,
                       n_miss_cumsum)
  
  if (order) {
    return(dplyr::arrange(res, -n_miss))
  } else {
    return(res)
  }
}

# Outputs are the same 

base_output <- miss_case_summary_base(airquality)
current_output <- miss_case_summary(airquality)

all_equal(base_output, current_output)
[1] TRUE

And here's the results from a microbenchmark w/ airquality and diamonds:

mb <- microbenchmark::microbenchmark(orig = miss_case_summary(airquality), 
                                     old_school = miss_case_summary_base(airquality))

Unit: milliseconds
       expr       min        lq     mean    median        uq       max neval
       orig 349.15314 393.77068 468.8828 423.84853 509.88387 937.31800   100
 old_school  21.25257  22.69959  28.5989  25.35215  32.29996  51.44304   100

Here the base solution is ~16 times quicker than the current one, though we're still only talking milliseconds.

Things get more obvious when we move to diamonds:

library(ggplot2)
data(diamonds)
mb2 <- microbenchmark::microbenchmark(orig = miss_case_summary(diamonds),
                                      old_school = miss_case_summary_base(diamonds))

Unit: seconds
       expr       min         lq       mean     median         uq        max neval
       orig 67.256863 102.099246 102.850157 104.472683 106.681096 116.636314   100
 old_school  3.222446   3.634497   3.823108   3.770275   3.980973   5.133206   100

This is another order of magnitude :) The units are now seconds, and the current miss_case_summary is around 25 times slower than the function written in base.

When we tidy this up a little with a mutate:

miss_case_summary_mut <- function(data, order = FALSE, ...){
  
  res <- data %>% 
    dplyr::mutate(pct_miss = apply(X = ., MARGIN = 1, FUN = function(x) (mean(is.na(x)) * 100)),
                  n_miss = apply(X = data, MARGIN = 1, FUN = function(x) (sum(is.na(x)))), 
                  case = 1:nrow(.), 
                  n_miss_cumsum = cumsum(n_miss)) %>%
    dplyr::select(case,
                  n_miss,
                  pct_miss,
                  n_miss_cumsum)
  
  if (order) {
    return(dplyr::arrange(res, -n_miss))
  } else {
    return(res)
  }
}

mb3 <- microbenchmark::microbenchmark(mut = miss_case_summary_mut(airquality), 
                                     old_school = miss_case_summary_base(airquality))

Unit: milliseconds
       expr      min       lq     mean   median       uq      max neval
        mut 18.96697 19.56027 20.89321 20.08280 20.99360 44.78389   100
 old_school 11.88736 12.35850 13.12145 12.70893 13.54379 18.63309   100

library(ggplot2)
data(diamonds)
mb4 <- microbenchmark::microbenchmark(mut = miss_case_summary_mut(diamonds), 
                                      old_school = miss_case_summary_base(diamonds))

Unit: seconds
       expr      min       lq     mean   median       uq      max neval
        mut 1.338729 2.427672 2.439025 2.532146 2.602633 2.772421   100
 old_school 1.264281 2.376645 2.347995 2.460165 2.546791 3.255949   100

The mutate solution is roughly as quick as the base solution.
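For what it's worth, both versions above still call apply(), which coerces the data frame to a matrix and runs a function per row. Since is.na() on a data frame already returns a logical matrix, rowSums()/rowMeans() can do all of the work at once. A fully vectorised sketch (miss_case_summary_vec is a hypothetical name, not a function in the package):

```r
# Hypothetical, fully vectorised sketch: is.na() on a data frame returns a
# logical matrix, so rowSums()/rowMeans() replace the per-row apply() calls.
miss_case_summary_vec <- function(data, order = FALSE, ...) {
  na_mat <- is.na(data)
  n_miss <- rowSums(na_mat)
  res <- data.frame(
    case          = seq_len(nrow(data)),
    n_miss        = n_miss,
    pct_miss      = rowMeans(na_mat) * 100,
    n_miss_cumsum = cumsum(n_miss)
  )
  if (order) {
    # the `order` argument shadows base::order() here, hence the explicit call
    res <- res[base::order(res$n_miss, decreasing = TRUE), ]
  }
  res
}
```

On airquality this produces the same columns as the versions above; only the row names differ when order = TRUE.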

@njtierney
Owner

🎉 @ColinFay Thanks for taking the time to do this, it is super awesome, great job! :)

You have hit upon all of the things that I wanted to know about: performance and equal output.

It is interesting that base operations really are much faster, good to know that we've probably got lots of room for improvement in the future.

How would you feel about wrapping this up into a pull request in the next week or two, in time for the CRAN release (v0.2.0 or v0.3.0 [pending on how much we get done!]) at the end of this month?

@ColinFay
Contributor Author

ColinFay commented Oct 5, 2017

@njtierney I just benchmarked all of the miss_* functions; here are the results:

mb_miss <- microbenchmark(miss_case_cumsum = miss_case_cumsum(iris),
                          miss_case_pct = miss_case_pct(iris),
                          miss_case_prop = miss_case_prop(iris),
                          miss_case_summary = miss_case_summary(iris),
                          miss_case_table = miss_case_table(iris),
                          miss_prop_summary = miss_prop_summary(iris),
                          miss_var_cumsum = miss_var_cumsum(iris),
                          miss_var_pct = miss_var_pct(iris),
                          miss_var_prop = miss_var_prop(iris),
miss_var_run = miss_var_run(iris),
                          miss_var_span = miss_var_span(iris, Species, 2),
                          miss_var_summary = miss_var_summary(iris),
                          miss_var_table = miss_var_table(iris)
)
Unit: milliseconds
              expr          min           lq         mean       median           uq          max neval
  miss_case_cumsum 38755.524316 47262.163167 67580.360315 77145.548654 80307.896329 90460.844409   100
     miss_case_pct     1.471261     1.674807     2.511312     2.639975     3.080532     3.728153   100
    miss_case_prop     1.491867     1.671794     2.579630     2.662190     3.080234     5.941145   100
 miss_case_summary 42329.380487 47071.160921 65944.475240 75982.870742 79688.354552 92552.756479   100
   miss_case_table 19992.680837 23336.115267 32529.062504 36345.907107 39093.041141 63130.532522   100
 miss_prop_summary    11.802318    18.260203    22.280662    23.320887    23.993652    66.332404   100
   miss_var_cumsum     5.085456     7.809327     9.121983    10.112508    10.323440    17.388096   100
      miss_var_pct     1.841673     2.852282     3.455549     3.728420     3.835149     6.466376   100
     miss_var_prop     1.853609     2.865887     3.469881     3.725174     3.827486     8.856607   100
      miss_var_run     8.547970    13.232820    17.381646    17.307650    18.965652    47.030692   100
     miss_var_span  4385.277110  5857.601094  7785.928276  8762.017150  9335.342613 11706.524023   100
  miss_var_summary     6.293708     9.667991    11.486791    12.589340    13.005405    18.226312   100
    miss_var_table     9.894916    13.899672    17.856583    19.315199    20.053010    38.358893   100

Here, the four time-consuming functions are:

  • miss_case_cumsum, which first performs a miss_case_summary; this function should be removed in the next release anyway, as suggested in Add miss_var_cumsum and miss_case_cumsum into miss_var_summary and miss_case_summary #98.
  • miss_case_summary, which will be rewritten as shown in the previous comment.
  • miss_case_table, which is written with purrrlyr and has to be rewritten in the same way as miss_case_summary.
  • miss_var_span, which is slower but not by the same order of magnitude, so it can maybe wait for a following release.

I'll work on these two and send you a PR :)

@ColinFay
Contributor Author

ColinFay commented Oct 5, 2017

{purrrlyr} seems to be the biggest time consumer here:

miss_case_table.default <- function(data){
  
  purrrlyr::by_row(.d = data,
                   # how many are missing in each row?
                   ..f = ~n_miss(.),
                   .collate = "row",
                   .to = "n_miss_in_case") %>%
    dplyr::group_by(n_miss_in_case) %>%
    dplyr::tally() %>%
    dplyr::mutate(pct_miss = (n / nrow(data) * 100)) %>%
    dplyr::rename(n_cases = n)
  
}

a <- miss_case_table(airquality)

miss_case_table_base <- function(data){
  res <- data
  res$n_miss_in_case <- apply(data, MARGIN = 1, FUN = n_miss)
  res %>%
    dplyr::group_by(n_miss_in_case) %>%
    dplyr::tally() %>%
    dplyr::mutate(pct_miss = (n / nrow(data) * 100)) %>%
    dplyr::rename(n_cases = n)
  
}

b <- miss_case_table_base(airquality)
all_equal(a, b)
[1] TRUE
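The apply() call could also be dropped entirely here. A sketch (miss_case_table_vec is a hypothetical name) that counts per-case NAs with rowSums() and tabulates with table() in base R:

```r
# Hypothetical sketch: count NAs per row with rowSums() on the logical
# matrix from is.na(), then tabulate the counts with table().
miss_case_table_vec <- function(data) {
  n_miss_in_case <- rowSums(is.na(data))
  tab <- table(n_miss_in_case)
  data.frame(
    n_miss_in_case = as.integer(names(tab)),
    n_cases        = as.integer(tab),
    pct_miss       = as.integer(tab) / nrow(data) * 100
  )
}
```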

library(ggplot2)
data(diamonds)
mb_table <- microbenchmark::microbenchmark(base_table = miss_case_table_base(diamonds))

Unit: milliseconds
expr      min       lq     mean   median       uq      max neval
base_table 620.2758 661.8756 1065.734 719.7365 802.8198 5155.629   100

I just rewrote the {purrrlyr} part, and this new version is ~20 times quicker than the current one.

Not sure rewriting the dplyr part is worth the effort, though.

I'll send a PR this weekend :)

@ColinFay
Contributor Author

Hey,

Is this issue fixed by the cpp PR from Romain?

@njtierney
Owner

Yes, ish! We're still merging the pull request at the moment actually - just doing some final testing.

I'll close this once I have the final bit of rcpp code merged in, it would be nice to keep all of the code together so we can look at the speedups! :)

@njtierney njtierney moved this from Priority to In Progress in CRAN Version 0.2.0 Nov 20, 2017
@njtierney
Owner

OK, so I couldn't quite reprex this as I would have liked (a gist of the full doc is here), but here is a summary of the speed improvements. These will get merged in very soon.

# tidy up the benchmark output into a tibble
tidy_mb <- function(mb_obj){
  tibble::as_tibble(mb_obj) %>%
    dplyr::arrange(expr) %>%
    dplyr::group_by(expr) %>%
    dplyr::mutate(row_id = 1:n()) %>%
    dplyr::ungroup()
}

mb_old_funs <- readr::read_rds("~/Downloads/mb_old_master.rds")
mb_new_funs <- readr::read_rds("~/Downloads/mb_new_row_means.rds")

library(tidyverse)
#> ── Attaching packages ─────────────────────────────────────────────────────────────────────────── tidyverse 1.2.1 ──
#> ✔ ggplot2 2.2.1.9000     ✔ purrr   0.2.4     
#> ✔ tibble  1.3.4          ✔ dplyr   0.7.4     
#> ✔ tidyr   0.7.2          ✔ stringr 1.2.0     
#> ✔ readr   1.1.1          ✔ forcats 0.2.0
#> ── Conflicts ────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag()    masks stats::lag()

tidy_mb_old <- tidy_mb(mb_old_funs) %>% mutate(type = "old")
tidy_mb_new <- tidy_mb(mb_new_funs) %>% mutate(type = "new")

mb_combined <- bind_rows(tidy_mb_old, tidy_mb_new)

# microbenchmark's time column is in nanoseconds
mb_combined %>%
  group_by(type,expr) %>%
  summarise_at(.vars = vars(time),
               .funs = funs(min,mean,median,max)) %>%
  arrange(expr) %>%
  knitr::kable()
| type | expr              |       min |        mean |      median |       max |
|:-----|:------------------|----------:|------------:|------------:|----------:|
| new  | miss_case_pct     |    117943 |    184379.5 |    178148.0 |    402108 |
| old  | miss_case_pct     |    106193 |    177510.1 |    180681.0 |    266083 |
| new  | miss_case_prop    |    117773 |    180140.1 |    180333.5 |    361126 |
| old  | miss_case_prop    |    110790 |    181768.4 |    180751.0 |    303183 |
| new  | miss_case_summary |   8419355 |   9754482.1 |   9321849.5 |  13886879 |
| old  | miss_case_summary | 148829857 | 173563729.5 | 171862037.5 | 242414116 |
| new  | miss_case_table   |   5346805 |   6507095.1 |   5914116.5 |  11927311 |
| old  | miss_case_table   | 114623875 | 129453846.5 | 128055961.5 | 199175146 |
| new  | miss_prop_summary |   5833181 |   6983529.0 |   6496031.0 |  11128861 |
| old  | miss_prop_summary |   5451496 |   7257661.4 |   6552096.5 |  19571237 |
| new  | miss_var_pct      |    592197 |    745296.5 |    694597.5 |   2849866 |
| old  | miss_var_pct      |    575854 |   1130419.0 |    694072.5 |  37490316 |
| new  | miss_var_prop     |    589345 |    737217.7 |    673482.0 |   2695309 |
| old  | miss_var_prop     |    584099 |    727344.2 |    687114.5 |   2647324 |
| new  | miss_var_run      |   6059785 |   7127802.0 |   6559186.5 |  14579901 |
| old  | miss_var_run      |   5585718 |   7049816.8 |   6612873.0 |  13353024 |
| new  | miss_var_span     |  19837499 |  24554941.6 |  22591093.0 |  95149500 |
| old  | miss_var_span     |  19753235 |  23586619.1 |  22935410.5 |  64614291 |
| new  | miss_var_summary  |   4238269 |   4978811.9 |   4568662.0 |   9697506 |
| old  | miss_var_summary  |   4031169 |   5054523.3 |   4737219.5 |   8254638 |
| new  | miss_var_table    |   7108722 |   8238089.9 |   7847992.5 |  12146477 |
| old  | miss_var_table    |   6919870 |   8810073.9 |   7883966.5 |  50184873 |
| new  | gg_miss_case      |  14480615 |  17156486.4 |  16442428.0 |  23802507 |
| old  | gg_miss_case      | 142392590 | 179270247.4 | 179722623.0 | 249276889 |
| new  | gg_miss_fct       |  85789881 |  96275362.5 |  95870843.5 | 164521438 |
| old  | gg_miss_fct       |  86559868 |  99718903.4 |  97457500.5 | 177047262 |
| new  | gg_miss_var       |  10884530 |  13189968.1 |  11967101.5 |  85998486 |
| old  | gg_miss_var       |  10038797 |  12704118.9 |  12072091.5 |  38914761 |
| new  | gg_miss_span      |  27039057 |  30972826.0 |  30713547.0 |  37507724 |
| old  | gg_miss_span      |  25270603 |  32757441.6 |  30862518.5 | 145657585 |
library(ggplot2)
mb_combined %>%
  ggplot(aes(x = expr,
             y = time,
             colour = type)) +
  geom_boxplot() +
  coord_flip()

# let's focus on the ones that have a more noticeable difference
mb_combined %>%
  dplyr::filter(expr %in% c("gg_miss_case",
                     "miss_var_table",
                     "miss_var_summary",
                     "miss_var_span",
                     "miss_case_table",
                     "miss_case_summary")) %>%
  ggplot(aes(x = expr,
             y = time,
             colour = type)) +
  geom_boxplot() +
  coord_flip()
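To turn the table above into a single number per function, one option is to take the ratio of old to new median times. A sketch (speedup_by_expr is a hypothetical helper, assuming a combined data frame with expr, type, and time columns like the one built above):

```r
# Hypothetical helper: ratio of median old time to median new time, per expr,
# computed in base R from a long-format benchmark data frame.
speedup_by_expr <- function(mb_combined) {
  med  <- aggregate(time ~ expr + type, data = mb_combined, FUN = median)
  wide <- reshape(med, idvar = "expr", timevar = "type", direction = "wide")
  wide$speedup <- wide$time.old / wide$time.new
  wide[, c("expr", "speedup")]
}
```

On the numbers above this should put miss_case_summary, miss_case_table, and gg_miss_case well above a 10x speedup, with the remaining functions close to 1x.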

@njtierney
Owner

For the moment this issue is addressed, but the c++ branch will make everything faster once again. I will open up another issue for that comparison when we come to it.

@njtierney njtierney moved this from In Progress to Done in CRAN Version 0.2.0 Dec 15, 2017
@romainfrancois
Contributor

Yep. Let's just not make a huge "C++ branch", though. Once we have the initial one merged in, we can make small branches that are easily reviewable and mergeable.

@njtierney
Owner

Yup, I agree - it is ideal to have a nice compact branch with incremental changes at each pull request :)
