Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions dask_kubernetes/experimental/kubecluster.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

import asyncio
import atexit
from contextlib import suppress
from enum import Enum
import time
from typing import ClassVar
import weakref

import kubernetes_asyncio as kubernetes

from distributed.core import rpc
from distributed.core import Status, rpc
from distributed.deploy import Cluster

from distributed.utils import Log, Logs, LoopRunner
from distributed.utils import Log, Logs, LoopRunner, TimeoutError

from dask_kubernetes.common.auth import ClusterAuth
from dask_kubernetes.operator import (
Expand Down Expand Up @@ -103,6 +111,8 @@ class KubeCluster(Cluster):
KubeCluster.from_name
"""

_instances: ClassVar[weakref.WeakSet[KubeCluster]] = weakref.WeakSet()

def __init__(
self,
name,
Expand Down Expand Up @@ -133,6 +143,8 @@ def __init__(
self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
self.loop = self._loop_runner.loop

self._instances.add(self)

super().__init__(asynchronous=asynchronous, **kwargs)
if not self.asynchronous:
self._loop_runner.start()
Expand Down Expand Up @@ -363,11 +375,11 @@ async def _delete_worker_group(self, name):
name=f"{self.name}-cluster-{name}",
)

def close(self):
def close(self, timeout=3600):
"""Delete the dask cluster"""
return self.sync(self._close)
return self.sync(self._close, timeout=timeout)

async def _close(self):
async def _close(self, timeout=None):
await super()._close()
if self.shutdown_on_close:
async with kubernetes.client.api_client.ApiClient() as api_client:
Expand All @@ -379,7 +391,12 @@ async def _close(self):
namespace=self.namespace,
name=self.cluster_name,
)
start = time.time()
while (await self._get_cluster()) is not None:
if time.time() > start + timeout:
raise TimeoutError(
f"Timed out deleting cluster resource {self.cluster_name}"
)
await asyncio.sleep(1)

def scale(self, n, worker_group="default"):
Expand Down Expand Up @@ -537,3 +554,19 @@ def from_name(cls, name, **kwargs):
>>> cluster = KubeCluster.from_name(name="simple-cluster")
"""
return cls(name=name, create_mode=CreateMode.CONNECT_ONLY, **kwargs)


@atexit.register
def reap_clusters():
async def _reap_clusters():
for cluster in list(KubeCluster._instances):
if cluster.shutdown_on_close and cluster.status != Status.closed:
await ClusterAuth.load_first(cluster.auth)
with suppress(TimeoutError):
if cluster.asynchronous:
await cluster.close(timeout=10)
else:
cluster.close(timeout=10)

loop = asyncio.get_event_loop()
loop.run_until_complete(_reap_clusters())