Support hfh 0.10 implicit auth (#5031)
* support hfh 0.10 implicit auth

* update tests

* Bump minimum hfh to 0.2.0 and test minimum version

* style

* fix test

* fix tests

* again

* lucain's comment

* fix ci
lhoestq committed Oct 5, 2022
1 parent 08dfdc9 commit c1a66f0
Showing 9 changed files with 143 additions and 50 deletions.
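
For context, a minimal sketch of what "implicit auth" means for users (the dataset name and token below are hypothetical, and a token is assumed to have been saved with huggingface-cli login): with huggingface_hub >= 0.10 the stored token is picked up automatically, so private repositories can be loaded without passing use_auth_token explicitly.

from datasets import load_dataset

# Explicit auth: pass a token yourself (or use_auth_token=True).
ds = load_dataset("some-user/some-private-dataset", split="train", use_auth_token="hf_xxx")

# Implicit auth (huggingface_hub >= 0.10): the token saved by `huggingface-cli login`
# is used automatically when use_auth_token is left as None.
ds = load_dataset("some-user/some-private-dataset", split="train")
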
14 changes: 7 additions & 7 deletions .github/workflows/ci.yml
@@ -38,7 +38,7 @@ jobs:
matrix:
test: ['unit', 'integration']
os: [ubuntu-latest, windows-latest]
pyarrow_version: [latest, 6.0.1]
deps_versions: [latest, minimum]
continue-on-error: ${{ matrix.test == 'integration' }}
runs-on: ${{ matrix.os }}
steps:
@@ -63,12 +63,12 @@
run: |
pip install .[tests]
pip install -r additional-tests-requirements.txt --no-deps
- name: Install latest PyArrow
if: ${{ matrix.pyarrow_version == 'latest' }}
run: pip install pyarrow --upgrade
- name: Install PyArrow ${{ matrix.pyarrow_version }}
if: ${{ matrix.pyarrow_version != 'latest' }}
run: pip install pyarrow==${{ matrix.pyarrow_version }}
- name: Install dependencies (latest versions)
if: ${{ matrix.deps_versions == 'latest' }}
run: pip install --upgrade pyarrow huggingface-hub
- name: Install dependencies (minimum versions)
if: ${{ matrix.deps_versions != 'latest' }}
run: pip install pyarrow==6.0.1 huggingface-hub==0.2.0 transformers
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
3 changes: 2 additions & 1 deletion setup.py
@@ -85,7 +85,8 @@
# for data streaming via http
"aiohttp",
# To get datasets from the Datasets Hub on huggingface.co
"huggingface-hub>=0.1.0,<1.0.0",
# minimum 0.2.0 for set_access_token
"huggingface-hub>=0.2.0,<1.0.0",
# Utilities from PyPA to e.g., compare versions
"packaging",
"responses<0.19",
3 changes: 2 additions & 1 deletion src/datasets/arrow_dataset.py
@@ -102,6 +102,7 @@
from .tasks import TaskTemplate
from .utils import logging
from .utils._hf_hub_fixes import create_repo
from .utils._hf_hub_fixes import list_repo_files as hf_api_list_repo_files
from .utils.file_utils import _retry, cached_path, estimate_dataset_size, hf_hub_url
from .utils.info_utils import is_small_dataset
from .utils.py_utils import asdict, convert_file_size_to_int, unique_values
@@ -4245,7 +4246,7 @@ def shards_with_embedded_external_files(shards):

shards = shards_with_embedded_external_files(shards)

files = api.list_repo_files(repo_id, repo_type="dataset", revision=branch, token=token)
files = hf_api_list_repo_files(api, repo_id, repo_type="dataset", revision=branch, token=token)
data_files = [file for file in files if file.startswith("data/")]

def path_in_repo(_index, shard):
21 changes: 8 additions & 13 deletions src/datasets/load.py
@@ -29,7 +29,7 @@

import fsspec
import requests
from huggingface_hub import HfApi, HfFolder
from huggingface_hub import HfApi

from . import config
from .arrow_dataset import Dataset
@@ -62,6 +62,7 @@
)
from .splits import Split
from .tasks import TaskTemplate
from .utils._hf_hub_fixes import dataset_info as hf_api_dataset_info
from .utils.deprecation_utils import deprecated
from .utils.file_utils import (
OfflineModeIsEnabled,
@@ -744,14 +745,11 @@ def __init__(
increase_load_count(name, resource_type="dataset")

def get_module(self) -> DatasetModule:
if isinstance(self.download_config.use_auth_token, bool):
token = HfFolder.get_token() if self.download_config.use_auth_token else None
else:
token = self.download_config.use_auth_token
hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
hfh_dataset_info = hf_api_dataset_info(
HfApi(config.HF_ENDPOINT),
self.name,
revision=self.revision,
token=token if token else "no-token",
use_auth_token=self.download_config.use_auth_token,
timeout=100.0,
)
patterns = (
@@ -1112,14 +1110,11 @@ def dataset_module_factory(
_raise_if_offline_mode_is_enabled()
hf_api = HfApi(config.HF_ENDPOINT)
try:
if isinstance(download_config.use_auth_token, bool):
token = HfFolder.get_token() if download_config.use_auth_token else None
else:
token = download_config.use_auth_token
dataset_info = hf_api.dataset_info(
dataset_info = hf_api_dataset_info(
hf_api,
repo_id=path,
revision=revision,
token=token if token else "no-token",
use_auth_token=download_config.use_auth_token,
timeout=100.0,
)
except Exception as e: # noqa: catch any exception of hf_hub and consider that the dataset doesn't exist
79 changes: 77 additions & 2 deletions src/datasets/utils/_hf_hub_fixes.py
@@ -1,7 +1,8 @@
from typing import Optional
from typing import List, Optional, Union

import huggingface_hub
from huggingface_hub import HfApi
from huggingface_hub import HfApi, HfFolder
from huggingface_hub.hf_api import DatasetInfo
from packaging import version


@@ -99,3 +100,77 @@ def delete_repo(
token=token,
repo_type=repo_type,
)


def dataset_info(
hf_api: HfApi,
repo_id: str,
*,
revision: Optional[str] = None,
timeout: Optional[float] = None,
use_auth_token: Optional[Union[bool, str]] = None,
) -> DatasetInfo:
"""
The huggingface_hub.HfApi.dataset_info parameters changed in 0.10.0 and some of them were deprecated.
This function checks the huggingface_hub version to call the right parameters.
Args:
hf_api (`huggingface_hub.HfApi`): Hub client
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
revision (`str`, *optional*):
The revision of the dataset repository from which to get the
information.
timeout (`float`, *optional*):
Timeout in seconds for the request to the Hub.
use_auth_token (`bool` or `str`, *optional*):
Whether to use the `auth_token` provided from the
`huggingface_hub` cli. If not logged in, a valid `auth_token`
can be passed in as a string.
Returns:
[`hf_api.DatasetInfo`]: The dataset repository information.
<Tip>
Raises the following errors:
- [`~utils.RepositoryNotFoundError`]
If the repository to download from cannot be found. This may be because it doesn't exist,
or because it is set to `private` and you do not have access.
- [`~utils.RevisionNotFoundError`]
If the revision to download from cannot be found.
</Tip>
"""
if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
if use_auth_token is False:
token = "no-token"
elif isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token() or "no-token"
return hf_api.dataset_info(
repo_id,
revision=revision,
token=token,
timeout=timeout,
)
else: # the `token` parameter is deprecated in huggingface_hub>=0.10.0
return hf_api.dataset_info(repo_id, revision=revision, timeout=timeout, use_auth_token=use_auth_token)


def list_repo_files(
hf_api: HfApi,
repo_id: str,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
token: Optional[str] = None,
timeout: Optional[float] = None,
) -> List[str]:
"""
The huggingface_hub.HfApi.list_repo_files parameters changed in 0.10.0 and some of them were deprecated.
This function checks the huggingface_hub version to call the right parameters.
"""
if version.parse(huggingface_hub.__version__) < version.parse("0.10.0"):
return hf_api.list_repo_files(repo_id, revision=revision, repo_type=repo_type, token=token, timeout=timeout)
else: # the `token` parameter is deprecated in huggingface_hub>=0.10.0
return hf_api.list_repo_files(
repo_id, revision=revision, repo_type=repo_type, use_auth_token=token, timeout=timeout
)
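
A short usage sketch for the two shims above (repository names are hypothetical). Callers in datasets pass their auth argument through unchanged: older huggingface_hub releases receive a resolved token, while releases >= 0.10 get use_auth_token and can resolve the stored token themselves.

from huggingface_hub import HfApi
from datasets import config
from datasets.utils._hf_hub_fixes import dataset_info, list_repo_files

api = HfApi(config.HF_ENDPOINT)

# use_auth_token=None lets huggingface_hub >= 0.10 pick up the stored token implicitly;
# on older versions the shim falls back to HfFolder.get_token() itself.
info = dataset_info(api, "some-user/some-dataset", use_auth_token=None, timeout=100.0)

# list_repo_files keeps the old `token` argument name and maps it to
# `use_auth_token` when huggingface_hub >= 0.10 is installed.
files = list_repo_files(api, "some-user/some-dataset", repo_type="dataset", token=None)
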
16 changes: 10 additions & 6 deletions src/datasets/utils/file_utils.py
@@ -22,7 +22,9 @@
from typing import List, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse

import huggingface_hub
import requests
from huggingface_hub import HfFolder

from .. import __version__, config
from ..download.download_config import DownloadConfig
@@ -218,7 +220,9 @@ def cached_path(


def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str:
ua = f"datasets/{__version__}; python/{config.PY_VERSION}"
ua = f"datasets/{__version__}"
ua += f"; python/{config.PY_VERSION}"
ua += f"; huggingface_hub/{huggingface_hub.__version__}"
ua += f"; pyarrow/{config.PYARROW_VERSION}"
if config.TORCH_AVAILABLE:
ua += f"; torch/{config.TORCH_VERSION}"
@@ -239,13 +243,13 @@ def get_authentication_headers_for_url(url: str, use_auth_token: Optional[Union[
"""Handle the HF authentication"""
headers = {}
if url.startswith(config.HF_ENDPOINT):
token = None
if isinstance(use_auth_token, str):
if use_auth_token is False:
token = None
elif isinstance(use_auth_token, str):
token = use_auth_token
elif bool(use_auth_token):
from huggingface_hub import hf_api
else:
token = HfFolder.get_token()

token = hf_api.HfFolder.get_token()
if token:
headers["authorization"] = f"Bearer {token}"
return headers
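
A hedged sketch of how the updated helper resolves authentication for Hub URLs (the URL is illustrative): False disables authentication entirely, a string is used verbatim as the token, and None or True fall back to the token stored by huggingface-cli login.

from datasets import config
from datasets.utils.file_utils import get_authentication_headers_for_url

url = config.HF_ENDPOINT + "/datasets/some-user/some-private-dataset/resolve/main/data.txt"

# None (or True): fall back to the locally stored token, if any.
headers = get_authentication_headers_for_url(url, use_auth_token=None)

# False: send no Authorization header at all.
assert get_authentication_headers_for_url(url, use_auth_token=False) == {}
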
3 changes: 2 additions & 1 deletion tests/test_filesystem.py
@@ -12,6 +12,7 @@
extract_path_from_uri,
is_remote_filesystem,
)
from datasets.utils._hf_hub_fixes import dataset_info as hf_api_dataset_info

from .utils import require_lz4, require_zstandard

@@ -93,7 +94,7 @@ def test_fs_isfile(protocol, zip_jsonl_path, jsonl_gz_path):

@pytest.mark.integration
def test_hf_filesystem(hf_token, hf_api, hf_private_dataset_repo_txt_data, text_file):
repo_info = hf_api.dataset_info(hf_private_dataset_repo_txt_data, token=hf_token)
repo_info = hf_api_dataset_info(hf_api, hf_private_dataset_repo_txt_data, use_auth_token=hf_token)
hffs = HfFileSystem(repo_info=repo_info, token=hf_token)
assert sorted(hffs.glob("*")) == [".gitattributes", "data"]
assert hffs.isdir("data")
36 changes: 18 additions & 18 deletions tests/test_load.py
@@ -756,18 +756,6 @@ def test_load_dataset_streaming_csv(path_extension, streaming, csv_path, bz2_csv
assert ds_item == {"col_1": "0", "col_2": 0, "col_3": 0.0}


@require_pil
@pytest.mark.integration
@pytest.mark.parametrize("streaming", [False, True])
def test_load_dataset_private_zipped_images(hf_private_dataset_repo_zipped_img_data, hf_token, streaming):
ds = load_dataset(
hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=hf_token
)
assert isinstance(ds, IterableDataset if streaming else Dataset)
ds_items = list(ds)
assert len(ds_items) == 2


@pytest.mark.parametrize("streaming", [False, True])
@pytest.mark.parametrize("data_file", ["zip_csv_path", "zip_csv_with_dir_path", "csv_path"])
def test_load_dataset_zip_csv(data_file, streaming, zip_csv_path, zip_csv_with_dir_path, csv_path):
@@ -876,20 +864,32 @@ def assert_auth(url, *args, headers, **kwargs):

@pytest.mark.integration
def test_load_streaming_private_dataset(hf_token, hf_private_dataset_repo_txt_data):
with pytest.raises(FileNotFoundError):
load_dataset(hf_private_dataset_repo_txt_data, streaming=True)
ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True, use_auth_token=hf_token)
ds = load_dataset(hf_private_dataset_repo_txt_data, streaming=True)
assert next(iter(ds)) is not None


@pytest.mark.integration
def test_load_streaming_private_dataset_with_zipped_data(hf_token, hf_private_dataset_repo_zipped_txt_data):
with pytest.raises(FileNotFoundError):
load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True)
ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True, use_auth_token=hf_token)
ds = load_dataset(hf_private_dataset_repo_zipped_txt_data, streaming=True)
assert next(iter(ds)) is not None


@require_pil
@pytest.mark.integration
@pytest.mark.parametrize("implicit_token", [False, True])
@pytest.mark.parametrize("streaming", [False, True])
def test_load_dataset_private_zipped_images(
hf_private_dataset_repo_zipped_img_data, hf_token, streaming, implicit_token
):
use_auth_token = None if implicit_token else hf_token
ds = load_dataset(
hf_private_dataset_repo_zipped_img_data, split="train", streaming=streaming, use_auth_token=use_auth_token
)
assert isinstance(ds, IterableDataset if streaming else Dataset)
ds_items = list(ds)
assert len(ds_items) == 2


def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir, tmp_path, caplog):
cache_dir1 = tmp_path / "cache1"
cache_dir2 = tmp_path / "cache2"
18 changes: 17 additions & 1 deletion tests/test_metric_common.py
@@ -38,6 +38,9 @@
UNSUPPORTED_ON_WINDOWS = {"code_eval"}
_on_windows = os.name == "nt"

REQUIRE_TRANSFORMERS = {"bertscore", "frugalscore", "perplexity"}
_has_transformers = importlib.util.find_spec("transformers") is not None


def skip_if_metric_requires_fairseq(test_case):
@wraps(test_case)
@@ -50,6 +53,17 @@ def wrapper(self, metric_name):
return wrapper


def skip_if_metric_requires_transformers(test_case):
@wraps(test_case)
def wrapper(self, metric_name):
if not _has_transformers and metric_name in REQUIRE_TRANSFORMERS:
self.skipTest('"test requires transformers"')
else:
test_case(self, metric_name)

return wrapper


def skip_on_windows_if_not_windows_compatible(test_case):
@wraps(test_case)
def wrapper(self, metric_name):
@@ -67,7 +81,9 @@ def get_local_metric_names():


@parameterized.named_parameters(get_local_metric_names())
@for_all_test_methods(skip_if_metric_requires_fairseq, skip_on_windows_if_not_windows_compatible)
@for_all_test_methods(
skip_if_metric_requires_fairseq, skip_if_metric_requires_transformers, skip_on_windows_if_not_windows_compatible
)
@local
@pytest.mark.integration
class LocalMetricTest(parameterized.TestCase):
