Skip to content

Commit

Permalink
Add support for snapshotting.
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel committed Jul 6, 2023
1 parent aa2795b commit b697b23
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 5 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,8 @@

export(WEIGHTS_INDEX_NAME)
export(WEIGHTS_NAME)
export(hub_dataset_info)
export(hub_download)
export(hub_repo_info)
export(hub_snapshot)
importFrom(rlang,"%||%")
11 changes: 7 additions & 4 deletions R/hub_download.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
#' @param force_download For re-downloading of files that are cached.
#' @param ... currently unused.
#'
#' @returns The file path of the downloaded or cached file.
#' @returns The file path of the downloaded or cached file. The snapshot path is returned
#' as an attribute.
#' @examples
#' try({
#' withr::with_envvar(c(HUGGINGFACE_HUB_CACHE = tempdir()), {
Expand Down Expand Up @@ -142,7 +143,7 @@ hub_download <- function(repo_id, filename, ..., revision = "main", repo_type =
type = "download",
)
progress <- function(down, up) {
if (down[1] !=0) {
if (down[1] != 0) {
cli::cli_progress_update(total = down[1], set = down[2], id = bar_id)
}
TRUE
Expand All @@ -156,6 +157,7 @@ hub_download <- function(repo_id, filename, ..., revision = "main", repo_type =
fs::file_move(tmp, blob_path)

# fs::link_create doesn't work for linking files on windows.
try(fs::file_delete(pointer_path), silent = TRUE) # delete the link to avoid warnings
file.symlink(blob_path, pointer_path)
})

Expand All @@ -171,8 +173,9 @@ hub_url <- function(repo_id, filename, ..., revision = "main", repo_type = "mode
}

# Builds the path of the pointer file for `relative_filename` inside the
# snapshot directory `<storage_folder>/snapshots/<revision>`.
#
# The snapshot directory itself is attached to the returned path as the
# "snapshot_path" attribute so callers can recover it from the pointer path.
get_pointer_path <- function(storage_folder, revision, relative_filename) {
  snapshot_path <- fs::path(storage_folder, "snapshots", revision)
  pointer_path <- fs::path(snapshot_path, relative_filename)
  attr(pointer_path, "snapshot_path") <- snapshot_path
  pointer_path
}

Expand Down
44 changes: 44 additions & 0 deletions R/hub_info.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#' Queries information about Hub repositories
#'
#' @inheritParams hub_download
#' @param files_metadata Obtain files metadata information when querying repository information.
#' @returns The parsed body of the Hub API response, as produced by
#'   [httr::content()]. An error is raised if the request fails with an HTTP
#'   error status.
#' @export
hub_repo_info <- function(repo_id, ..., repo_type = NULL, revision = NULL, files_metadata = FALSE) {
  # Models are the default repository type; other types use the pluralised
  # type name as the API path segment (e.g. "dataset" -> "datasets").
  if (is.null(repo_type) || repo_type == "model") {
    path <- glue::glue("https://huggingface.co/api/models/{repo_id}")
  } else {
    path <- glue::glue("https://huggingface.co/api/{repo_type}s/{repo_id}")
  }

  if (!is.null(revision)) {
    path <- glue::glue("{path}/revision/{revision}")
  }

  # Only send the `blobs` query parameter when file metadata is requested.
  params <- list()
  if (files_metadata) {
    params$blobs <- TRUE
  }

  results <- httr::GET(
    path,
    query = params,
    httr::add_headers(
      "user-agent" = "hfhub/0.0.1"
    )
  )

  # Fail loudly on 4xx/5xx responses instead of returning the parsed error
  # body as if it were repository information.
  httr::stop_for_status(results)

  httr::content(results)
}

#' @describeIn hub_repo_info Query information from a Hub Dataset
#' @export
hub_dataset_info <- function(repo_id, ..., revision = NULL, files_metadata = FALSE) {
  # Thin wrapper that pins `repo_type` to "dataset". No `token` is forwarded:
  # this function has no such argument or variable, and `hub_repo_info()` does
  # not consume one.
  hub_repo_info(
    repo_id,
    revision = revision,
    repo_type = "dataset",
    files_metadata = files_metadata
  )
}

53 changes: 53 additions & 0 deletions R/hub_snapshot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#' Snapshot the entire repository
#'
#' Downloads and stores all files from a Hugging Face Hub repository.
#' @inheritParams hub_download
#' @param allow_patterns A character vector containing patterns that are used to
#'   filter allowed files to snapshot. A file is kept if it matches at least one
#'   of the patterns.
#' @param ignore_patterns A character vector containing patterns to reject files
#'   from being downloaded. A file is dropped if it matches any of the patterns.
#'
#' @returns The snapshot path, taken from the `"snapshot_path"` attribute of the
#'   value returned by [hub_download()].
#' @export
hub_snapshot <- function(repo_id, ..., revision = "main", repo_type = "model",
                         local_files_only = FALSE, force_download = FALSE,
                         allow_patterns = NULL, ignore_patterns = NULL) {
  # Query the requested revision so the file list and the commit sha used for
  # the downloads below describe the same state of the repository.
  info <- hub_repo_info(repo_id, repo_type = repo_type, revision = revision)
  all_files <- vapply(info$siblings, function(x) x$rfilename, character(1))

  files <- all_files
  if (!is.null(allow_patterns)) {
    # Match every pattern individually; `grepl()` only honours the first
    # element when handed a vector of patterns.
    matched <- lapply(allow_patterns, function(pattern) {
      all_files[grepl(pattern, all_files)]
    })
    files <- unique(unlist(matched))
  }

  if (!is.null(ignore_patterns)) {
    for (pattern in ignore_patterns) {
      files <- files[!grepl(pattern, files)]
    }
  }

  # Without at least one download there is no snapshot path to return.
  if (length(files) == 0) {
    cli::cli_abort(
      "No files in repository {.val {repo_id}} match the requested patterns."
    )
  }

  # NOTE(review): this bar id is never referenced again in this function; it is
  # kept as-is to avoid changing console output -- confirm whether it is needed.
  id <- cli::cli_progress_bar(
    name = "Downloading files",
    type = "tasks",
    total = length(files),
    clear = FALSE
  )

  # `i` must exist before the step is created because the message interpolates it.
  i <- 0
  cli::cli_progress_step("Snapshotting files {i}/{length(files)}")
  for (i in seq_along(files)) {
    # Download against the resolved commit sha so every file comes from the
    # same snapshot directory.
    d <- hub_download(
      repo_id = repo_id,
      filename = files[i],
      revision = info$sha,
      repo_type = repo_type,
      local_files_only = local_files_only,
      force_download = force_download
    )
  }

  attr(d, "snapshot_path")
}
3 changes: 2 additions & 1 deletion man/hub_download.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions man/hub_repo_info.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 39 additions & 0 deletions man/hub_snapshot.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions tests/testthat/_snaps/hub_snapshot.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# snapshot

Code
p <- hub_snapshot("dfalbel/cran-packages", repo_type = "dataset",
allow_patterns = "\\.R")
Message <cliMessage>
i Snapshotting files 0/4
v Snapshotting files 4/4 [0ms]

7 changes: 7 additions & 0 deletions tests/testthat/test-hub_info.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
skip_on_cran()

# Checks that dataset metadata can be fetched from the Hub and exposes the
# expected author and a minimum number of files.
test_that("dataset info", {
  dataset_info <- hub_dataset_info("dfalbel/cran-packages")
  n_siblings <- length(dataset_info$siblings)

  expect_equal(dataset_info$author, "dfalbel")
  expect_gte(n_siblings, 13)
})
12 changes: 12 additions & 0 deletions tests/testthat/test-hub_snapshot.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# hub_snapshot() queries the Hub over HTTP, so this file is skipped on CRAN.
skip_on_cran()

# Snapshots the files of a dataset repository that match "\\.R" and records the
# cli messages emitted while doing so.
test_that("snapshot", {
expect_snapshot({
p <- hub_snapshot("dfalbel/cran-packages", repo_type = "dataset", allow_patterns = "\\.R")
},
# Replace the bracketed elapsed-time suffix (e.g. "[1.2s]") with a fixed
# "[0ms]" so the recorded output does not depend on timing.
transform = function(x) {
sub("\\[[0-9]+[a-z]+\\]", "[0ms]", x = x)
})

# The returned snapshot path should contain at least the four matched files.
expect_true(length(fs::dir_ls(p)) >= 4)
})

0 comments on commit b697b23

Please sign in to comment.