Skip to content

Commit

Permalink
Add support for extra keyword arguments for ECS tasks (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
skozlovf committed Oct 29, 2020
1 parent f829d94 commit 027bc1a
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions dask_cloudprovider/providers/aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ class Task:
Whether to use a private IP (if True) or public IP (if False) with Fargate.
Defaults to False, i.e. public IP.
task_kwargs: dict (optional)
Additional keyword arguments for the ECS task.
kwargs:
Any additional kwargs which may need to be stored for later use.
Expand All @@ -117,6 +120,7 @@ def __init__(
name=None,
platform_version=None,
fargate_use_private_ip=False,
task_kwargs=None,
**kwargs
):
self.lock = asyncio.Lock()
Expand Down Expand Up @@ -144,6 +148,7 @@ def __init__(
self.platform_version = platform_version
self._fargate_use_private_ip = fargate_use_private_ip
self.kwargs = kwargs
self.task_kwargs = task_kwargs
self.status = "created"

def __await__(self):
Expand Down Expand Up @@ -210,18 +215,19 @@ async def start(self):
timeout = Timeout(60, "Unable to start %s after 60 seconds" % self.task_type)
while timeout.run():
try:
kwargs = (
{"tags": dict_to_aws(self.tags)}
if await self._is_long_arn_format_enabled()
else {}
) # Tags are only supported if you opt into long arn format so we need to check for that
kwargs = self.task_kwargs.copy() if self.task_kwargs is not None else {}

# Tags are only supported if you opt into long arn format so we need to check for that
if await self._is_long_arn_format_enabled():
kwargs["tags"] = dict_to_aws(self.tags)
if self.platform_version and self.fargate:
kwargs["platformVersion"] = self.platform_version
async with self._client("ecs") as ecs:
response = await ecs.run_task(
cluster=self.cluster_arn,
taskDefinition=self.task_definition_arn,
overrides={

kwargs.update(
{
"cluster": self.cluster_arn,
"taskDefinition": self.task_definition_arn,
"overrides": {
"containerOverrides": [
{
"name": "dask-{}".format(self.task_type),
Expand All @@ -232,9 +238,9 @@ async def start(self):
}
]
},
count=1,
launchType="FARGATE" if self.fargate else "EC2",
networkConfiguration={
"count": 1,
"launchType": "FARGATE" if self.fargate else "EC2",
"networkConfiguration": {
"awsvpcConfiguration": {
"subnets": self._vpc_subnets,
"securityGroups": self._security_groups,
Expand All @@ -243,8 +249,11 @@ async def start(self):
else "DISABLED",
}
},
**kwargs
)
}
)

async with self._client("ecs") as ecs:
response = await ecs.run_task(**kwargs)

if not response.get("tasks"):
raise RuntimeError(response) # print entire response
Expand Down Expand Up @@ -450,6 +459,8 @@ class ECSCluster(SpecCluster):
Any extra command line arguments to pass to dask-scheduler, e.g. ``["--tls-cert", "/path/to/cert.pem"]``
Defaults to `None`, no extra command line arguments.
scheduler_task_kwargs: dict (optional)
Additional keyword arguments for the scheduler ECS task.
worker_cpu: int (optional)
The amount of CPU to request for worker tasks in milli-cpu (1/1024).
Expand All @@ -471,6 +482,8 @@ class ECSCluster(SpecCluster):
Any extra command line arguments to pass to dask-worker, e.g. ``["--tls-cert", "/path/to/cert.pem"]``
Defaults to `None`, no extra command line arguments.
worker_task_kwargs: dict (optional)
Additional keyword arguments for the workers ECS task.
n_workers: int (optional)
Number of workers to start on cluster creation.
Expand Down Expand Up @@ -609,10 +622,12 @@ def __init__(
scheduler_mem=None,
scheduler_timeout=None,
scheduler_extra_args=None,
scheduler_task_kwargs=None,
worker_cpu=None,
worker_mem=None,
worker_gpu=None,
worker_extra_args=None,
worker_task_kwargs=None,
n_workers=None,
cluster_arn=None,
cluster_name_template=None,
Expand Down Expand Up @@ -646,10 +661,12 @@ def __init__(
self._scheduler_mem = scheduler_mem
self._scheduler_timeout = scheduler_timeout
self._scheduler_extra_args = scheduler_extra_args
self._scheduler_task_kwargs = scheduler_task_kwargs
self._worker_cpu = worker_cpu
self._worker_mem = worker_mem
self._worker_gpu = worker_gpu
self._worker_extra_args = worker_extra_args
self._worker_task_kwargs = worker_task_kwargs
self._n_workers = n_workers
self.cluster_arn = cluster_arn
self.cluster_name = None
Expand Down Expand Up @@ -846,6 +863,7 @@ async def _start(self,):
scheduler_options = {
"task_definition_arn": self.scheduler_task_definition_arn,
"fargate": self._fargate_scheduler,
"task_kwargs": self._scheduler_task_kwargs,
**options,
}
worker_options = {
Expand All @@ -855,6 +873,7 @@ async def _start(self,):
"mem": self._worker_mem,
"gpu": self._worker_gpu,
"extra_args": self._worker_extra_args,
"task_kwargs": self._worker_task_kwargs,
**options,
}

Expand Down

0 comments on commit 027bc1a

Please sign in to comment.