Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/dstack/_internal/core/backends/nebius/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ComputeWithPrivilegedSupport,
generate_unique_instance_name,
get_user_data,
merge_tags,
)
from dstack._internal.core.backends.base.offers import get_catalog_offers, get_offers_disk_modifier
from dstack._internal.core.backends.nebius import resources
Expand Down Expand Up @@ -150,6 +151,18 @@ def create_instance(
if backend_data.cluster is not None:
cluster_id = backend_data.cluster.id

labels = {
"owner": "dstack",
"dstack_project": instance_config.project_name.lower(),
"dstack_name": instance_config.instance_name,
"dstack_user": instance_config.user.lower(),
}
labels = merge_tags(
base_tags=labels,
backend_tags=self.config.tags,
resource_tags=instance_config.tags,
)
labels = resources.filter_invalid_labels(labels)
gpus = instance_offer.instance.resources.gpus
create_disk_op = resources.create_disk(
sdk=self._sdk,
Expand All @@ -159,6 +172,7 @@ def create_instance(
image_family="ubuntu24.04-cuda12"
if gpus and gpus[0].name == "B200"
else "ubuntu22.04-cuda12",
labels=labels,
)
create_instance_op = None
try:
Expand All @@ -184,6 +198,7 @@ def create_instance(
disk_id=create_disk_op.resource_id,
subnet_id=self._get_subnet_id(instance_offer.region),
preemptible=instance_offer.instance.resources.spot,
labels=labels,
)
_wait_for_instance(self._sdk, create_instance_op)
except BaseException:
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/core/backends/nebius/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from nebius.aio.service_error import RequestError

from dstack._internal.core.backends.base.configurator import (
TAGS_MAX_NUM,
BackendRecord,
Configurator,
raise_invalid_credentials_error,
Expand All @@ -18,6 +19,7 @@
NebiusServiceAccountCreds,
NebiusStoredConfig,
)
from dstack._internal.core.errors import BackendError, ServerClientError
from dstack._internal.core.models.backends.base import BackendType


Expand Down Expand Up @@ -53,6 +55,19 @@ def validate_config(self, config: NebiusBackendConfigWithCreds, default_creds_en
f" some of the valid options: {sorted(valid_fabrics)}"
),
)
self._check_config_tags(config)

def _check_config_tags(self, config: NebiusBackendConfigWithCreds):
if not config.tags:
return
if len(config.tags) > TAGS_MAX_NUM:
raise ServerClientError(
f"Maximum number of tags exceeded. Up to {TAGS_MAX_NUM} tags is allowed."
)
try:
resources.validate_labels(config.tags)
except BackendError as e:
raise ServerClientError(e.args[0])

def create_backend(
self, project_name: str, config: NebiusBackendConfigWithCreds
Expand Down
8 changes: 7 additions & 1 deletion src/dstack/_internal/core/backends/nebius/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Dict, Literal, Optional, Union

from pydantic import Field, root_validator

Expand Down Expand Up @@ -141,6 +141,12 @@ class NebiusBackendConfig(CoreModel):
)
),
] = None
tags: Annotated[
Optional[Dict[str, str]],
Field(
description="The tags (labels) that will be assigned to resources created by `dstack`"
),
] = None


class NebiusBackendConfigWithCreds(NebiusBackendConfig):
Expand Down
47 changes: 45 additions & 2 deletions src/dstack/_internal/core/backends/nebius/resources.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
import re
import time
from collections import defaultdict
from collections.abc import Container as ContainerT
from collections.abc import Generator, Iterable, Sequence
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
from typing import Optional
from typing import Dict, Optional

from nebius.aio.authorization.options import options_to_metadata
from nebius.aio.operation import Operation as SDKOperation
Expand Down Expand Up @@ -249,13 +250,14 @@ def get_default_subnet(sdk: SDK, project_id: str) -> Subnet:


def create_disk(
sdk: SDK, name: str, project_id: str, size_mib: int, image_family: str
sdk: SDK, name: str, project_id: str, size_mib: int, image_family: str, labels: Dict[str, str]
) -> SDKOperation[Operation]:
client = DiskServiceClient(sdk)
request = CreateDiskRequest(
metadata=ResourceMetadata(
name=name,
parent_id=project_id,
labels=labels,
),
spec=DiskSpec(
size_mebibytes=size_mib,
Expand Down Expand Up @@ -288,12 +290,14 @@ def create_instance(
disk_id: str,
subnet_id: str,
preemptible: bool,
labels: Dict[str, str],
) -> SDKOperation[Operation]:
client = InstanceServiceClient(sdk)
request = CreateInstanceRequest(
metadata=ResourceMetadata(
name=name,
parent_id=project_id,
labels=labels,
),
spec=InstanceSpec(
cloud_init_user_data=user_data,
Expand Down Expand Up @@ -367,3 +371,42 @@ def delete_cluster(sdk: SDK, cluster_id: str) -> None:
metadata=REQUEST_MD,
)
)


def filter_invalid_labels(labels: Dict[str, str]) -> Dict[str, str]:
filtered_labels = {}
for k, v in labels.items():
if not _is_valid_label(k, v):
logger.warning("Skipping invalid label '%s: %s'", k, v)
continue
filtered_labels[k] = v
return filtered_labels


def validate_labels(labels: Dict[str, str]):
for k, v in labels.items():
if not _is_valid_label(k, v):
raise BackendError("Invalid resource labels")


def _is_valid_label(key: str, value: str) -> bool:
# TODO: [Nebius] current validation logic reuses GCP's approach.
# There is no public information on Nebius labels restrictions.
return is_valid_resource_name(key) and is_valid_label_value(value)


MAX_RESOURCE_NAME_LEN = 63
NAME_PATTERN = re.compile(r"^[a-z][_\-a-z0-9]{0,62}$")
LABEL_VALUE_PATTERN = re.compile(r"^[_\-a-z0-9]{0,63}$")


def is_valid_resource_name(name: str) -> bool:
if len(name) < 1 or len(name) > MAX_RESOURCE_NAME_LEN:
return False
match = re.match(NAME_PATTERN, name)
return match is not None


def is_valid_label_value(value: str) -> bool:
match = re.match(LABEL_VALUE_PATTERN, value)
return match is not None