diff --git a/dask_kubernetes/experimental/kubecluster.py b/dask_kubernetes/experimental/kubecluster.py index c635f99bc..e1aee1420 100644 --- a/dask_kubernetes/experimental/kubecluster.py +++ b/dask_kubernetes/experimental/kubecluster.py @@ -1,11 +1,19 @@ +from __future__ import annotations + import asyncio +import atexit +from contextlib import suppress from enum import Enum +import time +from typing import ClassVar +import weakref + import kubernetes_asyncio as kubernetes -from distributed.core import rpc +from distributed.core import Status, rpc from distributed.deploy import Cluster -from distributed.utils import Log, Logs, LoopRunner +from distributed.utils import Log, Logs, LoopRunner, TimeoutError from dask_kubernetes.common.auth import ClusterAuth from dask_kubernetes.operator import ( @@ -103,6 +111,8 @@ class KubeCluster(Cluster): KubeCluster.from_name """ + _instances: ClassVar[weakref.WeakSet[KubeCluster]] = weakref.WeakSet() + def __init__( self, name, @@ -133,6 +143,8 @@ def __init__( self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous) self.loop = self._loop_runner.loop + self._instances.add(self) + super().__init__(asynchronous=asynchronous, **kwargs) if not self.asynchronous: self._loop_runner.start() @@ -363,11 +375,11 @@ async def _delete_worker_group(self, name): name=f"{self.name}-cluster-{name}", ) - def close(self): + def close(self, timeout=3600): """Delete the dask cluster""" - return self.sync(self._close) + return self.sync(self._close, timeout=timeout) - async def _close(self): + async def _close(self, timeout=None): await super()._close() if self.shutdown_on_close: async with kubernetes.client.api_client.ApiClient() as api_client: @@ -379,7 +391,12 @@ async def _close(self): namespace=self.namespace, name=self.cluster_name, ) + start = time.time() while (await self._get_cluster()) is not None: + if time.time() > start + timeout: + raise TimeoutError( + f"Timed out deleting cluster resource {self.cluster_name}" + ) await 
@atexit.register
def reap_clusters():
    """Close any clusters that are still open when the interpreter exits.

    Registered with ``atexit``. Every ``KubeCluster`` tracked in
    ``KubeCluster._instances`` that has ``shutdown_on_close`` set and whose
    status is not ``Status.closed`` is closed with a 10 second timeout, so a
    slow deletion only briefly delays interpreter shutdown. A
    ``TimeoutError`` raised by an individual close is suppressed and the
    remaining clusters are still attempted.
    """
    # Snapshot the weak set first: entries may disappear while we await.
    pending = [
        cluster
        for cluster in list(KubeCluster._instances)
        if cluster.shutdown_on_close and cluster.status != Status.closed
    ]
    if not pending:
        # Nothing to clean up, so do not create an event loop at exit.
        return

    async def _reap_clusters():
        for cluster in pending:
            await ClusterAuth.load_first(cluster.auth)
            with suppress(TimeoutError):
                if cluster.asynchronous:
                    await cluster.close(timeout=10)
                else:
                    cluster.close(timeout=10)

    # asyncio.run() creates and closes its own loop. The previous
    # asyncio.get_event_loop() call is deprecated outside a running loop and
    # can hand back a loop that is already closed (or raise) during
    # interpreter shutdown, which would abort the cleanup entirely.
    asyncio.run(_reap_clusters())