From 285893037fe9eac83f363611b4799168aabb3992 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 20 Sep 2023 16:16:25 +0200 Subject: [PATCH] Reduce memory usage during culling for shuffling and merge (#8197) --- distributed/shuffle/_merge.py | 14 ++++++-------- distributed/shuffle/_shuffle.py | 7 +++++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py index 75e72e884d..7d69f40d14 100644 --- a/distributed/shuffle/_merge.py +++ b/distributed/shuffle/_merge.py @@ -1,7 +1,6 @@ # mypy: ignore-errors from __future__ import annotations -from collections import defaultdict from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any @@ -243,15 +242,14 @@ def _cull_dependencies( all input partitions. This method does not require graph materialization. """ - deps = defaultdict(set) + deps = {} parts_out = parts_out or self._keys_to_parts(keys) + keys = {(self.name_input_left, i) for i in range(self.npartitions)} + keys |= {(self.name_input_right, i) for i in range(self.npartitions)} + # Protect against mutations later on with frozenset + keys = frozenset(keys) for part in parts_out: - deps[(self.name, part)] |= { - (self.name_input_left, i) for i in range(self.npartitions) - } - deps[(self.name, part)] |= { - (self.name_input_right, i) for i in range(self.npartitions) - } + deps[(self.name, part)] = keys return deps def _keys_to_parts(self, keys: Iterable[str]) -> set[str]: diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 43fe87b4c5..6c80dc26d3 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -227,8 +227,11 @@ def cull( parameter. """ parts_out = self._keys_to_parts(keys) - input_parts = {(self.name_input, i) for i in range(self.npartitions_input)} - culled_deps = {(self.name, part): input_parts.copy() for part in parts_out} + # Protect against mutations later on with frozenset + input_parts = frozenset( + {(self.name_input, i) for i in range(self.npartitions_input)} + ) + culled_deps = {(self.name, part): input_parts for part in parts_out} if parts_out != set(self.parts_out): culled_layer = self._cull(parts_out)