77 changes: 70 additions & 7 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
@@ -2,6 +2,7 @@
import tempfile
import threading
import time
from enum import Enum
from typing import List, Optional, Tuple

from gpuhunt import KNOWN_NVIDIA_GPUS, AcceleratorVendor
@@ -62,9 +63,28 @@
NVIDIA_GPU_NAME_TO_GPU_INFO = {gpu.name: gpu for gpu in KNOWN_NVIDIA_GPUS}
NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO.keys()

NVIDIA_GPU_RESOURCE = "nvidia.com/gpu"
NVIDIA_GPU_COUNT_LABEL = f"{NVIDIA_GPU_RESOURCE}.count"
NVIDIA_GPU_PRODUCT_LABEL = f"{NVIDIA_GPU_RESOURCE}.product"
NVIDIA_GPU_NODE_TAINT = NVIDIA_GPU_RESOURCE

# Taints we know and tolerate when creating our objects, e.g., the jump pod.
TOLERATED_NODE_TAINTS = (NVIDIA_GPU_NODE_TAINT,)
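
For reference, the f-strings above expand to the node labels published by NVIDIA GPU Feature Discovery and the conventional GPU node taint key; a quick sanity check, assuming the constants above are in scope:

    assert NVIDIA_GPU_RESOURCE == "nvidia.com/gpu"
    assert NVIDIA_GPU_COUNT_LABEL == "nvidia.com/gpu.count"
    assert NVIDIA_GPU_PRODUCT_LABEL == "nvidia.com/gpu.product"
    assert NVIDIA_GPU_NODE_TAINT == "nvidia.com/gpu"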

DUMMY_REGION = "-"


class Operator(str, Enum):
EXISTS = "Exists"
IN = "In"


class TaintEffect(str, Enum):
NO_EXECUTE = "NoExecute"
NO_SCHEDULE = "NoSchedule"
PREFER_NO_SCHEDULE = "PreferNoSchedule"
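
Both enums subclass str, so each member compares equal to, and serializes as, its literal Kubernetes API value; a minimal sketch, assuming the standard kubernetes Python client:

    from kubernetes import client

    assert Operator.EXISTS == "Exists"
    assert TaintEffect.NO_SCHEDULE == "NoSchedule"

    # Members can therefore be passed directly to client models:
    toleration = client.V1Toleration(
        key=NVIDIA_GPU_NODE_TAINT,
        operator=Operator.EXISTS,
        effect=TaintEffect.NO_SCHEDULE,
    )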


class KubernetesCompute(
ComputeWithFilteredOffersCached,
ComputeWithPrivilegedSupport,
@@ -181,6 +201,7 @@ def run_job(
resources_requests: dict[str, str] = {}
resources_limits: dict[str, str] = {}
node_affinity: Optional[client.V1NodeAffinity] = None
tolerations: list[client.V1Toleration] = []
volumes_: list[client.V1Volume] = []
volume_mounts: list[client.V1VolumeMount] = []

@@ -226,21 +247,28 @@ def run_job(
"Requesting %d GPU(s), node labels: %s", gpu_min, matching_gpu_label_values
)
# TODO: support other GPU vendors
resources_requests["nvidia.com/gpu"] = str(gpu_min)
resources_limits["nvidia.com/gpu"] = str(gpu_min)
resources_requests[NVIDIA_GPU_RESOURCE] = str(gpu_min)
resources_limits[NVIDIA_GPU_RESOURCE] = str(gpu_min)
node_affinity = client.V1NodeAffinity(
    required_during_scheduling_ignored_during_execution=client.V1NodeSelector(
        node_selector_terms=[
            client.V1NodeSelectorTerm(
                match_expressions=[
                    client.V1NodeSelectorRequirement(
                        key="nvidia.com/gpu.product",
                        operator="In",
                        key=NVIDIA_GPU_PRODUCT_LABEL,
                        operator=Operator.IN,
                        values=list(matching_gpu_label_values),
                    ),
                ],
            ),
        ],
    ),
)
# The GPU taint effect should be NoSchedule, but we also add a NoExecute toleration just in case.
for effect in [TaintEffect.NO_SCHEDULE, TaintEffect.NO_EXECUTE]:
tolerations.append(
client.V1Toleration(
key=NVIDIA_GPU_NODE_TAINT, operator=Operator.EXISTS, effect=effect
)
)

if (memory_min := resources_spec.memory.min) is not None:
resources_requests["memory"] = _render_memory(memory_min)
@@ -304,6 +332,7 @@ def run_job(
)
],
affinity=client.V1Affinity(node_affinity=node_affinity) if node_affinity else None,
tolerations=tolerations,
volumes=volumes_,
),
)
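
To inspect what these scheduling constraints look like on the wire, one can render them with the client's own serializer; a sketch with a hypothetical "Tesla-T4" product value:

    from kubernetes import client

    api = client.ApiClient()

    toleration = client.V1Toleration(
        key="nvidia.com/gpu", operator="Exists", effect="NoSchedule"
    )
    print(api.sanitize_for_serialization(toleration))
    # {'effect': 'NoSchedule', 'key': 'nvidia.com/gpu', 'operator': 'Exists'}

    affinity = client.V1Affinity(
        node_affinity=client.V1NodeAffinity(
            required_during_scheduling_ignored_during_execution=client.V1NodeSelector(
                node_selector_terms=[
                    client.V1NodeSelectorTerm(
                        match_expressions=[
                            client.V1NodeSelectorRequirement(
                                key="nvidia.com/gpu.product",
                                operator="In",
                                values=["Tesla-T4"],  # hypothetical value
                            )
                        ]
                    )
                ]
            )
        )
    )
    print(api.sanitize_for_serialization(affinity))
    # {'nodeAffinity': {'requiredDuringSchedulingIgnoredDuringExecution':
    #   {'nodeSelectorTerms': [{'matchExpressions': [{'key': 'nvidia.com/gpu.product',
    #    'operator': 'In', 'values': ['Tesla-T4']}]}]}}}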
@@ -527,8 +556,8 @@ def _get_gpus_from_node_labels(labels: dict[str, str]) -> tuple[list[Gpu], Optional[str]]:
# "A100" but a product name like "Tesla-T4" or "A100-SXM4-40GB".
# Thus, we convert the product name to a known gpu name.
# TODO: support other GPU vendors
gpu_count = labels.get("nvidia.com/gpu.count")
gpu_product = labels.get("nvidia.com/gpu.product")
gpu_count = labels.get(NVIDIA_GPU_COUNT_LABEL)
gpu_product = labels.get(NVIDIA_GPU_PRODUCT_LABEL)
if gpu_count is None or gpu_product is None:
return [], None
gpu_count = int(gpu_count)
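
For illustration, the labels consumed here, as GPU Feature Discovery would publish them on a hypothetical node with four T4s:

    labels = {
        "nvidia.com/gpu.count": "4",            # hypothetical node labels
        "nvidia.com/gpu.product": "Tesla-T4",
    }
    # "Tesla-T4" is normalized to the known gpuhunt name "T4", so
    # _get_gpus_from_node_labels(labels) would return four T4 GPU entries.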
@@ -647,6 +676,39 @@ def _create_jump_pod_service(
namespace=namespace,
name=pod_name,
)

node_list = call_api_method(api.list_node, client.V1NodeList)
nodes = get_value(node_list, ".items", list[client.V1Node], required=True)
# Set to False if we find at least one node without any "hard" taint, that is, if we don't
# need to specify any tolerations.
toleration_required = True
# (key, effect) pairs.
tolerated_taints: set[tuple[str, str]] = set()
for node in nodes:
# True if the node has at least one NoExecute or NoSchedule taint.
has_hard_taint = False
taints = get_value(node, ".spec.taints", list[client.V1Taint]) or []
for taint in taints:
effect = get_value(taint, ".effect", str, required=True)
# A "soft" taint, ignore.
if effect == TaintEffect.PREFER_NO_SCHEDULE:
continue
has_hard_taint = True
key = get_value(taint, ".key", str, required=True)
if key in TOLERATED_NODE_TAINTS:
tolerated_taints.add((key, effect))
if not has_hard_taint:
toleration_required = False
break
tolerations: list[client.V1Toleration] = []
if toleration_required:
for key, effect in tolerated_taints:
tolerations.append(
client.V1Toleration(key=key, operator=Operator.EXISTS, effect=effect)
)
if not tolerations:
logger.warning("No appropriate node found, the jump pod may never be scheduled")

commands = _get_jump_pod_commands(authorized_keys=ssh_public_keys)
pod = client.V1Pod(
metadata=client.V1ObjectMeta(
@@ -667,7 +729,8 @@
)
],
)
]
],
tolerations=tolerations,
),
)
call_api_method(
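
A minimal, self-contained rehearsal of the jump-pod taint scan above, using plain tuples in place of V1Node objects (hypothetical helper and data, not dstack API):

    def pick_tolerations(nodes: list[list[tuple[str, str]]]) -> list[tuple[str, str]]:
        tolerated_keys = {"nvidia.com/gpu"}
        toleration_required = True
        tolerated_taints: set[tuple[str, str]] = set()
        for taints in nodes:
            hard = [(k, e) for k, e in taints if e != "PreferNoSchedule"]
            if not hard:
                # A schedulable node exists; no tolerations needed.
                toleration_required = False
                break
            tolerated_taints.update((k, e) for k, e in hard if k in tolerated_keys)
        return sorted(tolerated_taints) if toleration_required else []

    # One untainted node in the cluster -> no tolerations required:
    assert pick_tolerations([[("nvidia.com/gpu", "NoSchedule")], []]) == []
    # Every node carries the GPU taint -> tolerate it:
    assert pick_tolerations([[("nvidia.com/gpu", "NoSchedule")]]) == [
        ("nvidia.com/gpu", "NoSchedule")
    ]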