Skip to content

Commit

Permalink
Add option to load models from cache without outgoing traffic (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Mar 29, 2023
1 parent d94737b commit 543e9e4
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 16 deletions.
14 changes: 12 additions & 2 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ defmodule Bumblebee do
operating system. You can also configure it globally by
setting the `BUMBLEBEE_CACHE_DIR` environment variable
* `:offline` - if `true`, only cached files are accessed and
missing files result in an error. You can also configure it
globally by setting the `BUMBLEBEE_OFFLINE` environment
variable to `true` or `1`
* `:auth_token` - the token to use as HTTP bearer authorization
for remote files
Expand Down Expand Up @@ -797,22 +802,27 @@ defmodule Bumblebee do
# Downloads `filename` from the given Hugging Face repository, going
# through the local cache.
#
# The `:revision`, `:cache_dir`, `:offline`, `:auth_token` and `:subdir`
# repository options are read from `opts`. When `:subdir` is set, it is
# prepended to `filename`. Returns the result of
# `HuggingFace.Hub.cached_download/2`.
defp download({:hf, repository_id, opts}, filename) do
  revision = opts[:revision]
  cache_dir = opts[:cache_dir]
  offline = opts[:offline]
  auth_token = opts[:auth_token]
  subdir = opts[:subdir]

  filename = if subdir, do: subdir <> "/" <> filename, else: filename

  url = HuggingFace.Hub.file_url(repository_id, filename, revision)

  # Exactly one download call, and it forwards `:offline`. A call that
  # omits the option would reach out to the network even in offline mode.
  HuggingFace.Hub.cached_download(url,
    cache_dir: cache_dir,
    offline: offline,
    auth_token: auth_token
  )
end

# Normalizes a repository reference into the `{:hf, repository_id, opts}`
# form.
#
# The two-element form gets an empty option list. The three-element form
# has its options validated: only `:revision`, `:cache_dir`, `:offline`,
# `:auth_token` and `:subdir` are accepted, and `Keyword.validate!/2`
# raises `ArgumentError` for any other key.
defp normalize_repository!({:hf, repository_id}) when is_binary(repository_id) do
  {:hf, repository_id, []}
end

defp normalize_repository!({:hf, repository_id, opts}) when is_binary(repository_id) do
  # Validate once, against the full list of supported keys. Validating
  # against a list without `:offline` would reject that option outright.
  opts = Keyword.validate!(opts, [:revision, :cache_dir, :offline, :auth_token, :subdir])
  {:hf, repository_id, opts}
end

Expand Down
48 changes: 34 additions & 14 deletions lib/bumblebee/huggingface/hub.ex
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ defmodule Bumblebee.HuggingFace.Hub do
Defaults to the standard cache location for the given operating
system
* `:offline` - if `true`, the cached path is returned if it exists
and an error otherwise
* `:auth_token` - the token to use as HTTP bearer authorization
for remote files
"""
@spec cached_download(String.t(), keyword()) :: {:ok, String.t()} | {:error, String.t()}
def cached_download(url, opts \\ []) do
cache_dir = opts[:cache_dir] || bumblebee_cache_dir()
offline = opts[:offline] || bumblebee_offline?()
auth_token = opts[:auth_token]

dir = Path.join(cache_dir, "huggingface")
Expand All @@ -47,25 +51,37 @@ defmodule Bumblebee.HuggingFace.Hub do
[]
end

with {:ok, etag, download_url} <- head_download(url, headers) do
metadata_path = Path.join(dir, metadata_filename(url))
entry_path = Path.join(dir, entry_filename(url, etag))
metadata_path = Path.join(dir, metadata_filename(url))

if offline do
case load_json(metadata_path) do
{:ok, %{"etag" => ^etag}} ->
{:ok, %{"etag" => etag}} ->
entry_path = Path.join(dir, entry_filename(url, etag))
{:ok, entry_path}

_ ->
case HTTP.download(download_url, entry_path, headers: headers) |> finish_request() do
:ok ->
:ok = store_json(metadata_path, %{"etag" => etag, "url" => url})
{:ok, entry_path}

error ->
File.rm_rf!(metadata_path)
File.rm_rf!(entry_path)
error
end
{:error, "could not find file in local cache and outgoing traffic is disabled"}
end
else
with {:ok, etag, download_url} <- head_download(url, headers) do
entry_path = Path.join(dir, entry_filename(url, etag))

case load_json(metadata_path) do
{:ok, %{"etag" => ^etag}} ->
{:ok, entry_path}

_ ->
case HTTP.download(download_url, entry_path, headers: headers) |> finish_request() do
:ok ->
:ok = store_json(metadata_path, %{"etag" => etag, "url" => url})
{:ok, entry_path}

error ->
File.rm_rf!(metadata_path)
File.rm_rf!(entry_path)
error
end
end
end
end
end
Expand Down Expand Up @@ -145,4 +161,8 @@ defmodule Bumblebee.HuggingFace.Hub do
:filename.basedir(:user_cache, "bumblebee")
end
end

# Reports whether offline mode is switched on globally, i.e. whether the
# `BUMBLEBEE_OFFLINE` environment variable is set to "1" or "true".
defp bumblebee_offline?() do
  case System.get_env("BUMBLEBEE_OFFLINE") do
    value when value in ["1", "true"] -> true
    _other -> false
  end
end
end
29 changes: 29 additions & 0 deletions test/bumblebee/huggingface/hub_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,35 @@ defmodule Bumblebee.HuggingFace.HubTest do

assert {:error, "repository not found"} = Hub.cached_download(url, cache_dir: tmp_dir)
end

@tag :tmp_dir
test "returns cached without checking etag when :offline is enabled",
     %{bypass: bypass, tmp_dir: tmp_dir} do
  url = url(bypass.port) <> "/file.json"

  # Each route is registered with `expect_once`; the online call below
  # performs the HEAD and GET requests.
  Bypass.expect_once(bypass, "HEAD", "/file.json", &serve_with_etag(&1, ~s/"hash"/, ""))
  Bypass.expect_once(bypass, "GET", "/file.json", &serve_with_etag(&1, ~s/"hash"/, "{}"))

  # Populate the cache with a regular (online) download.
  assert {:ok, online_path} = Hub.cached_download(url, cache_dir: tmp_dir)
  assert File.read!(online_path) == "{}"

  # With `offline: true` the file is expected to come from the cache.
  assert {:ok, offline_path} = Hub.cached_download(url, cache_dir: tmp_dir, offline: true)
  assert File.read!(offline_path) == "{}"
end

@tag :tmp_dir
test "returns an error when :offline is enabled and file not in cache",
     %{bypass: bypass, tmp_dir: tmp_dir} do
  # No download has been made into `tmp_dir` in this test, so the offline
  # lookup has nothing to return.
  missing_url = url(bypass.port) <> "/file.json"
  result = Hub.cached_download(missing_url, cache_dir: tmp_dir, offline: true)

  assert {:error, "could not find file in local cache and outgoing traffic is disabled"} =
           result
end
end

# Builds the base URL of a server listening on `port` on localhost.
defp url(port) do
  "http://localhost:" <> to_string(port)
end
Expand Down

0 comments on commit 543e9e4

Please sign in to comment.