diff --git a/src/dstack/_internal/core/backends/nebius/compute.py b/src/dstack/_internal/core/backends/nebius/compute.py index 610636986..7c1335f6b 100644 --- a/src/dstack/_internal/core/backends/nebius/compute.py +++ b/src/dstack/_internal/core/backends/nebius/compute.py @@ -19,6 +19,7 @@ ComputeWithPrivilegedSupport, generate_unique_instance_name, get_user_data, + merge_tags, ) from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier from dstack._internal.core.backends.nebius import resources @@ -150,6 +151,18 @@ def create_instance( if backend_data.cluster is not None: cluster_id = backend_data.cluster.id + labels = { + "owner": "dstack", + "dstack_project": instance_config.project_name.lower(), + "dstack_name": instance_config.instance_name, + "dstack_user": instance_config.user.lower(), + } + labels = merge_tags( + base_tags=labels, + backend_tags=self.config.tags, + resource_tags=instance_config.tags, + ) + labels = resources.filter_invalid_labels(labels) gpus = instance_offer.instance.resources.gpus create_disk_op = resources.create_disk( sdk=self._sdk, @@ -159,6 +172,7 @@ def create_instance( image_family="ubuntu24.04-cuda12" if gpus and gpus[0].name == "B200" else "ubuntu22.04-cuda12", + labels=labels, ) create_instance_op = None try: @@ -184,6 +198,7 @@ def create_instance( disk_id=create_disk_op.resource_id, subnet_id=self._get_subnet_id(instance_offer.region), preemptible=instance_offer.instance.resources.spot, + labels=labels, ) _wait_for_instance(self._sdk, create_instance_op) except BaseException: diff --git a/src/dstack/_internal/core/backends/nebius/configurator.py b/src/dstack/_internal/core/backends/nebius/configurator.py index 331e15344..09b756a9b 100644 --- a/src/dstack/_internal/core/backends/nebius/configurator.py +++ b/src/dstack/_internal/core/backends/nebius/configurator.py @@ -3,6 +3,7 @@ from nebius.aio.service_error import RequestError from dstack._internal.core.backends.base.configurator 
def _check_config_tags(self, config: NebiusBackendConfigWithCreds):
    """Validate the backend-level ``tags`` of a Nebius backend config.

    Does nothing when ``config.tags`` is unset or empty.

    Raises:
        ServerClientError: if more than ``TAGS_MAX_NUM`` tags are configured,
            or if ``resources.validate_labels`` rejects any of them.
    """
    if not config.tags:
        return
    if len(config.tags) > TAGS_MAX_NUM:
        raise ServerClientError(
            f"Maximum number of tags exceeded. Up to {TAGS_MAX_NUM} tags is allowed."
        )
    try:
        resources.validate_labels(config.tags)
    except BackendError as e:
        # Re-raise as a client-facing error, but keep the original exception
        # chained so the root cause stays visible in tracebacks and logs.
        raise ServerClientError(e.args[0]) from e
# Label keys: a lowercase letter followed by up to 62 lowercase letters,
# digits, underscores or hyphens (63 characters in total at most).
MAX_RESOURCE_NAME_LEN = 63
# NOTE: the patterns are anchored with `\Z` rather than `$`, because `$` also
# matches right before a trailing newline and would accept e.g. "name\n".
NAME_PATTERN = re.compile(r"^[a-z][_\-a-z0-9]{0,62}\Z")
# Label values: up to 63 lowercase letters, digits, underscores or hyphens.
# An empty value is allowed.
LABEL_VALUE_PATTERN = re.compile(r"^[_\-a-z0-9]{0,63}\Z")


def filter_invalid_labels(labels: Dict[str, str]) -> Dict[str, str]:
    """Return a new dict with only the labels that pass validation.

    Every dropped label is reported with a warning so that the omission
    is visible in the logs. The input mapping is not modified.
    """
    filtered_labels = {}
    for k, v in labels.items():
        if not _is_valid_label(k, v):
            logger.warning("Skipping invalid label '%s: %s'", k, v)
            continue
        filtered_labels[k] = v
    return filtered_labels


def validate_labels(labels: Dict[str, str]):
    """Check that every label passes validation.

    Raises:
        BackendError: on the first label whose key or value is invalid.
    """
    for k, v in labels.items():
        if not _is_valid_label(k, v):
            raise BackendError("Invalid resource labels")


def _is_valid_label(key: str, value: str) -> bool:
    """Return True if both the label key and the label value are valid."""
    # TODO: [Nebius] current validation logic reuses GCP's approach.
    # There is no public information on Nebius labels restrictions.
    return is_valid_resource_name(key) and is_valid_label_value(value)


def is_valid_resource_name(name: str) -> bool:
    """Return True if `name` is 1 to 63 characters long and matches `NAME_PATTERN`."""
    if len(name) < 1 or len(name) > MAX_RESOURCE_NAME_LEN:
        return False
    # `fullmatch` guarantees the whole string is consumed by the pattern.
    return NAME_PATTERN.fullmatch(name) is not None


def is_valid_label_value(value: str) -> bool:
    """Return True if `value` (possibly empty) matches `LABEL_VALUE_PATTERN`."""
    return LABEL_VALUE_PATTERN.fullmatch(value) is not None