Skip to content

Commit

Permalink
Pass GPU info to worker task (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtomlinson committed Aug 27, 2019
1 parent e16a730 commit 00ffbc6
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions dask_cloudprovider/providers/aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,17 @@ class Worker(Task):
Other kwargs to be passed to :class:`Task`.
"""

def __init__(self, scheduler: str, cpu: int, mem: int, **kwargs):
def __init__(self, scheduler: str, cpu: int, mem: int, gpu: int, **kwargs):
super().__init__(**kwargs)
self.task_type = "worker"
self.scheduler = scheduler
self._cpu = cpu
self._mem = mem
self._gpu = gpu
self._overrides = {
"command": [
"dask-worker",
"dask-cuda-worker" if self._gpu else "dask-worker",
self.scheduler,
"--name",
str(self.name),
"--nthreads",
Expand All @@ -341,7 +343,6 @@ def __init__(self, scheduler: str, cpu: int, mem: int, **kwargs):
"60",
]
}
self.environment["DASK_SCHEDULER_ADDRESS"] = self.scheduler


class ECSCluster(SpecCluster):
Expand Down Expand Up @@ -695,6 +696,7 @@ async def _start(self,):
"fargate": self._fargate_workers,
"cpu": self._worker_cpu,
"mem": self._worker_mem,
"gpu": self._worker_gpu,
**options,
}

Expand Down

0 comments on commit 00ffbc6

Please sign in to comment.