Merged
dask_kubernetes/common/networking.py (11 additions, 4 deletions)

@@ -11,13 +11,17 @@


 async def get_external_address_for_scheduler_service(
-    core_api, service, port_forward_cluster_ip=None, service_name_resolution_retries=20
+    core_api,
+    service,
+    port_forward_cluster_ip=None,
+    service_name_resolution_retries=20,
+    port_name="comm",
 ):
     """Take a service object and return the scheduler address."""
     [port] = [
         port.port
         for port in service.spec.ports
-        if port.name == service.metadata.name or port.name == "comm"
+        if port.name == service.metadata.name or port.name == port_name
     ]
     if service.spec.type == "LoadBalancer":
         lb = service.status.load_balancer.ingress[0]
@@ -104,13 +108,16 @@ async def port_forward_dashboard(service_name, namespace):
     return port


-async def get_scheduler_address(service_name, namespace):
+async def get_scheduler_address(service_name, namespace, port_name="comm"):
     async with kubernetes.client.api_client.ApiClient() as api_client:
         api = kubernetes.client.CoreV1Api(api_client)
         service = await api.read_namespaced_service(service_name, namespace)
         port_forward_cluster_ip = None
         address = await get_external_address_for_scheduler_service(
-            api, service, port_forward_cluster_ip=port_forward_cluster_ip
+            api,
+            service,
+            port_forward_cluster_ip=port_forward_cluster_ip,
+            port_name=port_name,
         )
         return address

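The `port_name` parameter added above lets callers resolve a port other than the scheduler comm port from the same Service; the operator uses it to reach the dashboard. A minimal usage sketch, assuming Kubernetes credentials can be loaded via `ClusterAuth.load_first()` and using placeholder service and namespace names:

```python
import asyncio

from dask_kubernetes.common.auth import ClusterAuth
from dask_kubernetes.common.networking import get_scheduler_address


async def main():
    # Configure kubernetes_asyncio credentials (kubeconfig, in-cluster, etc.)
    await ClusterAuth.load_first()

    # Default port_name="comm" resolves the scheduler comm endpoint
    comm = await get_scheduler_address("example-cluster-service", "default")

    # Naming the "dashboard" port resolves the HTTP endpoint instead
    dashboard = await get_scheduler_address(
        "example-cluster-service", "default", port_name="dashboard"
    )
    print(comm, dashboard)


asyncio.run(main())
```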
dask_kubernetes/operator/operator.py (56 additions, 12 deletions)

@@ -1,12 +1,14 @@
 import asyncio

-from distributed.core import rpc
+import aiohttp
+from contextlib import suppress

 import kopf
 import kubernetes_asyncio as kubernetes

 from uuid import uuid4

+from distributed.core import rpc
+
 from dask_kubernetes.common.auth import ClusterAuth
 from dask_kubernetes.common.networking import (
     get_scheduler_address,
@@ -195,6 +197,52 @@ async def daskworkergroup_create(spec, name, namespace, logger, **kwargs):
     )


+async def retire_workers(
+    n_workers, scheduler_service_name, worker_group_name, namespace, logger
+):
+    # Try gracefully retiring via the HTTP API
+    dashboard_address = await get_scheduler_address(
+        scheduler_service_name,
+        namespace,
+        port_name="dashboard",
+    )
+    async with aiohttp.ClientSession() as session:
+        url = f"{dashboard_address}/api/v1/retire_workers"
+        params = {"n": n_workers}
+        async with session.post(url, json=params) as resp:
+            if resp.status < 300:
+                retired_workers = await resp.json()
+                return [retired_workers[w]["name"] for w in retired_workers.keys()]
+
+    # Otherwise try gracefully retiring via the RPC
+    logger.info(
+        f"Scaling {worker_group_name} failed via the HTTP API, falling back to the Dask RPC"
+    )
+    # Dask version mismatches between the operator and scheduler may cause this to fail in any number of unexpected ways
+    with suppress(Exception):
+        comm_address = await get_scheduler_address(
+            scheduler_service_name,
+            namespace,
+        )
+        async with rpc(comm_address) as scheduler_comm:
+            return await scheduler_comm.workers_to_close(
+                n=n_workers,
+                attribute="name",
+            )
+
+    # Finally fall back to last-in-first-out scaling
+    logger.info(
+        f"Scaling {worker_group_name} failed via the Dask RPC, falling back to LIFO scaling"
+    )
+    async with kubernetes.client.api_client.ApiClient() as api_client:
+        api = kubernetes.client.CoreV1Api(api_client)
+        workers = await api.list_namespaced_pod(
+            namespace=namespace,
+            label_selector=f"dask.org/workergroup-name={worker_group_name}",
+        )
+        return [w.metadata.name for w in workers.items[-n_workers:]]
+
+
 @kopf.on.update("daskworkergroup")
 async def daskworkergroup_update(spec, name, namespace, logger, **kwargs):
     async with kubernetes.client.api_client.ApiClient() as api_client:
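The first tier of `retire_workers` above assumes the scheduler dashboard serves a `/api/v1/retire_workers` endpoint returning a JSON mapping of worker address to info dicts that include a `"name"` key. A standalone sketch of that call with a placeholder dashboard address, failing loudly where the operator instead falls through to the next tier:

```python
import asyncio

import aiohttp


async def retire_n_workers(dashboard_address, n):
    async with aiohttp.ClientSession() as session:
        url = f"{dashboard_address}/api/v1/retire_workers"
        async with session.post(url, json={"n": n}) as resp:
            # The operator falls back to the RPC here; this sketch just raises
            resp.raise_for_status()
            retired = await resp.json()
            # Extract the worker names, mirroring retire_workers above
            return [info["name"] for info in retired.values()]


# e.g. asyncio.run(retire_n_workers("http://localhost:8787", 2))
```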
@@ -226,17 +274,13 @@ async def daskworkergroup_update(spec, name, namespace, logger, **kwargs):
f"Scaled worker group {name} up to {spec['worker']['replicas']} workers."
)
if workers_needed < 0:
service_address = await get_scheduler_address(
f"{spec['cluster']}-service", namespace
worker_ids = await retire_workers(
n_workers=-workers_needed,
scheduler_service_name=f"{spec['cluster']}-service",
worker_group_name=name,
namespace=namespace,
logger=logger,
)
logger.info(
f"Asking scheduler to retire {-workers_needed} on {service_address}"
)
async with rpc(service_address) as scheduler:
worker_ids = await scheduler.workers_to_close(
n=-workers_needed, attribute="name"
)
# TODO: Check that were deting workers in the right worker group
logger.info(f"Workers to close: {worker_ids}")
for wid in worker_ids:
await api.delete_namespaced_pod(
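The second tier asks the scheduler over the Dask RPC which workers it would prefer to lose; requesting the `"name"` attribute is what lets the update handler above delete the matching pods. A standalone sketch with a placeholder comm address:

```python
import asyncio

from distributed.core import rpc


async def pick_workers_to_close(comm_address, n):
    # Ask the scheduler to nominate n workers to close, keyed by worker name
    async with rpc(comm_address) as scheduler_comm:
        return await scheduler_comm.workers_to_close(n=n, attribute="name")


# e.g. asyncio.run(pick_workers_to_close("tcp://example-cluster-service:8786", 2))
```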