From 4e919a7f3428605e28e17c738de199fa774a2234 Mon Sep 17 00:00:00 2001
From: Dmitry Meyer
Date: Tue, 7 Oct 2025 13:45:41 +0000
Subject: [PATCH] Kubernetes: add NVIDIA GPU toleration

Part-of: https://github.com/dstackai/dstack/issues/3126
---
 .../core/backends/kubernetes/compute.py      | 77 +++++++++++++++++--
 1 file changed, 70 insertions(+), 7 deletions(-)

diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py
index e64c32884..9668a17f3 100644
--- a/src/dstack/_internal/core/backends/kubernetes/compute.py
+++ b/src/dstack/_internal/core/backends/kubernetes/compute.py
@@ -2,6 +2,7 @@
 import tempfile
 import threading
 import time
+from enum import Enum
 from typing import List, Optional, Tuple
 
 from gpuhunt import KNOWN_NVIDIA_GPUS, AcceleratorVendor
@@ -62,9 +63,28 @@
 NVIDIA_GPU_NAME_TO_GPU_INFO = {gpu.name: gpu for gpu in KNOWN_NVIDIA_GPUS}
 NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO.keys()
 
+NVIDIA_GPU_RESOURCE = "nvidia.com/gpu"
+NVIDIA_GPU_COUNT_LABEL = f"{NVIDIA_GPU_RESOURCE}.count"
+NVIDIA_GPU_PRODUCT_LABEL = f"{NVIDIA_GPU_RESOURCE}.product"
+NVIDIA_GPU_NODE_TAINT = NVIDIA_GPU_RESOURCE
+
+# Taints we know and tolerate when creating our objects, e.g., the jump pod.
+TOLERATED_NODE_TAINTS = (NVIDIA_GPU_NODE_TAINT,)
+
 DUMMY_REGION = "-"
 
 
+class Operator(str, Enum):
+    EXISTS = "Exists"
+    IN = "In"
+
+
+class TaintEffect(str, Enum):
+    NO_EXECUTE = "NoExecute"
+    NO_SCHEDULE = "NoSchedule"
+    PREFER_NO_SCHEDULE = "PreferNoSchedule"
+
+
 class KubernetesCompute(
     ComputeWithFilteredOffersCached,
     ComputeWithPrivilegedSupport,
@@ -181,6 +201,7 @@ def run_job(
         resources_requests: dict[str, str] = {}
         resources_limits: dict[str, str] = {}
         node_affinity: Optional[client.V1NodeAffinity] = None
+        tolerations: list[client.V1Toleration] = []
         volumes_: list[client.V1Volume] = []
         volume_mounts: list[client.V1VolumeMount] = []
 
@@ -226,21 +247,28 @@ def run_job(
                 "Requesting %d GPU(s), node labels: %s", gpu_min, matching_gpu_label_values
             )
             # TODO: support other GPU vendors
-            resources_requests["nvidia.com/gpu"] = str(gpu_min)
-            resources_limits["nvidia.com/gpu"] = str(gpu_min)
+            resources_requests[NVIDIA_GPU_RESOURCE] = str(gpu_min)
+            resources_limits[NVIDIA_GPU_RESOURCE] = str(gpu_min)
             node_affinity = client.V1NodeAffinity(
                 required_during_scheduling_ignored_during_execution=[
                     client.V1NodeSelectorTerm(
                         match_expressions=[
                             client.V1NodeSelectorRequirement(
-                                key="nvidia.com/gpu.product",
-                                operator="In",
+                                key=NVIDIA_GPU_PRODUCT_LABEL,
+                                operator=Operator.IN,
                                 values=list(matching_gpu_label_values),
                             ),
                         ],
                     ),
                 ],
             )
+            # The effect should be NoSchedule, but we also add a NoExecute toleration just in case.
+            for effect in [TaintEffect.NO_SCHEDULE, TaintEffect.NO_EXECUTE]:
+                tolerations.append(
+                    client.V1Toleration(
+                        key=NVIDIA_GPU_NODE_TAINT, operator=Operator.EXISTS, effect=effect
+                    )
+                )
 
         if (memory_min := resources_spec.memory.min) is not None:
             resources_requests["memory"] = _render_memory(memory_min)
@@ -304,6 +332,7 @@
                     )
                 ],
                 affinity=node_affinity,
+                tolerations=tolerations,
                 volumes=volumes_,
             ),
         )
@@ -527,8 +556,8 @@ def _get_gpus_from_node_labels(labels: dict[str, str]) -> tuple[list[Gpu], Optio
     # "A100" but a product name like "Tesla-T4" or "A100-SXM4-40GB".
     # Thus, we convert the product name to a known gpu name.
     # TODO: support other GPU vendors
-    gpu_count = labels.get("nvidia.com/gpu.count")
-    gpu_product = labels.get("nvidia.com/gpu.product")
+    gpu_count = labels.get(NVIDIA_GPU_COUNT_LABEL)
+    gpu_product = labels.get(NVIDIA_GPU_PRODUCT_LABEL)
     if gpu_count is None or gpu_product is None:
         return [], None
     gpu_count = int(gpu_count)
@@ -647,6 +676,39 @@ def _create_jump_pod_service(
         namespace=namespace,
         name=pod_name,
     )
+
+    node_list = call_api_method(api.list_node, client.V1NodeList)
+    nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
+    # Set to False if we find at least one node without any "hard" taint, that is, if we
+    # don't need to specify tolerations.
+    toleration_required = True
+    # (key, effect) pairs.
+    tolerated_taints: set[tuple[str, str]] = set()
+    for node in nodes:
+        # True if the node has at least one NoExecute or NoSchedule taint.
+        has_hard_taint = False
+        taints = get_value(node, ".spec.taints", list[client.V1Taint]) or []
+        for taint in taints:
+            effect = get_value(taint, ".effect", str, required=True)
+            # A "soft" taint, ignore.
+            if effect == TaintEffect.PREFER_NO_SCHEDULE:
+                continue
+            has_hard_taint = True
+            key = get_value(taint, ".key", str, required=True)
+            if key in TOLERATED_NODE_TAINTS:
+                tolerated_taints.add((key, effect))
+        if not has_hard_taint:
+            toleration_required = False
+            break
+    tolerations: list[client.V1Toleration] = []
+    if toleration_required:
+        for key, effect in tolerated_taints:
+            tolerations.append(
+                client.V1Toleration(key=key, operator=Operator.EXISTS, effect=effect)
+            )
+        if not tolerations:
+            logger.warning("No appropriate node found, the jump pod may never be scheduled")
+
     commands = _get_jump_pod_commands(authorized_keys=ssh_public_keys)
     pod = client.V1Pod(
         metadata=client.V1ObjectMeta(
@@ -667,7 +729,8 @@
                         )
                     ],
                 )
-            ]
+            ],
+            tolerations=tolerations,
         ),
     )
     call_api_method(
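
Why the toleration is needed: nodes dedicated to GPUs (for example, nodes
managed by the NVIDIA GPU Operator) commonly carry the "nvidia.com/gpu"
taint with the NoSchedule effect, so GPU pods are only schedulable there if
they tolerate it. Below is a minimal sketch, runnable outside dstack, of the
toleration pair that run_job() now attaches to GPU pods. It assumes only the
official `kubernetes` Python package; the constant is inlined from the patch.

    from kubernetes import client

    NVIDIA_GPU_NODE_TAINT = "nvidia.com/gpu"  # inlined from the patch

    # Tolerate both hard effects; operator "Exists" matches any taint value.
    tolerations = [
        client.V1Toleration(key=NVIDIA_GPU_NODE_TAINT, operator="Exists", effect=effect)
        for effect in ("NoSchedule", "NoExecute")
    ]

    # containers=[] is a stand-in; we only want to see the serialized form.
    spec = client.V1PodSpec(containers=[], tolerations=tolerations)
    print(client.ApiClient().sanitize_for_serialization(spec))

Since Operator and TaintEffect derive from (str, Enum), their members compare
equal to and serialize as the plain strings "Exists", "In", "NoSchedule",
etc., which is why the patch can pass them directly where the Kubernetes
client expects strings.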
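
The jump-pod logic in _create_jump_pod_service() is more conservative: it
only adds tolerations when every node carries at least one "hard"
(NoSchedule/NoExecute) taint, because if any untainted node exists the jump
pod can land there without special-casing GPU nodes. A self-contained sketch
of that decision follows; the function name and the plain (key, effect)
tuples standing in for V1Taint objects are illustrative, not from the patch.

    # Each node is reduced to a list of (taint key, taint effect) pairs.
    TOLERATED_NODE_TAINTS = ("nvidia.com/gpu",)
    SOFT_EFFECT = "PreferNoSchedule"

    def pick_tolerations(nodes: list[list[tuple[str, str]]]) -> set[tuple[str, str]]:
        tolerated: set[tuple[str, str]] = set()
        for taints in nodes:
            hard = [(key, effect) for key, effect in taints if effect != SOFT_EFFECT]
            if not hard:
                # A freely schedulable node exists; no tolerations are needed.
                return set()
            tolerated.update(t for t in hard if t[0] in TOLERATED_NODE_TAINTS)
        return tolerated

    # Untainted CPU node present: nothing to tolerate.
    print(pick_tolerations([[("nvidia.com/gpu", "NoSchedule")], []]))  # set()
    # GPU-only cluster: tolerate the NVIDIA taint with its observed effect.
    print(pick_tolerations([[("nvidia.com/gpu", "NoSchedule")]]))

One case the sketch collapses but the patch reports: when tolerations are
required yet no tolerated taint is found (every node is hard-tainted by keys
we do not know), the patch logs a warning that the jump pod may never be
scheduled instead of silently proceeding with an empty list.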