diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d0e77a5541..873a41be56 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2687,7 +2687,8 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float: nbytes: int = 0 for dts in deps: nbytes += dts.nbytes - return nbytes / self.bandwidth + # Add a fixed 10ms penalty per transfer. See distributed#5324 + return nbytes / self.bandwidth + 0.01 * len(deps) def get_task_duration(self, ts: TaskState) -> float: """Get the estimated computation cost of the given task (not including @@ -2799,13 +2800,17 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: """ dts: TaskState comm_bytes: int = 0 + xfers = 0 for dts in ts.dependencies: if ws not in dts.who_has: nbytes = dts.get_nbytes() - comm_bytes += nbytes + # amortize transfer cost over all waiters + comm_bytes += nbytes / len(dts.waiters) + xfers += 1 - stack_time: float = ws.occupancy / ws.nthreads - start_time: float = stack_time + comm_bytes / self.bandwidth + stack_time = ws.occupancy / ws.nthreads + # Add a fixed 10ms penalty per transfer. See distributed#5324 + start_time = stack_time + comm_bytes / self.bandwidth + xfers * 0.01 if ts.actor: return (len(ws.actors), start_time, ws.nbytes)