Skip to content

Commit

Permalink
Support explicit runtime platform architecture selection for ECS and …
Browse files Browse the repository at this point in the history
…Fargate clusters
  • Loading branch information
dmitry-livchak committed Mar 12, 2024
1 parent 1eb17ff commit 4991917
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions dask_cloudprovider/aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,10 @@ class ECSCluster(SpecCluster, ConfigMixin):
The docker image to use for the scheduler and worker tasks.
Defaults to ``daskdev/dask:latest`` or ``rapidsai/rapidsai:latest`` if ``worker_gpu`` is set.
cpu_architecture: str (optional)
Runtime platform CPU architecture
Defaults to ``X86_64``.
scheduler_cpu: int (optional)
The amount of CPU to request for the scheduler in milli-cpu (1/1024).
Expand Down Expand Up @@ -712,6 +716,7 @@ def __init__(
fargate_workers=None,
fargate_spot=None,
image=None,
cpu_architecture="X86_64",
scheduler_cpu=None,
scheduler_mem=None,
scheduler_port=8786,
Expand Down Expand Up @@ -754,6 +759,7 @@ def __init__(
mount_volumes_on_scheduler=False,
**kwargs,
):
self._cpu_architecture = cpu_architecture
self._fargate_scheduler = fargate_scheduler
self._fargate_workers = fargate_workers
self._fargate_spot = fargate_spot
Expand Down Expand Up @@ -1223,6 +1229,7 @@ async def _create_scheduler_task_definition_arn(self):
if self._volumes and self._mount_volumes_on_scheduler
else [],
requiresCompatibilities=["FARGATE"] if self._fargate_scheduler else [],
runtimePlatform={"cpuArchitecture": self._cpu_architecture},
cpu=str(self._scheduler_cpu),
memory=str(self._scheduler_mem),
tags=dict_to_aws(self.tags),
Expand Down Expand Up @@ -1297,6 +1304,7 @@ async def _create_worker_task_definition_arn(self):
],
volumes=self._volumes if self._volumes else [],
requiresCompatibilities=["FARGATE"] if self._fargate_workers else [],
runtimePlatform={"cpuArchitecture": self._cpu_architecture},
cpu=str(self._worker_cpu),
memory=str(self._worker_mem),
tags=dict_to_aws(self.tags),
Expand Down

0 comments on commit 4991917

Please sign in to comment.