Skip to content

Commit

Permalink
Merge pull request #224 from jacobtomlinson/async-support
Browse files Browse the repository at this point in the history
Async cluster creation
  • Loading branch information
quasiben committed Dec 18, 2020
2 parents d72ba79 + 5062260 commit 02b50e2
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 60 deletions.
75 changes: 47 additions & 28 deletions dask_cloudprovider/azure/azurevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ def __init__(
self.env_vars = env_vars

async def create_vm(self):
[subnet_info, *_] = self.cluster.network_client.subnets.list(
self.cluster.resource_group, self.cluster.vnet
[subnet_info, *_] = await self.cluster.call_async(
self.cluster.network_client.subnets.list,
self.cluster.resource_group,
self.cluster.vnet,
)

nic_parameters = {
Expand All @@ -76,15 +78,19 @@ async def create_vm(self):
}
],
"networkSecurityGroup": {
"id": self.cluster.network_client.network_security_groups.get(
self.cluster.resource_group, self.security_group
"id": (
await self.cluster.call_async(
self.cluster.network_client.network_security_groups.get,
self.cluster.resource_group,
self.security_group,
)
).id,
"location": self.location,
},
"tags": self.cluster.get_tags(),
}
if self.public_ingress:
self.public_ip = (
self.public_ip = await self.cluster.call_async(
self.cluster.network_client.public_ip_addresses.begin_create_or_update(
self.cluster.resource_group,
self.nic_name,
Expand All @@ -95,18 +101,18 @@ async def create_vm(self):
"public_ip_address_version": "IPV4",
"tags": self.cluster.get_tags(),
},
).result()
).result
)
nic_parameters["ip_configurations"][0]["public_ip_address"] = {
"id": self.public_ip.id
}
self.cluster._log("Assigned public IP")
self.nic = (
self.nic = await self.cluster.call_async(
self.cluster.network_client.network_interfaces.begin_create_or_update(
self.cluster.resource_group,
self.nic_name,
nic_parameters,
).result()
).result
)
self.cluster._log("Network interface ready")

