Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 84 additions & 23 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@
AMD_GPU_NAME_TO_DEVICE_IDS,
AMD_GPU_NODE_TAINT,
AMD_GPU_RESOURCE,
LABEL_VALUE_MAX_LENGTH,
NVIDIA_GPU_NAME_TO_GPU_INFO,
NVIDIA_GPU_NODE_TAINT,
NVIDIA_GPU_PRODUCT_LABEL,
NVIDIA_GPU_RESOURCE,
OBJECT_NAME_MAX_LENGTH,
PodPhase,
TaintEffect,
build_base_labels,
build_dockerconfigjson,
filter_invalid_labels,
format_dstack_label_key,
format_memory,
get_amd_gpu_from_node_labels,
get_gpu_request_from_gpu_spec,
Expand Down Expand Up @@ -191,20 +192,38 @@ def run_job(
_create_jump_pod_service_if_not_exists(
api=api,
namespace=namespace,
project_name=run.project_name,
jump_pod_name=jump_pod_name,
jump_pod_service_name=jump_pod_service_name,
jump_pod_port=cluster.proxy_jump.port,
project_ssh_public_key=project_ssh_public_key.strip(),
)

pod_name = generate_unique_instance_name_for_job(run, job)
pod_name = generate_unique_instance_name_for_job(
run, job, max_length=LABEL_VALUE_MAX_LENGTH
)

base_labels = build_base_labels(
component="job",
unique_name=pod_name,
project=run.project_name,
name=job.job_spec.job_name,
user=run.user,
)
labels = merge_tags(
base_tags=base_labels,
resource_tags=run.run_spec.configuration.tags,
)
labels = filter_invalid_labels(labels)

registry_auth_secret_name: Optional[str] = None
with ExitStack() as exit_stack:
if job.job_spec.registry_auth is not None:
registry_auth_secret_name = _get_registry_auth_secret_name(pod_name)
_create_registry_auth_secret(
api=api,
namespace=namespace,
labels=labels,
secret_name=registry_auth_secret_name,
image_name=job.job_spec.image_name,
username=job.job_spec.registry_auth.username,
Expand All @@ -224,6 +243,7 @@ def run_job(
_create_job_pod(
api=api,
namespace=namespace,
labels=labels,
pod_name=pod_name,
registry_auth_secret_name=registry_auth_secret_name,
run_spec=run.run_spec,
Expand Down Expand Up @@ -264,10 +284,13 @@ def run_job(
api.create_namespaced_service(
namespace=namespace,
body=client.V1Service(
metadata=client.V1ObjectMeta(name=pod_service_name),
metadata=client.V1ObjectMeta(
name=pod_service_name,
labels=labels,
),
spec=client.V1ServiceSpec(
type="ClusterIP",
selector={"app.kubernetes.io/name": pod_name},
selector=_build_service_selector_from_labels(base_labels),
ports=[client.V1ServicePort(port=DSTACK_RUNNER_SSH_PORT)],
),
),
Expand Down Expand Up @@ -444,14 +467,30 @@ def create_gateway(
"The `kubernetes` backend does not support the `instance_type`"
" gateway configuration property"
)
instance_name = generate_unique_gateway_instance_name(configuration)

instance_name = generate_unique_gateway_instance_name(
configuration, max_length=LABEL_VALUE_MAX_LENGTH
)

base_labels = build_base_labels(
component="gateway",
unique_name=instance_name,
project=configuration.project_name,
name=configuration.instance_name,
)
labels = merge_tags(
base_tags=base_labels,
resource_tags=configuration.tags,
)
labels = filter_invalid_labels(labels)

commands = _get_gateway_commands(
authorized_keys=[configuration.ssh_key_pub], router=configuration.router
)
pod = client.V1Pod(
metadata=client.V1ObjectMeta(
name=instance_name,
labels={"app.kubernetes.io/name": instance_name},
labels=labels,
),
spec=client.V1PodSpec(
containers=[
Expand Down Expand Up @@ -486,10 +525,11 @@ def create_gateway(
service = client.V1Service(
metadata=client.V1ObjectMeta(
name=_get_pod_service_name(instance_name),
labels=labels,
),
spec=client.V1ServiceSpec(
type="LoadBalancer",
selector={"app.kubernetes.io/name": instance_name},
selector=_build_service_selector_from_labels(base_labels),
ports=[
client.V1ServicePort(
name="ssh",
Expand Down Expand Up @@ -608,6 +648,7 @@ def register_volume(self, volume: Volume) -> VolumeProvisioningData:

def create_volume(self, volume: Volume) -> VolumeProvisioningData:
assert isinstance(volume.configuration, KubernetesVolumeConfiguration)
assert volume.configuration.size is not None

region = volume.configuration.region
cluster = self.region_cluster_map.get(region)
Expand All @@ -618,21 +659,21 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
api = client.CoreV1Api(cluster.api_client)
namespace = cluster.namespace

labels = {
format_dstack_label_key("owner"): "dstack",
format_dstack_label_key("project"): volume.project_name,
format_dstack_label_key("name"): volume.name,
format_dstack_label_key("user"): volume.user,
}
pvc_name = generate_unique_volume_name(volume, max_length=LABEL_VALUE_MAX_LENGTH)

base_labels = build_base_labels(
component="volume",
unique_name=pvc_name,
project=volume.project_name,
name=volume.name,
user=volume.user,
)
labels = merge_tags(
base_tags=labels,
base_tags=base_labels,
resource_tags=volume.configuration.tags,
)
labels = filter_invalid_labels(labels)

assert volume.configuration.size is not None

pvc_name = generate_unique_volume_name(volume, max_length=OBJECT_NAME_MAX_LENGTH)
pvc = client.V1PersistentVolumeClaim(
metadata=client.V1ObjectMeta(
name=pvc_name,
Expand Down Expand Up @@ -789,11 +830,19 @@ def _gpu_matches_gpu_spec(gpu: Gpu, gpu_spec: GPUSpec) -> bool:
def _create_jump_pod_service_if_not_exists(
api: client.CoreV1Api,
namespace: str,
project_name: str,
jump_pod_name: str,
jump_pod_service_name: str,
jump_pod_port: Optional[int],
project_ssh_public_key: str,
) -> None:
base_labels = build_base_labels(
component="ssh-proxy",
unique_name=jump_pod_name,
project=project_name,
)
labels = filter_invalid_labels(base_labels)

service: Optional[client.V1Service] = None
pod: Optional[client.V1Pod] = None
_namespace = call_api_method(
Expand All @@ -805,7 +854,6 @@ def _create_jump_pod_service_if_not_exists(
_namespace = client.V1Namespace(
metadata=client.V1ObjectMeta(
name=namespace,
labels={"app.kubernetes.io/name": namespace},
),
)
api.create_namespace(body=_namespace)
Expand Down Expand Up @@ -867,7 +915,7 @@ def _create_jump_pod_service_if_not_exists(
pod = client.V1Pod(
metadata=client.V1ObjectMeta(
name=jump_pod_name,
labels={"app.kubernetes.io/name": jump_pod_name},
labels=labels,
),
spec=client.V1PodSpec(
containers=[
Expand Down Expand Up @@ -897,10 +945,13 @@ def _create_jump_pod_service_if_not_exists(
name=jump_pod_service_name,
)
service = client.V1Service(
metadata=client.V1ObjectMeta(name=jump_pod_service_name),
metadata=client.V1ObjectMeta(
name=jump_pod_service_name,
labels=labels,
),
spec=client.V1ServiceSpec(
type="NodePort",
selector={"app.kubernetes.io/name": jump_pod_name},
selector=_build_service_selector_from_labels(base_labels),
ports=[
client.V1ServicePort(
port=JUMP_POD_SSH_PORT,
Expand Down Expand Up @@ -1038,6 +1089,7 @@ def _get_jump_pod_commands(authorized_keys: list[str]) -> list[str]:
def _create_registry_auth_secret(
api: client.CoreV1Api,
namespace: str,
labels: dict[str, str],
secret_name: str,
image_name: str,
username: str,
Expand All @@ -1049,7 +1101,10 @@ def _create_registry_auth_secret(
password=password,
)
secret = client.V1Secret(
metadata=client.V1ObjectMeta(name=secret_name),
metadata=client.V1ObjectMeta(
name=secret_name,
labels=labels,
),
type="kubernetes.io/dockerconfigjson",
string_data={".dockerconfigjson": dockerconfigjson},
)
Expand All @@ -1062,6 +1117,7 @@ def _create_registry_auth_secret(
def _create_job_pod(
api: client.CoreV1Api,
namespace: str,
labels: dict[str, str],
pod_name: str,
registry_auth_secret_name: Optional[str],
run_spec: RunSpec,
Expand Down Expand Up @@ -1186,7 +1242,7 @@ def _create_job_pod(
pod = client.V1Pod(
metadata=client.V1ObjectMeta(
name=pod_name,
labels={"app.kubernetes.io/name": pod_name},
labels=labels,
),
spec=client.V1PodSpec(
containers=[
Expand Down Expand Up @@ -1399,6 +1455,11 @@ def _run_ssh_command(
return proc.returncode, proc.stdout


def _build_service_selector_from_labels(labels: dict[str, str]) -> dict[str, str]:
label_key = "app.kubernetes.io/instance"
return {label_key: labels[label_key]}


def _get_pod_service_name(pod_name: str) -> str:
return f"{pod_name}-service"

Expand Down
29 changes: 24 additions & 5 deletions src/dstack/_internal/core/backends/kubernetes/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Mapping
from decimal import Decimal
from enum import Enum
from typing import Callable, Optional, Union, cast
from typing import Callable, Literal, Optional, Union, cast

from gpuhunt import KNOWN_AMD_GPUS, KNOWN_NVIDIA_GPUS, AcceleratorVendor

Expand Down Expand Up @@ -135,6 +135,29 @@ def __sub__(self, other: Self) -> Self:
return type(self)(**dct)


def build_base_labels(
*,
component: Literal["ssh-proxy", "job", "gateway", "volume"],
unique_name: str,
project: str,
name: Optional[str] = None,
user: Optional[str] = None,
) -> dict[str, str]:
labels = {
"app.kubernetes.io/name": f"dstack-{component}",
# app.kubernetes.io/component would be redundant as app.kubernetes.io/name already includes
# it with dstack- prefix
"app.kubernetes.io/instance": unique_name,
"app.kubernetes.io/managed-by": "dstack",
"k8s.dstack.ai/project": project,
}
if name is not None:
labels["k8s.dstack.ai/name"] = name
if user is not None:
labels["k8s.dstack.ai/user"] = user
return labels


def filter_invalid_labels(labels: dict[str, str]) -> dict[str, str]:
filtered_labels: dict[str, str] = {}
for k, v in labels.items():
Expand Down Expand Up @@ -178,10 +201,6 @@ def validate_label_value(value: str) -> None:
raise ValueError("Invalid value")


def format_dstack_label_key(name: str) -> str:
return f"k8s.dstack.ai/{name}"


def build_dockerconfigjson(image_name: str, username: str, password: str) -> str:
registry = docker_utils.parse_image_name(image_name).registry
if registry is None or docker_utils.is_default_registry(registry):
Expand Down
Loading