Skip to content

Commit

Permalink
more general fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Jun 14, 2024
1 parent 6e380f7 commit 077d798
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 27 deletions.
8 changes: 5 additions & 3 deletions src/datasets/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from . import config
from .table import CastError
from .utils.deprecation_utils import deprecated
from .utils.track import TrackedIterable, tracked_list, tracked_str
from .utils.track import TrackedIterableFromGenerator, tracked_list, tracked_str


class DatasetsError(Exception):
Expand Down Expand Up @@ -65,9 +65,11 @@ def from_cast_error(
)
formatted_tracked_gen_kwargs: List[str] = []
for gen_kwarg in gen_kwargs.values():
if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterable)):
if not isinstance(gen_kwarg, (tracked_str, tracked_list, TrackedIterableFromGenerator)):
continue
while isinstance(gen_kwarg, (tracked_list, TrackedIterable)) and gen_kwarg.last_item is not None:
while (
isinstance(gen_kwarg, (tracked_list, TrackedIterableFromGenerator)) and gen_kwarg.last_item is not None
):
gen_kwarg = gen_kwarg.last_item
if isinstance(gen_kwarg, tracked_str):
gen_kwarg = gen_kwarg.get_origin()
Expand Down
26 changes: 4 additions & 22 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from io import BytesIO
from itertools import chain
from pathlib import Path, PurePosixPath
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
from unittest.mock import patch
from urllib.parse import urljoin, urlparse
from xml.etree import ElementTree as ET
Expand All @@ -47,7 +47,7 @@
from . import tqdm as hf_tqdm
from ._filelock import FileLock
from .extract import ExtractManager
from .track import TrackedIterable
from .track import TrackedIterableFromGenerator


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -1564,25 +1564,7 @@ def xxml_dom_minidom_parse(filename_or_file, download_config: Optional[DownloadC
return xml.dom.minidom.parse(f, **kwargs)


class _IterableFromGenerator(TrackedIterable):
"""Utility class to create an iterable from a generator function, in order to reset the generator when needed."""

def __init__(self, generator: Callable, *args):
super().__init__()
self.generator = generator
self.args = args

def __iter__(self):
for x in self.generator(*self.args):
self.last_item = x
yield x
self.last_item = None

def __reduce__(self):
return (self.__class__, (self.generator, *self.args))


class ArchiveIterable(_IterableFromGenerator):
class ArchiveIterable(TrackedIterableFromGenerator):
"""An iterable of (path, fileobj) from a TAR archive, used by `iter_archive`"""

@staticmethod
Expand Down Expand Up @@ -1647,7 +1629,7 @@ def from_urlpath(cls, urlpath_or_buf, download_config: Optional[DownloadConfig]
return cls(cls._iter_from_urlpath, urlpath_or_buf, download_config)


class FilesIterable(_IterableFromGenerator):
class FilesIterable(TrackedIterableFromGenerator):
"""An iterable of paths from a list of directories or files"""

@classmethod
Expand Down
17 changes: 15 additions & 2 deletions src/datasets/utils/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,26 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(current={self.last_item})"


class TrackedIterable(Iterable):
def __init__(self) -> None:
class TrackedIterableFromGenerator(Iterable):
"""Utility class to create an iterable from a generator function, in order to reset the generator when needed."""

def __init__(self, generator, *args):
super().__init__()
self.generator = generator
self.args = args
self.last_item = None

def __iter__(self):
for x in self.generator(*self.args):
self.last_item = x
yield x
self.last_item = None

def __repr__(self) -> str:
if self.last_item is None:
return super().__repr__()
else:
return f"{self.__class__.__name__}(current={self.last_item})"

def __reduce__(self):
return (self.__class__, (self.generator, *self.args))

0 comments on commit 077d798

Please sign in to comment.