ADD S3 support for downloading and uploading processed datasets #1723

Merged: 57 commits, Jan 26, 2021

Changes from 16 commits

Commits (57)
6cf8813
added fsspec and fsspec[s3] adjusted save_to_disk function
philschmid Jan 7, 2021
aa90496
added reading from s3
philschmid Jan 7, 2021
6dba448
fixed save_to_disk for s3 path
philschmid Jan 8, 2021
ac1d90f
implemented tests
philschmid Jan 8, 2021
49b2fd8
added filesystem utils to arrow dataset and dataset dict
philschmid Jan 11, 2021
7eb6160
added tests for filesystem_utils
philschmid Jan 11, 2021
82f891d
added DatasetDict test
philschmid Jan 11, 2021
ed17c5e
changed var from s3_ to proc_
philschmid Jan 12, 2021
3bbab8c
Merge remote-tracking branch 'upstream/master'
philschmid Jan 12, 2021
574c9dd
fixed error in load from disk function
philschmid Jan 12, 2021
0315025
fixing directory creation
philschmid Jan 12, 2021
8d072e6
removed fs.makedirs since files has to be saved temp local
philschmid Jan 12, 2021
7a70258
fixed code quality checks
philschmid Jan 12, 2021
29313ab
fixed quality check
philschmid Jan 12, 2021
5f6749f
added noqa for pytest to work with moto
philschmid Jan 12, 2021
354e39f
stupid mistake with wrong order at imports
philschmid Jan 12, 2021
bc36832
adjuste boto3 version to work with moto in tests
philschmid Jan 12, 2021
57bcbe7
removed pytest fixtures from unittest class
philschmid Jan 12, 2021
3493e87
forgot to remove fixture as parameter...
philschmid Jan 12, 2021
b4fa6a9
Make it working with Windows style paths.
mfuntowicz Jan 13, 2021
0be3a11
Merge pull request #1 from philschmid/add_s3
philschmid Jan 13, 2021
5fbfcd7
fixed code quality
philschmid Jan 13, 2021
b540612
Merge remote-tracking branch 'upstream/master'
philschmid Jan 13, 2021
9a1b282
fixed hopefully the last path problems for WIN
philschmid Jan 13, 2021
081f4bc
added Path().pathjoin with posix to load_from_disk for DatasetDict keys
philschmid Jan 13, 2021
2ce5fec
fixed win path problem
philschmid Jan 13, 2021
d346c6f
create conditional dataset_dict_split_path for creating correct path …
philschmid Jan 13, 2021
f25a036
added s3 as extra requires
philschmid Jan 14, 2021
df78d8b
fixed boto imports for docs
philschmid Jan 14, 2021
e3fa922
added S3FileSystem with documentation
philschmid Jan 17, 2021
fb992a5
reworked everything for datasets.filesystem
philschmid Jan 17, 2021
53a6a4b
documentation and styling
philschmid Jan 17, 2021
85f0297
added s3fs for documentation
philschmid Jan 17, 2021
8885a7b
handle optional s3fs dependency
lhoestq Jan 18, 2021
b91345c
fix test
lhoestq Jan 18, 2021
93a5f5b
adjusted doc order and renamed preproc_dataset_path to extract_path_f…
philschmid Jan 18, 2021
8b55b89
added temp dir when saving
philschmid Jan 19, 2021
2bf289d
fixed quality
philschmid Jan 19, 2021
83e4673
added documentation
philschmid Jan 19, 2021
04042ea
implemented save_to_disk for local remote filesystem with temp dir
philschmid Jan 19, 2021
ec29076
fixed documentation example
philschmid Jan 19, 2021
187e01d
fixed documentation for botocore and boto3
philschmid Jan 19, 2021
7785f90
Merge branch 'master' of git://github.com/huggingface/datasets
philschmid Jan 19, 2021
926f31c
Update docs/source/filesystems.rst
philschmid Jan 22, 2021
22b33d7
Update docs/source/filesystems.rst
philschmid Jan 22, 2021
72440ba
Update docs/source/filesystems.rst
philschmid Jan 22, 2021
ea273a8
Update src/datasets/arrow_dataset.py
philschmid Jan 22, 2021
5359003
Update src/datasets/arrow_dataset.py
philschmid Jan 22, 2021
fd106e4
Update src/datasets/filesystems/__init__.py
philschmid Jan 22, 2021
878f8b7
Update src/datasets/filesystems/s3filesystem.py
philschmid Jan 22, 2021
0b1a2f8
Update src/datasets/filesystems/s3filesystem.py
philschmid Jan 22, 2021
a3bebd5
Update src/datasets/load.py
philschmid Jan 22, 2021
eb69cdb
removed unnecessary @mock_s3
philschmid Jan 22, 2021
8b7cd48
Update docs/source/filesystems.rst
philschmid Jan 22, 2021
9d7f5c6
Update docs/source/filesystems.rst
philschmid Jan 26, 2021
8514bee
Update src/datasets/filesystems/s3filesystem.py
philschmid Jan 26, 2021
a8738ca
Update docs/source/processing.rst
philschmid Jan 26, 2021
114 changes: 61 additions & 53 deletions setup.py
@@ -59,19 +59,19 @@
from setuptools import find_packages
from setuptools import setup

DOCLINES = __doc__.split('\n')
DOCLINES = __doc__.split("\n")

REQUIRED_PKGS = [
# We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
'numpy>=1.17',
"numpy>=1.17",
# Backend and serialization. Minimum 0.17.1 to support extension array
'pyarrow>=0.17.1',
"pyarrow>=0.17.1",
# For smart caching dataset processing
'dill',
"dill",
# For performance gains with apache arrow
'pandas',
"pandas",
# for downloading datasets over HTTPS
'requests>=2.19.0',
"requests>=2.19.0",
# progress bars in download and scripts
# tqdm 4.50.0 introduced permission errors on windows
# see https://app.circleci.com/pipelines/github/huggingface/datasets/235/workflows/cfb6a39f-68eb-4802-8b17-2cd5e8ea7369/jobs/1111
@@ -82,38 +82,44 @@
"xxhash",
# for better multiprocessing
"multiprocess",
# for saving datsets to local or s3
"fsspec",
"fsspec[s3]",
# for getting credentials from aws_profile
"boto3",
# to get metadata of optional dependencies such as torch or tensorflow for Python versions that don't have it
"importlib_metadata;python_version<'3.8'"
"importlib_metadata;python_version<'3.8'",
]

BENCHMARKS_REQUIRE = [
'numpy==1.18.5',
'tensorflow==2.3.0',
'torch==1.6.0',
'transformers==3.0.2',
"numpy==1.18.5",
"tensorflow==2.3.0",
"torch==1.6.0",
"transformers==3.0.2",
]

TESTS_REQUIRE = [
'apache-beam',
'absl-py',
'bs4',
'conllu',
'elasticsearch',
'faiss-cpu',
'langdetect',
'lxml',
'mwparserfromhell',
'nltk',
'openpyxl',
'py7zr',
'pytest',
'pytest-xdist',
'tensorflow',
'torch',
'tldextract',
'transformers',
'zstandard',
'rarfile',
"apache-beam",
"absl-py",
"bs4",
"conllu",
"elasticsearch",
"faiss-cpu",
"langdetect",
"lxml",
"mwparserfromhell",
"nltk",
"openpyxl",
"py7zr",
"pytest",
"pytest-xdist",
"tensorflow",
"torch",
"tldextract",
"transformers",
"zstandard",
"rarfile",
"moto[s3]",
]

if os.name == "nt": # windows
@@ -128,34 +134,36 @@


EXTRAS_REQUIRE = {
'apache-beam': ['apache-beam'],
'tensorflow': ['tensorflow>=2.2.0'],
'tensorflow_gpu': ['tensorflow-gpu>=2.2.0'],
'torch': ['torch'],
'dev': TESTS_REQUIRE + QUALITY_REQUIRE,
'tests': TESTS_REQUIRE,
'quality': QUALITY_REQUIRE,
'benchmarks': BENCHMARKS_REQUIRE,
'docs': ["recommonmark", "sphinx==3.1.2", "sphinx-markdown-tables", "sphinx-rtd-theme==0.4.3", "sphinx-copybutton"]
"apache-beam": ["apache-beam"],
"tensorflow": ["tensorflow>=2.2.0"],
"tensorflow_gpu": ["tensorflow-gpu>=2.2.0"],
"torch": ["torch"],
"dev": TESTS_REQUIRE + QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
"quality": QUALITY_REQUIRE,
"benchmarks": BENCHMARKS_REQUIRE,
"docs": [
"recommonmark",
"sphinx==3.1.2",
"sphinx-markdown-tables",
"sphinx-rtd-theme==0.4.3",
"sphinx-copybutton",
],
}

setup(
name='datasets',
name="datasets",
version="1.2.0",
description=DOCLINES[0],
long_description='\n'.join(DOCLINES[2:]),
author='HuggingFace Inc.',
author_email='thomas@huggingface.co',
url='https://github.com/huggingface/datasets',
download_url='https://github.com/huggingface/datasets/tags',
license='Apache 2.0',
long_description="\n".join(DOCLINES[2:]),
author="HuggingFace Inc.",
author_email="thomas@huggingface.co",
url="https://github.com/huggingface/datasets",
download_url="https://github.com/huggingface/datasets/tags",
license="Apache 2.0",
package_dir={"": "src"},
packages=find_packages("src"),
package_data={
'datasets': [
'scripts/templates/*',
],
},
package_data={"datasets": ["scripts/templates/*",],},
scripts=["datasets-cli"],
install_requires=REQUIRED_PKGS,
extras_require=EXTRAS_REQUIRE,
@@ -171,5 +179,5 @@
"Programming Language :: Python :: 3.7",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
keywords='datasets machine learning datasets metrics',
keywords="datasets machine learning datasets metrics",
)
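
Note: in the 16-commit state shown above, fsspec[s3] and boto3 are added to REQUIRED_PKGS, so every install pulls in the S3 stack. A later commit in this PR ("added s3 as extra requires") moves this behind an optional extra instead. A minimal sketch of that idea; the extra name and the exact package list here are assumptions for illustration, not the merged code:

    # Hypothetical sketch: ship the S3 stack as an opt-in extra rather than a hard requirement,
    # so `pip install datasets[s3]` pulls in the filesystem dependencies on demand.
    # The extra name "s3" and the package list are assumptions.
    EXTRAS_REQUIRE = {
        # ... existing extras (tensorflow, torch, dev, ...) ...
        "s3": ["fsspec", "boto3", "s3fs"],
    }
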
52 changes: 42 additions & 10 deletions src/datasets/arrow_dataset.py
@@ -44,7 +44,7 @@
from .info import DatasetInfo
from .search import IndexableMixin
from .splits import NamedSplit
from .utils import map_nested
from .utils import get_filesystem_from_dataset_path, is_remote_filesystem, map_nested
from .utils.logging import WARNING, get_logger, get_verbosity, set_verbosity_warning


@@ -425,17 +425,26 @@ def __setstate__(self, state):
if self._indices is None and self._indices_data_files:
self._indices = reader._read_files(self._indices_data_files)

def save_to_disk(self, dataset_path: str):
def save_to_disk(
self, dataset_path: str, aws_profile="default", aws_access_key_id=None, aws_secret_access_key=None
):
"""
Save the dataset in a dataset directory
Save the dataset in a dataset directory or to a s3 bucket

Args:
dataset_path (``str``): path of the dataset directory where the dataset will be saved to
dataset_path (``str``): path or s3 uri of the dataset directory where the dataset will be saved to
aws_profile (:obj:`str`, `optional`, defaults to :obj:``default``): the aws profile used to create the `boto_session` for uploading the data to s3
aws_access_key_id (:obj:`str`, `optional`, defaults to :obj:``None``): the aws access key id used to create the `boto_session` for uploading the data to s3
aws_secret_access_key (:obj:`str`, `optional`, defaults to :obj:``None``): the aws secret access key used to create the `boto_session` for uploading the data to s3
[Inline review thread on the new credential parameters]

julien-c (Jan 13, 2021): If we have boto3 as an optional dependency, maybe there's a way to not use those (verbose) params and use something like a profile name instead? (Not sure, just a question.)

n1t0 (Jan 13, 2021): As @lhoestq mentioned, we can probably avoid using these params by letting the user provide a custom fs directly. I think this has several advantages (avoids having too many params, lets us remove code specific to s3, ...).

philschmid (author): So you would suggest something like that?

    import fsspec

    s3 = fsspec.filesystem("s3", anon=False, key=aws_access_key_id, secret=aws_secret_access_key)

    dataset.save_to_disk('s3://my-s3-bucket-with-region/dataset/train', fs=s3)

What I don't like about that is the manual creation, since fsspec is not that well documented for the remote filesystems, e.g. when you want to know which "credentials" you need, you have to go to the s3fs documentation.

What do you think if we remove the named arguments aws_profile... and handle it how fsspec does, with a storage_options dict?

    dataset.save_to_disk('s3://my-s3-bucket-with-region/dataset/train',
                         storage_options={
                             'aws_access_key_id': 123,
                             'aws_secret_access_key': 123
                         })

lhoestq: Indeed, the docstring of fsspec.filesystem is not ideal:

    Signature: fsspec.filesystem(protocol, **storage_options)
    Docstring:
    Instantiate filesystems for given protocol and arguments

    ``storage_options`` are specific to the protocol being chosen, and are
    passed directly to the class.

Maybe we can have better documentation on our side instead, using a wrapper:

    import datasets

    fs = datasets.filesystem("s3", anon=False, key=aws_access_key_id, secret=aws_secret_access_key)

where the docstring of datasets.filesystem is more complete and includes examples for popular filesystems like s3.

lhoestq (Jan 14, 2021): Another option would be to make the filesystem classes easily available:

    >>> from datasets.filesystems import S3FileSystem  # S3FileSystem is simply the class from s3fs
    >>> S3FileSystem?  # show the docstring

shows

    Init signature: s3fs.core.S3FileSystem(*args, **kwargs)
    Docstring:
    Access S3 as if it were a file system.

    This exposes a filesystem-like API (ls, cp, open, etc.) on top of S3
    storage.

    Provide credentials either explicitly (``key=``, ``secret=``) or depend
    on boto's credential methods. See botocore documentation for more
    information. If no credentials are available, use ``anon=True``.

    Parameters
    ----------
    anon : bool (False)
        Whether to use anonymous connection (public buckets only). If False,
        uses the key/secret given, or boto's credential resolver (client_kwargs,
        environment, variables, config files, EC2 IAM server, in that order)
    key : string (None)
        If not anonymous, use this access key ID, if specified
    secret : string (None)
        If not anonymous, use this secret access key, if specified
    token : string (None)
        If not anonymous, use this security token, if specified
    etc.

Member:
  • If we add the various params separately, the more new filesystems we support, the more complicated it becomes. For the user it becomes difficult to know which params to look at, and for us to document everything. It seems like a good option to test things out, but having to change it down the road will require breaking changes, and we usually try to avoid these as much as possible.
  • Using some kind of storage_options just like fsspec.filesystem might be better, but it seems difficult to document as well. I think the same argument applies to having a datasets.filesystem helper.

I think I have a preference for your second option @lhoestq:

  • It seems easy to show all the available filesystems to the user, with each of them having meaningful documentation.
  • We can probably add tests for those we support explicitly.
  • If all of them share the same interface, then power users can probably use anything they want from fsspec without explicit support from us?

What do you guys think?

Member: Yes, I think the second option is interesting and I totally agree with your three points. Maybe let's start by having S3FileSystem in datasets.filesystems and we can add the other ones later.

In the documentation of save_to_disk/load_from_disk we can then say that any filesystem from datasets.filesystems or fsspec can be used.

philschmid (author): I have rebuilt everything so that you can now pass in an fsspec-like filesystem.

    from datasets import S3FileSystem, load_from_disk

    s3 = S3FileSystem(key=aws_access_key_id, secret=aws_secret_access_key)  # doctest: +SKIP

    dataset = load_from_disk('s3://my-private-datasets/imdb/train', fs=s3)  # doctest: +SKIP

    print(len(dataset))
    # 25000

I also created a "draft" documentation version with a few examples for S3FileSystem. Your feedback would be nice. Afterwards, I would adjust the documentation of save_to_disk/load_from_disk.

[screenshot of the draft documentation]

[End of review thread]
"""
assert (
not self.list_indexes()
), "please remove all the indexes using `dataset.drop_index` before saving a dataset"
self = pickle.loads(pickle.dumps(self))
# gets filesystem from dataset, either s3:// or file:// and adjusted dataset_path
fs, dataset_path = get_filesystem_from_dataset_path(
dataset_path, aws_profile, aws_access_key_id, aws_secret_access_key
)
os.makedirs(dataset_path, exist_ok=True)
# Write indices if needed
if self._indices is not None:
@@ -455,12 +464,13 @@ def save_to_disk(self, dataset_path: str):
self._inplace_history = [{"transforms": []}]
# Copy all files into the dataset directory
for data_file in self._data_files + self._indices_data_files:
# Copy file to destination directory
src = data_file["filename"]
filename = Path(src).name
dest = os.path.join(dataset_path, filename)
if src != dest:
shutil.copy(src, dest)
fs.put(src, dest)
elif fs.protocol != "file":
fs.put(src, dest)
# Change path to relative path from inside the destination directory
data_file["filename"] = filename
# Get state
@@ -472,19 +482,38 @@ def save_to_disk(self, dataset_path: str):
len(h["transforms"]) == 0 for h in state.get("_inplace_history", [])
), "in-place history needs to be empty"
# Serialize state
with open(os.path.join(dataset_path, "state.json"), "w", encoding="utf-8") as state_file:
with fs.open(os.path.join(dataset_path, "state.json"), "w", encoding="utf-8") as state_file:
json.dump(state, state_file, indent=2, sort_keys=True)
with open(os.path.join(dataset_path, "dataset_info.json"), "w", encoding="utf-8") as dataset_info_file:
with fs.open(os.path.join(dataset_path, "dataset_info.json"), "w", encoding="utf-8") as dataset_info_file:
json.dump(dataset_info, dataset_info_file, indent=2, sort_keys=True)
logger.info("Dataset saved in {}".format(dataset_path))
# removes temp empty directory if files are uploaded to s3
if "s3" in fs.protocol:
shutil.rmtree(dataset_path.split("/")[0])

@staticmethod
def load_from_disk(dataset_path: str) -> "Dataset":
def load_from_disk(
dataset_path: str, aws_profile="default", aws_access_key_id=None, aws_secret_access_key=None, anon=False
) -> "Dataset":
"""Load the dataset from a dataset directory

Args:
dataset_path (``str``): path of the dataset directory where the dataset will be loaded from
dataset_path (``str``): path or s3 uri of the dataset directory where the dataset will be loaded from
aws_profile (:obj:`str`, `optional`, defaults to :obj:``default``): the aws profile used to create the `boto_session` for downloading the data to s3
aws_access_key_id (:obj:`str`, `optional`, defaults to :obj:``None``): the aws access key id used to create the `boto_session` for downloading the data to s3
aws_secret_access_key (:obj:`str`, `optional`, defaults to :obj:``None``): the aws secret access key used to create the `boto_session` for downloading the data to s3
anon (:obj:`boolean`, `optional`, defaults to :obj:``False``): The connection can be anonymous - in which case only publicly-available, read-only buckets are accessible, for anonymous connection use `anon=True`
"""
# copies file from filesystem if it is s3 to local filesystem and modifies dataset_path to temp directory containing local copies
if is_remote_filesystem(dataset_path):
# gets filesystem from dataset, either s3:// or file://
fs, proc_dataset_path = get_filesystem_from_dataset_path(
dataset_path, aws_profile, aws_access_key_id, aws_secret_access_key, anon
)
tmp_dir = tempfile.TemporaryDirectory()
dataset_path = os.path.join(tmp_dir.name, proc_dataset_path)
fs.download(proc_dataset_path, dataset_path, recursive=True)

with open(os.path.join(dataset_path, "state.json"), "r", encoding="utf-8") as state_file:
state = json.load(state_file)
with open(os.path.join(dataset_path, "dataset_info.json"), "r", encoding="utf-8") as dataset_info_file:
@@ -496,6 +525,9 @@ def load_from_disk(dataset_path: str) -> "Dataset":
for data_file in state.get("_data_files", []) + state.get("_indices_data_files", []):
data_file["filename"] = os.path.join(dataset_path, data_file["filename"])
dataset.__setstate__(state)

if "tmp_dir" in vars():
tmp_dir.cleanup()
return dataset

@property
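
Taken together, the save_to_disk/load_from_disk changes above make a round trip through S3 possible with the credential arguments of this revision. A usage sketch based only on the signatures in this diff; the bucket name and credentials are placeholders, and later commits in this PR replace these arguments with a filesystem object:

    from datasets import Dataset, load_dataset

    dataset = load_dataset("imdb", split="train")

    # Upload the processed split to S3 (per the signature at this point of the PR).
    dataset.save_to_disk(
        "s3://my-bucket/imdb/train",                      # placeholder bucket/prefix
        aws_access_key_id="<aws_access_key_id>",          # or rely on aws_profile="default"
        aws_secret_access_key="<aws_secret_access_key>",
    )

    # Download it back: the files are copied to a local temporary directory first,
    # then loaded as a regular on-disk dataset.
    reloaded = Dataset.load_from_disk(
        "s3://my-bucket/imdb/train",
        aws_access_key_id="<aws_access_key_id>",
        aws_secret_access_key="<aws_secret_access_key>",
    )
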
46 changes: 37 additions & 9 deletions src/datasets/dataset_dict.py
@@ -9,6 +9,7 @@

from .arrow_dataset import Dataset
from .features import Features
from .utils import get_filesystem_from_dataset_path


class DatasetDict(dict):
@@ -478,31 +479,58 @@ def shuffle(
}
)

def save_to_disk(self, dataset_dict_path: str):
def save_to_disk(
self, dataset_dict_path: str, aws_profile="default", aws_access_key_id=None, aws_secret_access_key=None
):
"""
Save the dataset dict in a dataset dict directory.
Save the dataset dict in a dataset dict directory or to a s3 bucket

Args:
dataset_dict_path (``str``): path of the dataset dict directory where the dataset dict will be saved to
aws_profile (:obj:`str`, `optional`, defaults to :obj:``default``): the aws profile used to create the `boto_session` for uploading the data to s3
aws_access_key_id (:obj:`str`, `optional`, defaults to :obj:``None``): the aws access key id used to create the `boto_session` for uploading the data to s3
aws_secret_access_key (:obj:`str`, `optional`, defaults to :obj:``None``): the aws secret access key used to create the `boto_session` for uploading the data to s3
"""
os.makedirs(dataset_dict_path, exist_ok=True)
fs, proc_dataset_dict_path = get_filesystem_from_dataset_path(
dataset_dict_path, aws_profile, aws_access_key_id, aws_secret_access_key
)
os.makedirs(proc_dataset_dict_path, exist_ok=True)
json.dump(
{"splits": list(self)}, open(os.path.join(dataset_dict_path, "dataset_dict.json"), "w", encoding="utf-8")
{"splits": list(self)},
fs.open(os.path.join(proc_dataset_dict_path, "dataset_dict.json"), "w", encoding="utf-8"),
)
for k, dataset in self.items():
dataset.save_to_disk(os.path.join(dataset_dict_path, k))
dataset.save_to_disk(
os.path.join(dataset_dict_path, k), aws_profile, aws_access_key_id, aws_secret_access_key
)

@staticmethod
def load_from_disk(dataset_dict_path: str) -> "DatasetDict":
def load_from_disk(
dataset_dict_path: str, aws_profile="default", aws_access_key_id=None, aws_secret_access_key=None, anon=False
) -> "DatasetDict":
"""
Load the dataset dict from a dataset dict directory
Load the dataset dict from a dataset dict directory or from a s3 bucket

Args:
dataset_dict_path (``str``): path of the dataset dict directory where the dataset dict will be loaded from
aws_profile (:obj:`str`, `optional`, defaults to :obj:``default``): the aws profile used to create the `boto_session` for uploading the data to s3
aws_access_key_id (:obj:`str`, `optional`, defaults to :obj:``None``): the aws access key id used to create the `boto_session` for uploading the data to s3
aws_secret_access_key (:obj:`str`, `optional`, defaults to :obj:``None``): the aws secret access key used to create the `boto_session` for uploading the data to s3
anon (:obj:`boolean`, `optional`, defaults to :obj:``False``): The connection can be anonymous - in which case only publicly-available, read-only buckets are accessible, for anonymous connection use `anon=True`

"""
dataset_dict = DatasetDict()
for k in json.load(open(os.path.join(dataset_dict_path, "dataset_dict.json"), "r", encoding="utf-8"))[
fs, proc_dataset_dict_path = get_filesystem_from_dataset_path(
dataset_dict_path, aws_profile, aws_access_key_id, aws_secret_access_key
)
for k in json.load(fs.open(os.path.join(proc_dataset_dict_path, "dataset_dict.json"), "r", encoding="utf-8"))[
"splits"
]:
dataset_dict[k] = Dataset.load_from_disk(os.path.join(dataset_dict_path, k))
dataset_dict[k] = Dataset.load_from_disk(
os.path.join(dataset_dict_path, k),
aws_profile,
aws_access_key_id,
aws_secret_access_key,
anon,
)
return dataset_dict
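
The DatasetDict variant mirrors the single-dataset case: a dataset_dict.json listing the splits is written at the root, and each split is saved to <root>/<split> via Dataset.save_to_disk. A short sketch with the same placeholder bucket and credentials:

    from datasets import load_dataset

    dataset_dict = load_dataset("imdb")  # DatasetDict with train/test/unsupervised splits

    # Writes s3://my-bucket/imdb/dataset_dict.json plus one directory per split.
    dataset_dict.save_to_disk(
        "s3://my-bucket/imdb",                            # placeholder bucket/prefix
        aws_access_key_id="<aws_access_key_id>",
        aws_secret_access_key="<aws_secret_access_key>",
    )
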
29 changes: 22 additions & 7 deletions src/datasets/load.py
@@ -34,6 +34,7 @@
from .info import DATASET_INFOS_DICT_FILE_NAME
from .metric import Metric
from .splits import Split
from .utils import get_filesystem_from_dataset_path
from .utils.download_manager import GenerateMode
from .utils.file_utils import HF_MODULES_CACHE, DownloadConfig, cached_path, head_hf_s3, hf_bucket_url, hf_github_url
from .utils.filelock import FileLock
@@ -620,24 +621,38 @@ def load_dataset(
return ds


def load_from_disk(dataset_path: str) -> Union[Dataset, DatasetDict]:
def load_from_disk(
dataset_path: str,
aws_profile="default",
aws_access_key_id=None,
aws_secret_access_key=None,
anon=False,
) -> Union[Dataset, DatasetDict]:
"""
Load a dataset that was previously saved using ``dataset.save_to_disk(dataset_path)``.
Load a dataset that was previously saved using ``dataset.save_to_disk(dataset_path)`` from s3 or local filesystem.

Args:
dataset_path (``str``): path of a Dataset directory or a DatasetDict directory
aws_profile (:obj:`str`, `optional`, defaults to :obj:``default``): the aws profile used to create the `boto_session` for downloading the data to s3
aws_access_key_id (:obj:`str`, `optional`, defaults to :obj:``None``): the aws access key id used to create the `boto_session` for downloading the data to s3
aws_secret_access_key (:obj:`str`, `optional`, defaults to :obj:``None``): the aws secret access key used to create the `boto_session` for downloading the data to s3
anon (:obj:`boolean`, `optional`, defaults to :obj:``False``): The connection can be anonymous - in which case only publicly-available, read-only buckets are accessible, for anonymous connection use `anon=True`

Returns:
``datasets.Dataset`` or ``datasets.DatasetDict``
if `dataset_path` is a path of a dataset directory: the dataset requested,
if `dataset_path` is a path of a dataset dict directory: a ``datasets.DatasetDict`` with each split.
"""
if not os.path.isdir(dataset_path):
# gets filesystem from dataset, either s3:// or file:// and adjusted dataset_path
fs, proc_dataset_path = get_filesystem_from_dataset_path(
dataset_path, aws_profile, aws_access_key_id, aws_secret_access_key
)
if not fs.exists(proc_dataset_path):
raise FileNotFoundError("Directory {} not found".format(dataset_path))
if os.path.exists(os.path.join(dataset_path, "dataset_info.json")):
return Dataset.load_from_disk(dataset_path)
elif os.path.exists(os.path.join(dataset_path, "dataset_dict.json")):
return DatasetDict.load_from_disk(dataset_path)
if fs.isfile(os.path.join(proc_dataset_path, "dataset_info.json")):
return Dataset.load_from_disk(dataset_path, aws_profile, aws_access_key_id, aws_secret_access_key, anon)
elif fs.isfile(os.path.join(proc_dataset_path, "dataset_dict.json")):
return DatasetDict.load_from_disk(dataset_path, aws_profile, aws_access_key_id, aws_secret_access_key, anon)
else:
raise FileNotFoundError(
"Directory {} is neither a dataset directory nor a dataset dict directory.".format(dataset_path)
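
For reference, the API that the review thread above converged on, and that later commits in this PR implement, passes a filesystem object instead of individual credential arguments. A sketch based on the example posted in the discussion (bucket name and credentials are placeholders):

    from datasets import load_from_disk
    from datasets.filesystems import S3FileSystem

    s3 = S3FileSystem(key="<aws_access_key_id>", secret="<aws_secret_access_key>")

    dataset = load_from_disk("s3://my-private-datasets/imdb/train", fs=s3)
    print(len(dataset))
    # 25000
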