Push dataset infos.json to Hub #3467

Merged
merged 6 commits into from Dec 21, 2021
81 changes: 77 additions & 4 deletions src/datasets/arrow_dataset.py
@@ -73,7 +73,7 @@
from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table
from .info import DatasetInfo
from .search import IndexableMixin
from .splits import NamedSplit, Split
from .splits import NamedSplit, Split, SplitInfo
from .table import (
InMemoryTable,
MemoryMappedTable,
@@ -3404,15 +3404,15 @@ def to_parquet(

return ParquetDatasetWriter(self, path_or_buf, batch_size=batch_size, **parquet_writer_kwargs).write()

def push_to_hub(
def _push_parquet_shards_to_hub(
self,
repo_id: str,
split: Optional[str] = None,
private: Optional[bool] = False,
token: Optional[str] = None,
branch: Optional[str] = None,
shard_size: Optional[int] = 500 << 20,
):
) -> Tuple[str, str, int, int]:
"""Pushes the dataset to the hub.
The dataset is pushed using HTTP requests and does not need to have git or git-lfs installed.

@@ -3437,12 +3437,18 @@ def push_to_hub(
The size of the dataset shards to be uploaded to the hub. The dataset will be pushed in files
of the size specified here, in bytes. Defaults to a shard size of 500MB.

Returns:
repo_id (:obj:`str`): ID of the repository in `<user>/<dataset_name>` or `<org>/<dataset_name>` format
split (:obj:`str`): name of the uploaded split
uploaded_size (:obj:`int`): number of uploaded bytes
dataset_nbytes (:obj:`int`): approximate size in bytes of the uploaded dataset after decompression

Example:
.. code-block:: python

>>> dataset.push_to_hub("<organization>/<dataset_id>", split="evaluation")
"""
api = HfApi()
api = HfApi(endpoint=config.HF_ENDPOINT)
token = token if token is not None else HfFolder.get_token()

if token is None:
@@ -3518,6 +3524,7 @@ def delete_file(file):
):
delete_file(file)

uploaded_size = 0
for index, shard in utils.tqdm(
enumerate(shards),
desc="Pushing dataset shards to the dataset hub",
@@ -3526,6 +3533,7 @@ def delete_file(file):
):
buffer = BytesIO()
shard.to_parquet(buffer)
uploaded_size += buffer.tell()
api.upload_file(
path_or_fileobj=buffer.getvalue(),
path_in_repo=path_in_repo(index),
@@ -3535,6 +3543,71 @@ def delete_file(file):
revision=branch,
identical_ok=True,
)
return repo_id, split, uploaded_size, dataset_nbytes

def push_to_hub(
self,
repo_id: str,
split: Optional[str] = None,
private: Optional[bool] = False,
token: Optional[str] = None,
branch: Optional[str] = None,
shard_size: Optional[int] = 500 << 20,
):
"""Pushes the dataset to the hub.
The dataset is pushed using HTTP requests and does not need to have git or git-lfs installed.

Args:
repo_id (:obj:`str`):
The ID of the repository to push to in the following format: `<user>/<dataset_name>` or
`<org>/<dataset_name>`. Also accepts `<dataset_name>`, which will default to the namespace
of the logged-in user.
split (Optional, :obj:`str`):
The name of the split that will be given to that dataset. Defaults to `self.split`.
private (Optional :obj:`bool`, defaults to :obj:`False`):
Whether the dataset repository should be set to private or not. Only affects repository creation:
a repository that already exists will not be affected by that parameter.
token (Optional :obj:`str`):
An optional authentication token for the Hugging Face Hub. If no token is passed, will default
to the token saved locally when logging in with ``huggingface-cli login``. Will raise an error
if no token is passed and the user is not logged in.
branch (Optional :obj:`str`):
The git branch on which to push the dataset. This defaults to the default branch as specified
in your repository, which defaults to `"main"`.
shard_size (Optional :obj:`int`):
The size of the dataset shards to be uploaded to the hub. The dataset will be pushed in files
of the size specified here, in bytes. Defaults to a shard size of 500MB.

Example:
.. code-block:: python

>>> dataset.push_to_hub("<organization>/<dataset_id>", split="evaluation")
"""
repo_id, split, uploaded_size, dataset_nbytes = self._push_parquet_shards_to_hub(
repo_id=repo_id, split=split, private=private, token=token, branch=branch, shard_size=shard_size
)
organization, dataset_name = repo_id.split("/")
info_to_dump = self.info.copy()
info_to_dump.download_checksums = None
info_to_dump.download_size = uploaded_size
info_to_dump.dataset_size = dataset_nbytes
info_to_dump.size_in_bytes = uploaded_size + dataset_nbytes
info_to_dump.splits = {
split: SplitInfo(split, num_bytes=dataset_nbytes, num_examples=len(self), dataset_name=dataset_name)
}
buffer = BytesIO()
buffer.write(f'{{"{organization}--{dataset_name}": '.encode())
info_to_dump._dump_info(buffer)
buffer.write(b"}")
HfApi(endpoint=config.HF_ENDPOINT).upload_file(
path_or_fileobj=buffer.getvalue(),
path_in_repo=config.DATASETDICT_INFOS_FILENAME,
repo_id=repo_id,
token=token,
repo_type="dataset",
revision=branch,
identical_ok=True,
)

@transmit_format
@fingerprint_transform(inplace=False)
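For reference, a minimal sketch (not part of the diff) of the payload that the new push_to_hub assembles in the BytesIO buffer before uploading it as dataset_infos.json. The repository name and byte counts are made-up placeholders, and only the fields this PR fills in are shown; the real file also carries the remaining DatasetInfo fields.

import json

organization, dataset_name = "my-org", "my-dataset"  # hypothetical repo_id "my-org/my-dataset"
info_dict = {
    "download_checksums": None,
    "download_size": 123_456,            # bytes of the uploaded parquet shards
    "dataset_size": 654_321,             # approximate size after decompression
    "size_in_bytes": 123_456 + 654_321,
    "splits": {
        "train": {"name": "train", "num_bytes": 654_321, "num_examples": 1_000, "dataset_name": dataset_name}
    },
}
payload = json.dumps({f"{organization}--{dataset_name}": info_dict}).encode("utf-8")
# payload mirrors the BytesIO contents handed to HfApi.upload_file above.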
12 changes: 7 additions & 5 deletions src/datasets/builder.py
@@ -203,6 +203,7 @@ def __init__(
name: Optional[str] = None,
hash: Optional[str] = None,
base_path: Optional[str] = None,
info: Optional[DatasetInfo] = None,
features: Optional[Features] = None,
use_auth_token: Optional[Union[bool, str]] = None,
namespace: Optional[str] = None,
@@ -263,11 +264,12 @@ def __init__(

# prepare info: DatasetInfo is a standardized dataclass across all datasets
# Prefill datasetinfo
info = self.get_exported_dataset_info()
info.update(self._info())
info.builder_name = self.name
info.config_name = self.config.name
info.version = self.config.version
if info is None:
info = self.get_exported_dataset_info()
info.update(self._info())
info.builder_name = self.name
info.config_name = self.config.name
info.version = self.config.version
self.info = info
# update info with user specified infos
if features is not None:
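A hedged usage sketch of what the new optional info argument enables downstream: when a prefilled DatasetInfo reaches the builder (for example, recovered from a pushed dataset_infos.json), it is used as-is instead of being rebuilt from _info(). The repository name below is hypothetical and assumed to carry a dataset_infos.json.

from datasets import load_dataset_builder

builder = load_dataset_builder("my-org/my-dataset")  # hypothetical repo pushed with push_to_hub
# With this change the builder's info arrives prefilled from dataset_infos.json,
# so sizes and split metadata are available before any data is downloaded.
print(builder.info.dataset_size, builder.info.splits)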
45 changes: 38 additions & 7 deletions src/datasets/dataset_dict.py
@@ -3,23 +3,25 @@
import json
import os
import re
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import fsspec
import numpy as np

from datasets.splits import NamedSplit, Split
from datasets.utils.doc_utils import is_documented_by
from huggingface_hub import HfApi

from . import config
from .arrow_dataset import Dataset
from .features import Features
from .filesystems import extract_path_from_uri, is_remote_filesystem
from .info import DatasetInfo
from .splits import NamedSplit, Split, SplitDict, SplitInfo
from .table import Table
from .tasks import TaskTemplate
from .utils import logging
from .utils.deprecation_utils import deprecated
from .utils.doc_utils import is_documented_by
from .utils.typing import PathLike


@@ -931,12 +933,41 @@ def push_to_hub(

>>> dataset_dict.push_to_hub("<organization>/<dataset_id>")
"""
for key in self.keys():
logger.warning(f"Pushing split {key} to the Hub.")
self._check_values_type()
total_uploaded_size = 0
total_dataset_nbytes = 0
info_to_dump: DatasetInfo = next(iter(self.values())).info.copy()
dataset_name = repo_id.split("/")[-1]
info_to_dump.splits = SplitDict(dataset_name=dataset_name)
for split in self.keys():
logger.warning(f"Pushing split {split} to the Hub.")
# The split=key needs to be removed before merging
self[key].push_to_hub(
repo_id, split=key, private=private, token=token, branch=branch, shard_size=shard_size
repo_id, split, uploaded_size, dataset_nbytes = self[split]._push_parquet_shards_to_hub(
repo_id, split=split, private=private, token=token, branch=branch, shard_size=shard_size
)
total_uploaded_size += uploaded_size
total_dataset_nbytes += dataset_nbytes
info_to_dump.splits[split] = SplitInfo(
str(split), num_bytes=dataset_nbytes, num_examples=len(self[split]), dataset_name=dataset_name
)
organization, dataset_name = repo_id.split("/")
info_to_dump.download_checksums = None
info_to_dump.download_size = total_uploaded_size
info_to_dump.dataset_size = total_dataset_nbytes
info_to_dump.size_in_bytes = total_uploaded_size + total_dataset_nbytes
buffer = BytesIO()
buffer.write(f'{{"{organization}--{dataset_name}": '.encode())
info_to_dump._dump_info(buffer)
buffer.write(b"}")
HfApi(endpoint=config.HF_ENDPOINT).upload_file(
path_or_fileobj=buffer.getvalue(),
path_in_repo=config.DATASETDICT_INFOS_FILENAME,
repo_id=repo_id,
token=token,
repo_type="dataset",
revision=branch,
identical_ok=True,
)


class IterableDatasetDict(dict):
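An end-to-end sketch of the DatasetDict path, assuming you are logged in with huggingface-cli login and that the target repository name (a placeholder here) is writable: each split is pushed as parquet shards via _push_parquet_shards_to_hub, then a single aggregated dataset_infos.json is uploaded.

from datasets import load_dataset

dset_dict = load_dataset("lhoestq/demo1")  # any DatasetDict, e.g. with train/test splits
dset_dict.push_to_hub("my-org/my-dataset", shard_size=500 << 20)
# The target repo now holds parquet shards per split plus one dataset_infos.json
# keyed by "my-org--my-dataset" with per-split byte sizes and example counts.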
4 changes: 2 additions & 2 deletions src/datasets/info.py
@@ -34,7 +34,7 @@
import json
import os
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

from datasets.tasks.text_classification import TextClassification

@@ -274,7 +274,7 @@ def copy(self) -> "DatasetInfo":
return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})


class DatasetInfosDict(dict):
class DatasetInfosDict(Dict[str, DatasetInfo]):
def write_to_directory(self, dataset_infos_dir, overwrite=False):
total_dataset_infos = {}
dataset_infos_path = os.path.join(dataset_infos_dir, config.DATASETDICT_INFOS_FILENAME)
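The info.py change only tightens the type annotation of DatasetInfosDict; a minimal sketch of what that annotation expresses (the key and description are toy values):

from datasets.info import DatasetInfo, DatasetInfosDict

# DatasetInfosDict now advertises str keys and DatasetInfo values, matching how
# dataset_infos.json is keyed by config name.
infos = DatasetInfosDict({"my-org--my-dataset": DatasetInfo(description="toy example")})
assert isinstance(infos["my-org--my-dataset"], DatasetInfo)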
33 changes: 25 additions & 8 deletions src/datasets/load.py
@@ -47,6 +47,7 @@
from .dataset_dict import DatasetDict, IterableDatasetDict
from .features import Features
from .filesystems import extract_path_from_uri, is_remote_filesystem
from .info import DatasetInfo, DatasetInfosDict
from .iterable_dataset import IterableDataset
from .metric import Metric
from .packaged_modules import _EXTENSION_TO_MODULE, _PACKAGED_DATASETS_MODULES, hash_python_lines
@@ -398,7 +399,7 @@ def _create_importable_file(
name: str,
download_mode: GenerateMode,
) -> Tuple[str, str]:
importable_directory_path = os.path.join(dynamic_modules_path, module_namespace, name.replace("/", "___"))
importable_directory_path = os.path.join(dynamic_modules_path, module_namespace, name.replace("/", "--"))
Path(importable_directory_path).mkdir(parents=True, exist_ok=True)
(Path(importable_directory_path).parent / "__init__.py").touch(exist_ok=True)
hash = files_to_hash([local_path] + [loc[1] for loc in local_imports])
@@ -413,7 +414,7 @@ def _create_importable_file(
)
logger.debug(f"Created importable dataset file at {importable_local_file}")
module_path = ".".join(
[os.path.basename(dynamic_modules_path), module_namespace, name.replace("/", "___"), hash, name.split("/")[-1]]
[os.path.basename(dynamic_modules_path), module_namespace, name.replace("/", "--"), hash, name.split("/")[-1]]
)
return module_path, hash

@@ -741,6 +742,11 @@ def get_module(self) -> DatasetModule:
"name": os.path.basename(self.path),
"base_path": self.path,
}
if os.path.isfile(os.path.join(self.path, config.DATASETDICT_INFOS_FILENAME)):
with open(os.path.join(self.path, config.DATASETDICT_INFOS_FILENAME), encoding="utf-8") as f:
dataset_infos: DatasetInfosDict = json.load(f)
builder_kwargs["name"] = next(iter(dataset_infos.values()))
builder_kwargs["info"] = DatasetInfo.from_dict(dataset_infos[builder_kwargs["name"]])
return DatasetModule(module_path, hash, builder_kwargs)


@@ -799,7 +805,7 @@ def get_module(self) -> DatasetModule:
token = HfFolder.get_token() if self.download_config.use_auth_token else None
else:
token = self.download_config.use_auth_token
dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
self.name,
revision=self.revision,
token=token,
@@ -808,11 +814,11 @@
patterns = (
sanitize_patterns(self.data_files)
if self.data_files is not None
else get_patterns_in_dataset_repository(dataset_info)
else get_patterns_in_dataset_repository(hfh_dataset_info)
)
data_files = DataFilesDict.from_hf_repo(
patterns,
dataset_info=dataset_info,
dataset_info=hfh_dataset_info,
allowed_extensions=ALL_ALLOWED_EXTENSIONS,
)
infered_module_names = {
@@ -828,9 +834,20 @@
builder_kwargs = {
"hash": hash,
"data_files": data_files,
"name": self.name.replace("/", "___"),
"name": self.name.replace("/", "--"),
"base_path": hf_hub_url(self.name, "", revision=self.revision),
}
try:
dataset_infos_path = cached_path(
hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.revision),
download_config=self.download_config,
)
with open(dataset_infos_path, encoding="utf-8") as f:
dataset_infos: DatasetInfosDict = json.load(f)
builder_kwargs["name"] = next(iter(dataset_infos))
builder_kwargs["info"] = DatasetInfo.from_dict(dataset_infos[builder_kwargs["name"]])
except FileNotFoundError:
pass
return DatasetModule(module_path, hash, builder_kwargs)


@@ -918,7 +935,7 @@ def __init__(

def get_module(self) -> DatasetModule:
dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules()
importable_directory_path = os.path.join(dynamic_modules_path, "datasets", self.name.replace("/", "___"))
importable_directory_path = os.path.join(dynamic_modules_path, "datasets", self.name.replace("/", "--"))
hashes = (
[h for h in os.listdir(importable_directory_path) if len(h) == 64]
if os.path.isdir(importable_directory_path)
@@ -945,7 +962,7 @@ def _get_modification_time(module_hash):
[
os.path.basename(dynamic_modules_path),
"datasets",
self.name.replace("/", "___"),
self.name.replace("/", "--"),
hash,
self.name.split("/")[-1],
]
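A sketch (outside the diff) of the loader-side consumption added here: the resolved dataset_infos.json is parsed, its first key becomes the builder's config name, and its value is turned back into a DatasetInfo that is passed to the builder. The file path is illustrative; DatasetInfo.from_dict and the "<org>--<name>" key format are the ones used above.

import json

from datasets.info import DatasetInfo

with open("dataset_infos.json", encoding="utf-8") as f:
    dataset_infos = json.load(f)           # e.g. {"my-org--my-dataset": {...}}

config_name = next(iter(dataset_infos))    # first (and only) key
builder_kwargs = {
    "name": config_name,
    "info": DatasetInfo.from_dict(dataset_infos[config_name]),
}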
2 changes: 1 addition & 1 deletion tests/test_inspect.py
@@ -10,7 +10,7 @@
("acronym_identification", "default"),
("lhoestq/squad", "plain_text"),
("lhoestq/test", "default"),
("lhoestq/demo1", "lhoestq___demo1"),
("lhoestq/demo1", "lhoestq--demo1"),
],
)
def test_get_dataset_config_names(path, expected):
4 changes: 2 additions & 2 deletions tests/test_load.py
@@ -453,14 +453,14 @@ def test_load_dataset_builder_for_community_dataset_with_script():
assert builder.info.features == Features({"text": Value("string")})
namespace = SAMPLE_DATASET_IDENTIFIER[: SAMPLE_DATASET_IDENTIFIER.index("/")]
assert builder._relative_data_dir().startswith(namespace)
assert SAMPLE_DATASET_IDENTIFIER.replace("/", "___") in builder.__module__
assert SAMPLE_DATASET_IDENTIFIER.replace("/", "--") in builder.__module__


def test_load_dataset_builder_for_community_dataset_without_script():
builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER2)
assert isinstance(builder, DatasetBuilder)
assert builder.name == "text"
assert builder.config.name == SAMPLE_DATASET_IDENTIFIER2.replace("/", "___")
assert builder.config.name == SAMPLE_DATASET_IDENTIFIER2.replace("/", "--")
assert isinstance(builder.config.data_files, DataFilesDict)
assert len(builder.config.data_files["train"]) > 0
assert len(builder.config.data_files["test"]) > 0
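For clarity, the naming convention these tests now expect: the config/module name derived from a community repository ID uses "--" as the separator instead of "___". A trivial illustration:

repo_id = "lhoestq/demo1"
assert repo_id.replace("/", "--") == "lhoestq--demo1"  # new convention checked by the tests above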