Migrate daskworkergroup_replica_update to kr8s (#746)
jacobtomlinson committed Jul 26, 2023
1 parent 19b48b8 commit 26e137a
Showing 1 changed file with 64 additions and 90 deletions.
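
Editor's note: the change swaps the kubernetes_asyncio client calls in this handler for the kr8s API used elsewhere in the controller. Below is a minimal, hedged sketch of the kr8s calls involved, runnable on its own; the namespace, label value, image, and deployment spec are illustrative placeholders, not values taken from the operator.

# Hedged sketch (not part of the commit): listing worker deployments and
# creating/deleting one with kr8s, the pattern this handler now uses.
# The namespace, labels, image, and spec below are illustrative placeholders.
import asyncio

import kr8s
from kr8s.asyncio.objects import Deployment


async def main():
    # Count current workers by label, as the handler does for current_workers.
    deployments = await kr8s.asyncio.get(
        "deployments",
        namespace="default",
        label_selector={"dask.org/workergroup-name": "example-default"},
    )
    print(f"Current workers: {len(deployments)}")

    # Create a Deployment from a plain dict, mirroring how the handler wraps
    # the spec returned by build_worker_deployment_spec, then delete it again.
    spec = {
        "apiVersion": "apps/v1",
        "kind": "Deployment",
        "metadata": {"name": "example-worker", "namespace": "default"},
        "spec": {
            "replicas": 1,
            "selector": {"matchLabels": {"app": "example-worker"}},
            "template": {
                "metadata": {"labels": {"app": "example-worker"}},
                "spec": {
                    "containers": [
                        {"name": "worker", "image": "ghcr.io/dask/dask:latest"}
                    ]
                },
            },
        },
    }
    worker_deployment = await Deployment(spec, namespace="default")
    await worker_deployment.create()
    await worker_deployment.delete()


asyncio.run(main())
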
154 changes: 64 additions & 90 deletions dask_kubernetes/operator/controller/controller.py
@@ -8,10 +8,9 @@
 import aiohttp
 import kopf
 import kr8s
-from kr8s.asyncio.objects import Pod
+from kr8s.asyncio.objects import Pod, Deployment
 import kubernetes_asyncio as kubernetes
 from importlib_metadata import entry_points
-from kubernetes_asyncio.client import ApiException

 from dask_kubernetes.operator._objects import (
     DaskCluster,
@@ -570,103 +569,78 @@ async def daskworkergroup_replica_update(
     name, namespace, meta, spec, new, body, logger, **kwargs
 ):
     cluster_name = spec["cluster"]
+    wg = await DaskWorkerGroup(body, namespace=namespace)
+    try:
+        cluster = await wg.cluster()
+    except kr8s.NotFoundError:
+        # No need to scale if cluster is deleted, pods will be cleaned up
+        return

     # Replica updates can come in quick succession and the changes must be applied atomically to ensure
     # the number of workers ends in the correct state
     async with worker_group_scale_locks[f"{namespace}/{name}"]:
-        async with kubernetes.client.api_client.ApiClient() as api_client:
-            customobjectsapi = kubernetes.client.CustomObjectsApi(api_client)
-            corev1api = kubernetes.client.CoreV1Api(api_client)
-
-            try:
-                cluster = await customobjectsapi.get_namespaced_custom_object(
-                    group="kubernetes.dask.org",
-                    version="v1",
-                    plural="daskclusters",
-                    namespace=namespace,
-                    name=cluster_name,
-                )
-            except ApiException as e:
-                if e.status == 404:
-                    # No need to scale if worker group is deleted, pods will be cleaned up
-                    return
-                else:
-                    raise e
-
-            cluster_labels = cluster.get("metadata", {}).get("labels", {})
-
-            workers = await corev1api.list_namespaced_pod(
-                namespace=namespace,
-                label_selector=f"dask.org/workergroup-name={name}",
-            )
-            current_workers = len(
-                [w for w in workers.items if w.status.phase != "Terminating"]
-            )
-            desired_workers = new
-            workers_needed = desired_workers - current_workers
-            labels = _get_labels(meta)
-            annotations = _get_annotations(meta)
-            worker_spec = spec["worker"]
-            if "metadata" in worker_spec:
-                if "annotations" in worker_spec["metadata"]:
-                    annotations.update(**worker_spec["metadata"]["annotations"])
-                if "labels" in worker_spec["metadata"]:
-                    labels.update(**worker_spec["metadata"]["labels"])
-
-            SIZE = int(
-                dask.config.get("kubernetes.controller.worker-allocation.batch-size")
-                or 0
-            )
-            DELAY = int(
-                dask.config.get("kubernetes.controller.worker-allocation.delay") or 0
-            )
-            batch_size = min(workers_needed, SIZE) if SIZE else workers_needed
-            if workers_needed > 0:
-                for _ in range(batch_size):
-                    data = build_worker_deployment_spec(
-                        worker_group_name=name,
-                        namespace=namespace,
-                        cluster_name=cluster_name,
-                        uuid=uuid4().hex[:10],
-                        pod_spec=worker_spec["spec"],
-                        annotations=annotations,
-                        labels=labels,
-                    )
-                    kopf.adopt(data, owner=body)
-                    kopf.label(data, labels=cluster_labels)
-                    await kubernetes.client.AppsV1Api(
-                        api_client
-                    ).create_namespaced_deployment(
-                        namespace=namespace,
-                        body=data,
-                    )
-                if SIZE:
-                    if workers_needed > SIZE:
-                        raise kopf.TemporaryError(
-                            "Added maximum number of workers for this batch but still need to create more workers, "
-                            f"waiting for {DELAY} seconds before continuing.",
-                            delay=DELAY,
-                        )
-                logger.info(f"Scaled worker group {name} up to {desired_workers} workers.")
-            if workers_needed < 0:
-                worker_ids = await retire_workers(
-                    n_workers=-workers_needed,
-                    scheduler_service_name=f"{cluster_name}-scheduler",
-                    worker_group_name=name,
-                    namespace=namespace,
-                    logger=logger,
-                )
-                logger.info(f"Workers to close: {worker_ids}")
-                for wid in worker_ids:
-                    await kubernetes.client.AppsV1Api(
-                        api_client
-                    ).delete_namespaced_deployment(
-                        name=wid,
-                        namespace=namespace,
-                    )
-                logger.info(
-                    f"Scaled worker group {name} down to {desired_workers} workers."
-                )
+        current_workers = len(
+            await kr8s.asyncio.get(
+                "deployments",
+                namespace=namespace,
+                label_selector={"dask.org/workergroup-name": name},
+            )
+        )
+        desired_workers = new
+        workers_needed = desired_workers - current_workers
+        labels = _get_labels(meta)
+        annotations = _get_annotations(meta)
+        worker_spec = spec["worker"]
+        if "metadata" in worker_spec:
+            if "annotations" in worker_spec["metadata"]:
+                annotations.update(**worker_spec["metadata"]["annotations"])
+            if "labels" in worker_spec["metadata"]:
+                labels.update(**worker_spec["metadata"]["labels"])
+
+        batch_size = int(
+            dask.config.get("kubernetes.controller.worker-allocation.batch-size") or 0
+        )
+        batch_size = min(workers_needed, batch_size) if batch_size else workers_needed
+        batch_delay = int(
+            dask.config.get("kubernetes.controller.worker-allocation.delay") or 0
+        )
+        if workers_needed > 0:
+            for _ in range(batch_size):
+                data = build_worker_deployment_spec(
+                    worker_group_name=name,
+                    namespace=namespace,
+                    cluster_name=cluster_name,
+                    uuid=uuid4().hex[:10],
+                    pod_spec=worker_spec["spec"],
+                    annotations=annotations,
+                    labels=labels,
+                )
+                kopf.adopt(data, owner=body)
+                kopf.label(data, labels=cluster.labels)
+                worker_deployment = await Deployment(data, namespace=namespace)
+                await worker_deployment.create()
+            if workers_needed > batch_size:
+                raise kopf.TemporaryError(
+                    "Added maximum number of workers for this batch but still need to create more workers, "
+                    f"waiting for {batch_delay} seconds before continuing.",
+                    delay=batch_delay,
+                )
+            logger.info(f"Scaled worker group {name} up to {desired_workers} workers.")
+        if workers_needed < 0:
+            worker_ids = await retire_workers(
+                n_workers=-workers_needed,
+                scheduler_service_name=f"{cluster_name}-scheduler",
+                worker_group_name=name,
+                namespace=namespace,
+                logger=logger,
+            )
+            logger.info(f"Workers to close: {worker_ids}")
+            for wid in worker_ids:
+                worker_deployment = await Deployment(wid, namespace=namespace)
+                await worker_deployment.delete()
+            logger.info(
+                f"Scaled worker group {name} down to {desired_workers} workers."
+            )


 @kopf.on.delete("daskworkergroup.kubernetes.dask.org", optional=True)
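Editor's note: the scale-up path is throttled by the two Dask config keys the new code reads, kubernetes.controller.worker-allocation.batch-size and kubernetes.controller.worker-allocation.delay; when more workers are still needed after a batch, the handler raises kopf.TemporaryError so kopf reruns it after the delay. A hedged sketch of setting those keys programmatically follows; a real deployment would normally set them through the operator's own configuration rather than in code.

# Hedged sketch (not part of the commit): enabling batched worker creation by
# setting the config keys the handler reads with dask.config.get.
import dask

dask.config.set(
    {
        # Create at most 10 worker deployments per invocation of the handler.
        "kubernetes.controller.worker-allocation.batch-size": 10,
        # Ask kopf to retry after 5 seconds when more workers are still needed.
        "kubernetes.controller.worker-allocation.delay": 5,
    }
)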
