Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 29 additions & 37 deletions R/dataset-flickr.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#' Flickr8k Dataset
#'
#' Loads the Flickr8k dataset consisting of 8,000 images with five human-annotated captions per image.
#' The images in this dataset are in RGB format and vary in spatial resolution.
#'
#' The dataset is split into:
#' - `"train"`: training subset with captions.
#' - `"test"`: test subset with captions.
#' The Flickr8k and Flickr30k collections are **image captionning** datasets
#' composed of 8,000 and 30,000 color images respectively, each paired with five
#' human-annotated captions. The images are in RGB format with varying spatial
#' resolutions, and these datasets are widely used for training and evaluating
#' vision-language models.
#'
#' @inheritParams fgvc_aircraft_dataset
#' @param root : Root directory for dataset storage. The dataset will be stored under `root/flickr8k`.
Expand All @@ -25,17 +24,27 @@
#' first_item <- flickr8k[1]
#' first_item$x # image array with shape {3, H, W}
#' first_item$y # character vector containing five captions.
#'
#' # Load the Flickr30k caption dataset
#' flickr30k <- flickr30k_caption_dataset(download = TRUE)
#'
#' # Access the first item
#' first_item <- flickr30k[1]
#' first_item$x # image array with shape {3, H, W}
#' first_item$y # character vector containing five captions.
#' }
#'
#' @name flickr8k_caption_dataset
#' @aliases flickr8k_caption_dataset
#' @title Flickr8k Caption Dataset
#' @name flickr_caption_dataset
#' @title Flickr Caption Datasets
#' @rdname flickr_caption_dataset
#' @family caption_dataset
#' @export
flickr8k_caption_dataset <- torch::dataset(
name = "flickr8k",
training_file = "train.rds",
test_file = "test.rds",
class_index_file = "classes.rds",
archive_size = "1 GB",

resources = list(
c("https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip", "bf6c1abcb8e4a833b7f922104de18627"),
Expand All @@ -55,10 +64,9 @@ flickr8k_caption_dataset <- torch::dataset(
self$target_transform <- target_transform
self$train <- train
self$split <- if (train) "train" else "test"

cli_inform("{.cls {class(self)[[1]]}} Dataset (~1GB) will be downloaded and processed if not already cached.")

if (download)
cli_inform("{.cls {class(self)[[1]]}} Dataset (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
self$download()

if (!self$check_exists())
Expand Down Expand Up @@ -117,15 +125,16 @@ flickr8k_caption_dataset <- torch::dataset(
self$captions <- data$captions
self$classes <- readRDS(file.path(self$processed_folder, self$class_index_file))

cli_inform("Split '{self$split}' loaded with {length(self$images)} samples.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$images)} images across {length(self$classes)} classes.")
},

download = function() {

if (self$check_exists())
return()

cli_inform("Downloading {.cls {class(self)[[1]]}} split: '{self$split}'")
cli_inform("Downloading {.cls {class(self)[[1]]}}...")

fs::dir_create(self$raw_folder)

for (r in self$resources) {
Expand All @@ -141,6 +150,9 @@ flickr8k_caption_dataset <- torch::dataset(
utils::untar(tar_path, exdir = self$raw_folder)
}
}

cli_inform("{.cls {class(self)[[1]]}} dataset downloaded and extracted successfully.")

},

check_processed_exists = function() {
Expand Down Expand Up @@ -183,13 +195,6 @@ flickr8k_caption_dataset <- torch::dataset(

#' Flickr30k Dataset
#'
#' Loads the Flickr30k dataset consisting of 30,000 images with five human-annotated captions per image.
#' The images in this dataset are in RGB format and vary in spatial resolution.
#'
#' The dataset is split into:
#' - `"train"`: training subset with captions.
#' - `"test"`: test subset with captions.
#'
#' @inheritParams flickr8k_caption_dataset
#' @param root Character. Root directory where the dataset will be stored under `root/flickr30k`.
#'
Expand All @@ -198,24 +203,12 @@ flickr8k_caption_dataset <- torch::dataset(
#' - `x`: a H x W x 3 integer array representing an RGB image.
#' - `y`: a character vector containing all five captions associated with the image.
#'
#' @examples
#' \dontrun{
#' # Load the Flickr30k caption dataset
#' flickr30k <- flickr30k_caption_dataset(download = TRUE)
#'
#' # Access the first item
#' first_item <- flickr30k[1]
#' first_item$x # image array with shape {3, H, W}
#' first_item$y # character vector containing five captions.
#' }
#'
#' @name flickr30k_caption_dataset
#' @aliases flickr30k_caption_dataset
#' @title Flickr30k Caption Dataset
#' @rdname flickr_caption_dataset
#' @export
flickr30k_caption_dataset <- torch::dataset(
name = "flickr30k",
inherit = flickr8k_caption_dataset,
archive_size = "4.1 GB",
resources = list(
c("https://uofi.app.box.com/shared/static/1cpolrtkckn4hxr1zhmfg0ln9veo6jpl.gz", "985ac761bbb52ca49e0c474ae806c07c"),
c("https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip", "4fa8c08369d22fe16e41dc124bd1adc2")
Expand All @@ -234,9 +227,8 @@ flickr30k_caption_dataset <- torch::dataset(
self$train <- train
self$split <- if (train) "train" else "test"

cli_inform("{.cls {class(self)[[1]]}} Dataset (~4.1GB) will be downloaded and processed if not already cached.")

if (download)
cli_inform("{.cls {class(self)[[1]]}} Dataset (~{.emph {self$archive_size}}) will be downloaded and processed if not already available.")
self$download()

if (!self$check_exists())
Expand All @@ -261,7 +253,7 @@ flickr30k_caption_dataset <- torch::dataset(
self$captions <- vapply(self$filenames, function(f) caption_to_index[[f]], integer(1))
self$classes <- captions_map

cli_inform("Split '{self$split}' loaded with {length(self$images)} samples.")
cli_inform("{.cls {class(self)[[1]]}} dataset loaded with {length(self$images)} images across {length(self$classes)} classes.")
},

check_exists = function() {
Expand Down
4 changes: 4 additions & 0 deletions man/coco_caption_dataset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 0 additions & 58 deletions man/flickr8k_caption_dataset.Rd

This file was deleted.

49 changes: 38 additions & 11 deletions man/flickr30k_caption_dataset.Rd → man/flickr_caption_dataset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 18 additions & 8 deletions tests/testthat/test-dataset-flickr.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ t <- withr::local_tempdir()
test_that("tests for the flickr8k dataset for train split", {
skip_on_cran()

expect_error(
flickr8k <- flickr8k_caption_dataset(root = tempfile()),
class = "rlang_error"
)
skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")


flickr8k <- flickr8k_caption_dataset(root = t, train = TRUE, download = TRUE)
expect_length(flickr8k, 6000)
Expand All @@ -27,6 +26,9 @@ test_that("tests for the flickr8k dataset for train split", {
test_that("tests for the flickr8k dataset for test split", {
skip_on_cran()

skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")

flickr8k <- flickr8k_caption_dataset(root = t, train = FALSE)
expect_length(flickr8k, 1000)
first_item <- flickr8k[1]
Expand All @@ -43,6 +45,9 @@ test_that("tests for the flickr8k dataset for test split", {

test_that("tests for the flickr8k dataset for dataloader", {
skip_on_cran()

skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")

flickr8k <- flickr8k_caption_dataset(
root = t,
Expand Down Expand Up @@ -70,10 +75,9 @@ test_that("tests for the flickr8k dataset for dataloader", {
test_that("tests for the flickr30k dataset for train split", {
skip_on_cran()

expect_error(
flickr30k <- flickr30k_caption_dataset(root = tempfile()),
class = "rlang_error"
)
skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")


flickr30k <- flickr30k_caption_dataset(root = t, train = TRUE, download = TRUE)
expect_length(flickr30k, 29000)
Expand All @@ -93,6 +97,9 @@ test_that("tests for the flickr30k dataset for train split", {
test_that("tests for the flickr30k dataset for test split", {
skip_on_cran()

skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")

flickr30k <- flickr30k_caption_dataset(root = t, train = FALSE)
expect_length(flickr30k, 1000)
first_item <- flickr30k[1]
Expand All @@ -110,6 +117,9 @@ test_that("tests for the flickr30k dataset for test split", {
test_that("tests for the flickr30k dataset for dataloader", {
skip_on_cran()

skip_if(Sys.getenv("TEST_LARGE_DATASETS", unset = 0) != 1,
"Skipping test: set TEST_LARGE_DATASETS=1 to enable tests requiring large downloads.")

flickr30k <- flickr30k_caption_dataset(
root = t,
transform = function(x) {
Expand Down