Skip to content

Commit

Permalink
ECSCluster attr cleanup (#357)
Browse files Browse the repository at this point in the history
* Mixin for updating class attrs from config

* Add default value in cloudprovider.yaml

* Black formatting

* Format cloudprovider.yaml
  • Loading branch information
pwerth committed Aug 5, 2022
1 parent 3592253 commit cd2de56
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 73 deletions.
101 changes: 31 additions & 70 deletions dask_cloudprovider/aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_default_vpc,
get_vpc_subnets,
create_default_security_group,
ConfigMixin,
)

from distributed.deploy.spec import SpecCluster
Expand Down Expand Up @@ -455,7 +456,7 @@ def __init__(
}


class ECSCluster(SpecCluster):
class ECSCluster(SpecCluster, ConfigMixin):
"""Deploy a Dask cluster using ECS
This creates a dask scheduler and workers on an existing ECS cluster.
Expand Down Expand Up @@ -696,9 +697,9 @@ class ECSCluster(SpecCluster):

def __init__(
self,
fargate_scheduler=False,
fargate_workers=False,
fargate_spot=False,
fargate_scheduler=None,
fargate_workers=None,
fargate_spot=None,
image=None,
scheduler_cpu=None,
scheduler_mem=None,
Expand Down Expand Up @@ -807,91 +808,62 @@ async def _start(

self.config = dask.config.get("cloudprovider.ecs", {})

if self._region_name is None:
self._region_name = self.config.get("region_name")

if self._aws_access_key_id is None:
self._aws_access_key_id = self.config.get("aws_access_key_id")

if self._aws_secret_access_key is None:
self._aws_secret_access_key = self.config.get("aws_secret_access_key")
for attr in [
"aws_access_key_id",
"aws_secret_access_key",
"cloudwatch_logs_default_retention",
"cluster_name_template",
"environment",
"fargate_scheduler",
"fargate_spot",
"fargate_workers",
"fargate_use_private_ip",
"n_workers",
"platform_version",
"region_name",
"scheduler_cpu",
"scheduler_mem",
"scheduler_timeout",
"skip_cleanup",
"tags",
"task_role_policies",
"worker_cpu",
"worker_gpu", # TODO Detect whether cluster is GPU capable
"worker_mem",
"worker_nthreads",
"vpc",
]:
self.update_attr_from_config(attr=attr, private=True)

# Cleanup any stale resources before we start
if self._skip_cleanup is None:
self._skip_cleanup = self.config.get("skip_cleanup")
if not self._skip_cleanup:
await _cleanup_stale_resources(
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
region_name=self._region_name,
)

if self._fargate_scheduler is None:
self._fargate_scheduler = self.config.get("fargate_scheduler")
if self._fargate_workers is None:
self._fargate_workers = self.config.get("fargate_workers")
if self._fargate_spot is None:
self._fargate_spot = self.config.get("fargate_spot")

if self._tags is None:
self._tags = self.config.get("tags")

if self._environment is None:
self._environment = self.config.get("environment")

if self._find_address_timeout is None:
self._find_address_timeout = self.config.get("find_address_timeout", 60)

if self._worker_gpu is None:
self._worker_gpu = self.config.get(
"worker_gpu"
) # TODO Detect whether cluster is GPU capable

if self.image is None:
if self._worker_gpu:
self.image = self.config.get("gpu_image")
else:
self.image = self.config.get("image")

if self._scheduler_cpu is None:
self._scheduler_cpu = self.config.get("scheduler_cpu")

if self._scheduler_mem is None:
self._scheduler_mem = self.config.get("scheduler_mem")

if self._scheduler_timeout is None:
self._scheduler_timeout = self.config.get("scheduler_timeout")

if self._scheduler_extra_args is None:
comma_separated_args = self.config.get("scheduler_extra_args")
self._scheduler_extra_args = (
comma_separated_args.split(",") if comma_separated_args else None
)

if self._worker_cpu is None:
self._worker_cpu = self.config.get("worker_cpu")

if self._worker_nthreads is None:
self._worker_nthreads = self.config.get("worker_nthreads")

if self._worker_mem is None:
self._worker_mem = self.config.get("worker_mem")

if self._worker_extra_args is None:
comma_separated_args = self.config.get("worker_extra_args")
self._worker_extra_args = (
comma_separated_args.split(",") if comma_separated_args else None
)

if self._n_workers is None:
self._n_workers = self.config.get("n_workers")

if self._cluster_name_template is None:
self._cluster_name_template = self.config.get("cluster_name_template")

if self._platform_version is None:
self._platform_version = self.config.get("platform_version")

if self.cluster_arn is None:
self.cluster_arn = (
self.config.get("cluster_arn") or await self._create_cluster()
Expand All @@ -910,9 +882,6 @@ async def _start(
or await self._create_execution_role()
)

if self._task_role_policies is None:
self._task_role_policies = self.config.get("task_role_policies")

if self._task_role_arn is None:
self._task_role_arn = (
self.config.get("task_role_arn") or await self._create_task_role()
Expand All @@ -923,20 +892,12 @@ async def _start(
"cloudwatch_logs_stream_prefix"
).format(cluster_name=self.cluster_name)

if self._cloudwatch_logs_default_retention is None:
self._cloudwatch_logs_default_retention = self.config.get(
"cloudwatch_logs_default_retention"
)

if self.cloudwatch_logs_group is None:
self.cloudwatch_logs_group = (
self.config.get("cloudwatch_logs_group")
or await self._create_cloudwatch_logs_group()
)

if self._vpc is None:
self._vpc = self.config.get("vpc")

if self._vpc == "default":
async with self._client("ec2") as client:
self._vpc = await get_default_vpc(client)
Expand Down
12 changes: 12 additions & 0 deletions dask_cloudprovider/aws/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ def get_sleep_duration(current_try, min_sleep_millis=10, max_sleep_millis=5000):
return min(current_sleep_millis, max_sleep_millis) / 1000 # return in seconds


class ConfigMixin:
def update_attr_from_config(self, attr: str, private: bool):
"""Update class attribute of given cluster based on config, if not already set. If `private` is True, the class
attribute will be prefixed with an underscore.
This mixin can be applied to any class that has a config dict attribute.
"""
prefix = "_" if private else ""
if getattr(self, f"{prefix}{attr}") is None:
setattr(self, f"{prefix}{attr}", self.config.get(attr))


async def get_latest_ami_id(client, name_glob, owner):
images = await client.describe_images(
Filters=[
Expand Down
23 changes: 23 additions & 0 deletions dask_cloudprovider/aws/tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,26 @@ def test_get_sleep_duration_negative_try():
current_try=-1, min_sleep_millis=10, max_sleep_millis=5000
)
assert duration == 0.01


def test_config_mixin():
from dask_cloudprovider.aws.helper import ConfigMixin

class MockCluster(ConfigMixin):
config = None
_attr1 = "foo"
attr2 = None

def __init__(self):
self.config = {"attr2": "bar"}

cluster_with_mixin = MockCluster()

# Test that nothing happens if attr is already set
attr1 = cluster_with_mixin._attr1
cluster_with_mixin.update_attr_from_config(attr="attr1", private=True)
assert cluster_with_mixin._attr1 == attr1

# Test that attr is updated if existing value is None
cluster_with_mixin.update_attr_from_config(attr="attr2", private=False)
assert cluster_with_mixin.attr2 == "bar"
8 changes: 5 additions & 3 deletions dask_cloudprovider/cloudprovider.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
cloudprovider:
ecs:
fargate_scheduler: False # Use fargate mode for the scheduler
fargate_workers: False # Use fargate mode for the workers
fargate_scheduler: false # Use fargate mode for the scheduler
fargate_spot: false
fargate_workers: false # Use fargate mode for the workers
fargate_use_private_ip: false
scheduler_cpu: 1024 # Millicpu (1024ths of a CPU core)
scheduler_mem: 4096 # Memory in MB
# scheduler_extra_args: "--tls-cert,/path/to/cert.pem,--tls-key,/path/to/cert.key,--tls-ca-file,/path/to/ca.key"
Expand Down Expand Up @@ -32,7 +34,7 @@ cloudprovider:
tags: {} # Tags to apply to all AWS resources created by the cluster manager
environment: {} # Environment variables that are set within a task container
find_address_timeout: 60 # Configurable timeout in seconds for finding the task IP from the cloudwatch logs.
skip_cleanup: False # Skip cleaning up of stale resources
skip_cleanup: false # Skip cleaning up of stale resources

ec2:
region: null # AWS region to create cluster. Defaults to environment or account default region.
Expand Down

0 comments on commit cd2de56

Please sign in to comment.