Skip to content

Commit

Permalink
Merge pull request #91 from joeschmid/cli-extra-args
Browse files Browse the repository at this point in the history
Allow extra command line args to scheduler & worker
  • Loading branch information
martindurant committed Apr 28, 2020
2 parents 20d10f0 + 8c5eff3 commit 915596e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
2 changes: 2 additions & 0 deletions dask_cloudprovider/cloudprovider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ cloudprovider:
fargate_workers: False # Use fargate mode for the workers
scheduler_cpu: 1024 # Millicpu (1024ths of a CPU core)
scheduler_mem: 4096 # Memory in MB
# scheduler_extra_args: "--tls-cert,/path/to/cert.pem,--tls-key,/path/to/cert.key,--tls-ca-file,/path/to/ca.key" # Comma-separated extra command line args for dask-scheduler
worker_cpu: 4096 # Millicpu (1024ths of a CPU core)
worker_mem: 16384 # Memory in MB
worker_gpu: 0 # Number of GPUs for each worker
# worker_extra_args: "--tls-cert,/path/to/cert.pem,--tls-key,/path/to/cert.key,--tls-ca-file,/path/to/ca.key" # Comma-separated extra command line args for dask-worker
n_workers: 0 # Number of workers to start the cluster with
scheduler_timeout: "5 minutes" # Length of inactivity to wait before closing the cluster

Expand Down
52 changes: 49 additions & 3 deletions dask_cloudprovider/providers/aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import uuid
import warnings
import weakref
from typing import List

from botocore.exceptions import ClientError
import aiobotocore
Expand Down Expand Up @@ -358,7 +359,15 @@ class Worker(Task):
Other kwargs to be passed to :class:`Task`.
"""

def __init__(self, scheduler: str, cpu: int, mem: int, gpu: int, **kwargs):
def __init__(
self,
scheduler: str,
cpu: int,
mem: int,
gpu: int,
extra_args: List[str],
**kwargs
):
super().__init__(**kwargs)
self.task_type = "worker"
self.scheduler = scheduler
Expand All @@ -378,6 +387,7 @@ def __init__(self, scheduler: str, cpu: int, mem: int, gpu: int, **kwargs):
"--death-timeout",
"60",
]
+ (list() if not extra_args else extra_args)
}


Expand Down Expand Up @@ -413,6 +423,10 @@ class ECSCluster(SpecCluster):
The scheduler task will exit after this amount of time if there are no clients connected.
Defaults to ``5 minutes``.
scheduler_extra_args: List[str] (optional)
Any extra command line arguments to pass to dask-scheduler, e.g. ``["--tls-cert", "/path/to/cert.pem"]``.
Defaults to `None`, no extra command line arguments.
worker_cpu: int (optional)
The amount of CPU to request for worker tasks in milli-cpu (1/1024).
Expand All @@ -429,6 +443,11 @@ class ECSCluster(SpecCluster):
cluster. Fargate is not supported at this time.
Defaults to `None`, no GPUs.
worker_extra_args: List[str] (optional)
Any extra command line arguments to pass to dask-worker, e.g. ``["--tls-cert", "/path/to/cert.pem"]``.
Defaults to `None`, no extra command line arguments.
n_workers: int (optional)
Number of workers to start on cluster creation.
Expand Down Expand Up @@ -547,9 +566,11 @@ def __init__(
scheduler_cpu=None,
scheduler_mem=None,
scheduler_timeout=None,
scheduler_extra_args=None,
worker_cpu=None,
worker_mem=None,
worker_gpu=None,
worker_extra_args=None,
n_workers=None,
cluster_arn=None,
cluster_name_template=None,
Expand Down Expand Up @@ -578,9 +599,11 @@ def __init__(
self._scheduler_cpu = scheduler_cpu
self._scheduler_mem = scheduler_mem
self._scheduler_timeout = scheduler_timeout
self._scheduler_extra_args = scheduler_extra_args
self._worker_cpu = worker_cpu
self._worker_mem = worker_mem
self._worker_gpu = worker_gpu
self._worker_extra_args = worker_extra_args
self._n_workers = n_workers
self.cluster_arn = cluster_arn
self.cluster_name = None
Expand Down Expand Up @@ -664,12 +687,24 @@ async def _start(self,):
if self._scheduler_timeout is None:
self._scheduler_timeout = self.config.get("scheduler_timeout")

if self._scheduler_extra_args is None:
comma_separated_args = self.config.get("scheduler_extra_args")
self._scheduler_extra_args = (
comma_separated_args.split(",") if comma_separated_args else None
)

if self._worker_cpu is None:
self._worker_cpu = self.config.get("worker_cpu")

if self._worker_mem is None:
self._worker_mem = self.config.get("worker_mem")

if self._worker_extra_args is None:
comma_separated_args = self.config.get("worker_extra_args")
self._worker_extra_args = (
comma_separated_args.split(",") if comma_separated_args else None
)

if self._n_workers is None:
self._n_workers = self.config.get("n_workers")

Expand Down Expand Up @@ -765,6 +800,7 @@ async def _start(self,):
"cpu": self._worker_cpu,
"mem": self._worker_mem,
"gpu": self._worker_gpu,
"extra_args": self._worker_extra_args,
**options,
}

Expand Down Expand Up @@ -997,7 +1033,12 @@ async def _create_scheduler_task_definition_arn(self):
"dask-scheduler",
"--idle-timeout",
self._scheduler_timeout,
],
]
+ (
list()
if not self._scheduler_extra_args
else self._scheduler_extra_args
),
"logConfiguration": {
"logDriver": "awslogs",
"options": {
Expand Down Expand Up @@ -1053,7 +1094,12 @@ async def _create_worker_task_definition_arn(self):
"{}MB".format(int(self._worker_mem)),
"--death-timeout",
"60",
],
]
+ (
list()
if not self._worker_extra_args
else self._worker_extra_args
),
"logConfiguration": {
"logDriver": "awslogs",
"options": {
Expand Down

0 comments on commit 915596e

Please sign in to comment.