diff --git a/distributed/scheduler.py b/distributed/scheduler.py index e38447c70b..cb73a0d0dc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2621,7 +2621,17 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: on the given worker. """ dts: TaskState - deps: set = ts.dependencies.difference(ws.has_what) + deps: set + if 10 * len(ts.dependencies) < len(ws.has_what): + # In the common case where the number of dependencies is + # much less than the number of tasks that we have, + # construct the set of deps that require communication in + # O(len(dependencies)) rather than O(len(has_what)) time. + # Factor of 10 is a guess at the overhead of explicit + # iteration as opposed to just calling set.difference + deps = {dep for dep in ts.dependencies if dep not in ws.has_what} + else: + deps = ts.dependencies.difference(ws.has_what) nbytes: int = 0 for dts in deps: nbytes += dts.nbytes