Skip to content

Commit

Permalink
Prevent data duplication on unspill (#5936)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 15, 2022
1 parent 2fffe74 commit 7a69b5e
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 160 deletions.
27 changes: 16 additions & 11 deletions distributed/spill.py
Expand Up @@ -2,7 +2,7 @@

import logging
import time
from collections.abc import Mapping
from collections.abc import Mapping, MutableMapping
from contextlib import contextmanager
from functools import partial
from typing import Any, Literal, NamedTuple, cast
Expand All @@ -14,7 +14,9 @@
from distributed.sizeof import safe_sizeof

logger = logging.getLogger(__name__)
has_zict_210 = parse_version(zict.__version__) > parse_version("2.0.0")
has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0")
# At the time of writing, zict 2.2.0 has not been released yet, so also accept development builds from the git tip.
has_zict_220 = parse_version(zict.__version__) >= parse_version("2.2.0.dev2")


class SpilledSize(NamedTuple):
Expand All @@ -38,7 +40,7 @@ class SpillBuffer(zict.Buffer):
the total size of the stored data exceeds the target. If max_spill is provided the
key/value pairs won't be spilled once this threshold has been reached.
Paramaters
Parameters
----------
spill_directory: str
Location on disk to write the spill files to
Expand All @@ -63,14 +65,15 @@ def __init__(
):

if max_spill is not False and not has_zict_210:
raise ValueError("zict > 2.0.0 required to set max_weight")
raise ValueError("zict >= 2.1.0 required to set max-spill")

super().__init__(
fast={},
slow=Slow(spill_directory, max_spill),
n=target,
weight=_in_memory_weight,
)
slow: MutableMapping[str, Any] = Slow(spill_directory, max_spill)
if has_zict_220:
# If a value has remained in use somewhere on the worker since the last time it
# was unspilled, don't duplicate it
slow = zict.Cache(slow, zict.WeakValueMapping())

super().__init__(fast={}, slow=slow, n=target, weight=_in_memory_weight)
self.last_logged = 0
self.min_log_interval = min_log_interval
self.logged_pickle_errors = set() # keys logged with pickle error
Expand Down Expand Up @@ -204,7 +207,8 @@ def spilled_total(self) -> SpilledSize:
The two may differ substantially, e.g. if sizeof() is inaccurate or in case of
compression.
"""
return cast(Slow, self.slow).total_weight
slow = cast(zict.Cache, self.slow).data if has_zict_220 else self.slow
return cast(Slow, slow).total_weight


def _in_memory_weight(key: str, value: Any) -> int:
Expand All @@ -224,6 +228,7 @@ class HandledError(Exception):
pass


# zict.Func[str, Any] requires zict >= 2.2.0
class Slow(zict.Func):
max_weight: int | Literal[False]
weight_by_key: dict[str, SpilledSize]
Expand Down

0 comments on commit 7a69b5e

Please sign in to comment.