Skip to content

Commit

Permalink
avoiding coro::yield() and coro::generator() should make dataload…
Browse files Browse the repository at this point in the history
…ers within luz slightly faster.
  • Loading branch information
dfalbel committed Feb 14, 2023
1 parent 3c67824 commit e2ca79e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ export(set_opt_hparams)
export(setup)
import(torch)
importFrom(coro,as_iterator)
importFrom(coro,is_exhausted)
importFrom(fs,dir_ls)
importFrom(generics,fit)
importFrom(magrittr,"%>%")
Expand Down
12 changes: 6 additions & 6 deletions R/accelerator.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ as_device_dataloader <- function(x, device) {
#' @export
as_iterator.device_dataloader <- function(x) {
g <- NextMethod()
gen <- coro::generator(function() {
for(batch in g) {
coro::yield(to_device(batch, device = x$.device))
}
})
gen()
device <- x$.device
function() {
batch <- g()
to_device(batch, device = device)
}
}

to_device <- function(batch, device) {
if (!is.list(batch)) return(batch)
lapply(batch, function(x) {
if (inherits(x, "torch_tensor"))
x$to(device = device)
Expand Down

0 comments on commit e2ca79e

Please sign in to comment.