Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ Imports:
glue,
zeallot
Suggests:
arrow,
magick,
prettyunits,
testthat,
coro,
R.matlab,
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ export(transform_ten_crop)
export(transform_to_tensor)
export(transform_vflip)
export(vision_make_grid)
export(whoi_plankton_dataset)
export(whoi_small_plankton_dataset)
importFrom(grDevices,dev.off)
importFrom(graphics,polygon)
importFrom(jsonlite,fromJSON)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Added `lfw_people_dataset()` and `lfw_pairs_dataset()` for loading Labelled Faces in the Wild (LFW) datasets (@DerrickUnleashed, #203).
* Added `places365_dataset()` for loading the Places365 dataset (@koshtiakanksha, #196).
* Added `pascal_segmentation_dataset()`, and `pascal_detection_dataset()` for loading the Pascal Visual Object Classes datasets (@DerrickUnleashed, #209).
* Added `whoi_plankton_dataset()`, and `whoi_small_plankton_dataset()` (@cregouby, #236).

## New models

Expand Down
4 changes: 2 additions & 2 deletions R/dataset-caltech.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ caltech101_dataset <- torch::dataset(
self$image_indices <- c(self$image_indices, seq_along(imgs))
}

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

.getitem = function(index) {
Expand Down Expand Up @@ -205,7 +205,7 @@ caltech256_dataset <- torch::dataset(
}, seq_along(self$classes), images_per_class, SIMPLIFY = FALSE),
use.names = FALSE
)
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

check_exists = function() {
Expand Down
8 changes: 4 additions & 4 deletions R/dataset-coco.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ coco_detection_dataset <- torch::dataset(
) {

year <- match.arg(year)
split <- if (train) "train" else "val"
split <- ifelse(train, "train", "val")

root <- fs::path_expand(root)
self$root <- root
Expand All @@ -76,7 +76,7 @@ coco_detection_dataset <- torch::dataset(

self$data_dir <- fs::path(root, glue::glue("coco{year}"))

image_year <- if (year == "2016") "2014" else year
image_year <- ifelse(year == "2016", "2014", year)
self$image_dir <- fs::path(self$data_dir, glue::glue("{split}{image_year}"))
self$annotation_file <- fs::path(self$data_dir, "annotations",
glue::glue("instances_{split}{year}.json"))
Expand Down Expand Up @@ -288,7 +288,7 @@ coco_caption_dataset <- torch::dataset(
) {

year <- match.arg(year)
split <- if (train) "train" else "val"
split <- ifelse(train, "train", "val")

root <- fs::path_expand(root)
self$root <- root
Expand Down Expand Up @@ -329,7 +329,7 @@ coco_caption_dataset <- torch::dataset(
image_id <- ann$image_id
y <- ann$caption

prefix <- if (self$split == "train") "COCO_train2014_" else "COCO_val2014_"
prefix <- ifelse(self$split == "train", "COCO_train2014_", "COCO_val2014_")
filename <- paste0(prefix, sprintf("%012d", image_id), ".jpg")
image_path <- fs::path(self$image_dir, filename)

Expand Down
8 changes: 3 additions & 5 deletions R/dataset-eurosat.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#' @inheritParams mnist_dataset
#' @param root (Optional) Character. The root directory where the dataset will be stored.
#' if empty, will use the default `rappdirs::user_cache_dir("torch")`.
#' @param split Character. Must be one of `train`, `val`, or `test`.
#' @param split One of `"train"`, `"val"`, or `"test"`. Default is `"val"`.
#'
#' @return A `torch::dataset` object. Each item is a list with:
#' * `x`: a 64x64 image tensor with 3 (RGB) or 13 (all bands) channels
Expand All @@ -39,7 +39,7 @@ eurosat_dataset <- torch::dataset(

initialize = function(
root = tempdir(),
split = "train",
split = "val",
download = FALSE,
transform = NULL,
target_transform = NULL
Expand All @@ -53,7 +53,7 @@ eurosat_dataset <- torch::dataset(
self$images_dir <- file.path(self$root, class(self)[1], "images")
self$split_file <- file.path(self$root, fs::path_ext_remove(basename(self$split_url)))

if (download){
if (download) {
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
self$download()
}
Expand Down Expand Up @@ -184,5 +184,3 @@ eurosat100_dataset <- torch::dataset(
split_url = "https://huggingface.co/datasets/torchgeo/eurosat/resolve/main/eurosat-100-{split}.txt?download=true",
archive_size = "7 MB"
)


40 changes: 28 additions & 12 deletions R/dataset-fer.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
fer_dataset <- dataset(
name = "fer_dataset",
archive_size = "90 MB",
url = "https://huggingface.co/datasets/JimmyUnleashed/FER-2013/resolve/main/fer2013.tar.gz",
md5 = "ca95d94fe42f6ce65aaae694d18c628a",
classes = c(
"Angry",
"Disgust",
"Fear",
"Happy",
"Sad",
"Surprise",
"Neutral"
),

initialize = function(
root = tempdir(),
Expand All @@ -39,25 +50,25 @@ fer_dataset <- dataset(
target_transform = NULL,
download = FALSE
) {

self$root <- root
self$train <- train
self$transform <- transform
self$target_transform <- target_transform
self$split <- if (train) "Train" else "Test"
self$split <- ifelse(train, "Train", "Test")
self$folder_name <- "fer2013"
self$url <- "https://huggingface.co/datasets/JimmyUnleashed/FER-2013/resolve/main/fer2013.tar.gz"
self$md5 <- "ca95d94fe42f6ce65aaae694d18c628a"
self$classes <- c("Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral")
self$class_to_idx <- setNames(seq_along(self$classes), self$classes)

if (download){
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
if (download) {
cli_inform(
"Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available."
)
self$download()
}

if (!self$check_files()) {
runtime_error("Dataset not found. You can use `download = TRUE` to download it.")
runtime_error(
"Dataset not found. You can use `download = TRUE` to download it."
)
}

csv_file <- file.path(self$root, self$folder_name, "fer2013.csv")
Expand Down Expand Up @@ -87,11 +98,13 @@ fer_dataset <- dataset(

y <- self$y[i]

if (!is.null(self$transform))
if (!is.null(self$transform)) {
x <- self$transform(x)
}

if (!is.null(self$target_transform))
if (!is.null(self$target_transform)) {
y <- self$target_transform(y)
}

list(x = x, y = y)
},
Expand All @@ -112,11 +125,14 @@ fer_dataset <- dataset(

archive <- download_and_cache(self$url)

if (!tools::md5sum(archive) == self$md5)
if (!tools::md5sum(archive) == self$md5) {
runtime_error("Corrupt file! Delete the file in {archive} and try again.")
}

untar(archive, exdir = self$root)
cli_inform("Dataset {.cls {class(self)[[1]]}} downloaded and extracted successfully.")
cli_inform(
"Dataset {.cls {class(self)[[1]]}} downloaded and extracted successfully."
)
},

check_files = function() {
Expand Down
8 changes: 4 additions & 4 deletions R/dataset-fgvc.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ fgvc_aircraft_dataset <- dataset(
target_transform = NULL,
download = FALSE
) {

self$root <- root
self$split <- split
self$annotation_level <- annotation_level
Expand Down Expand Up @@ -132,10 +132,10 @@ fgvc_aircraft_dataset <- dataset(
.getitem = function(index) {
x <- jpeg::readJPEG(self$image_paths[index]) * 255

y <- if (self$annotation_level == "all") {
as.integer(self$labels_df[index, ])
if (self$annotation_level == "all") {
y <- as.integer(self$labels_df[index, ])
} else {
self$labels_df[[self$annotation_level]][index]
y <- self$labels_df[[self$annotation_level]][index]
}

if (!is.null(self$transform)) {
Expand Down
14 changes: 7 additions & 7 deletions R/dataset-flickr.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ flickr8k_caption_dataset <- torch::dataset(
self$transform <- transform
self$target_transform <- target_transform
self$train <- train
self$split <- if (train) "train" else "test"
self$split <- ifelse(train, "train", "test")

if (download)
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
self$download()
Expand Down Expand Up @@ -130,7 +130,7 @@ flickr8k_caption_dataset <- torch::dataset(

download = function() {

if (self$check_exists())
if (self$check_exists())
return()

cli_inform("Downloading {.cls {class(self)[[1]]}}...")
Expand Down Expand Up @@ -173,10 +173,10 @@ flickr8k_caption_dataset <- torch::dataset(
caption_index <- self$captions[[index]]
y <- self$classes[[caption_index]]

if (!is.null(self$transform))
if (!is.null(self$transform))
x <- self$transform(x)

if (!is.null(self$target_transform))
if (!is.null(self$target_transform))
y <- self$target_transform(y)

list(x = x, y = y)
Expand Down Expand Up @@ -225,13 +225,13 @@ flickr30k_caption_dataset <- torch::dataset(
self$transform <- transform
self$target_transform <- target_transform
self$train <- train
self$split <- if (train) "train" else "test"
self$split <- ifelse(train, "train", "test")

if (download)
cli_inform("Dataset {.cls {class(self)[[1]]}} (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
self$download()

if (!self$check_exists())
if (!self$check_exists())
cli_abort("Dataset not found. Use `download = TRUE` to download it.")

captions_path <- file.path(self$raw_folder, "dataset_flickr30k.json")
Expand Down
4 changes: 2 additions & 2 deletions R/dataset-lfw.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ lfw_people_dataset <- torch::dataset(
self$classes <- class_names
self$class_to_idx <- class_to_idx

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

download = function() {
Expand Down Expand Up @@ -283,7 +283,7 @@ lfw_pairs_dataset <- torch::dataset(
self$pairs <- do.call(rbind, pair_list)
self$img_path <- c(self$pairs$img1, self$pairs$img2)

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

.getitem = function(index) {
Expand Down
6 changes: 3 additions & 3 deletions R/dataset-mnist.R
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,8 @@ emnist_dataset <- dataset(
},

.getitem = function(index) {
data_set <- if (self$is_train) self$data else self$test_data
targets_set <- if (self$is_train) self$targets else self$test_targets
data_set <- ifelse(self$is_train, self$data, self$test_data)
targets_set <- ifelse(self$is_train, self$targets, self$test_targets)

x <- data_set[index, , ]
y <- targets_set[index]
Expand All @@ -502,7 +502,7 @@ emnist_dataset <- dataset(
},

.length = function() {
data_set <- if (self$is_train) self$data else self$test_data
data_set <- ifelse(self$is_train, self$data, self$test_data)
dim(data_set)[1]
},

Expand Down
6 changes: 3 additions & 3 deletions R/dataset-oxfordiiitpet.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ oxfordiiitpet_segmentation_dataset <- torch::dataset(
self$classes <- c("Cat", "Dog")
}

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

download = function() {
Expand Down Expand Up @@ -299,7 +299,7 @@ oxfordiiitpet_dataset <- dataset(
self$class_to_idx <- data$class_to_idx
self$classes <- names(self$class_to_idx)

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

.getitem = function(index) {
Expand Down Expand Up @@ -363,7 +363,7 @@ oxfordiiitpet_binary_dataset <- dataset(
self$class_to_idx <- data$class_to_idx
self$classes <- c("Cat", "Dog")

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

.getitem = function(index) {
Expand Down
4 changes: 2 additions & 2 deletions R/dataset-pascal.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ pascal_segmentation_dataset <- torch::dataset(
self$img_path <- data$img_path
self$mask_paths <- data$mask_paths

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

download = function() {
Expand Down Expand Up @@ -314,7 +314,7 @@ pascal_detection_dataset <- torch::dataset(
install.packages("xml2")
}

cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$img_path)} images across {length(self$classes)} classes.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {self$.length()} images across {length(self$classes)} classes.")
},

.getitem = function(index) {
Expand Down
Loading
Loading