Skip to content

Commit

Permalink
Refactor patching to specific submodule (#2639)
Browse files Browse the repository at this point in the history
* Create patching submodule

* Minor fix in docstring section header
  • Loading branch information
albertvillanova committed Jul 13, 2021
1 parent 4aff493 commit c722810
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 63 deletions.
64 changes: 1 addition & 63 deletions src/datasets/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,75 +3,13 @@
from typing import Optional, Union

from .utils.logging import get_logger
from .utils.patching import patch_submodule
from .utils.streaming_download_manager import xjoin, xopen


logger = get_logger(__name__)


class _PatchedModuleObj:
"""Set all the modules components as attributes of the _PatchedModuleObj object"""

def __init__(self, module):
if module is not None:
for key in getattr(module, "__all__", module.__dict__):
if not key.startswith("__"):
setattr(self, key, getattr(module, key))


class patch_submodule:
"""
Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.
Example::
>>> import importlib
>>> from datasets.load import prepare_module
>>> from datasets.streaming import patch_submodule, xjoin
>>>
>>> snli_module_path, _ = prepare_module("snli")
>>> snli_module = importlib.import_module(snli_module_path)
>>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
>>> patcher.start()
>>> assert snli_module.os.path.join is xjoin
"""

_active_patches = []

def __init__(self, obj, target: str, new):
self.obj = obj
self.target = target
self.new = new
self.key = target.split(".")[0]
self.original = getattr(obj, self.key, None)

def __enter__(self):
*submodules, attr = self.target.split(".")
current = self.obj
for key in submodules:
setattr(current, key, _PatchedModuleObj(getattr(current, key, None)))
current = getattr(current, key)
setattr(current, attr, self.new)

def __exit__(self, *exc_info):
setattr(self.obj, self.key, self.original)

def start(self):
"""Activate a patch."""
self.__enter__()
self._active_patches.append(self)

def stop(self):
"""Stop an active patch."""
try:
self._active_patches.remove(self)
except ValueError:
# If the patch hasn't been started this will fail
return None

return self.__exit__()


def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str, bool]] = None):
"""
Extend the `open` and `os.path.join` functions of the module to support data streaming.
Expand Down
67 changes: 67 additions & 0 deletions src/datasets/utils/patching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from .logging import get_logger


logger = get_logger(__name__)


class _PatchedModuleObj:
"""Set all the modules components as attributes of the _PatchedModuleObj object."""

def __init__(self, module):
if module is not None:
for key in getattr(module, "__all__", module.__dict__):
if not key.startswith("__"):
setattr(self, key, getattr(module, key))


class patch_submodule:
"""
Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.
Examples:
>>> import importlib
>>> from datasets.load import prepare_module
>>> from datasets.streaming import patch_submodule, xjoin
>>>
>>> snli_module_path, _ = prepare_module("snli")
>>> snli_module = importlib.import_module(snli_module_path)
>>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
>>> patcher.start()
>>> assert snli_module.os.path.join is xjoin
"""

_active_patches = []

def __init__(self, obj, target: str, new):
self.obj = obj
self.target = target
self.new = new
self.key = target.split(".")[0]
self.original = getattr(obj, self.key, None)

def __enter__(self):
*submodules, attr = self.target.split(".")
current = self.obj
for key in submodules:
setattr(current, key, _PatchedModuleObj(getattr(current, key, None)))
current = getattr(current, key)
setattr(current, attr, self.new)

def __exit__(self, *exc_info):
setattr(self.obj, self.key, self.original)

def start(self):
"""Activate a patch."""
self.__enter__()
self._active_patches.append(self)

def stop(self):
"""Stop an active patch."""
try:
self._active_patches.remove(self)
except ValueError:
# If the patch hasn't been started this will fail
return None

return self.__exit__()

1 comment on commit c722810

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==3.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.010115 / 0.011353 (-0.001238) 0.004001 / 0.011008 (-0.007008) 0.036542 / 0.038508 (-0.001967) 0.040359 / 0.023109 (0.017249) 0.345567 / 0.275898 (0.069669) 0.383086 / 0.323480 (0.059606) 0.008580 / 0.007986 (0.000594) 0.005871 / 0.004328 (0.001543) 0.010442 / 0.004250 (0.006191) 0.045279 / 0.037052 (0.008226) 0.372734 / 0.258489 (0.114245) 0.382151 / 0.293841 (0.088310) 0.025887 / 0.128546 (-0.102659) 0.008658 / 0.075646 (-0.066989) 0.296858 / 0.419271 (-0.122414) 0.051626 / 0.043533 (0.008093) 0.356346 / 0.255139 (0.101207) 0.392528 / 0.283200 (0.109329) 0.091891 / 0.141683 (-0.049792) 1.883359 / 1.452155 (0.431204) 1.921937 / 1.492716 (0.429220)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.021802 / 0.018006 (0.003795) 0.475789 / 0.000490 (0.475299) 0.004250 / 0.000200 (0.004050) 0.000445 / 0.000054 (0.000391)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.041204 / 0.037411 (0.003793) 0.025958 / 0.014526 (0.011432) 0.027711 / 0.176557 (-0.148845) 0.142963 / 0.737135 (-0.594172) 0.029022 / 0.296338 (-0.267316)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.419328 / 0.215209 (0.204119) 4.186895 / 2.077655 (2.109240) 2.242537 / 1.504120 (0.738417) 2.007982 / 1.541195 (0.466787) 2.032434 / 1.468490 (0.563944) 0.358842 / 4.584777 (-4.225935) 5.146939 / 3.745712 (1.401227) 3.442168 / 5.269862 (-1.827694) 1.327768 / 4.565676 (-3.237909) 0.041719 / 0.424275 (-0.382556) 0.005983 / 0.007607 (-0.001624) 0.534861 / 0.226044 (0.308816) 5.361613 / 2.268929 (3.092685) 2.700547 / 55.444624 (-52.744077) 2.324387 / 6.876477 (-4.552090) 2.349811 / 2.142072 (0.207739) 0.488977 / 4.805227 (-4.316250) 0.115905 / 6.500664 (-6.384760) 0.060956 / 0.075469 (-0.014513)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 12.389053 / 1.841788 (10.547265) 14.462584 / 8.074308 (6.388275) 30.495105 / 10.191392 (20.303713) 0.885414 / 0.680424 (0.204991) 0.596339 / 0.534201 (0.062138) 0.260623 / 0.579283 (-0.318660) 0.571787 / 0.434364 (0.137423) 0.195853 / 0.540337 (-0.344485) 1.023616 / 1.386936 (-0.363320)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.010000 / 0.011353 (-0.001353) 0.003963 / 0.011008 (-0.007045) 0.035644 / 0.038508 (-0.002864) 0.040541 / 0.023109 (0.017432) 0.338990 / 0.275898 (0.063092) 0.373133 / 0.323480 (0.049653) 0.008581 / 0.007986 (0.000596) 0.005086 / 0.004328 (0.000758) 0.010055 / 0.004250 (0.005805) 0.043418 / 0.037052 (0.006365) 0.337463 / 0.258489 (0.078974) 0.380023 / 0.293841 (0.086182) 0.026642 / 0.128546 (-0.101904) 0.008543 / 0.075646 (-0.067104) 0.296429 / 0.419271 (-0.122842) 0.052319 / 0.043533 (0.008786) 0.340573 / 0.255139 (0.085434) 0.366980 / 0.283200 (0.083780) 0.091051 / 0.141683 (-0.050631) 1.834597 / 1.452155 (0.382442) 1.857180 / 1.492716 (0.364464)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.096946 / 0.018006 (0.078940) 0.475635 / 0.000490 (0.475146) 0.052876 / 0.000200 (0.052676) 0.000587 / 0.000054 (0.000533)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.041885 / 0.037411 (0.004473) 0.026099 / 0.014526 (0.011574) 0.028489 / 0.176557 (-0.148067) 0.145543 / 0.737135 (-0.591592) 0.030658 / 0.296338 (-0.265680)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.409786 / 0.215209 (0.194577) 4.066752 / 2.077655 (1.989097) 2.122937 / 1.504120 (0.618818) 1.917466 / 1.541195 (0.376272) 1.931556 / 1.468490 (0.463066) 0.355350 / 4.584777 (-4.229427) 5.183337 / 3.745712 (1.437625) 3.188543 / 5.269862 (-2.081319) 1.156052 / 4.565676 (-3.409625) 0.041880 / 0.424275 (-0.382395) 0.005912 / 0.007607 (-0.001695) 0.526795 / 0.226044 (0.300751) 5.261721 / 2.268929 (2.992793) 2.650313 / 55.444624 (-52.794311) 2.237720 / 6.876477 (-4.638757) 2.272189 / 2.142072 (0.130117) 0.486781 / 4.805227 (-4.318446) 0.114353 / 6.500664 (-6.386311) 0.060948 / 0.075469 (-0.014521)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 12.682590 / 1.841788 (10.840803) 14.518013 / 8.074308 (6.443705) 30.714134 / 10.191392 (20.522742) 0.898083 / 0.680424 (0.217659) 0.582970 / 0.534201 (0.048769) 0.262239 / 0.579283 (-0.317044) 0.570363 / 0.434364 (0.135999) 0.204780 / 0.540337 (-0.335557) 1.124312 / 1.386936 (-0.262624)

CML watermark

Please sign in to comment.