Commit
Improve default patterns resolution (#6704)
* Separate filename and dirname patterns

* Nit

* Faster local files resolution

* Style

* Use context manager

* Replace `fsspec.get_fs_token_paths` with `url_to_fs`

* Fix

* Remove context manager
mariosasko committed Mar 15, 2024
1 parent a02997d commit d1d3c06
Showing 13 changed files with 79 additions and 80 deletions.
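Every file below applies the same core swap: `fsspec.get_fs_token_paths` (which tokenizes and may glob-expand its input) is replaced by the lighter `fsspec.core.url_to_fs`. A minimal sketch of the two call shapes, with an arbitrary local path not taken from the diff:

```python
# Minimal sketch of the API swap applied throughout this commit;
# the path is an arbitrary example.
import fsspec
from fsspec.core import url_to_fs

# Old: returns (filesystem, token, list_of_paths); the paths list comes from
# pattern expansion, which the call sites in this commit never need.
fs, _, paths = fsspec.get_fs_token_paths("file:///tmp/data", storage_options={})

# New: returns just (filesystem, path); storage options are passed as kwargs.
fs, path = url_to_fs("file:///tmp/data")
```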
5 changes: 3 additions & 2 deletions src/datasets/arrow_dataset.py
@@ -59,6 +59,7 @@
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
from fsspec.core import url_to_fs
from huggingface_hub import CommitInfo, CommitOperationAdd, CommitOperationDelete, DatasetCard, DatasetCardData, HfApi
from multiprocess import Pool
from tqdm.contrib.concurrent import thread_map
@@ -1504,7 +1505,7 @@ def save_to_disk(
num_shards = num_shards if num_shards is not None else num_proc

fs: fsspec.AbstractFileSystem
fs, _, _ = fsspec.get_fs_token_paths(dataset_path, storage_options=storage_options)
fs, _ = url_to_fs(dataset_path, **(storage_options or {}))

if not is_remote_filesystem(fs):
parent_cache_files_paths = {
@@ -1694,7 +1695,7 @@ def load_from_disk(
storage_options = fs.storage_options

fs: fsspec.AbstractFileSystem
fs, _, [dataset_path] = fsspec.get_fs_token_paths(dataset_path, storage_options=storage_options)
fs, dataset_path = url_to_fs(dataset_path, **(storage_options or {}))

dest_dataset_path = dataset_path
dataset_dict_json_path = posixpath.join(dest_dataset_path, config.DATASETDICT_JSON_FILENAME)
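A hedged round-trip sketch for the two methods touched above; the S3 URI and credentials are placeholder assumptions, not from the diff:

```python
# Hedged sketch: save_to_disk/load_from_disk accept remote URIs, which the
# url_to_fs call above resolves. Bucket name and anon flag are made up.
from datasets import Dataset, load_from_disk

ds = Dataset.from_dict({"x": [1, 2, 3]})
ds.save_to_disk("s3://my-bucket/ds", storage_options={"anon": False})
reloaded = load_from_disk("s3://my-bucket/ds", storage_options={"anon": False})
```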
15 changes: 6 additions & 9 deletions src/datasets/arrow_writer.py
@@ -24,6 +24,7 @@
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from fsspec.core import url_to_fs

from . import config
from .features import Features, Image, Value
@@ -327,14 +328,10 @@ def __init__(
self._disable_nullable = disable_nullable

if stream is None:
fs_token_paths = fsspec.get_fs_token_paths(path, storage_options=storage_options)
self._fs: fsspec.AbstractFileSystem = fs_token_paths[0]
self._path = (
fs_token_paths[2][0]
if not is_remote_filesystem(self._fs)
else self._fs.unstrip_protocol(fs_token_paths[2][0])
)
self.stream = self._fs.open(fs_token_paths[2][0], "wb")
fs, path = url_to_fs(path, **(storage_options or {}))
self._fs: fsspec.AbstractFileSystem = fs
self._path = path if not is_remote_filesystem(self._fs) else self._fs.unstrip_protocol(path)
self.stream = self._fs.open(path, "wb")
self._closable_stream = True
else:
self._fs = None
@@ -681,7 +678,7 @@ def finalize(self, metrics_query_result: dict):
"""

# Beam FileSystems require the system's path separator in the older versions
fs, _, [parquet_path] = fsspec.get_fs_token_paths(self._parquet_path)
fs, parquet_path = url_to_fs(self._parquet_path)
parquet_path = str(Path(parquet_path)) if not is_remote_filesystem(fs) else fs.unstrip_protocol(parquet_path)

shards = fs.glob(parquet_path + "*.parquet")
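For context, the refactored branch above opens the stream itself when only a path is given. A hedged usage sketch; the filename and feature schema are illustrative assumptions:

```python
# Hedged sketch of the path-based ArrowWriter branch refactored above;
# the filename and schema are illustrative assumptions.
from datasets.arrow_writer import ArrowWriter
from datasets.features import Features, Value

with ArrowWriter(path="out.arrow", features=Features({"text": Value("string")})) as writer:
    writer.write({"text": "hello"})
    writer.finalize()
```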
3 changes: 2 additions & 1 deletion src/datasets/builder.py
@@ -34,6 +34,7 @@

import fsspec
import pyarrow as pa
from fsspec.core import url_to_fs
from multiprocess import Pool
from tqdm.contrib.concurrent import thread_map

@@ -883,7 +884,7 @@ def download_and_prepare(

output_dir = output_dir if output_dir is not None else self._cache_dir
# output_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing)
fs, _, [output_dir] = fsspec.get_fs_token_paths(output_dir, storage_options=storage_options)
fs, output_dir = url_to_fs(output_dir, **(storage_options or {}))
self._fs = fs
self._output_dir = output_dir if not is_remote_filesystem(self._fs) else self._fs.unstrip_protocol(output_dir)

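As the comment above notes, `output_dir` may be a remote bucket; a hedged sketch, in which the bucket, data file, and token are placeholder assumptions:

```python
# Hedged sketch: preparing a builder into a remote bucket, as the comment
# above describes. The GCS bucket and data file are placeholder assumptions.
from datasets import load_dataset_builder

builder = load_dataset_builder("csv", data_files={"train": "train.csv"})
builder.download_and_prepare("gs://my-bucket/out", storage_options={"token": "anon"})
```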
55 changes: 38 additions & 17 deletions src/datasets/data_files.py
@@ -6,7 +6,7 @@
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import huggingface_hub
from fsspec import get_fs_token_paths
from fsspec.core import url_to_fs
from fsspec.implementations.http import HTTPFileSystem
from huggingface_hub import HfFileSystem
from packaging import version
@@ -46,36 +46,57 @@ class EmptyDatasetError(FileNotFoundError):
}
NON_WORDS_CHARS = "-._ 0-9"
if config.FSSPEC_VERSION < version.parse("2023.9.0"):
KEYWORDS_IN_PATH_NAME_BASE_PATTERNS = ["{keyword}[{sep}/]**", "**[{sep}/]{keyword}[{sep}/]**"]
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
"{keyword}/**",
"{keyword}[{sep}]*/**",
"**[{sep}/]{keyword}/**",
"**[{sep}/]{keyword}[{sep}]*/**",
]
elif config.FSSPEC_VERSION < version.parse("2023.12.0"):
KEYWORDS_IN_PATH_NAME_BASE_PATTERNS = ["{keyword}[{sep}/]**", "**/*[{sep}/]{keyword}[{sep}/]**"]
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/*[{sep}/]{keyword}[{sep}]*", "{keyword}[{sep}]*"]
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
"{keyword}/**/*",
"{keyword}[{sep}]*/**/*",
"**/*[{sep}/]{keyword}/**/*",
"**/*[{sep}/]{keyword}[{sep}]*/**/*",
]
else:
KEYWORDS_IN_PATH_NAME_BASE_PATTERNS = [
"**/{keyword}[{sep}]*",
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/{keyword}[{sep}]*", "**/*[{sep}]{keyword}[{sep}]*"]
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
"**/{keyword}/**",
"**/*[{sep}]{keyword}[{sep}]*",
"**/*[{sep}]{keyword}[{sep}]*/**",
"**/{keyword}[{sep}]*/**",
"**/*[{sep}]{keyword}/**",
"**/*[{sep}]{keyword}[{sep}]*/**",
]

DEFAULT_SPLITS = [Split.TRAIN, Split.VALIDATION, Split.TEST]
DEFAULT_PATTERNS_SPLIT_IN_PATH_NAME = {
DEFAULT_PATTERNS_SPLIT_IN_FILENAME = {
split: [
pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
for keyword in SPLIT_KEYWORDS[split]
for pattern in KEYWORDS_IN_PATH_NAME_BASE_PATTERNS
for pattern in KEYWORDS_IN_FILENAME_BASE_PATTERNS
]
for split in DEFAULT_SPLITS
}
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME = {
split: [
pattern.format(keyword=keyword, sep=NON_WORDS_CHARS)
for keyword in SPLIT_KEYWORDS[split]
for pattern in KEYWORDS_IN_DIR_NAME_BASE_PATTERNS
]
for split in DEFAULT_SPLITS
}


DEFAULT_PATTERNS_ALL = {
Split.TRAIN: ["**"],
}

ALL_SPLIT_PATTERNS = [SPLIT_PATTERN_SHARDED]
ALL_DEFAULT_PATTERNS = [
DEFAULT_PATTERNS_SPLIT_IN_PATH_NAME,
DEFAULT_PATTERNS_SPLIT_IN_DIR_NAME,
DEFAULT_PATTERNS_SPLIT_IN_FILENAME,
DEFAULT_PATTERNS_ALL,
]
if config.FSSPEC_VERSION < version.parse("2023.9.0"):
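To make the new split concrete: the filename patterns match split keywords in file names, while the dir-name patterns match them in any parent directory. A self-contained sketch using the constants from the `else` branch above (fsspec >= 2023.12.0):

```python
# Sketch: instantiating the base patterns above for the "train" split
# (constants copied from the fsspec >= 2023.12.0 branch of this hunk).
NON_WORDS_CHARS = "-._ 0-9"
KEYWORDS_IN_FILENAME_BASE_PATTERNS = ["**/{keyword}[{sep}]*", "**/*[{sep}]{keyword}[{sep}]*"]
KEYWORDS_IN_DIR_NAME_BASE_PATTERNS = [
    "**/{keyword}/**",
    "**/{keyword}[{sep}]*/**",
    "**/*[{sep}]{keyword}/**",
    "**/*[{sep}]{keyword}[{sep}]*/**",
]

filename_patterns = [
    p.format(keyword=kw, sep=NON_WORDS_CHARS)
    for kw in ("train", "training")  # SPLIT_KEYWORDS[Split.TRAIN]
    for p in KEYWORDS_IN_FILENAME_BASE_PATTERNS
]
# '**/train[-._ 0-9]*' matches e.g. 'data/train-00000.csv' (keyword in filename);
# '**/train/**' from the dir-name patterns matches e.g. 'train/part0.csv'.
print(filename_patterns)
```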
@@ -351,7 +372,7 @@ def resolve_pattern(
else:
base_path = ""
pattern, storage_options = _prepare_path_and_storage_options(pattern, download_config=download_config)
fs, _, _ = get_fs_token_paths(pattern, storage_options=storage_options)
fs, *_ = url_to_fs(pattern, **storage_options)
fs_base_path = base_path.split("::")[0].split("://")[-1] or fs.root_marker
fs_pattern = pattern.split("::")[0].split("://")[-1]
files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
@@ -409,7 +430,7 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig]
Output:
{"train": ["**"]}
{'train': ['**']}
Input:
@@ -435,8 +456,8 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig]
Output:
{'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**', 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**', 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
{'train': ['**/train[-._ 0-9]*', '**/*[-._ 0-9]train[-._ 0-9]*', '**/training[-._ 0-9]*', '**/*[-._ 0-9]training[-._ 0-9]*'],
'test': ['**/test[-._ 0-9]*', '**/*[-._ 0-9]test[-._ 0-9]*', '**/testing[-._ 0-9]*', '**/*[-._ 0-9]testing[-._ 0-9]*', ...]}
Input:
@@ -454,8 +475,8 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig]
Output:
{'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**', 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**', 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
{'train': ['**/train/**', '**/train[-._ 0-9]*/**', '**/*[-._ 0-9]train/**', '**/*[-._ 0-9]train[-._ 0-9]*/**', ...],
'test': ['**/test/**', '**/test[-._ 0-9]*/**', '**/*[-._ 0-9]test/**', '**/*[-._ 0-9]test[-._ 0-9]*/**', ...]}
Input:
@@ -504,7 +525,7 @@ def _get_single_origin_metadata(
download_config: Optional[DownloadConfig] = None,
) -> Tuple[str]:
data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
fs, _, _ = get_fs_token_paths(data_file, storage_options=storage_options)
fs, *_ = url_to_fs(data_file, **storage_options)
if isinstance(fs, HfFileSystem):
resolved_path = fs.resolve_path(data_file)
return (resolved_path.repo_id, resolved_path.revision)
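A hedged end-to-end sketch of `get_data_patterns`, whose docstring examples changed above; the directory layout is an assumption:

```python
# Hedged sketch: resolving default split patterns for a local layout such as
#   my_dataset/train.csv
#   my_dataset/test.csv
# The directory name and files are illustrative assumptions.
from datasets.data_files import get_data_patterns

patterns = get_data_patterns("my_dataset")
# With split keywords in filenames, this yields the filename patterns, e.g.
# {'train': ['**/train[-._ 0-9]*', ...], 'test': ['**/test[-._ 0-9]*', ...]}
```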
5 changes: 3 additions & 2 deletions src/datasets/dataset_dict.py
@@ -12,6 +12,7 @@

import fsspec
import numpy as np
from fsspec.core import url_to_fs
from huggingface_hub import (
CommitInfo,
CommitOperationAdd,
@@ -1280,7 +1281,7 @@ def save_to_disk(
storage_options = fs.storage_options

fs: fsspec.AbstractFileSystem
fs, _, _ = fsspec.get_fs_token_paths(dataset_dict_path, storage_options=storage_options)
fs, _ = url_to_fs(dataset_dict_path, **(storage_options or {}))

if num_shards is None:
num_shards = {k: None for k in self}
@@ -1354,7 +1355,7 @@ def load_from_disk(
storage_options = fs.storage_options

fs: fsspec.AbstractFileSystem
fs, _, [dataset_dict_path] = fsspec.get_fs_token_paths(dataset_dict_path, storage_options=storage_options)
fs, dataset_dict_path = url_to_fs(dataset_dict_path, **(storage_options or {}))

dataset_dict_json_path = posixpath.join(dataset_dict_path, config.DATASETDICT_JSON_FILENAME)
dataset_state_json_path = posixpath.join(dataset_dict_path, config.DATASET_STATE_JSON_FILENAME)
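A hedged round-trip sketch for the `DatasetDict` methods above; the output directory is an assumption:

```python
# Hedged sketch of the DatasetDict save/load path refactored above;
# the output directory is an illustrative assumption.
from datasets import Dataset, DatasetDict, load_from_disk

dd = DatasetDict({"train": Dataset.from_dict({"x": [1, 2, 3]})})
dd.save_to_disk("dd_dir")
reloaded = load_from_disk("dd_dir")
```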
25 changes: 9 additions & 16 deletions src/datasets/download/streaming_download_manager.py
@@ -16,6 +16,7 @@

import fsspec
from aiohttp.client_exceptions import ClientError
from fsspec.core import url_to_fs
from huggingface_hub.utils import EntryNotFoundError
from packaging import version

@@ -159,7 +160,7 @@ def xexists(urlpath: str, download_config: Optional[DownloadConfig] = None):
else:
urlpath, storage_options = _prepare_path_and_storage_options(urlpath, download_config=download_config)
main_hop, *rest_hops = urlpath.split("::")
fs, *_ = fsspec.get_fs_token_paths(urlpath, storage_options=storage_options)
fs, *_ = url_to_fs(urlpath, **storage_options)
return fs.exists(main_hop)


@@ -259,7 +260,7 @@ def xisfile(path, download_config: Optional[DownloadConfig] = None) -> bool:
else:
path, storage_options = _prepare_path_and_storage_options(path, download_config=download_config)
main_hop, *rest_hops = path.split("::")
fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
fs, *_ = url_to_fs(path, **storage_options)
return fs.isfile(main_hop)


@@ -279,7 +280,7 @@ def xgetsize(path, download_config: Optional[DownloadConfig] = None) -> int:
else:
path, storage_options = _prepare_path_and_storage_options(path, download_config=download_config)
main_hop, *rest_hops = path.split("::")
fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
fs, *_ = url_to_fs(path, **storage_options)
try:
size = fs.size(main_hop)
except EntryNotFoundError:
@@ -307,7 +308,7 @@ def xisdir(path, download_config: Optional[DownloadConfig] = None) -> bool:
else:
path, storage_options = _prepare_path_and_storage_options(path, download_config=download_config)
main_hop, *rest_hops = path.split("::")
fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
fs, *_ = url_to_fs(path, **storage_options)
inner_path = main_hop.split("://")[-1]
if not inner_path.strip("/"):
return True
@@ -546,7 +547,7 @@ def xlistdir(path: str, download_config: Optional[DownloadConfig] = None) -> Lis
# globbing inside a zip in a private repo requires authentication
path, storage_options = _prepare_path_and_storage_options(path, download_config=download_config)
main_hop, *rest_hops = path.split("::")
fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
fs, *_ = url_to_fs(path, **storage_options)
inner_path = main_hop.split("://")[-1]
if inner_path.strip("/") and not fs.isdir(inner_path):
raise FileNotFoundError(f"Directory doesn't exist: {path}")
@@ -573,11 +574,7 @@ def xglob(urlpath, *, recursive=False, download_config: Optional[DownloadConfig]
# globbing inside a zip in a private repo requires authentication
urlpath, storage_options = _prepare_path_and_storage_options(urlpath, download_config=download_config)
main_hop, *rest_hops = urlpath.split("::")
fs, *_ = fsspec.get_fs_token_paths(urlpath, storage_options=storage_options)
# - If there's no "*" in the pattern, get_fs_token_paths() doesn't do any pattern matching
# so to be able to glob patterns like "[0-9]", we have to call `fs.glob`.
# - Also "*" in get_fs_token_paths() only matches files: we have to call `fs.glob` to match directories.
# - If there is "**" in the pattern, `fs.glob` must be called anyway.
fs, *_ = url_to_fs(urlpath, **storage_options)
inner_path = main_hop.split("://")[1]
globbed_paths = fs.glob(inner_path)
protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1]
@@ -603,7 +600,7 @@ def xwalk(urlpath, download_config: Optional[DownloadConfig] = None, **kwargs):
# walking inside a zip in a private repo requires authentication
urlpath, storage_options = _prepare_path_and_storage_options(urlpath, download_config=download_config)
main_hop, *rest_hops = urlpath.split("::")
fs, *_ = fsspec.get_fs_token_paths(urlpath, storage_options=storage_options)
fs, *_ = url_to_fs(urlpath, **storage_options)
inner_path = main_hop.split("://")[-1]
if inner_path.strip("/") and not fs.isdir(inner_path):
return []
@@ -659,11 +656,7 @@ def glob(self, pattern, download_config: Optional[DownloadConfig] = None):
posix_path = "::".join([main_hop, urlpath, *rest_hops[1:]])
else:
storage_options = None
fs, *_ = fsspec.get_fs_token_paths(xjoin(posix_path, pattern), storage_options=storage_options)
# - If there's no "*" in the pattern, get_fs_token_paths() doesn't do any pattern matching
# so to be able to glob patterns like "[0-9]", we have to call `fs.glob`.
# - Also "*" in get_fs_token_paths() only matches files: we have to call `fs.glob` to match directories.
# - If there is "**" in the pattern, `fs.glob` must be called anyway.
fs, *_ = url_to_fs(xjoin(posix_path, pattern), **(storage_options or {}))
globbed_paths = fs.glob(xjoin(main_hop, pattern))
for globbed_path in globbed_paths:
yield type(self)("::".join([f"{fs.protocol}://{globbed_path}"] + rest_hops))
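The comments deleted in `xglob` and `glob` above explained a quirk of `get_fs_token_paths`: its expansion only handled `*` (and only for files), so `fs.glob` had to be called anyway. `url_to_fs` does no expansion at all, which leaves the explicit `fs.glob` as the single matching step. A local sketch; the path and pattern are illustrative:

```python
# Local sketch of the split of responsibilities after this commit:
# url_to_fs only resolves the filesystem; fs.glob does all matching,
# including '**' and character classes like '[0-9]'. Path is illustrative.
from fsspec.core import url_to_fs

fs, path = url_to_fs("file:///tmp/shards")
matches = fs.glob(path + "/**/part-[0-9]*.parquet")
```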
17 changes: 0 additions & 17 deletions src/datasets/filesystems/__init__.py
@@ -1,6 +1,5 @@
import importlib
import shutil
import threading
import warnings
from typing import List

@@ -68,19 +67,3 @@ def rename(fs: fsspec.AbstractFileSystem, src: str, dst: str):
shutil.move(fs._strip_protocol(src), fs._strip_protocol(dst))
else:
fs.mv(src, dst, recursive=True)


def _reset_fsspec_lock() -> None:
"""
Clear reference to the loop and thread.
This is necessary otherwise HTTPFileSystem hangs in the ML training loop.
Only required for fsspec >= 0.9.0
See https://github.com/fsspec/gcsfs/issues/379
"""
if hasattr(fsspec.asyn, "reset_lock"):
# for future fsspec>2022.05.0
fsspec.asyn.reset_lock()
else:
fsspec.asyn.iothread[0] = None
fsspec.asyn.loop[0] = None
fsspec.asyn.lock = threading.Lock()
5 changes: 3 additions & 2 deletions src/datasets/info.py
@@ -39,6 +39,7 @@
from typing import ClassVar, Dict, List, Optional, Union

import fsspec
from fsspec.core import url_to_fs
from huggingface_hub import DatasetCard, DatasetCardData

from . import config
@@ -251,7 +252,7 @@ def write_to_directory(
storage_options = fs.storage_options

fs: fsspec.AbstractFileSystem
fs, _, _ = fsspec.get_fs_token_paths(dataset_info_dir, storage_options=storage_options)
fs, *_ = url_to_fs(dataset_info_dir, **(storage_options or {}))
with fs.open(posixpath.join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f:
self._dump_info(f, pretty_print=pretty_print)
if self.license:
@@ -347,7 +348,7 @@ def from_directory(
storage_options = fs.storage_options

fs: fsspec.AbstractFileSystem
fs, _, _ = fsspec.get_fs_token_paths(dataset_info_dir, storage_options=storage_options)
fs, *_ = url_to_fs(dataset_info_dir, **(storage_options or {}))
logger.info(f"Loading Dataset info from {dataset_info_dir}")
if not dataset_info_dir:
raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.")
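A hedged round-trip sketch for the two `DatasetInfo` methods above; the directory and source data are assumptions:

```python
# Hedged sketch of the DatasetInfo write/read pair refactored above;
# the directory and source data are illustrative assumptions.
from datasets import Dataset
from datasets.info import DatasetInfo

ds = Dataset.from_dict({"x": [1, 2, 3]})
ds.info.write_to_directory("info_dir", pretty_print=True)
info = DatasetInfo.from_directory("info_dir")
```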
7 changes: 4 additions & 3 deletions src/datasets/iterable_dataset.py
@@ -9,14 +9,14 @@
from itertools import cycle, islice
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

import fsspec.asyn
import numpy as np
import pyarrow as pa

from . import config
from .arrow_dataset import Dataset, DatasetInfoMixin
from .features import Features
from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects
from .filesystems import _reset_fsspec_lock
from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter
from .info import DatasetInfo
from .splits import NamedSplit
@@ -1257,8 +1257,9 @@ def n_shards(self) -> int:

def _iter_pytorch(self):
ex_iterable = self._prepare_ex_iterable_for_iteration()
# fix for fsspec when using multiprocess
_reset_fsspec_lock()
# Fix for fsspec when using multiprocess to avoid hanging in the ML training loop. (only required for fsspec >= 0.9.0)
# See https://github.com/fsspec/gcsfs/issues/379
fsspec.asyn.reset_lock()
# check if there aren't too many workers
import torch.utils.data

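The inlined `fsspec.asyn.reset_lock()` above replaces the `_reset_fsspec_lock` helper removed from `datasets/filesystems/__init__.py`; the version-guarded fallback is unnecessary on the fsspec versions the library now supports. A hedged sketch of the scenario the reset guards against; the dataset name and worker count are assumptions:

```python
# Hedged sketch: without the reset above, DataLoader workers forked with a
# stale fsspec event loop could hang. Dataset and worker count are
# illustrative assumptions.
from torch.utils.data import DataLoader
from datasets import load_dataset

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
loader = DataLoader(ds.with_format("torch"), batch_size=8, num_workers=2)
batch = next(iter(loader))
```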
