Skip to content

Commit

Permalink
[Azure] Improve cluster config interface (#298)
Browse files Browse the repository at this point in the history
* Updated cluster config getters

* Move config class

* Add usage example to ClusterConfig docstring

* Linting

Co-authored-by: Jacob Tomlinson <jtomlinson@nvidia.com>
  • Loading branch information
DPeterK and jacobtomlinson committed Jul 1, 2021
1 parent 862f88f commit e49944a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 37 deletions.
58 changes: 21 additions & 37 deletions dask_cloudprovider/azure/azurevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid

import dask
from dask_cloudprovider.config import ClusterConfig
from dask_cloudprovider.generic.vmcluster import (
VMCluster,
VMInterface,
Expand Down Expand Up @@ -460,27 +461,19 @@ def __init__(
marketplace_plan: dict = {},
**kwargs,
):
self.config = dask.config.get("cloudprovider.azure.azurevm", {})
self.config = ClusterConfig(dask.config.get("cloudprovider.azure", {}))
self.scheduler_class = AzureVMScheduler
self.worker_class = AzureVMWorker
self.location = (
location
if location is not None
else dask.config.get("cloudprovider.azure.location")
)
self.location = self.config.get("location", override_with=location)
if self.location is None:
raise ConfigError("You must configure a location")
self.resource_group = (
resource_group
if resource_group is not None
else self.config.get("resource_group")
self.resource_group = self.config.get(
"resource_group", override_with=resource_group
)
if self.resource_group is None:
raise ConfigError("You must configure a resource_group")
self.public_ingress = (
public_ingress
if public_ingress is not None
else self.config.get("public_ingress")
self.public_ingress = self.config.get(
"azurevm.public_ingress", override_with=public_ingress
)
self.credentials, self.subscription_id = get_azure_cli_credentials()
self.compute_client = ComputeManagementClient(
Expand All @@ -489,50 +482,41 @@ def __init__(
self.network_client = NetworkManagementClient(
self.credentials, self.subscription_id
)
self.vnet = vnet if vnet is not None else self.config.get("vnet")
self.vnet = self.config.get("azurevm.vnet", override_with=vnet)
if self.vnet is None:
raise ConfigError("You must configure a vnet")
self.security_group = (
security_group
if security_group is not None
else self.config.get("security_group")
self.security_group = self.config.get(
"azurevm.security_group", override_with=security_group
)
if self.security_group is None:
raise ConfigError(
"You must configure a security group which allows traffic on 8786 and 8787"
)
self.vm_size = vm_size if vm_size is not None else self.config.get("vm_size")
self.disk_size = (
disk_size if disk_size is not None else self.config.get("disk_size")
)
self.vm_size = self.config.get("azurevm.vm_size", override_with=vm_size)
self.disk_size = self.config.get("azurevm.disk_size", override_with=disk_size)
if self.disk_size > 1023:
raise ValueError(
"VM OS disk canot be larger than 1023. Please change the ``disk_size`` config option."
)
self.scheduler_vm_size = (
scheduler_vm_size
if scheduler_vm_size is not None
else self.config.get("scheduler_vm_size")
self.scheduler_vm_size = self.config.get(
"azurevm.scheduler_vm_size", override_with=scheduler_vm_size
)
if self.scheduler_vm_size is None:
self.scheduler_vm_size = self.vm_size
self.gpu_instance = (
"_NC" in self.vm_size.upper() or "_ND" in self.vm_size.upper()
)
self.vm_image = self.config.get("vm_image")
self.vm_image = self.config.get("azurevm.vm_image")
for key in vm_image:
self.vm_image[key] = vm_image[key]
self.bootstrap = (
bootstrap if bootstrap is not None else self.config.get("bootstrap")
self.bootstrap = self.config.get("azurevm.bootstrap", override_with=bootstrap)
self.auto_shutdown = self.config.get(
"azurevm.auto_shutdown", override_with=auto_shutdown
)
self.auto_shutdown = (
auto_shutdown
if auto_shutdown is not None
else self.config.get("auto_shutdown")
)
self.docker_image = docker_image or self.config.get("docker_image")
self.debug = debug
self.marketplace_plan = marketplace_plan or self.config.get("marketplace_plan")
self.marketplace_plan = marketplace_plan or self.config.get(
"azurevm.marketplace_plan"
)
if self.marketplace_plan:
# Check that self.marketplace_plan contains the right options with values
if not all(
Expand Down
25 changes: 25 additions & 0 deletions dask_cloudprovider/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,31 @@
import yaml


class ClusterConfig(dict):
"""Simple config interface for dask-cloudprovider clusters, such as `AzureVMCluster`.
Enables '.' notation for nested access, as per `dask.config.get`.
Example
-------
>>> from dask_cloudprovider.config import ClusterConfig
>>> class RandomCluster(VMCluster):
... def __init__(self, option=None):
... self.config = ClusterConfig(dask.config.get("cloudprovider.random", {}))
... self.option = self.config.get("option", override_with=option)
"""

def __new__(cls, d):
return super().__new__(cls, d)

def get(self, key, default=None, override_with=None):
return dask.config.get(
key, default=default, config=self, override_with=override_with
)


fn = os.path.join(os.path.dirname(__file__), "cloudprovider.yaml")
dask.config.ensure_file(source=fn)

Expand Down

0 comments on commit e49944a

Please sign in to comment.