Cache backward compatibility with 2.15.0 (#6514)
* cache backward compatibility

* fix win

* debug run action for win

* fix
lhoestq committed Dec 21, 2023
1 parent e1b82ea commit 2afbf78
Showing 4 changed files with 116 additions and 25 deletions.
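In practice, the change means that a builder created with a current `datasets` install will reuse a cache directory written by `datasets` 2.15 (or the older 2.13 layout) instead of preparing the dataset again under the new directory scheme. A minimal usage sketch, with a hypothetical repo id and cache path; the actual behavior is exercised by the new test at the bottom of this diff:

from datasets import load_dataset_builder

# "user/my_dataset" and the cache path are placeholders, not taken from this commit.
builder = load_dataset_builder("user/my_dataset", cache_dir="~/.cache/huggingface/datasets")
# If a directory created by datasets<=2.15 already exists under cache_dir for this
# dataset/config, builder.cache_dir now points at it; otherwise the new layout is used.
print(builder.cache_dir)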
106 changes: 81 additions & 25 deletions src/datasets/builder.py
@@ -29,7 +29,8 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
-from typing import Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, Union
+from unittest.mock import patch

import fsspec
import pyarrow as pa
@@ -85,6 +86,10 @@
from .utils.track import tracked_list


if TYPE_CHECKING:
from .load import DatasetModule


logger = logging.get_logger(__name__)


@@ -400,6 +405,10 @@ def __init__(
if is_remote_url(self._cache_downloaded_dir)
else os.path.expanduser(self._cache_downloaded_dir)
)

# In case there exists a legacy cache directory
self._legacy_relative_data_dir = None

self._cache_dir = self._build_cache_dir()
if not is_remote_url(self._cache_dir_root):
os.makedirs(self._cache_dir_root, exist_ok=True)
@@ -452,23 +461,71 @@ def __setstate__(self, d):
def manual_download_instructions(self) -> Optional[str]:
return None

-    def _has_legacy_cache(self) -> bool:
-        """Check for the old cache directory template {cache_dir}/{namespace}___{builder_name}"""
-        if (
-            self.__module__.startswith("datasets.")
-            and not is_remote_url(self._cache_dir_root)
-            and self.config.name == "default"
-        ):
-            namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
-            legacy_config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
-            legacy_config_id = legacy_config_name + self.config_id[len(self.config.name) :]
-            legacy_cache_dir = os.path.join(
-                self._cache_dir_root,
-                self.name if namespace is None else f"{namespace}___{self.name}",
-                legacy_config_id,
-            )
-            return os.path.isdir(legacy_cache_dir)
-        return False
+    def _check_legacy_cache(self) -> Optional[str]:
+        """Check for the old cache directory template {cache_dir}/{namespace}___{builder_name} from 2.13"""
+        if (
+            self.__module__.startswith("datasets.")
+            and not is_remote_url(self._cache_dir_root)
+            and self.config.name == "default"
+        ):
+            from .packaged_modules import _PACKAGED_DATASETS_MODULES
+
+            namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
+            config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
+            config_id = config_name + self.config_id[len(self.config.name) :]
+            hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
+            legacy_relative_data_dir = posixpath.join(
+                self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
+                config_id,
+                "0.0.0",
+                hash,
+            )
+            legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
+            if os.path.isdir(legacy_cache_dir):
+                return legacy_relative_data_dir
+
+    def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]:
+        """Check for the old cache directory template {cache_dir}/{namespace}___{dataset_name}/{config_name}-xxx from 2.14 and 2.15"""
+        if self.__module__.startswith("datasets.") and not is_remote_url(self._cache_dir_root):
+            from .packaged_modules import _PACKAGED_DATASETS_MODULES
+            from .utils._dill import Pickler
+
+            def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
+                """
+                Used to update hash of packaged modules which is used for creating unique cache directories to reflect
+                different config parameters which are passed in metadata from readme.
+                """
+                params_to_exclude = {"config_name", "version", "description"}
+                params_to_add_to_hash = {
+                    param: value
+                    for param, value in sorted(config_parameters.items())
+                    if param not in params_to_exclude
+                }
+                m = Hasher()
+                m.update(hash)
+                m.update(params_to_add_to_hash)
+                return m.hexdigest()
+
+            namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
+            with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
+                config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files})
+            hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
+            if (
+                dataset_module.builder_configs_parameters.metadata_configs
+                and self.config.name in dataset_module.builder_configs_parameters.metadata_configs
+            ):
+                hash = update_hash_with_config_parameters(
+                    hash, dataset_module.builder_configs_parameters.metadata_configs[self.config.name]
+                )
+            legacy_relative_data_dir = posixpath.join(
+                self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
+                config_id,
+                "0.0.0",
+                hash,
+            )
+            legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
+            if os.path.isdir(legacy_cache_dir):
+                return legacy_relative_data_dir

@classmethod
def get_all_exported_dataset_infos(cls) -> DatasetInfosDict:
@@ -600,6 +657,14 @@ def builder_configs(cls) -> Dict[str, BuilderConfig]:
def cache_dir(self):
return self._cache_dir

def _use_legacy_cache_dir_if_possible(self, dataset_module: "DatasetModule"):
# Check for the legacy cache directory template (datasets<3.0.0)
self._legacy_relative_data_dir = (
self._check_legacy_cache2(dataset_module) or self._check_legacy_cache() or None
)
self._cache_dir = self._build_cache_dir()
self._output_dir = self._cache_dir

def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
"""Relative path of this dataset in cache_dir:
Will be:
@@ -608,21 +673,12 @@ def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
self.namespace___self.dataset_name/self.config.version/self.hash/
If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped.
"""

# Check for the legacy cache directory template (datasets<3.0.0)
if self._has_legacy_cache():
# use legacy names
dataset_name = self.name
config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
config_id = config_name + self.config_id[len(self.config.name) :]
else:
dataset_name = self.dataset_name
config_name = self.config.name
config_id = self.config_id
if self._legacy_relative_data_dir is not None and with_version and with_hash:
return self._legacy_relative_data_dir

namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
builder_data_dir = dataset_name if namespace is None else f"{namespace}___{dataset_name}"
builder_data_dir = posixpath.join(builder_data_dir, config_id)
builder_data_dir = self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}"
builder_data_dir = posixpath.join(builder_data_dir, self.config_id)
if with_version:
builder_data_dir = posixpath.join(builder_data_dir, str(self.config.version))
if with_hash and self.hash and isinstance(self.hash, str):
@@ -1285,7 +1341,7 @@ def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_mem
"""
cache_dir = self._fs._strip_protocol(self._output_dir)
dataset_name = self.dataset_name
-if self._has_legacy_cache():
+if self._check_legacy_cache():
dataset_name = self.name
dataset_kwargs = ArrowReader(cache_dir, self.info).read(
name=dataset_name,
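Taken together, `_check_legacy_cache`, `_check_legacy_cache2`, and `_relative_data_dir` distinguish three on-disk layouts. A rough sketch of what each looks like; the names and hash placeholders below are hypothetical, only the directory structure follows the code above:

import posixpath

cache_dir_root = "/path/to/hf/datasets"   # hypothetical cache root
prefix = "user___my_dataset"              # {namespace}___{dataset_name}, hypothetical

# datasets 2.13 layout, matched by _check_legacy_cache(): the config directory is derived from
# the repo id ("user/my_dataset" -> "user--my_dataset") plus the config_id suffix, the version is
# pinned to "0.0.0" and the hash is the packaged module hash.
layout_2_13 = posixpath.join(cache_dir_root, prefix, "user--my_dataset-<config_id suffix>", "0.0.0", "<module hash>")

# datasets 2.14/2.15 layout, matched by _check_legacy_cache2(): the config directory is
# "{config_name}-{hash of data_files}" computed with the legacy (unsorted dict keys) pickler,
# and the module hash may be updated with config parameters from the README metadata.
layout_2_15 = posixpath.join(cache_dir_root, prefix, "default-<data_files hash>", "0.0.0", "<updated module hash>")

# Current layout, built by _relative_data_dir() when no legacy directory is found.
layout_new = posixpath.join(cache_dir_root, prefix, "<config_id>", "<version>", "<builder hash>")

print(layout_2_13, layout_2_15, layout_new, sep="\n")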
1 change: 1 addition & 0 deletions src/datasets/load.py
@@ -2239,6 +2239,7 @@ def load_dataset_builder(
**builder_kwargs,
**config_kwargs,
)
builder_instance._use_legacy_cache_dir_if_possible(dataset_module)

return builder_instance

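The single added call is what triggers the fallback at load time: in builder.py above, `_use_legacy_cache_dir_if_possible` tries the 2.14/2.15 layout first, then the 2.13 layout, and otherwise keeps the freshly built directory. A minimal standalone sketch of that precedence, with hypothetical return values rather than the library's own code:

# Each check returns a relative data dir if the legacy directory exists, else None,
# mirroring "_check_legacy_cache2(...) or _check_legacy_cache() or None" in builder.py.
def pick_relative_data_dir(check_2_15, check_2_13, new_layout_dir):
    return check_2_15() or check_2_13() or new_layout_dir

picked = pick_relative_data_dir(
    lambda: None,                                                  # no 2.15-style dir found
    lambda: "user___my_dataset/user--my_dataset-abc/0.0.0/1234",   # hypothetical 2.13-style dir
    "user___my_dataset/default/0.0.0/5678",                        # hypothetical new-style dir
)
print(picked)  # prints the 2.13-style directory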
3 changes: 3 additions & 0 deletions src/datasets/utils/_dill.py
@@ -25,6 +25,7 @@

class Pickler(dill.Pickler):
dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())
_legacy_no_dict_keys_sorting = False

def save(self, obj, save_persistent_id=True):
obj_type = type(obj)
@@ -68,6 +69,8 @@ def save(self, obj, save_persistent_id=True):
dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)

def _batch_setitems(self, items):
if self._legacy_no_dict_keys_sorting:
return super()._batch_setitems(items)
# Ignore the order of keys in a dict
try:
# Faster, but fails for unorderable elements
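The new `_legacy_no_dict_keys_sorting` escape hatch exists because 2.15 cache directory names embed a hash of the `data_files` dict taken in insertion order, while the current Pickler sorts dict keys before hashing. Recomputing the old name therefore requires temporarily restoring the old behavior, which is what the `patch.object(...)` in builder.py does. A small sketch of the difference, using a made-up `data_files` mapping:

from unittest.mock import patch

from datasets.fingerprint import Hasher
from datasets.utils._dill import Pickler

# Hypothetical data_files mapping; only the key order matters for this sketch.
data_files = {"train": ["data/train-00000.parquet"], "test": ["data/test-00000.parquet"]}

new_style = Hasher.hash({"data_files": data_files})  # dict items are sorted before pickling
with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
    old_style = Hasher.hash({"data_files": data_files})  # insertion order kept, as in datasets<=2.15

# The two digests generally differ, so the legacy flag is needed to rebuild the old config_id.
print(new_style, old_style, new_style == old_style)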
31 changes: 31 additions & 0 deletions tests/test_load.py
@@ -1657,3 +1657,34 @@ def test_resolve_trust_remote_code_future(trust_remote_code, expected):
else:
with pytest.raises(expected):
resolve_trust_remote_code(trust_remote_code, repo_id="dummy")


@pytest.mark.integration
def test_reload_old_cache_from_2_15(tmp_path: Path):
cache_dir = tmp_path / "test_reload_old_cache_from_2_15"
builder_cache_dir = (
cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/7896925d64deea5d"
)
builder_cache_dir.mkdir(parents=True)
arrow_path = builder_cache_dir / "audiofolder_two_configs_in_metadata-train.arrow"
dataset_info_path = builder_cache_dir / "dataset_info.json"
with dataset_info_path.open("w") as f:
f.write("{}")
arrow_path.touch()
builder = load_dataset_builder(
"polinaeterna/audiofolder_two_configs_in_metadata",
"v2",
data_files="v2/train/*",
cache_dir=cache_dir.as_posix(),
)
assert builder.cache_dir == builder_cache_dir.as_posix() # old cache from 2.15

builder = load_dataset_builder(
"polinaeterna/audiofolder_two_configs_in_metadata", "v2", cache_dir=cache_dir.as_posix()
)
assert (
builder.cache_dir
== (
cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata" / "v2" / "0.0.0" / str(builder.hash)
).as_posix()
) # new cache
