Skip to content

Commit

Permalink
Destroy VM if creation fails to avoid leaking VMs (#367)
Browse files Browse the repository at this point in the history
* 365 EC2Cluster leaks VMs when adaptive scaling enabled if worker creation raises an exception

* flake8

* formatting
  • Loading branch information
martinedgefocus committed Aug 2, 2022
1 parent 75fc8ec commit 3592253
Showing 1 changed file with 58 additions and 36 deletions.
94 changes: 58 additions & 36 deletions dask_cloudprovider/aws/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
try:
from aiobotocore.session import get_session
import botocore.exceptions
import botocore.config
except ImportError as e:
msg = (
"Dask Cloud Provider AWS requirements are not installed.\n\n"
Expand Down Expand Up @@ -86,8 +87,9 @@ async def create_vm(self):
"""
# TODO Enable Spot support

boto_config = botocore.config.Config(retries=dict(max_attempts=10))
async with self.cluster.boto_session.create_client(
"ec2", region_name=self.region
"ec2", region_name=self.region, config=boto_config
) as client:
self.vpc = self.vpc or await get_default_vpc(client)
self.subnet_id = (
Expand Down Expand Up @@ -159,46 +161,66 @@ async def create_vm(self):

response = await client.run_instances(**vm_kwargs)
[self.instance] = response["Instances"]
await client.create_tags(
Resources=[self.instance["InstanceId"]],
Tags=[
{"Key": "Name", "Value": self.name},
{"Key": "Dask Cluster", "Value": self.cluster.uuid},
],
)
self.cluster._log(
f"Created instance {self.instance['InstanceId']} as {self.name}"
)

address_type = "Private" if self.use_private_ip else "Public"
ip_address_key = f"{address_type}IpAddress"
try: # Ensure we tear down any resources we allocated if something goes wrong
return await self.configure_vm(client)
except Exception:
self.cluster._log(
f"reclaiming vm because configure_vm failed {self.name}"
)
await self.destroy_vm()
raise

async def configure_vm(self, client):
timeout = Timeout(300, f"Failed to add tags for {self.instance['InstanceId']}")
backoff = 0.1
while timeout.run():
try:
await client.create_tags(
Resources=[self.instance["InstanceId"]],
Tags=[
{"Key": "Name", "Value": self.name},
{"Key": "Dask Cluster", "Value": self.cluster.uuid},
],
)
break
except Exception as e:
timeout.set_exception(e)

await asyncio.sleep(min(backoff, 10) + backoff % 1)
# Exponential backoff with a cap of 10 seconds and some jitter
backoff = backoff * 2

self.cluster._log(
f"Created instance {self.instance['InstanceId']} as {self.name}"
)

timeout = Timeout(
300,
f"Failed {address_type} IP for instance {self.instance['InstanceId']}",
)
while (
ip_address_key not in self.instance
or self.instance[ip_address_key] is None
) and timeout.run():
backoff = 0.1
await asyncio.sleep(
min(backoff, 10) + backoff % 1
) # Exponential backoff with a cap of 10 seconds and some jitter
try:
response = await client.describe_instances(
InstanceIds=[self.instance["InstanceId"]], DryRun=False
)
[reservation] = response["Reservations"]
[self.instance] = reservation["Instances"]
except botocore.exceptions.ClientError as e:
timeout.set_exception(e)
backoff = backoff * 2
return self.instance[ip_address_key]
address_type = "Private" if self.use_private_ip else "Public"
ip_address_key = f"{address_type}IpAddress"

default_error = (
f"Failed {address_type} IP for instance {self.instance['InstanceId']}"
)
timeout = Timeout(300, default_error)
backoff = 0.1
while self.instance.get(ip_address_key) is None and timeout.run():
try:
response = await client.describe_instances(
InstanceIds=[self.instance["InstanceId"]], DryRun=False
)
[reservation] = response["Reservations"]
[self.instance] = reservation["Instances"]
except botocore.exceptions.ClientError as e:
timeout.set_exception(e)
await asyncio.sleep(min(backoff, 10) + backoff % 1)
# Exponential backoff with a cap of 10 seconds and some jitter
backoff = backoff * 2
return self.instance[ip_address_key]

async def destroy_vm(self):
boto_config = botocore.config.Config(retries=dict(max_attempts=10))
async with self.cluster.boto_session.create_client(
"ec2", region_name=self.region
"ec2", region_name=self.region, config=boto_config
) as client:
await client.terminate_instances(
InstanceIds=[self.instance["InstanceId"]], DryRun=False
Expand Down

0 comments on commit 3592253

Please sign in to comment.