diff --git a/DESCRIPTION b/DESCRIPTION index 04336e59..82943801 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: torchvision Title: Models, Datasets and Transformations for Images -Version: 0.4.0.9000 +Version: 0.4.0.9001 Authors@R: c( person(given = "Daniel", family = "Falbel", @@ -26,7 +26,8 @@ RoxygenNote: 7.1.2 Suggests: testthat, magick, - coro + coro, + withr Imports: torch (>= 0.3.0), fs, diff --git a/R/dataset-cifar.R b/R/dataset-cifar.R index 53fa8511..7f4dcb59 100644 --- a/R/dataset-cifar.R +++ b/R/dataset-cifar.R @@ -83,7 +83,7 @@ cifar10_dataset <- torch::dataset( p <- download_and_cache(self$url) if (!tools::md5sum(p) == self$md5) - runtime_error("Corrupt file!") + runtime_error(sprintf("Corrupt file! Delete the file in '%s' and try again.", p)) utils::untar(p, exdir = self$root) }, diff --git a/R/dataset-mnist.R b/R/dataset-mnist.R index 65c95f51..0e066873 100644 --- a/R/dataset-mnist.R +++ b/R/dataset-mnist.R @@ -64,13 +64,11 @@ mnist_dataset <- dataset( filename <- tail(strsplit(r[1], "/")[[1]], 1) destpath <- file.path(self$raw_folder, filename) - withr::with_options( - list(timeout = 600), - utils::download.file(r[1], destfile = destpath) - ) + p <- download_and_cache(r[1], prefix = class(self)[1]) + fs::file_copy(p, destpath) if (!tools::md5sum(destpath) == r[2]) - runtime_error("MD5 sums are not identical for file: {r[1}.") + runtime_error("MD5 sums are not identical for file: {r[1]}.") } diff --git a/R/tiny-imagenet-dataset.R b/R/tiny-imagenet-dataset.R index 0205d62b..f09180e3 100644 --- a/R/tiny-imagenet-dataset.R +++ b/R/tiny-imagenet-dataset.R @@ -41,7 +41,8 @@ tiny_imagenet_dataset <- torch::dataset( rlang::inform("Downloding tiny imagenet dataset!") - download.file(self$url, raw_path) + p <- download_and_cache(self$url) + fs::file_copy(p, raw_path) rlang::inform("Download complete. 
Now unzipping.") diff --git a/R/utils.R b/R/utils.R index 6bd80456..1f145b30 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1,13 +1,27 @@ -download_and_cache <- function(url, redownload = FALSE) { +download_and_cache <- function(url, redownload = FALSE, prefix = NULL) { cache_path <- rappdirs::user_cache_dir("torch") fs::dir_create(cache_path) + if (!is.null(prefix)) { + cache_path <- file.path(cache_path, prefix) + } + try(fs::dir_create(cache_path, recurse = TRUE), silent = TRUE) path <- file.path(cache_path, fs::path_file(url)) - if (!file.exists(path) || redownload) - utils::download.file(url, path, mode = "wb") + if (!file.exists(path) || redownload) { + # we should first download to a temporary file because + # download problems could cause hard to debug errors. + tmp <- tempfile(fileext = fs::path_ext(path)) + on.exit({try({fs::file_delete(tmp)}, silent = TRUE)}, add = TRUE) + + withr::with_options( + list(timeout = 600), + utils::download.file(url, tmp, mode = "wb") + ) + fs::file_move(tmp, path) + } path } diff --git a/tests/testthat/test-models-vgg.R b/tests/testthat/test-models-vgg.R index 200d505a..3b02c061 100644 --- a/tests/testthat/test-models-vgg.R +++ b/tests/testthat/test-models-vgg.R @@ -18,12 +18,13 @@ test_that("vgg models works", { } - skip_on_os(os = "mac") # not downloading a bunch of files locally. - skip_on_os(os = "windows") # not downloading a bunch of files locally. + skip_on_ci() # unfortunately we don't have enough RAM on CI for that. + #skip_on_os(os = "mac") # not downloading a bunch of files locally. + #skip_on_os(os = "windows") # not downloading a bunch of files locally. for (m in vggs) { model <- m(pretrained = TRUE) - expect_tensor_shape(model(torch_ones(5, 3, 224, 224)), c(5, 1000)) + expect_tensor_shape(model(torch_ones(1, 3, 224, 224)), c(1, 1000)) rm(model) gc()