Expand Down Expand Up @@ -150,44 +156,57 @@ async def create_vm(self):
"tags": self.cluster.get_tags(),
}
self.cluster._log("Creating VM")
async_vm_creation = (
await self.cluster.call_async(
self.cluster.compute_client.virtual_machines.begin_create_or_update(
self.cluster.resource_group, self.name, vm_parameters
)
).wait
)
async_vm_creation.wait()
self.vm = self.cluster.compute_client.virtual_machines.get(
self.cluster.resource_group, self.name
self.vm = await self.cluster.call_async(
self.cluster.compute_client.virtual_machines.get,
self.cluster.resource_group,
self.name,
)
self.nic = self.cluster.network_client.network_interfaces.get(
self.cluster.resource_group, self.nic.name
self.nic = await self.cluster.call_async(
self.cluster.network_client.network_interfaces.get,
self.cluster.resource_group,
self.nic.name,
)
self.cluster._log(f"Created VM {self.name}")
if self.public_ingress:
return self.public_ip.ip_address
return self.nic.ip_configurations[0].private_ip_address

async def destroy_vm(self):
    """Delete this VM and the Azure resources that belong to it.

    Resources are removed in this order: the virtual machine, every disk in
    the cluster's resource group whose name contains this VM's name, the
    network interface and, when ``self.public_ingress`` is truthy, the
    public IP address. A log line is emitted through ``self.cluster._log``
    after each stage.

    Every SDK call is made inside a callable handed to
    ``self.cluster.call_async`` so that it is awaited rather than executed
    inline in this coroutine.
    """
    cluster = self.cluster
    resource_group = cluster.resource_group

    def delete_and_wait(begin_delete, resource_name):
        # Start the deletion *and* wait on the returned poller in one
        # callable, so neither step runs inline in the coroutine.
        # NOTE(review): assumes ``begin_delete`` itself performs the
        # initial request before returning the poller -- confirm against
        # the Azure SDK documentation.
        begin_delete(resource_group, resource_name).wait()

    def owned_disk_names():
        # Exhaust the listing inside the callable as well; iterating the
        # returned object is presumably where paging requests happen.
        return [
            disk.name
            for disk in cluster.compute_client.disks.list_by_resource_group(
                resource_group
            )
            if self.name in disk.name
        ]

    await cluster.call_async(
        delete_and_wait,
        cluster.compute_client.virtual_machines.begin_delete,
        self.name,
    )
    cluster._log(f"Terminated VM {self.name}")

    for disk_name in await cluster.call_async(owned_disk_names):
        await cluster.call_async(
            delete_and_wait,
            cluster.compute_client.disks.begin_delete,
            disk_name,
        )
    cluster._log(f"Removed disks for VM {self.name}")

    await cluster.call_async(
        delete_and_wait,
        cluster.network_client.network_interfaces.begin_delete,
        self.nic.name,
    )
    cluster._log("Deleted network interface")

    if self.public_ingress:
        await cluster.call_async(
            delete_and_wait,
            cluster.network_client.public_ip_addresses.begin_delete,
            self.public_ip.name,
        )
        cluster._log("Unassigned public IP")


Expand Down
6 changes: 3 additions & 3 deletions dask_cloudprovider/azure/tests/test_azurevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def skip_without_credentials(func):
vnet = dask.config.get("cloudprovider.azure.azurevm.vnet", None)
security_group = dask.config.get("cloudprovider.azure.azurevm.security_group", None)
location = dask.config.get("cloudprovider.azure.location", None)
if rg is None or vnet is None or security_group or location is None:
if rg is None or vnet is None or security_group is None or location is None:
return pytest.mark.skip(
reason="""
You must configure your Azure resource group and vnet to run this test.
Expand Down Expand Up @@ -60,9 +60,9 @@ async def test_create_cluster():
async with AzureVMCluster(asynchronous=True) as cluster:
assert cluster.status == Status.running

cluster.scale(1)
cluster.scale(2)
await cluster
assert len(cluster.workers) == 1
assert len(cluster.workers) == 2

async with Client(cluster, asynchronous=True) as client:

Expand Down
6 changes: 3 additions & 3 deletions dask_cloudprovider/digitalocean/droplet.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ async def create_vm(self):
env_vars=self.env_vars,
),
)
self.droplet.create()
await self.cluster.call_async(self.droplet.create)
for action in self.droplet.get_actions():
while action.status != "completed":
action.load()
await asyncio.sleep(0.1)
while self.droplet.ip_address is None:
self.droplet.load()
await self.cluster.call_async(self.droplet.load)
await asyncio.sleep(0.1)
self.cluster._log(f"Created droplet {self.name}")

return self.droplet.ip_address

async def destroy_vm(self):
    """Destroy the droplet backing this VM and log the termination.

    ``self.droplet.destroy`` is passed uncalled to
    ``self.cluster.call_async`` and awaited, so the droplet is destroyed
    exactly once and not by a direct call inside this coroutine.
    """
    await self.cluster.call_async(self.droplet.destroy)
    self.cluster._log(f"Terminated droplet {self.name}")


Expand Down
65 changes: 42 additions & 23 deletions dask_cloudprovider/gcp/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
VMInterface,
SchedulerMixin,
)
from dask_cloudprovider.gcp.utils import build_request


from distributed.core import Status
Expand Down Expand Up @@ -191,49 +192,57 @@ async def create_vm(self):
self.gcp_config = self.create_gcp_config()

try:
inst = (
inst = await self.cluster.call_async(
self.cluster.compute.instances()
.insert(project=self.projectid, zone=self.zone, body=self.gcp_config)
.execute()
.execute
)
self.gcp_inst = inst
self.id = self.gcp_inst["id"]
except HttpError as e:
# something failed
print(str(e))
raise Exception(str(e))
while self.update_status() != "RUNNING":
while await self.update_status() != "RUNNING":
await asyncio.sleep(0.5)

self.internal_ip = self.get_internal_ip()
self.internal_ip = await self.get_internal_ip()
if self.config.get("public_ingress", True):
self.external_ip = self.get_external_ip()
self.external_ip = await self.get_external_ip()
else:
self.external_ip = None
self.cluster._log(
f"{self.name}\n\tInternal IP: {self.internal_ip}\n\tExternal IP: {self.external_ip}"
)
return self.internal_ip, self.external_ip

async def get_internal_ip(self):
    """Return the internal network IP of this instance.

    Lists the instances in ``self.projectid`` / ``self.zone`` filtered by
    ``name={self.name}`` and reads ``networkIP`` from the first network
    interface of the first item in the response.

    This is a coroutine: callers must ``await`` it.
    """
    request = self.cluster.compute.instances().list(
        project=self.projectid, zone=self.zone, filter=f"name={self.name}"
    )
    # Hand over the bound ``execute`` uncalled so that ``call_async`` is the
    # one that invokes it.
    response = await self.cluster.call_async(request.execute)
    return response["items"][0]["networkInterfaces"][0]["networkIP"]

async def get_external_ip(self):
    """Return the external (NAT) IP of this instance.

    Lists the instances in ``self.projectid`` / ``self.zone`` filtered by
    ``name={self.name}`` and reads ``natIP`` from the first access config of
    the first network interface of the first item in the response.

    This is a coroutine: callers must ``await`` it.
    """
    request = self.cluster.compute.instances().list(
        project=self.projectid, zone=self.zone, filter=f"name={self.name}"
    )
    # Hand over the bound ``execute`` uncalled so that ``call_async`` is the
    # one that invokes it.
    response = await self.cluster.call_async(request.execute)
    interface = response["items"][0]["networkInterfaces"][0]
    return interface["accessConfigs"][0]["natIP"]

def update_status(self):
d = (
async def update_status(self):
d = await self.cluster.call_async(
self.cluster.compute.instances()
.list(project=self.projectid, zone=self.zone, filter=f"name={self.name}")
.execute()
.execute
)
self.gcp_inst = d

Expand All @@ -253,9 +262,11 @@ def expand_source_image(self, source_image):

async def close(self):
    """Log the shutdown and delete this instance.

    The delete request for ``self.name`` in ``self.projectid`` /
    ``self.zone`` is built here; its bound ``execute`` is passed uncalled to
    ``self.cluster.call_async`` and awaited, so the request is executed
    exactly once.
    """
    self.cluster._log(f"Closing Instance: {self.name}")
    delete_request = self.cluster.compute.instances().delete(
        project=self.projectid, zone=self.zone, instance=self.name
    )
    await self.cluster.call_async(delete_request.execute)


class GCPScheduler(SchedulerMixin, GCPInstance):
Expand Down Expand Up @@ -565,8 +576,14 @@ def __init__(self):
self._compute = self.refresh_client()

def refresh_client(self):

if os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", False):
return googleapiclient.discovery.build("compute", "v1")
import google.oauth2.service_account # google-auth

creds = google.oauth2.service_account.Credentials.from_service_account_file(
os.environ["GOOGLE_APPLICATION_CREDENTIALS"],
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
else:
import google.auth.credentials # google-auth

Expand All @@ -582,7 +599,9 @@ def refresh_client(self):
# take first row
f_.write(creds_rows[0][1])
creds, _ = google.auth.load_credentials_from_file(filename=f)
return googleapiclient.discovery.build("compute", "v1", credentials=creds)
return googleapiclient.discovery.build(
"compute", "v1", credentials=creds, requestBuilder=build_request(creds)
)

def instances(self):
try:
Expand Down
4 changes: 2 additions & 2 deletions dask_cloudprovider/gcp/tests/test_gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ async def test_create_cluster():

assert cluster.status == Status.running

cluster.scale(1)
cluster.scale(2)
await cluster
assert len(cluster.workers) == 1
assert len(cluster.workers) == 2

async with Client(cluster, asynchronous=True) as client:

Expand Down
5 changes: 5 additions & 0 deletions dask_cloudprovider/gcp/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from dask_cloudprovider.gcp.utils import build_request


def test_build_request():
    """The factory from ``build_request()`` yields a truthy request object."""
    make_request = build_request()
    # Arguments after ``http`` are forwarded to the request constructor:
    # here a pass-through callable and a URL.
    request = make_request(None, lambda x: x, "https://example.com")
    assert request
14 changes: 14 additions & 0 deletions dask_cloudprovider/gcp/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import httplib2
import googleapiclient.http
import google_auth_httplib2


def build_request(credentials=None):
    """Return a request factory that gives each request its own transport.

    The returned callable accepts ``(http, *args, **kwargs)`` and ignores the
    ``http`` it is given: it creates a new ``httplib2.Http`` on every call,
    wraps it in ``google_auth_httplib2.AuthorizedHttp`` when ``credentials``
    is not ``None``, and forwards the remaining arguments to
    ``googleapiclient.http.HttpRequest``.

    NOTE(review): presumably this exists so that requests issued from
    different threads never share one ``httplib2.Http`` -- confirm against
    the google-api-python-client thread-safety guidance.
    """

    def request_factory(http, *args, **kwargs):
        transport = httplib2.Http()
        if credentials is None:
            return googleapiclient.http.HttpRequest(transport, *args, **kwargs)
        authorized = google_auth_httplib2.AuthorizedHttp(
            credentials, http=transport
        )
        return googleapiclient.http.HttpRequest(authorized, *args, **kwargs)

    return request_factory
48 changes: 47 additions & 1 deletion dask_cloudprovider/generic/tests/test_vmcluster.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,55 @@
import pytest

from dask_cloudprovider.generic.vmcluster import VMCluster
import asyncio
import time

from dask_cloudprovider.generic.vmcluster import VMCluster, VMInterface


class DummyWorker(VMInterface):
    """Placeholder worker VM used only by the tests in this module."""


class DummyScheduler(VMInterface):
    """Placeholder scheduler VM used only by the tests in this module."""


class DummyCluster(VMCluster):
    """Minimal ``VMCluster`` subclass wired to the dummy VM classes."""

    # The two class attributes this subclass sets on top of VMCluster.
    worker_class = DummyWorker
    scheduler_class = DummyScheduler


@pytest.mark.asyncio
async def test_init():
    """Instantiating the base ``VMCluster`` directly raises ``RuntimeError``."""
    with pytest.raises(RuntimeError):
        VMCluster(asynchronous=True)


@pytest.mark.asyncio
async def test_call_async():
    """``call_async`` returns each callable's result and runs calls concurrently.

    Four calls that each sleep for 0.1 seconds are gathered together; the
    results must come back in order and the whole batch must finish in under
    0.2 seconds, which is only possible if the calls overlap.
    """
    cluster = DummyCluster(asynchronous=True)

    def blocking(string):
        time.sleep(0.1)
        return string

    words = ["hello", "world", "foo", "bar"]

    try:
        # perf_counter is monotonic, so the measurement cannot be skewed by a
        # wall-clock adjustment the way time.time() can.
        start = time.perf_counter()
        results = await asyncio.gather(
            *(cluster.call_async(blocking, word) for word in words)
        )
        elapsed = time.perf_counter() - start
    finally:
        # Close the cluster even when the gather above fails, so a failing
        # run does not leave it open.
        await cluster.close()

    assert results == words

    # Each call to ``blocking`` takes 0.1 seconds, but they should've been
    # run concurrently.
    assert elapsed < 0.2

0 comments on commit 02b50e2

Please sign in to comment.