Move all default values into config file (#26)
jacobtomlinson committed Aug 20, 2019
1 parent 3fdb831 commit a554351
Showing 2 changed files with 114 additions and 134 deletions.
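
The commit applies one refactoring pattern throughout: inline fallback defaults passed to `config.get(...)` (and the `x if x is None else ...` ternaries around them) are removed, and the defaults now live only in `cloudprovider.yaml`. A minimal, generic sketch of that pattern follows; the names in it are illustrative stand-ins, not identifiers from the diff.

```python
# Illustrative sketch of the refactoring this commit applies throughout
# ECSCluster._start; "config", "value" and "some_key" are generic stand-ins.
config = {"some_key": 1024}  # stand-in for dask.config.get("cloudprovider.ecs", {})
value = None                 # stand-in for a constructor argument left unset

# Before: the fallback default was duplicated at every call site.
value = config.get("some_key", 1024) if value is None else value

# After: the default lives only in cloudprovider.yaml, so no inline fallback.
value = None  # reset to demonstrate the second form
if value is None:
    value = config.get("some_key")
```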
45 changes: 28 additions & 17 deletions dask_cloudprovider/cloudprovider.yaml
@@ -1,21 +1,32 @@
cloudprovider: {}
cloudprovider:

# ecs: {}
# # scheduler_cpu: 1024 # Millicpu (1024ths of a CPU core)
# # scheduler_mem: 4096 # Memory in MB
# # worker_cpu: 4096 # Millicpu (1024ths of a CPU core)
# # worker_mem: 16384 # Memory in MB
# # n_workers: 2
ecs:
fargate_scheduler: False # Use fargate mode for the scheduler
fargate_workers: False # Use fargate mode for the workers
scheduler_cpu: 1024 # Millicpu (1024ths of a CPU core)
scheduler_mem: 4096 # Memory in MB
worker_cpu: 4096 # Millicpu (1024ths of a CPU core)
worker_mem: 16384 # Memory in MB
worker_gpu: 0 # Number of GPUs for each worker
n_workers: 0 # Number of workers to start the cluster with
scheduler_timeout: '5 minutes' # Length of inactivity to wait before closing the cluster

# # image: 'daskdev/dask:1.2.0'
# # cluster_name_template: 'dask-{uuid}' # Template to use when creating a cluster
# # cluster_arn: null # ARN of existing ECS cluster to use (if not set one will be created)
# # execution_role_arn: null # Arn of existing execution role to use (if not set one will be created)
# # task_role_arn: null # Arn of existing task role to use (if not set one will be created)
image: 'daskdev/dask:latest' # Docker image to use for non GPU tasks
gpu_image: 'rapidsai/rapidsai:latest' # Docker image to use for GPU tasks
cluster_name_template: 'dask-{uuid}' # Template to use when creating a cluster
cluster_arn: '' # ARN of existing ECS cluster to use (if not set one will be created)
execution_role_arn: '' # Arn of existing execution role to use (if not set one will be created)
task_role_arn: '' # Arn of existing task role to use (if not set one will be created)
task_role_policies: [] # List of policy arns to attach to tasks (e.g S3 read only access)

# # cloudwatch_logs_group: null # Name of existing cloudwatch logs group to use (if not set one will be created)
# # cloudwatch_logs_stream_prefix: '{cluster_name}' # Stream prefix template
# # cloudwatch_logs_default_retention: 30 # Number of days to retain logs (only applied if not using existing group)
cloudwatch_logs_group: '' # Name of existing cloudwatch logs group to use (if not set one will be created)
cloudwatch_logs_stream_prefix: '{cluster_name}' # Stream prefix template
cloudwatch_logs_default_retention: 30 # Number of days to retain logs (only applied if not using existing group)

# # vpc: default # VPC to use for tasks
# # security_groups: [] # Security groups to use (if not set one will be created)
vpc: 'default' # VPC to use for tasks
subnets: [] # VPC subnets to use (will use all available if not set)
security_groups: [] # Security groups to use (if not set one will be created)

tags: {} # Tags to apply to all AWS resources created by the cluster manager

skip_cleanup: False # Skip cleaning up of stale resources
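
For context, the `ECSCluster._start` code in the next file reads this section through dask's config system, so the values above act as package-wide defaults. A minimal sketch of that lookup, assuming `dask_cloudprovider` registers `cloudprovider.yaml` with `dask.config` when imported:

```python
# Minimal sketch of how the defaults above are consumed; assumes importing
# dask_cloudprovider merges cloudprovider.yaml into dask's config.
import dask.config
import dask_cloudprovider  # noqa: F401  # config registration assumed to happen on import

ecs_defaults = dask.config.get("cloudprovider.ecs", {})
print(ecs_defaults.get("scheduler_cpu"))  # 1024, per the YAML above
print(ecs_defaults.get("worker_mem"))     # 16384, per the YAML above
```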
203 changes: 86 additions & 117 deletions dask_cloudprovider/providers/aws/ecs.py
@@ -23,7 +23,6 @@
DEFAULT_TAGS = {
"createdBy": "dask-cloudprovider"
} # Package tags to apply to all resources
DEFAULT_CLUSTER_NAME_TEMPLATE = "dask-{uuid}"


class Task:
@@ -529,6 +528,7 @@ def __init__(
self._worker_gpu = worker_gpu
self._n_workers = n_workers
self.cluster_arn = cluster_arn
self.cluster_name = None
self._cluster_name_template = cluster_name_template
self._execution_role_arn = execution_role_arn
self._task_role_arn = task_role_arn
@@ -556,84 +556,58 @@ async def _start(self,):
self.config = dask.config.get("cloudprovider.ecs", {})

# Cleanup any stale resources before we start
self._skip_cleanup = (
self.config.get("skip_cleanup", False)
if self._skip_cleanup is None
else self._skip_cleanup
)
if self._skip_cleanup is None:
self._skip_cleanup = self.config.get("skip_cleanup")
if not self._skip_cleanup:
await _cleanup_stale_resources()

self._clients = await self._get_clients()
self._fargate_scheduler = (
self.config.get("fargate_scheduler", False)
if self._fargate_scheduler is None
else self._fargate_scheduler
)
self._fargate_workers = (
self.config.get("fargate_workers", False)
if self._fargate_workers is None
else self._fargate_workers
)
self._tags = self.config.get("tags", {}) if self._tags is None else self._tags
self._worker_gpu = (
self.config.get("worker_gpu")
if self._worker_gpu is None
else self._worker_gpu
) # TODO Detect whether cluster is GPU capable
self.image = (
self.config.get(
"image",
"rapidsai/rapidsai:latest"
if self._worker_gpu
else "daskdev/dask:latest",
)
if self.image is None
else self.image
)
self._scheduler_cpu = (
self.config.get("scheduler_cpu", 1024)
if self._scheduler_cpu is None
else self._scheduler_cpu
)
self._scheduler_mem = (
self.config.get("scheduler_mem", 4096)
if self._scheduler_mem is None
else self._scheduler_mem
)
self._scheduler_timeout = (
self.config.get("scheduler_timeout", "5 minutes")
if self._scheduler_timeout is None
else self._scheduler_timeout
)
self._worker_cpu = (
self.config.get("worker_cpu", 4096)
if self._worker_cpu is None
else self._worker_cpu
)
self._worker_mem = (
self.config.get("worker_mem", 16384)
if self._worker_mem is None
else self._worker_mem
)
self._n_workers = (
self.config.get("n_workers", 0)
if self._n_workers is None
else self._n_workers
)

self.cluster_name = None
self._cluster_name_template = (
self.config.get("cluster_name", DEFAULT_CLUSTER_NAME_TEMPLATE)
if self._cluster_name_template is None
else self._cluster_name_template
)
if self._fargate_scheduler is None:
self._fargate_scheduler = self.config.get("fargate_scheduler")
if self._fargate_workers is None:
self._fargate_workers = self.config.get("fargate_workers")

if self._tags is None:
self._tags = self.config.get("tags")

if self._worker_gpu is None:
self._worker_gpu = self.config.get(
"worker_gpu"
) # TODO Detect whether cluster is GPU capable

if self.image is None:
if self._worker_gpu:
self.image = self.config.get("gpu_image")
else:
self.image = self.config.get("image")

if self._scheduler_cpu is None:
self._scheduler_cpu = self.config.get("scheduler_cpu")

if self._scheduler_mem is None:
self._scheduler_mem = self.config.get("scheduler_mem")

if self._scheduler_timeout is None:
self._scheduler_timeout = self.config.get("scheduler_timeout")

if self._worker_cpu is None:
self._worker_cpu = self.config.get("worker_cpu")

if self._worker_mem is None:
self._worker_mem = self.config.get("worker_mem")

if self._n_workers is None:
self._n_workers = self.config.get("n_workers")

if self._cluster_name_template is None:
self._cluster_name_template = self.config.get("cluster_name_template")

self.cluster_arn = (
self.config.get("cluster_arn", await self._create_cluster())
if self.cluster_arn is None
else self.cluster_arn
)

if self.cluster_arn is None:
self.cluster_arn = (
self.config.get("cluster_arn") or await self._create_cluster()
)

if self.cluster_name is None:
[cluster_info] = (
await self._clients["ecs"].describe_clusters(
@@ -642,57 +616,52 @@ async def _start(self,):
)["clusters"]
self.cluster_name = cluster_info["clusterName"]

self._execution_role_arn = (
self.config.get("execution_role_arn", await self._create_execution_role())
if self._execution_role_arn is None
else self._execution_role_arn
)
self._task_role_policies = (
self.config.get("task_role_policies", [])
if self._task_role_policies is None
else self._task_role_policies
)
self._task_role_arn = (
self.config.get("task_role_arn", await self._create_task_role())
if self._task_role_arn is None
else self._task_role_arn
)

self._cloudwatch_logs_stream_prefix = (
self.config.get("cloudwatch_logs_stream_prefix", "{cluster_name}")
if self._cloudwatch_logs_stream_prefix is None
else self._cloudwatch_logs_stream_prefix
).format(cluster_name=self.cluster_name)
self._cloudwatch_logs_default_retention = (
self.config.get("cloudwatch_logs_default_retention", 30)
if self._cloudwatch_logs_default_retention is None
else self._cloudwatch_logs_default_retention
)
self.cloudwatch_logs_group = (
self.config.get(
"cloudwatch_logs_group", await self._create_cloudwatch_logs_group()
)
if self.cloudwatch_logs_group is None
else self.cloudwatch_logs_group
)

self._vpc = (
self.config.get("vpc", "default") if self._vpc is None else self._vpc
)

if self._execution_role_arn is None:
self._execution_role_arn = (
self.config.get("execution_role_arn")
or await self._create_execution_role()
)

if self._task_role_policies is None:
self._task_role_policies = self.config.get("task_role_policies")

if self._task_role_arn is None:
self._task_role_arn = (
self.config.get("task_role_arn") or await self._create_task_role()
)

if self._cloudwatch_logs_stream_prefix is None:
self._cloudwatch_logs_stream_prefix = self.config.get(
"cloudwatch_logs_stream_prefix"
).format(cluster_name=self.cluster_name)

if self._cloudwatch_logs_default_retention is None:
self._cloudwatch_logs_default_retention = self.config.get(
"cloudwatch_logs_default_retention"
)

if self.cloudwatch_logs_group is None:
self.cloudwatch_logs_group = (
self.config.get("cloudwatch_logs_group")
or await self._create_cloudwatch_logs_group()
)

if self._vpc is None:
self._vpc = self.config.get("vpc")

if self._vpc == "default":
self._vpc = await self._get_default_vpc()

self._vpc_subnets = (
self.config.get("subnets", await self._get_vpc_subnets())
if self._vpc_subnets is None
else self._vpc_subnets
)
if self._vpc_subnets is None:
self._vpc_subnets = (
self.config.get("subnets") or await self._get_vpc_subnets()
)

self._security_groups = (
self.config.get("security_groups", await self._create_security_groups())
if self._security_groups is None
else self._security_groups
)
if self._security_groups is None:
self._security_groups = (
self.config.get("security_groups")
or await self._create_security_groups()
)

self.scheduler_task_definition_arn = (
await self._create_scheduler_task_definition_arn()

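With the defaults sourced from the config file, any of them can be overridden without code changes, either by editing the user's dask config file (for example `~/.config/dask/cloudprovider.yaml`) or programmatically. A hedged usage sketch follows; the `ECSCluster` import path is assumed from this repository, and actually running it would contact AWS.

```python
# Usage sketch only: overriding the file-based defaults at runtime.
# Keys mirror the cloudprovider.yaml section above; running this for real
# would create AWS resources, so treat it as illustrative.
import dask.config
from dask_cloudprovider import ECSCluster  # import path assumed for this repo

with dask.config.set({"cloudprovider.ecs.n_workers": 2,
                      "cloudprovider.ecs.worker_mem": 8192}):
    cluster = ECSCluster()  # arguments left as None fall back to the config values
```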