Skip to content

Commit

Permalink
new argument worker_nthreads in ECSCluster (#321)
Browse files Browse the repository at this point in the history
  • Loading branch information
drorspei committed Dec 1, 2021
1 parent 8cf979b commit 09c72a2
Showing 1 changed file with 23 additions and 3 deletions.
26 changes: 23 additions & 3 deletions dask_cloudprovider/aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
import warnings
import weakref
from typing import List
from typing import List, Optional

import dask

Expand Down Expand Up @@ -423,6 +423,7 @@ def __init__(
cpu: int,
mem: int,
gpu: int,
nthreads: Optional[int],
extra_args: List[str],
**kwargs
):
Expand All @@ -432,14 +433,19 @@ def __init__(
self._cpu = cpu
self._mem = mem
self._gpu = gpu
self._nthreads = nthreads
self._overrides = {
"command": [
"dask-cuda-worker" if self._gpu else "dask-worker",
self.scheduler,
"--name",
str(self.name),
"--nthreads",
"{}".format(max(int(self._cpu / 1024), 1)),
"{}".format(
max(int(self._cpu / 1024), 1)
if nthreads is None
else self._nthreads
),
"--memory-limit",
"{}GB".format(int(self._mem / 1024)),
"--death-timeout",
Expand Down Expand Up @@ -507,6 +513,10 @@ class ECSCluster(SpecCluster):
Defaults to ``4096`` (four vCPUs).
See the `troubleshooting guide`_ for information on the valid values for this argument.
worker_nthreads: int (optional)
The number of threads to use in each worker.
Defaults to 1 per vCPU.
worker_mem: int (optional)
The amount of memory to request for worker tasks in MB.
Expand Down Expand Up @@ -697,6 +707,7 @@ def __init__(
scheduler_task_kwargs=None,
scheduler_address=None,
worker_cpu=None,
worker_nthreads=None,
worker_mem=None,
worker_gpu=None,
worker_extra_args=None,
Expand Down Expand Up @@ -740,6 +751,7 @@ def __init__(
self._scheduler_task_kwargs = scheduler_task_kwargs
self._scheduler_address = scheduler_address
self._worker_cpu = worker_cpu
self._worker_nthreads = worker_nthreads
self._worker_mem = worker_mem
self._worker_gpu = worker_gpu
self._worker_extra_args = worker_extra_args
Expand Down Expand Up @@ -846,6 +858,9 @@ async def _start(
if self._worker_cpu is None:
self._worker_cpu = self.config.get("worker_cpu")

if self._worker_nthreads is None:
self._worker_nthreads = self.config.get("worker_nthreads")

if self._worker_mem is None:
self._worker_mem = self.config.get("worker_mem")

Expand Down Expand Up @@ -957,6 +972,7 @@ async def _start(
"fargate": self._fargate_workers,
"fargate_capacity_provider": "FARGATE_SPOT" if self._fargate_spot else None,
"cpu": self._worker_cpu,
"nthreads": self._worker_nthreads,
"mem": self._worker_mem,
"gpu": self._worker_gpu,
"extra_args": self._worker_extra_args,
Expand Down Expand Up @@ -1234,7 +1250,11 @@ async def _create_worker_task_definition_arn(self):
"command": [
"dask-cuda-worker" if self._worker_gpu else "dask-worker",
"--nthreads",
"{}".format(max(int(self._worker_cpu / 1024), 1)),
"{}".format(
max(int(self._worker_cpu / 1024), 1)
if self._worker_nthreads is None
else self._worker_nthreads
),
"--memory-limit",
"{}MB".format(int(self._worker_mem)),
"--death-timeout",
Expand Down

0 comments on commit 09c72a2

Please sign in to comment.