diff --git a/dask_kubernetes/common/networking.py b/dask_kubernetes/common/networking.py
index 95ea501f8..9008856c4 100644
--- a/dask_kubernetes/common/networking.py
+++ b/dask_kubernetes/common/networking.py
@@ -11,13 +11,17 @@
 
 
 async def get_external_address_for_scheduler_service(
-    core_api, service, port_forward_cluster_ip=None, service_name_resolution_retries=20
+    core_api,
+    service,
+    port_forward_cluster_ip=None,
+    service_name_resolution_retries=20,
+    port_name="comm",
 ):
     """Take a service object and return the scheduler address."""
     [port] = [
         port.port
         for port in service.spec.ports
-        if port.name == service.metadata.name or port.name == "comm"
+        if port.name == service.metadata.name or port.name == port_name
     ]
     if service.spec.type == "LoadBalancer":
         lb = service.status.load_balancer.ingress[0]
@@ -104,13 +108,16 @@ async def port_forward_dashboard(service_name, namespace):
     return port
 
 
-async def get_scheduler_address(service_name, namespace):
+async def get_scheduler_address(service_name, namespace, port_name="comm"):
     async with kubernetes.client.api_client.ApiClient() as api_client:
         api = kubernetes.client.CoreV1Api(api_client)
         service = await api.read_namespaced_service(service_name, namespace)
         port_forward_cluster_ip = None
         address = await get_external_address_for_scheduler_service(
-            api, service, port_forward_cluster_ip=port_forward_cluster_ip
+            api,
+            service,
+            port_forward_cluster_ip=port_forward_cluster_ip,
+            port_name=port_name,
         )
         return address
 
diff --git a/dask_kubernetes/operator/operator.py b/dask_kubernetes/operator/operator.py
index b900c4dee..553337389 100644
--- a/dask_kubernetes/operator/operator.py
+++ b/dask_kubernetes/operator/operator.py
@@ -1,12 +1,14 @@
 import asyncio
-
-from distributed.core import rpc
+import aiohttp
+from contextlib import suppress
 
 import kopf
 import kubernetes_asyncio as kubernetes
 
 from uuid import uuid4
 
+from distributed.core import rpc
+
 from dask_kubernetes.common.auth import ClusterAuth
 from dask_kubernetes.common.networking import (
     get_scheduler_address,
@@ -195,6 +197,52 @@ async def daskworkergroup_create(spec, name, namespace, logger, **kwargs):
     )
 
 
+async def retire_workers(
+    n_workers, scheduler_service_name, worker_group_name, namespace, logger
+):
+    # Try gracefully retiring via the HTTP API
+    dashboard_address = await get_scheduler_address(
+        scheduler_service_name,
+        namespace,
+        port_name="dashboard",
+    )
+    async with aiohttp.ClientSession() as session:
+        url = f"{dashboard_address}/api/v1/retire_workers"
+        params = {"n": n_workers}
+        async with session.post(url, json=params) as resp:
+            if resp.status <= 300:
+                retired_workers = await resp.json()
+                return [retired_workers[w]["name"] for w in retired_workers.keys()]
+
+    # Otherwise try gracefully retiring via the RPC
+    logger.info(
+        f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC"
+    )
+    # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
+    with suppress(Exception):
+        comm_address = await get_scheduler_address(
+            scheduler_service_name,
+            namespace,
+        )
+        async with rpc(comm_address) as scheduler_comm:
+            return await scheduler_comm.workers_to_close(
+                n=n_workers,
+                attribute="name",
+            )
+
+    # Finally fall back to last-in-first-out scaling
+    logger.info(
+        f"Scaling {worker_group_name} failed via the Dask RPC, falling back to LIFO scaling"
+    )
+    async with kubernetes.client.api_client.ApiClient() as api_client:
+        api = kubernetes.client.CoreV1Api(api_client)
+        workers = await api.list_namespaced_pod(
+            namespace=namespace,
+            label_selector=f"dask.org/workergroup-name={worker_group_name}",
+        )
+        return [w.metadata.name for w in workers.items[-n_workers:]]
+
+
 @kopf.on.update("daskworkergroup")
 async def daskworkergroup_update(spec, name, namespace, logger, **kwargs):
     async with kubernetes.client.api_client.ApiClient() as api_client:
@@ -226,17 +274,13 @@ async def daskworkergroup_update(spec, name, namespace, logger, **kwargs):
                 f"Scaled worker group {name} up to {spec['worker']['replicas']} workers."
             )
         if workers_needed < 0:
-            service_address = await get_scheduler_address(
-                f"{spec['cluster']}-service", namespace
+            worker_ids = await retire_workers(
+                n_workers=-workers_needed,
+                scheduler_service_name=f"{spec['cluster']}-service",
+                worker_group_name=name,
+                namespace=namespace,
+                logger=logger,
             )
-            logger.info(
-                f"Asking scheduler to retire {-workers_needed} on {service_address}"
-            )
-            async with rpc(service_address) as scheduler:
-                worker_ids = await scheduler.workers_to_close(
-                    n=-workers_needed, attribute="name"
-                )
-            # TODO: Check that were deting workers in the right worker group
             logger.info(f"Workers to close: {worker_ids}")
             for wid in worker_ids:
                 await api.delete_namespaced_pod(
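
Below is a small standalone sketch (not part of the diff) of the graceful-retirement fallback order that `retire_workers` implements: try the scheduler's `/api/v1/retire_workers` HTTP endpoint first, then fall back to `workers_to_close` over the Dask RPC. The `retire` helper name and the localhost addresses are illustrative assumptions standing in for the addresses the operator resolves from the scheduler `Service`; the HTTP tier also assumes the scheduler was started with distributed's API routes enabled.

```python
import asyncio
from contextlib import suppress

import aiohttp
from distributed.core import rpc


async def retire(
    n_workers,
    dashboard_address="http://localhost:8787",  # assumed port-forwarded dashboard/HTTP API
    comm_address="tcp://localhost:8786",        # assumed port-forwarded scheduler comm
):
    # Tier 1: the scheduler HTTP API (mirrors the operator's first attempt).
    with suppress(aiohttp.ClientError):
        async with aiohttp.ClientSession() as session:
            url = f"{dashboard_address}/api/v1/retire_workers"
            async with session.post(url, json={"n": n_workers}) as resp:
                if resp.status == 200:
                    retired = await resp.json()
                    return [info["name"] for info in retired.values()]

    # Tier 2: the Dask RPC (may fail on operator/scheduler version mismatches).
    with suppress(Exception):
        async with rpc(comm_address) as scheduler:
            return await scheduler.workers_to_close(n=n_workers, attribute="name")

    # Neither graceful path worked; the operator would fall back to LIFO pod deletion here.
    return []


if __name__ == "__main__":
    print(asyncio.run(retire(2)))
```

To try it against a running cluster you could port-forward the scheduler service first (e.g. `kubectl port-forward svc/<cluster>-service 8787 8786`); in the operator itself the two addresses come from `get_scheduler_address()` with `port_name="dashboard"` and the default `"comm"`, so the sketch only reproduces the ordering of the fallbacks.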