From 420bb921257304eb9b2f164116e86ac6a3040b24 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 7 Mar 2025 16:12:05 +0500 Subject: [PATCH 1/5] Introduce ComputeWith classes to detect compute features --- .../_internal/core/backends/__init__.py | 94 +++++++----- .../_internal/core/backends/aws/backend.py | 3 +- .../_internal/core/backends/aws/compute.py | 18 ++- .../core/backends/aws/configurator.py | 3 +- .../_internal/core/backends/azure/backend.py | 3 +- .../_internal/core/backends/azure/compute.py | 10 +- .../core/backends/azure/configurator.py | 3 +- .../_internal/core/backends/base/backend.py | 8 +- .../_internal/core/backends/base/compute.py | 137 +++++++++++++----- .../core/backends/base/configurator.py | 6 +- .../_internal/core/backends/configurators.py | 4 + .../_internal/core/backends/cudo/backend.py | 3 +- .../_internal/core/backends/cudo/compute.py | 6 +- .../core/backends/cudo/configurator.py | 3 +- .../core/backends/datacrunch/backend.py | 3 +- .../core/backends/datacrunch/compute.py | 6 +- .../core/backends/datacrunch/configurator.py | 3 +- .../_internal/core/backends/gcp/backend.py | 3 +- .../_internal/core/backends/gcp/compute.py | 12 +- .../core/backends/gcp/configurator.py | 3 +- .../core/backends/kubernetes/backend.py | 3 +- .../core/backends/kubernetes/compute.py | 6 +- .../core/backends/kubernetes/configurator.py | 3 +- .../core/backends/lambdalabs/backend.py | 3 +- .../core/backends/lambdalabs/compute.py | 6 +- .../core/backends/lambdalabs/configurator.py | 3 +- .../_internal/core/backends/local/backend.py | 3 +- .../_internal/core/backends/local/compute.py | 18 ++- .../_internal/core/backends/nebius/backend.py | 3 +- .../core/backends/nebius/configurator.py | 3 +- .../_internal/core/backends/oci/backend.py | 3 +- .../_internal/core/backends/oci/compute.py | 8 +- .../core/backends/oci/configurator.py | 3 +- .../_internal/core/backends/runpod/backend.py | 3 +- .../_internal/core/backends/runpod/compute.py | 6 +- 
.../core/backends/runpod/configurator.py | 3 +- .../core/backends/tensordock/backend.py | 3 +- .../core/backends/tensordock/compute.py | 6 +- .../core/backends/tensordock/configurator.py | 3 +- .../_internal/core/backends/vastai/backend.py | 3 +- .../core/backends/vastai/configurator.py | 3 +- .../_internal/core/backends/vultr/backend.py | 3 +- .../_internal/core/backends/vultr/compute.py | 8 +- .../core/backends/vultr/configurator.py | 3 +- .../background/tasks/process_instances.py | 9 +- .../tasks/process_placement_groups.py | 5 +- .../tasks/process_submitted_jobs.py | 5 +- .../background/tasks/process_volumes.py | 7 +- .../server/services/gateways/__init__.py | 6 +- .../server/services/jobs/__init__.py | 11 +- .../_internal/server/services/volumes.py | 5 +- src/dstack/_internal/server/testing/common.py | 27 ++++ .../background/tasks/test_process_gateways.py | 4 + .../tasks/test_process_instances.py | 2 + .../tasks/test_process_placement_groups.py | 2 + .../tasks/test_process_submitted_jobs.py | 2 + .../tasks/test_process_submitted_volumes.py | 2 + .../tasks/test_process_terminating_jobs.py | 13 +- .../_internal/server/routers/test_gateways.py | 3 + .../_internal/server/routers/test_runs.py | 3 + .../_internal/server/routers/test_volumes.py | 2 + 61 files changed, 417 insertions(+), 131 deletions(-) diff --git a/src/dstack/_internal/core/backends/__init__.py b/src/dstack/_internal/core/backends/__init__.py index ae6903b99..8ffa38df3 100644 --- a/src/dstack/_internal/core/backends/__init__.py +++ b/src/dstack/_internal/core/backends/__init__.py @@ -1,42 +1,58 @@ +from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, + ComputeWithGatewaySupport, + ComputeWithMultinodeSupport, + ComputeWithPlacementGroupSupport, + ComputeWithPrivateGatewaySupport, + ComputeWithReservationSupport, + ComputeWithVolumeSupport, +) +from dstack._internal.core.backends.base.configurator import Configurator +from 
dstack._internal.core.backends.configurators import list_available_configurator_classes from dstack._internal.core.models.backends.base import BackendType -BACKENDS_WITH_MULTINODE_SUPPORT = [ - BackendType.AWS, - BackendType.AZURE, - BackendType.GCP, - BackendType.REMOTE, - BackendType.OCI, - BackendType.VULTR, -] -BACKENDS_WITH_CREATE_INSTANCE_SUPPORT = [ - BackendType.AWS, - BackendType.DSTACK, - BackendType.AZURE, - BackendType.CUDO, - BackendType.DATACRUNCH, - BackendType.GCP, - BackendType.LAMBDA, - BackendType.OCI, - BackendType.TENSORDOCK, - BackendType.VULTR, -] -BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT = [ - BackendType.AWS, -] -BACKENDS_WITH_RESERVATION_SUPPORT = [ - BackendType.AWS, -] -BACKENDS_WITH_GATEWAY_SUPPORT = [ - BackendType.AWS, - BackendType.AZURE, - BackendType.GCP, - BackendType.KUBERNETES, -] -BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT = [BackendType.AWS] -BACKENDS_WITH_VOLUMES_SUPPORT = [ - BackendType.AWS, - BackendType.GCP, - BackendType.LOCAL, - BackendType.RUNPOD, -] +def _get_backends_with_compute_feature( + configurator_classes: list[type[Configurator]], + compute_feature_class: type, +) -> list[BackendType]: + backend_types = [] + for configurator_class in configurator_classes: + compute_class = configurator_class.BACKEND_CLASS.COMPUTE_CLASS + if issubclass(compute_class, compute_feature_class): + backend_types.append(configurator_class.TYPE) + return backend_types + + +_configurator_classes = list_available_configurator_classes() + + +# TODO: Add LocalBackend to lists if it's enabled +BACKENDS_WITH_CREATE_INSTANCE_SUPPORT = _get_backends_with_compute_feature( + configurator_classes=_configurator_classes, + compute_feature_class=ComputeWithCreateInstanceSupport, +) +BACKENDS_WITH_MULTINODE_SUPPORT = [BackendType.REMOTE] + _get_backends_with_compute_feature( + configurator_classes=_configurator_classes, + compute_feature_class=ComputeWithMultinodeSupport, +) +BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT = _get_backends_with_compute_feature( + 
configurator_classes=_configurator_classes, + compute_feature_class=ComputeWithPlacementGroupSupport, +) +BACKENDS_WITH_RESERVATION_SUPPORT = _get_backends_with_compute_feature( + configurator_classes=_configurator_classes, + compute_feature_class=ComputeWithReservationSupport, +) +BACKENDS_WITH_GATEWAY_SUPPORT = _get_backends_with_compute_feature( + configurator_classes=_configurator_classes, + compute_feature_class=ComputeWithGatewaySupport, +) +BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT = _get_backends_with_compute_feature( + configurator_classes=_configurator_classes, + compute_feature_class=ComputeWithPrivateGatewaySupport, +) +BACKENDS_WITH_VOLUMES_SUPPORT = _get_backends_with_compute_feature( + configurator_classes=_configurator_classes, + compute_feature_class=ComputeWithVolumeSupport, +) diff --git a/src/dstack/_internal/core/backends/aws/backend.py b/src/dstack/_internal/core/backends/aws/backend.py index 4defe8937..2f0024408 100644 --- a/src/dstack/_internal/core/backends/aws/backend.py +++ b/src/dstack/_internal/core/backends/aws/backend.py @@ -8,7 +8,8 @@ class AWSBackend(Backend): - TYPE: BackendType = BackendType.AWS + TYPE = BackendType.AWS + COMPUTE_CLASS = AWSCompute def __init__(self, config: AWSConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 08210f02f..7714e894d 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -12,6 +12,13 @@ from dstack._internal.core.backends.aws.models import AWSAccessKeyCreds from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithCreateInstanceSupport, + ComputeWithGatewaySupport, + ComputeWithMultinodeSupport, + ComputeWithPlacementGroupSupport, + ComputeWithPrivateGatewaySupport, + ComputeWithReservationSupport, + ComputeWithVolumeSupport, generate_unique_gateway_instance_name, generate_unique_instance_name, 
generate_unique_volume_name, @@ -62,7 +69,16 @@ class AWSVolumeBackendData(CoreModel): iops: int -class AWSCompute(Compute): +class AWSCompute( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithMultinodeSupport, + ComputeWithReservationSupport, + ComputeWithPlacementGroupSupport, + ComputeWithGatewaySupport, + ComputeWithPrivateGatewaySupport, + ComputeWithVolumeSupport, +): def __init__(self, config: AWSConfig): super().__init__() self.config = config diff --git a/src/dstack/_internal/core/backends/aws/configurator.py b/src/dstack/_internal/core/backends/aws/configurator.py index 9c884d9ec..924fb7d88 100644 --- a/src/dstack/_internal/core/backends/aws/configurator.py +++ b/src/dstack/_internal/core/backends/aws/configurator.py @@ -53,7 +53,8 @@ class AWSConfigurator(Configurator): - TYPE: BackendType = BackendType.AWS + TYPE = BackendType.AWS + BACKEND_CLASS = AWSBackend def validate_config(self, config: AWSBackendConfigWithCreds, default_creds_enabled: bool): if is_core_model_instance(config.creds, AWSDefaultCreds) and not default_creds_enabled: diff --git a/src/dstack/_internal/core/backends/azure/backend.py b/src/dstack/_internal/core/backends/azure/backend.py index 1878f7db5..2022a21ac 100644 --- a/src/dstack/_internal/core/backends/azure/backend.py +++ b/src/dstack/_internal/core/backends/azure/backend.py @@ -6,7 +6,8 @@ class AzureBackend(Backend): - TYPE: BackendType = BackendType.AZURE + TYPE = BackendType.AZURE + COMPUTE_CLASS = AzureCompute def __init__(self, config: AzureConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index eb9c8fd8d..d8382ff64 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -39,6 +39,9 @@ from dstack._internal.core.backends.azure.config import AzureConfig from dstack._internal.core.backends.base.compute import ( Compute, + 
ComputeWithCreateInstanceSupport, + ComputeWithGatewaySupport, + ComputeWithMultinodeSupport, generate_unique_gateway_instance_name, generate_unique_instance_name, get_gateway_user_data, @@ -71,7 +74,12 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("30GB"), max=Memory.parse("4095GB")) -class AzureCompute(Compute): +class AzureCompute( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithMultinodeSupport, + ComputeWithGatewaySupport, +): def __init__(self, config: AzureConfig, credential: TokenCredential): super().__init__() self.config = config diff --git a/src/dstack/_internal/core/backends/azure/configurator.py b/src/dstack/_internal/core/backends/azure/configurator.py index 785c328e6..efb8858b9 100644 --- a/src/dstack/_internal/core/backends/azure/configurator.py +++ b/src/dstack/_internal/core/backends/azure/configurator.py @@ -72,7 +72,8 @@ class AzureConfigurator(Configurator): - TYPE: BackendType = BackendType.AZURE + TYPE = BackendType.AZURE + BACKEND_CLASS = AzureBackend def validate_config(self, config: AzureBackendConfigWithCreds, default_creds_enabled: bool): if is_core_model_instance(config.creds, AzureDefaultCreds) and not default_creds_enabled: diff --git a/src/dstack/_internal/core/backends/base/backend.py b/src/dstack/_internal/core/backends/base/backend.py index 545f4923a..ee19fd989 100644 --- a/src/dstack/_internal/core/backends/base/backend.py +++ b/src/dstack/_internal/core/backends/base/backend.py @@ -1,12 +1,18 @@ from abc import ABC, abstractmethod +from typing import ClassVar from dstack._internal.core.backends.base.compute import Compute from dstack._internal.core.models.backends.base import BackendType class Backend(ABC): - TYPE: BackendType + TYPE: ClassVar[BackendType] + # `COMPUTE_CLASS` is used to introspect compute features without initializing it. + COMPUTE_CLASS: ClassVar[type[Compute]] @abstractmethod def compute(self) -> Compute: + """ + Returns Compute instance. 
+ """ pass diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index e28399e58..7f9290590 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -46,6 +46,11 @@ class Compute(ABC): + """ + A base class for all compute implementations with minimal features. + If a compute supports additional features, it must also subclass `ComputeWith*` classes. + """ + def __init__(self): self._offers_cache_lock = threading.Lock() self._offers_cache = TTLCache(maxsize=5, ttl=30) @@ -86,18 +91,6 @@ def terminate_instance( """ pass - def create_instance( - self, - instance_offer: InstanceOfferWithAvailability, - instance_config: InstanceConfiguration, - ) -> JobProvisioningData: - """ - Launches a new instance. It should return `JobProvisioningData` ASAP. - If required to wait to get the IP address or SSH port, return partially filled `JobProvisioningData` - and implement `update_provisioning_data()`. - """ - raise NotImplementedError() - def update_provisioning_data( self, provisioning_data: JobProvisioningData, @@ -114,6 +107,67 @@ def update_provisioning_data( """ pass + def _get_offers_cached_key(self, requirements: Optional[Requirements] = None) -> int: + # Requirements is not hashable, so we use a hack to get arguments hash + if requirements is None: + return hash(None) + return hash(requirements.json()) + + @cachedmethod( + cache=lambda self: self._offers_cache, + key=_get_offers_cached_key, + lock=lambda self: self._offers_cache_lock, + ) + def get_offers_cached( + self, requirements: Optional[Requirements] = None + ) -> List[InstanceOfferWithAvailability]: + return self.get_offers(requirements) + + +class ComputeWithCreateInstanceSupport(ABC): + """ + Must be subclassed and implemented to support fleets (instance creation without running a job). 
+ Typically, a compute that runs VMs would implement it, + and a compute that runs containers would not. + """ + + @abstractmethod + def create_instance( + self, + instance_offer: InstanceOfferWithAvailability, + instance_config: InstanceConfiguration, + ) -> JobProvisioningData: + """ + Launches a new instance. It should return `JobProvisioningData` ASAP. + If required to wait to get the IP address or SSH port, return partially filled `JobProvisioningData` + and implement `update_provisioning_data()`. + """ + pass + + +class ComputeWithMultinodeSupport: + """ + Must be subclassed to support multinode tasks and cluster fleets. + Instances provisioned in the same project/region must be interconnected. + """ + + pass + + +class ComputeWithReservationSupport: + """ + Must be subclassed to support provisioning from reservations. + """ + + pass + + +class ComputeWithPlacementGroupSupport(ABC): + """ + Must be subclassed and implemented to support placement groups. + """ + + @abstractmethod def create_placement_group( self, placement_group: PlacementGroup, @@ -121,8 +175,9 @@ def create_placement_group( """ Creates a placement group. """ - raise NotImplementedError() + pass + @abstractmethod def delete_placement_group( self, placement_group: PlacementGroup, @@ -131,8 +186,15 @@ def delete_placement_group( """ Deletes a placement group. If the group does not exist, it should not raise errors but return silently. """ - raise NotImplementedError() + pass + +class ComputeWithGatewaySupport(ABC): + """ + Must be subclassed and implemented to support gateways. + """ + + @abstractmethod def create_gateway( self, configuration: GatewayComputeConfiguration, @@ -140,8 +202,9 @@ def create_gateway( """ Creates a gateway instance. """ - raise NotImplementedError() + pass + @abstractmethod def terminate_gateway( self, instance_id: str, @@ -152,21 +215,39 @@ def terminate_gateway( Terminates a gateway instance. 
Generally, it passes the call to `terminate_instance()`, but may perform additional work such as deleting a load balancer when a gateway has one. """ - raise NotImplementedError() + pass + + +class ComputeWithPrivateGatewaySupport: + """ + Must be subclassed to support private gateways. + `create_gateway()` must be able to create private gateways. + """ + + pass + +class ComputeWithVolumeSupport(ABC): + """ + Must be subclassed and implemented to support volumes. + """ + + @abstractmethod def register_volume(self, volume: Volume) -> VolumeProvisioningData: """ Returns VolumeProvisioningData for an existing volume. Used to add external volumes to dstack. """ - raise NotImplementedError() + pass + @abstractmethod def create_volume(self, volume: Volume) -> VolumeProvisioningData: """ Creates a new volume. """ raise NotImplementedError() + @abstractmethod def delete_volume(self, volume: Volume): """ Deletes a volume. @@ -176,13 +257,17 @@ def delete_volume(self, volume: Volume): def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData: """ Attaches a volume to the instance. - If the volume is not found, it should raise `ComputeError()` instead of a thrid-party exception. + If the volume is not found, it should raise `ComputeError()`. + Implement only if compute may return `VolumeProvisioningData.attachable`. + Otherwise, volumes should be attached by `run_job()`. """ raise NotImplementedError() def detach_volume(self, volume: Volume, instance_id: str, force: bool = False): """ Detaches a volume from the instance. + Implement only if compute may return `VolumeProvisioningData.detachable`. + Otherwise, volumes should be detached on instance termination. 
""" raise NotImplementedError() @@ -195,22 +280,6 @@ def is_volume_detached(self, volume: Volume, instance_id: str) -> bool: """ return True - def _get_offers_cached_key(self, requirements: Optional[Requirements] = None) -> int: - # Requirements is not hashable, so we use a hack to get arguments hash - if requirements is None: - return hash(None) - return hash(requirements.json()) - - @cachedmethod( - cache=lambda self: self._offers_cache, - key=_get_offers_cached_key, - lock=lambda self: self._offers_cache_lock, - ) - def get_offers_cached( - self, requirements: Optional[Requirements] = None - ) -> List[InstanceOfferWithAvailability]: - return self.get_offers(requirements) - def get_job_instance_name(run: Run, job: Job) -> str: return job.job_spec.job_name diff --git a/src/dstack/_internal/core/backends/base/configurator.py b/src/dstack/_internal/core/backends/base/configurator.py index 6ce555b0e..994266c43 100644 --- a/src/dstack/_internal/core/backends/base/configurator.py +++ b/src/dstack/_internal/core/backends/base/configurator.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, ClassVar, List, Optional from uuid import UUID from dstack._internal.core.backends.base.backend import Backend @@ -47,7 +47,9 @@ class Configurator(ABC): in `dstack._internal.core.backends.configurators`. """ - TYPE: BackendType + TYPE: ClassVar[BackendType] + # `BACKEND_CLASS` is used to introspect backend features without initializing it. 
+ BACKEND_CLASS: ClassVar[type[Backend]] @abstractmethod def validate_config(self, config: AnyBackendConfigWithCreds, default_creds_enabled: bool): diff --git a/src/dstack/_internal/core/backends/configurators.py b/src/dstack/_internal/core/backends/configurators.py index 5b1b099f3..cf1db1fc5 100644 --- a/src/dstack/_internal/core/backends/configurators.py +++ b/src/dstack/_internal/core/backends/configurators.py @@ -132,6 +132,10 @@ def list_available_backend_types() -> List[BackendType]: return available_backend_types +def list_available_configurator_classes() -> List[type[Configurator]]: + return _CONFIGURATOR_CLASSES + + def register_configurator(configurator: Type[Configurator]): """ A hook to for registering new configurators without importing them. diff --git a/src/dstack/_internal/core/backends/cudo/backend.py b/src/dstack/_internal/core/backends/cudo/backend.py index 06defc36c..389eacbed 100644 --- a/src/dstack/_internal/core/backends/cudo/backend.py +++ b/src/dstack/_internal/core/backends/cudo/backend.py @@ -5,7 +5,8 @@ class CudoBackend(Backend): - TYPE: BackendType = BackendType.CUDO + TYPE = BackendType.CUDO + COMPUTE_CLASS = CudoCompute def __init__(self, config: CudoConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/cudo/compute.py b/src/dstack/_internal/core/backends/cudo/compute.py index 21dfb6599..bcdb89e08 100644 --- a/src/dstack/_internal/core/backends/cudo/compute.py +++ b/src/dstack/_internal/core/backends/cudo/compute.py @@ -4,6 +4,7 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, generate_unique_instance_name, get_job_instance_name, get_shim_commands, @@ -29,7 +30,10 @@ MAX_RESOURCE_NAME_LEN = 30 -class CudoCompute(Compute): +class CudoCompute( + Compute, + ComputeWithCreateInstanceSupport, +): def __init__(self, config: CudoConfig): super().__init__() self.config = config diff --git 
a/src/dstack/_internal/core/backends/cudo/configurator.py b/src/dstack/_internal/core/backends/cudo/configurator.py index a34916e62..ee2da8381 100644 --- a/src/dstack/_internal/core/backends/cudo/configurator.py +++ b/src/dstack/_internal/core/backends/cudo/configurator.py @@ -29,7 +29,8 @@ class CudoConfigurator(Configurator): - TYPE: BackendType = BackendType.CUDO + TYPE = BackendType.CUDO + BACKEND_CLASS = CudoBackend def validate_config(self, config: CudoBackendConfigWithCreds, default_creds_enabled: bool): self._validate_cudo_api_key(config.creds.api_key) diff --git a/src/dstack/_internal/core/backends/datacrunch/backend.py b/src/dstack/_internal/core/backends/datacrunch/backend.py index 93b594d28..ef18d5729 100644 --- a/src/dstack/_internal/core/backends/datacrunch/backend.py +++ b/src/dstack/_internal/core/backends/datacrunch/backend.py @@ -5,7 +5,8 @@ class DataCrunchBackend(Backend): - TYPE: BackendType = BackendType.DATACRUNCH + TYPE = BackendType.DATACRUNCH + COMPUTE_CLASS = DataCrunchCompute def __init__(self, config: DataCrunchConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index f9373e32f..9a86869f0 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -2,6 +2,7 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, generate_unique_instance_name, get_shim_commands, ) @@ -33,7 +34,10 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=IMAGE_SIZE, max=None) -class DataCrunchCompute(Compute): +class DataCrunchCompute( + Compute, + ComputeWithCreateInstanceSupport, +): def __init__(self, config: DataCrunchConfig): super().__init__() self.config = config diff --git a/src/dstack/_internal/core/backends/datacrunch/configurator.py 
b/src/dstack/_internal/core/backends/datacrunch/configurator.py index c477ce9e8..359d811e1 100644 --- a/src/dstack/_internal/core/backends/datacrunch/configurator.py +++ b/src/dstack/_internal/core/backends/datacrunch/configurator.py @@ -26,7 +26,8 @@ class DataCrunchConfigurator(Configurator): - TYPE: BackendType = BackendType.DATACRUNCH + TYPE = BackendType.DATACRUNCH + BACKEND_CLASS = DataCrunchBackend def validate_config( self, config: DataCrunchBackendConfigWithCreds, default_creds_enabled: bool diff --git a/src/dstack/_internal/core/backends/gcp/backend.py b/src/dstack/_internal/core/backends/gcp/backend.py index 84f13f172..f3720fc1a 100644 --- a/src/dstack/_internal/core/backends/gcp/backend.py +++ b/src/dstack/_internal/core/backends/gcp/backend.py @@ -5,7 +5,8 @@ class GCPBackend(Backend): - TYPE: BackendType = BackendType.GCP + TYPE = BackendType.GCP + COMPUTE_CLASS = GCPCompute def __init__(self, config: GCPConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index c8231b28c..6a344142a 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -12,6 +12,10 @@ import dstack._internal.core.backends.gcp.resources as gcp_resources from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithCreateInstanceSupport, + ComputeWithGatewaySupport, + ComputeWithMultinodeSupport, + ComputeWithVolumeSupport, generate_unique_gateway_instance_name, generate_unique_instance_name, generate_unique_volume_name, @@ -69,7 +73,13 @@ class GCPVolumeDiskBackendData(CoreModel): disk_type: str -class GCPCompute(Compute): +class GCPCompute( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithMultinodeSupport, + ComputeWithGatewaySupport, + ComputeWithVolumeSupport, +): def __init__(self, config: GCPConfig): super().__init__() self.config = config diff --git 
a/src/dstack/_internal/core/backends/gcp/configurator.py b/src/dstack/_internal/core/backends/gcp/configurator.py index 6f7449650..426d05aa4 100644 --- a/src/dstack/_internal/core/backends/gcp/configurator.py +++ b/src/dstack/_internal/core/backends/gcp/configurator.py @@ -111,7 +111,8 @@ class GCPConfigurator(Configurator): - TYPE: BackendType = BackendType.GCP + TYPE = BackendType.GCP + BACKEND_CLASS = GCPBackend def validate_config(self, config: GCPBackendConfigWithCreds, default_creds_enabled: bool): if is_core_model_instance(config.creds, GCPDefaultCreds) and not default_creds_enabled: diff --git a/src/dstack/_internal/core/backends/kubernetes/backend.py b/src/dstack/_internal/core/backends/kubernetes/backend.py index 5905950f8..4a1ca4a24 100644 --- a/src/dstack/_internal/core/backends/kubernetes/backend.py +++ b/src/dstack/_internal/core/backends/kubernetes/backend.py @@ -5,7 +5,8 @@ class KubernetesBackend(Backend): - TYPE: BackendType = BackendType.KUBERNETES + TYPE = BackendType.KUBERNETES + COMPUTE_CLASS = KubernetesCompute def __init__(self, config: KubernetesConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 3fc38380e..c1932acab 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -9,6 +9,7 @@ from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithGatewaySupport, generate_unique_gateway_instance_name, generate_unique_instance_name_for_job, get_docker_commands, @@ -54,7 +55,10 @@ NVIDIA_GPU_NAMES = NVIDIA_GPU_NAME_TO_GPU_INFO.keys() -class KubernetesCompute(Compute): +class KubernetesCompute( + Compute, + ComputeWithGatewaySupport, +): def __init__(self, config: KubernetesConfig): super().__init__() self.config = config.copy() diff --git a/src/dstack/_internal/core/backends/kubernetes/configurator.py 
b/src/dstack/_internal/core/backends/kubernetes/configurator.py index 972e2dfab..76482a143 100644 --- a/src/dstack/_internal/core/backends/kubernetes/configurator.py +++ b/src/dstack/_internal/core/backends/kubernetes/configurator.py @@ -19,7 +19,8 @@ class KubernetesConfigurator(Configurator): - TYPE: BackendType = BackendType.KUBERNETES + TYPE = BackendType.KUBERNETES + BACKEND_CLASS = KubernetesBackend def validate_config( self, config: KubernetesBackendConfigWithCreds, default_creds_enabled: bool diff --git a/src/dstack/_internal/core/backends/lambdalabs/backend.py b/src/dstack/_internal/core/backends/lambdalabs/backend.py index 2c9478514..471f9540d 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/backend.py +++ b/src/dstack/_internal/core/backends/lambdalabs/backend.py @@ -5,7 +5,8 @@ class LambdaBackend(Backend): - TYPE: BackendType = BackendType.LAMBDA + TYPE = BackendType.LAMBDA + COMPUTE_CLASS = LambdaCompute def __init__(self, config: LambdaConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/lambdalabs/compute.py b/src/dstack/_internal/core/backends/lambdalabs/compute.py index 281095a21..63e377a4d 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/compute.py +++ b/src/dstack/_internal/core/backends/lambdalabs/compute.py @@ -6,6 +6,7 @@ from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithCreateInstanceSupport, generate_unique_instance_name, get_job_instance_name, get_shim_commands, @@ -27,7 +28,10 @@ MAX_INSTANCE_NAME_LEN = 60 -class LambdaCompute(Compute): +class LambdaCompute( + Compute, + ComputeWithCreateInstanceSupport, +): def __init__(self, config: LambdaConfig): super().__init__() self.config = config diff --git a/src/dstack/_internal/core/backends/lambdalabs/configurator.py b/src/dstack/_internal/core/backends/lambdalabs/configurator.py index 0c9775a8c..f92143d6e 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/configurator.py +++ 
b/src/dstack/_internal/core/backends/lambdalabs/configurator.py @@ -40,7 +40,8 @@ class LambdaConfigurator(Configurator): - TYPE: BackendType = BackendType.LAMBDA + TYPE = BackendType.LAMBDA + BACKEND_CLASS = LambdaBackend def validate_config(self, config: LambdaBackendConfigWithCreds, default_creds_enabled: bool): self._validate_lambda_api_key(config.creds.api_key) diff --git a/src/dstack/_internal/core/backends/local/backend.py b/src/dstack/_internal/core/backends/local/backend.py index 41d79f3b8..b279bc7b5 100644 --- a/src/dstack/_internal/core/backends/local/backend.py +++ b/src/dstack/_internal/core/backends/local/backend.py @@ -4,7 +4,8 @@ class LocalBackend(Backend): - TYPE: BackendType = BackendType.LOCAL + TYPE = BackendType.LOCAL + COMPUTE_CLASS = LocalCompute def __init__(self): self._compute = LocalCompute() diff --git a/src/dstack/_internal/core/backends/local/compute.py b/src/dstack/_internal/core/backends/local/compute.py index 1a158798f..86ee04de6 100644 --- a/src/dstack/_internal/core/backends/local/compute.py +++ b/src/dstack/_internal/core/backends/local/compute.py @@ -1,6 +1,10 @@ from typing import List, Optional -from dstack._internal.core.backends.base.compute import Compute +from dstack._internal.core.backends.base.compute import ( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithVolumeSupport, +) from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import ( @@ -18,7 +22,11 @@ logger = get_logger(__name__) -class LocalCompute(Compute): +class LocalCompute( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithVolumeSupport, +): def get_offers( self, requirements: Optional[Requirements] = None ) -> List[InstanceOfferWithAvailability]: @@ -85,6 +93,12 @@ def run_job( backend_data=None, ) + def register_volume(self, volume: Volume) -> VolumeProvisioningData: + return VolumeProvisioningData( + 
volume_id=volume.volume_id, + size_gb=volume.configuration.size_gb, + ) + def create_volume(self, volume: Volume) -> VolumeProvisioningData: return VolumeProvisioningData( volume_id=volume.name, diff --git a/src/dstack/_internal/core/backends/nebius/backend.py b/src/dstack/_internal/core/backends/nebius/backend.py index c6303c2d9..78db77aa0 100644 --- a/src/dstack/_internal/core/backends/nebius/backend.py +++ b/src/dstack/_internal/core/backends/nebius/backend.py @@ -5,7 +5,8 @@ class NebiusBackend(Backend): - TYPE: BackendType = BackendType.NEBIUS + TYPE = BackendType.NEBIUS + COMPUTE_CLASS = NebiusCompute def __init__(self, config: NebiusConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/nebius/configurator.py b/src/dstack/_internal/core/backends/nebius/configurator.py index 2ef561e29..7b6d4b40d 100644 --- a/src/dstack/_internal/core/backends/nebius/configurator.py +++ b/src/dstack/_internal/core/backends/nebius/configurator.py @@ -24,7 +24,8 @@ class NebiusConfigurator(Configurator): - TYPE: BackendType = BackendType.NEBIUS + TYPE = BackendType.NEBIUS + BACKEND_CLASS = NebiusBackend def validate_config(self, config: NebiusBackendConfigWithCreds, default_creds_enabled: bool): self._validate_nebius_creds(config.creds) diff --git a/src/dstack/_internal/core/backends/oci/backend.py b/src/dstack/_internal/core/backends/oci/backend.py index a1b1e96fc..7d1d5dfbd 100644 --- a/src/dstack/_internal/core/backends/oci/backend.py +++ b/src/dstack/_internal/core/backends/oci/backend.py @@ -5,7 +5,8 @@ class OCIBackend(Backend): - TYPE: BackendType = BackendType.OCI + TYPE = BackendType.OCI + COMPUTE_CLASS = OCICompute def __init__(self, config: OCIConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/oci/compute.py b/src/dstack/_internal/core/backends/oci/compute.py index dc7bf8426..8aa5263a4 100644 --- a/src/dstack/_internal/core/backends/oci/compute.py +++ b/src/dstack/_internal/core/backends/oci/compute.py @@ -6,6 +6,8 @@ 
from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithCreateInstanceSupport, + ComputeWithMultinodeSupport, generate_unique_instance_name, get_job_instance_name, get_user_data, @@ -46,7 +48,11 @@ CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("50GB"), max=Memory.parse("32TB")) -class OCICompute(Compute): +class OCICompute( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithMultinodeSupport, +): def __init__(self, config: OCIConfig): super().__init__() self.config = config diff --git a/src/dstack/_internal/core/backends/oci/configurator.py b/src/dstack/_internal/core/backends/oci/configurator.py index df5166f37..546d0393f 100644 --- a/src/dstack/_internal/core/backends/oci/configurator.py +++ b/src/dstack/_internal/core/backends/oci/configurator.py @@ -44,7 +44,8 @@ class OCIConfigurator(Configurator): - TYPE: BackendType = BackendType.OCI + TYPE = BackendType.OCI + BACKEND_CLASS = OCIBackend def validate_config(self, config: OCIBackendConfigWithCreds, default_creds_enabled: bool): if is_core_model_instance(config.creds, OCIDefaultCreds) and not default_creds_enabled: diff --git a/src/dstack/_internal/core/backends/runpod/backend.py b/src/dstack/_internal/core/backends/runpod/backend.py index c8c27bdfe..fff45ed6a 100644 --- a/src/dstack/_internal/core/backends/runpod/backend.py +++ b/src/dstack/_internal/core/backends/runpod/backend.py @@ -5,7 +5,8 @@ class RunpodBackend(Backend): - TYPE: BackendType = BackendType.RUNPOD + TYPE = BackendType.RUNPOD + COMPUTE_CLASS = RunpodCompute def __init__(self, config: RunpodConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/runpod/compute.py b/src/dstack/_internal/core/backends/runpod/compute.py index e20e9cd5a..5edbc14b0 100644 --- a/src/dstack/_internal/core/backends/runpod/compute.py +++ b/src/dstack/_internal/core/backends/runpod/compute.py @@ -5,6 +5,7 @@ from dstack._internal.core.backends.base.backend import Compute from 
dstack._internal.core.backends.base.compute import ( + ComputeWithVolumeSupport, generate_unique_instance_name, generate_unique_volume_name, get_docker_commands, @@ -39,7 +40,10 @@ CONTAINER_REGISTRY_AUTH_CLEANUP_INTERVAL = 60 * 60 * 24 # 24 hour -class RunpodCompute(Compute): +class RunpodCompute( + Compute, + ComputeWithVolumeSupport, +): _last_cleanup_time = None def __init__(self, config: RunpodConfig): diff --git a/src/dstack/_internal/core/backends/runpod/configurator.py b/src/dstack/_internal/core/backends/runpod/configurator.py index 8fe8ac27b..9d55bc94d 100644 --- a/src/dstack/_internal/core/backends/runpod/configurator.py +++ b/src/dstack/_internal/core/backends/runpod/configurator.py @@ -18,7 +18,8 @@ class RunpodConfigurator(Configurator): - TYPE: BackendType = BackendType.RUNPOD + TYPE = BackendType.RUNPOD + BACKEND_CLASS = RunpodBackend def validate_config(self, config: RunpodBackendConfigWithCreds, default_creds_enabled: bool): self._validate_runpod_api_key(config.creds.api_key) diff --git a/src/dstack/_internal/core/backends/tensordock/backend.py b/src/dstack/_internal/core/backends/tensordock/backend.py index 34efdca1a..b28ad8108 100644 --- a/src/dstack/_internal/core/backends/tensordock/backend.py +++ b/src/dstack/_internal/core/backends/tensordock/backend.py @@ -5,7 +5,8 @@ class TensorDockBackend(Backend): - TYPE: BackendType = BackendType.TENSORDOCK + TYPE = BackendType.TENSORDOCK + COMPUTE_CLASS = TensorDockCompute def __init__(self, config: TensorDockConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/tensordock/compute.py b/src/dstack/_internal/core/backends/tensordock/compute.py index 8a81a711b..fbeb16470 100644 --- a/src/dstack/_internal/core/backends/tensordock/compute.py +++ b/src/dstack/_internal/core/backends/tensordock/compute.py @@ -5,6 +5,7 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, 
generate_unique_instance_name, get_job_instance_name, get_shim_commands, @@ -31,7 +32,10 @@ MAX_INSTANCE_NAME_LEN = 60 -class TensorDockCompute(Compute): +class TensorDockCompute( + Compute, + ComputeWithCreateInstanceSupport, +): def __init__(self, config: TensorDockConfig): super().__init__() self.config = config diff --git a/src/dstack/_internal/core/backends/tensordock/configurator.py b/src/dstack/_internal/core/backends/tensordock/configurator.py index 192fd820c..db9688032 100644 --- a/src/dstack/_internal/core/backends/tensordock/configurator.py +++ b/src/dstack/_internal/core/backends/tensordock/configurator.py @@ -24,7 +24,8 @@ class TensorDockConfigurator(Configurator): - TYPE: BackendType = BackendType.TENSORDOCK + TYPE = BackendType.TENSORDOCK + BACKEND_CLASS = TensorDockBackend def validate_config( self, config: TensorDockBackendConfigWithCreds, default_creds_enabled: bool diff --git a/src/dstack/_internal/core/backends/vastai/backend.py b/src/dstack/_internal/core/backends/vastai/backend.py index 4b78b9ee0..9dbc99334 100644 --- a/src/dstack/_internal/core/backends/vastai/backend.py +++ b/src/dstack/_internal/core/backends/vastai/backend.py @@ -5,7 +5,8 @@ class VastAIBackend(Backend): - TYPE: BackendType = BackendType.VASTAI + TYPE = BackendType.VASTAI + COMPUTE_CLASS = VastAICompute def __init__(self, config: VastAIConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/vastai/configurator.py b/src/dstack/_internal/core/backends/vastai/configurator.py index eaed6667e..872c2a948 100644 --- a/src/dstack/_internal/core/backends/vastai/configurator.py +++ b/src/dstack/_internal/core/backends/vastai/configurator.py @@ -24,7 +24,8 @@ class VastAIConfigurator(Configurator): - TYPE: BackendType = BackendType.VASTAI + TYPE = BackendType.VASTAI + BACKEND_CLASS = VastAIBackend def validate_config(self, config: VastAIBackendConfigWithCreds, default_creds_enabled: bool): self._validate_vastai_creds(config.creds.api_key) diff --git 
a/src/dstack/_internal/core/backends/vultr/backend.py b/src/dstack/_internal/core/backends/vultr/backend.py index 71fd4c4c3..8e27410c8 100644 --- a/src/dstack/_internal/core/backends/vultr/backend.py +++ b/src/dstack/_internal/core/backends/vultr/backend.py @@ -5,7 +5,8 @@ class VultrBackend(Backend): - TYPE: BackendType = BackendType.VULTR + TYPE = BackendType.VULTR + COMPUTE_CLASS = VultrCompute def __init__(self, config: VultrConfig): self.config = config diff --git a/src/dstack/_internal/core/backends/vultr/compute.py b/src/dstack/_internal/core/backends/vultr/compute.py index 06c1e2ae3..074c1d536 100644 --- a/src/dstack/_internal/core/backends/vultr/compute.py +++ b/src/dstack/_internal/core/backends/vultr/compute.py @@ -6,6 +6,8 @@ from dstack._internal.core.backends.base.backend import Compute from dstack._internal.core.backends.base.compute import ( + ComputeWithCreateInstanceSupport, + ComputeWithMultinodeSupport, generate_unique_instance_name, get_job_instance_name, get_user_data, @@ -31,7 +33,11 @@ MAX_INSTANCE_NAME_LEN = 64 -class VultrCompute(Compute): +class VultrCompute( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithMultinodeSupport, +): def __init__(self, config: VultrConfig): super().__init__() self.config = config diff --git a/src/dstack/_internal/core/backends/vultr/configurator.py b/src/dstack/_internal/core/backends/vultr/configurator.py index 5654a4544..b4d22fccd 100644 --- a/src/dstack/_internal/core/backends/vultr/configurator.py +++ b/src/dstack/_internal/core/backends/vultr/configurator.py @@ -23,7 +23,8 @@ class VultrConfigurator(Configurator): - TYPE: BackendType = BackendType.VULTR + TYPE = BackendType.VULTR + BACKEND_CLASS = VultrBackend def validate_config(self, config: VultrBackendConfigWithCreds, default_creds_enabled: bool): self._validate_vultr_api_key(config.creds.api_key) diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py 
b/src/dstack/_internal/server/background/tasks/process_instances.py index a601f240c..62e434571 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -20,6 +20,8 @@ DSTACK_RUNNER_BINARY_PATH, DSTACK_SHIM_BINARY_PATH, DSTACK_WORKING_DIR, + ComputeWithCreateInstanceSupport, + ComputeWithPlacementGroupSupport, get_shim_env, get_shim_pre_start_commands, ) @@ -530,12 +532,15 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No for backend, instance_offer in offers: if instance_offer.backend not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT: continue + compute = backend.compute() + assert isinstance(compute, ComputeWithCreateInstanceSupport) instance_offer = _get_instance_offer_for_instance(instance_offer, instance) if ( instance_offer.backend in BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT and instance.fleet and instance_configuration.placement_group_name ): + assert isinstance(compute, ComputeWithPlacementGroupSupport) placement_group_model = _create_placement_group_if_does_not_exist( session=session, fleet_model=instance.fleet, @@ -546,7 +551,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No ) if placement_group_model is not None: placement_group = placement_group_model_to_placement_group(placement_group_model) - pgpd = await run_async(backend.compute().create_placement_group, placement_group) + pgpd = await run_async(compute.create_placement_group, placement_group) placement_group_model.provisioning_data = pgpd.json() session.add(placement_group_model) placement_groups.append(placement_group) @@ -559,7 +564,7 @@ async def _create_instance(session: AsyncSession, instance: InstanceModel) -> No ) try: job_provisioning_data = await run_async( - backend.compute().create_instance, + compute.create_instance, instance_offer, instance_configuration, ) diff --git 
a/src/dstack/_internal/server/background/tasks/process_placement_groups.py b/src/dstack/_internal/server/background/tasks/process_placement_groups.py index 3bfd1f20f..fb98f674d 100644 --- a/src/dstack/_internal/server/background/tasks/process_placement_groups.py +++ b/src/dstack/_internal/server/background/tasks/process_placement_groups.py @@ -5,6 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload +from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport from dstack._internal.core.errors import PlacementGroupInUseError from dstack._internal.server.db import get_session_ctx from dstack._internal.server.models import PlacementGroupModel, ProjectModel @@ -81,8 +82,10 @@ async def _delete_placement_group(placement_group_model: PlacementGroupModel): "Failed to delete placement group %s. Backend not available.", placement_group.name ) return + compute = backend.compute() + assert isinstance(compute, ComputeWithPlacementGroupSupport) try: - await run_async(backend.compute().delete_placement_group, placement_group) + await run_async(compute.delete_placement_group, placement_group) except PlacementGroupInUseError: logger.info( "Placement group %s is still in use. 
Skipping deletion for now.", placement_group.name diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index ab0efeff1..400f568a2 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import joinedload, lazyload, selectinload from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport from dstack._internal.core.errors import BackendError, ServerClientError from dstack._internal.core.models.common import NetworkMode from dstack._internal.core.models.fleets import ( @@ -700,13 +701,15 @@ async def _attach_volume( instance: InstanceModel, instance_id: str, ): + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) volume = volume_model_to_volume(volume_model) # Refresh only to check if the volume wasn't deleted before the lock await session.refresh(volume_model) if volume_model.deleted: raise ServerClientError("Cannot attach a deleted volume") attachment_data = await common_utils.run_async( - backend.compute().attach_volume, + compute.attach_volume, volume=volume, instance_id=instance_id, ) diff --git a/src/dstack/_internal/server/background/tasks/process_volumes.py b/src/dstack/_internal/server/background/tasks/process_volumes.py index 66831a0d6..cf07f15d4 100644 --- a/src/dstack/_internal/server/background/tasks/process_volumes.py +++ b/src/dstack/_internal/server/background/tasks/process_volumes.py @@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport from dstack._internal.core.errors import BackendError, BackendNotAvailable from dstack._internal.core.models.volumes import 
VolumeStatus from dstack._internal.server.db import get_session_ctx @@ -81,17 +82,19 @@ async def _process_submitted_volume(session: AsyncSession, volume_model: VolumeM await session.commit() return + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) try: if volume.configuration.volume_id is not None: logger.info("Registering external volume %s", volume_model.name) vpd = await run_async( - backend.compute().register_volume, + compute.register_volume, volume=volume, ) else: logger.info("Provisioning new volume %s", volume_model.name) vpd = await run_async( - backend.compute().create_volume, + compute.create_volume, volume=volume, ) except BackendError as e: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 2f192ee56..2d7f82c1d 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -16,6 +16,7 @@ ) from dstack._internal.core.backends.base.compute import ( Compute, + ComputeWithGatewaySupport, get_dstack_gateway_wheel, get_dstack_runner_version, ) @@ -91,6 +92,7 @@ async def create_gateway_compute( configuration: GatewayConfiguration, backend_id: Optional[uuid.UUID] = None, ) -> GatewayComputeModel: + assert isinstance(backend_compute, ComputeWithGatewaySupport) private_bytes, public_bytes = generate_rsa_key_pair_bytes() gateway_ssh_private_key = private_bytes.decode() gateway_ssh_public_key = public_bytes.decode() @@ -228,6 +230,8 @@ async def delete_gateways( backend = await get_project_backend_by_type_or_error( project=project, backend_type=gateway_model.backend.type ) + compute = backend.compute() + assert isinstance(compute, ComputeWithGatewaySupport) gateway_compute_configuration = get_gateway_compute_configuration(gateway_model) if ( gateway_model.gateway_compute is not None @@ -236,7 +240,7 @@ async def delete_gateways( logger.info("Deleting gateway compute for 
%s...", gateway_model.name) try: await run_async( - backend.compute().terminate_gateway, + compute.terminate_gateway, gateway_model.gateway_compute.instance_id, gateway_compute_configuration, gateway_model.gateway_compute.backend_data, diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 75a0c4a7b..a906e23fd 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -11,6 +11,7 @@ import dstack._internal.server.services.backends as backends_services from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT from dstack._internal.core.errors import ( BackendError, @@ -461,24 +462,26 @@ async def _detach_volume_from_job_instance( if volume.provisioning_data is None or not volume.provisioning_data.detachable: # Backends without `detach_volume` detach volumes automatically return detached + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) try: if job_model.volumes_detached_at is None: # We haven't tried detaching volumes yet, try soft detach first await run_async( - backend.compute().detach_volume, + compute.detach_volume, volume=volume, instance_id=jpd.instance_id, force=False, ) # For some backends, the volume may be detached immediately detached = await run_async( - backend.compute().is_volume_detached, + compute.is_volume_detached, volume=volume, instance_id=jpd.instance_id, ) else: detached = await run_async( - backend.compute().is_volume_detached, + compute.is_volume_detached, volume=volume, instance_id=jpd.instance_id, ) @@ -489,7 +492,7 @@ async def _detach_volume_from_job_instance( instance_model.name, ) await run_async( - backend.compute().detach_volume, + compute.detach_volume, volume=volume, 
instance_id=jpd.instance_id, force=True, diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 747af9c24..5157b2b45 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import joinedload, selectinload from dstack._internal.core.backends import BACKENDS_WITH_VOLUMES_SUPPORT +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport from dstack._internal.core.errors import ( BackendNotAvailable, ResourceExistsError, @@ -409,7 +410,9 @@ async def _delete_volume(session: AsyncSession, project: ProjectModel, volume_mo ) return + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) await common.run_async( - backend.compute().delete_volume, + compute.delete_volume, volume=volume, ) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 696f3bf22..a1112ebbc 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -8,6 +8,16 @@ import gpuhunt from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal.core.backends.base.compute import ( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithGatewaySupport, + ComputeWithMultinodeSupport, + ComputeWithPlacementGroupSupport, + ComputeWithPrivateGatewaySupport, + ComputeWithReservationSupport, + ComputeWithVolumeSupport, +) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode from dstack._internal.core.models.configurations import ( @@ -947,3 +957,20 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc, traceback): pass + + +class ComputeMockSpec( + Compute, + ComputeWithCreateInstanceSupport, + ComputeWithMultinodeSupport, + ComputeWithReservationSupport, + ComputeWithPlacementGroupSupport, + 
ComputeWithGatewaySupport, + ComputeWithPrivateGatewaySupport, + ComputeWithVolumeSupport, +): + """ + Can be used to create Compute mocks that pass all isinstance asserts. + """ + + pass diff --git a/src/tests/_internal/server/background/tasks/test_process_gateways.py b/src/tests/_internal/server/background/tasks/test_process_gateways.py index 9c4bcf029..159547af4 100644 --- a/src/tests/_internal/server/background/tasks/test_process_gateways.py +++ b/src/tests/_internal/server/background/tasks/test_process_gateways.py @@ -8,6 +8,7 @@ from dstack._internal.server.background.tasks.process_gateways import process_submitted_gateways from dstack._internal.server.testing.common import ( AsyncContextManager, + ComputeMockSpec, create_backend, create_gateway, create_project, @@ -37,6 +38,7 @@ async def test_provisions_gateway(self, test_db, session: AsyncSession): m.return_value = (backend, aws) pool_add.return_value = MagicMock() pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) + aws.compute.return_value = Mock(spec=ComputeMockSpec) aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( instance_id="i-1234567890", ip_address="2.2.2.2", @@ -68,6 +70,7 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( ) as m: aws = Mock() m.return_value = (backend, aws) + aws.compute.return_value = Mock(spec=ComputeMockSpec) aws.compute.return_value.create_gateway.side_effect = BackendError("Some error") await process_submitted_gateways() m.assert_called_once() @@ -99,6 +102,7 @@ async def test_marks_gateway_as_failed_if_fails_to_connect( aws = Mock() m.return_value = (backend, aws) connect_to_gateway_with_retry_mock.return_value = None + aws.compute.return_value = Mock(spec=ComputeMockSpec) aws.compute.return_value.create_gateway.return_value = GatewayProvisioningData( instance_id="i-1234567890", ip_address="2.2.2.2", diff --git a/src/tests/_internal/server/background/tasks/test_process_instances.py 
b/src/tests/_internal/server/background/tasks/test_process_instances.py index e13454097..8041c514d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_instances.py +++ b/src/tests/_internal/server/background/tasks/test_process_instances.py @@ -28,6 +28,7 @@ process_instances, ) from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_instance, create_job, create_pool, @@ -531,6 +532,7 @@ async def test_creates_instance( price=1.0, availability=InstanceAvailability.AVAILABLE, ) + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) backend_mock.compute.return_value.get_offers_cached.return_value = [offer] backend_mock.compute.return_value.create_instance.return_value = JobProvisioningData( backend=offer.backend, diff --git a/src/tests/_internal/server/background/tasks/test_process_placement_groups.py b/src/tests/_internal/server/background/tasks/test_process_placement_groups.py index becc8ec7c..a45051a48 100644 --- a/src/tests/_internal/server/background/tasks/test_process_placement_groups.py +++ b/src/tests/_internal/server/background/tasks/test_process_placement_groups.py @@ -7,6 +7,7 @@ process_placement_groups, ) from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_fleet, create_placement_group, create_project, @@ -34,6 +35,7 @@ async def test_deletes_placement_groups(self, test_db, session: AsyncSession): with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: aws_mock = Mock() m.return_value = aws_mock + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) await process_placement_groups() aws_mock.compute.return_value.delete_placement_group.assert_called_once() await session.refresh(placement_group1) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 346d01b29..98054b5fb 100644 --- 
a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -29,6 +29,7 @@ from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_fleet, create_instance, create_job, @@ -506,6 +507,7 @@ async def test_assigns_job_to_instance_with_volumes(self, test_db, session: Asyn backend_mock = Mock() m.return_value = backend_mock backend_mock.TYPE = BackendType.AWS + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) backend_mock.compute.return_value.attach_volume.return_value = VolumeAttachmentData() # Submitted jobs processing happens in two steps await process_submitted_jobs() diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py b/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py index 86fbed603..0b2f0b194 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py @@ -7,6 +7,7 @@ from dstack._internal.core.models.volumes import VolumeProvisioningData, VolumeStatus from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_project, create_user, create_volume, @@ -40,6 +41,7 @@ async def test_provisiones_volumes(self, test_db, session: AsyncSession): ) as m: aws_mock = Mock() m.return_value = aws_mock + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) aws_mock.compute.return_value.create_volume.return_value = VolumeProvisioningData( backend=BackendType.AWS, volume_id="1234", diff --git 
a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py index 937cc5156..531c3c86b 100644 --- a/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py @@ -16,6 +16,7 @@ from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel from dstack._internal.server.services.volumes import volume_model_to_volume from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_instance, create_job, create_pool, @@ -114,6 +115,7 @@ async def test_detaches_job_volumes(self, session: AsyncSession): with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: backend_mock = Mock() m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) backend_mock.compute.return_value.is_volume_detached.return_value = True await process_terminating_jobs() m.assert_awaited_once() @@ -163,6 +165,7 @@ async def test_force_detaches_job_volumes(self, session: AsyncSession): with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: backend_mock = Mock() m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) backend_mock.compute.return_value.is_volume_detached.return_value = False await process_terminating_jobs() m.assert_awaited_once() @@ -188,6 +191,7 @@ async def test_force_detaches_job_volumes(self, session: AsyncSession): ) + timedelta(minutes=30) backend_mock = Mock() m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) backend_mock.compute.return_value.is_volume_detached.return_value = False await process_terminating_jobs() m.assert_awaited_once() @@ -205,6 +209,7 @@ async def test_force_detaches_job_volumes(self, session: AsyncSession): with 
patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: backend_mock = Mock() m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) backend_mock.compute.return_value.is_volume_detached.return_value = True await process_terminating_jobs() m.assert_awaited_once() @@ -319,13 +324,15 @@ async def test_detaches_job_volumes_on_shared_instance(self, session: AsyncSessi with patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: backend_mock = Mock() m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) backend_mock.compute.return_value.is_volume_detached.return_value = True await process_terminating_jobs() - m.assert_awaited_once() - backend_mock.compute.return_value.detach_volume.assert_called_once() - backend_mock.compute.return_value.is_volume_detached.assert_called_once() + m.assert_awaited_once() + backend_mock.compute.return_value.detach_volume.assert_called_once() + backend_mock.compute.return_value.is_volume_detached.assert_called_once() + await session.refresh(job) await session.refresh(instance) assert job.status == JobStatus.TERMINATED diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index e85268074..5840735c2 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -13,6 +13,7 @@ ) from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_backend, create_gateway, create_gateway_compute, @@ -437,8 +438,10 @@ async def test_delete_gateway(self, test_db, session: AsyncSession, client: Asyn "dstack._internal.server.services.gateways.get_project_backend_by_type_or_error" ) as m: aws = Mock() + aws.compute.return_value = Mock(spec=ComputeMockSpec) 
aws.compute.return_value.terminate_gateway.return_value = None # success gcp = Mock() + gcp.compute.return_value = Mock(spec=ComputeMockSpec) gcp.compute.return_value.terminate_gateway.side_effect = DstackError() # fail def get_backend(project, backend_type): diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 5795eaddc..300a1aba9 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -51,6 +51,7 @@ from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.services.runs import run_model_to_run from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_backend, create_gateway, create_gateway_compute, @@ -1616,6 +1617,7 @@ async def test_error_if_backends_do_not_support_create_instance( ) backend = Mock() backend.TYPE = BackendType.AZURE + backend.compute.return_value = Mock(spec=ComputeMockSpec) backend.compute.return_value.get_offers_cached.return_value = [offer] backend.compute.return_value.create_instance.side_effect = NotImplementedError() run_plan_by_req.return_value = [(backend, offer)] @@ -1658,6 +1660,7 @@ async def test_backend_does_not_support_create_instance( backend = Mock() backend.TYPE = BackendType.VASTAI + backend.compute.return_value = Mock(spec=ComputeMockSpec) backend.compute.return_value.get_offers_cached.return_value = [offers] backend.compute.return_value.create_instance.side_effect = NotImplementedError() run_plan_by_req.return_value = [(backend, offers)] diff --git a/src/tests/_internal/server/routers/test_volumes.py b/src/tests/_internal/server/routers/test_volumes.py index 29106b758..21e73e331 100644 --- a/src/tests/_internal/server/routers/test_volumes.py +++ b/src/tests/_internal/server/routers/test_volumes.py @@ -14,6 +14,7 @@ from dstack._internal.server.models import VolumeAttachmentModel, VolumeModel from dstack._internal.server.services.projects 
import add_project_member from dstack._internal.server.testing.common import ( + ComputeMockSpec, create_instance, create_pool, create_project, @@ -366,6 +367,7 @@ async def test_deletes_volumes(self, test_db, session: AsyncSession, client: Asy ) as m: aws_mock = Mock() m.return_value = aws_mock + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) response = await client.post( f"/api/project/{project.name}/volumes/delete", headers=get_auth_headers(user.token), From da4a1698956919b6cbaffa7aaa38de0eb9170176 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 10 Mar 2025 10:56:53 +0500 Subject: [PATCH 2/5] Filter out offers based on offer.backend instead of backend.TYPE --- src/dstack/_internal/server/services/fleets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index a0b37d7b0..5cb140e92 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -308,7 +308,7 @@ async def get_create_instance_offers( offers = [ (backend, offer) for backend, offer in offers - if backend.TYPE in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT + if offer.backend in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT ] return offers @@ -600,7 +600,7 @@ async def create_instance( ) # Raise error if no backends suppport create_instance - backend_types = set((backend.TYPE for backend, _ in offers)) + backend_types = set(offer.backend for _, offer in offers) if all( (backend_type not in BACKENDS_WITH_CREATE_INSTANCE_SUPPORT) for backend_type in backend_types From 1dbeafb3862926a3373b1b292d209489e7a90fd5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 10 Mar 2025 11:30:27 +0500 Subject: [PATCH 3/5] Check backend type availability for volume and gateway configurations --- src/dstack/_internal/core/backends/__init__.py | 1 + .../_internal/core/backends/configurators.py | 9 +++++---- 
.../_internal/server/services/backends/__init__.py | 14 +++++++++++++- .../_internal/server/services/gateways/__init__.py | 10 +++++++--- src/dstack/_internal/server/services/volumes.py | 5 +++-- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/dstack/_internal/core/backends/__init__.py b/src/dstack/_internal/core/backends/__init__.py index 8ffa38df3..76c3210e9 100644 --- a/src/dstack/_internal/core/backends/__init__.py +++ b/src/dstack/_internal/core/backends/__init__.py @@ -27,6 +27,7 @@ def _get_backends_with_compute_feature( _configurator_classes = list_available_configurator_classes() +# The following backend lists do not include unavailable backends (i.e. backends missing deps). # TODO: Add LocalBackend to lists if it's enabled BACKENDS_WITH_CREATE_INSTANCE_SUPPORT = _get_backends_with_compute_feature( configurator_classes=_configurator_classes, diff --git a/src/dstack/_internal/core/backends/configurators.py b/src/dstack/_internal/core/backends/configurators.py index cf1db1fc5..239adf786 100644 --- a/src/dstack/_internal/core/backends/configurators.py +++ b/src/dstack/_internal/core/backends/configurators.py @@ -109,6 +109,7 @@ _BACKEND_TYPE_TO_CONFIGURATOR_CLASS_MAP = {c.TYPE: c for c in _CONFIGURATOR_CLASSES} +_BACKEND_TYPES = [c.TYPE for c in _CONFIGURATOR_CLASSES] def get_configurator(backend_type: Union[BackendType, str]) -> Optional[Configurator]: @@ -126,13 +127,13 @@ def list_available_backend_types() -> List[BackendType]: """ Lists all backend types available on the server. """ - available_backend_types = [] - for configurator_class in _BACKEND_TYPE_TO_CONFIGURATOR_CLASS_MAP.values(): - available_backend_types.append(configurator_class.TYPE) - return available_backend_types + return _BACKEND_TYPES def list_available_configurator_classes() -> List[type[Configurator]]: + """ + Lists all backend configurator classes available on the server. 
+ """ return _CONFIGURATOR_CLASSES diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index cf509cd1e..a6081ea25 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -11,7 +11,10 @@ Configurator, StoredBackendRecord, ) -from dstack._internal.core.backends.configurators import get_configurator +from dstack._internal.core.backends.configurators import ( + get_configurator, + list_available_backend_types, +) from dstack._internal.core.backends.local.backend import LocalBackend from dstack._internal.core.backends.models import ( AnyBackendConfig, @@ -364,3 +367,12 @@ async def get_instance_offers( offers = heapq.merge(*offers_by_backend, key=lambda i: i[1].price) # Put NOT_AVAILABLE, NO_QUOTA, and BUSY instances at the end, do not sort by price return sorted(offers, key=lambda i: not i[1].availability.is_available()) + + +def check_backed_type_available(backend_type: BackendType): + if backend_type not in list_available_backend_types(): + raise BackendNotAvailable( + f"Backend {backend_type.value} not available." + " Ensure that backend dependencies are installed." + f" Available backends: {[b.value for b in list_available_backend_types()]}." 
+ ) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 2d7f82c1d..253e4847c 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -39,6 +39,7 @@ from dstack._internal.server.db import get_db from dstack._internal.server.models import GatewayComputeModel, GatewayModel, ProjectModel from dstack._internal.server.services.backends import ( + check_backed_type_available, get_project_backend_by_type_or_error, get_project_backend_with_model_by_type_or_error, ) @@ -537,10 +538,12 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway: def _validate_gateway_configuration(configuration: GatewayConfiguration): + check_backed_type_available(configuration.backend) if configuration.backend not in BACKENDS_WITH_GATEWAY_SUPPORT: raise ServerClientError( - f"Gateways are not supported for {configuration.backend.value} backend. " - f"Supported backends: {[b.value for b in BACKENDS_WITH_GATEWAY_SUPPORT]}." + f"Gateways are not supported for {configuration.backend.value} backend." + " Available backends with gateway support:" + f" {[b.value for b in BACKENDS_WITH_GATEWAY_SUPPORT]}." ) if configuration.name is not None: @@ -552,7 +555,8 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration): ): raise ServerClientError( f"Private gateways are not supported for {configuration.backend.value} backend. " - f"Supported backends: {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}." + " Available backends with private gateway support:" + f" {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}." 
) if configuration.certificate is not None: diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 5157b2b45..17cf5ae99 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -376,10 +376,11 @@ async def generate_volume_name(session: AsyncSession, project: ProjectModel) -> def _validate_volume_configuration(configuration: VolumeConfiguration): if configuration.volume_id is None and configuration.size is None: raise ServerClientError("Volume must specify either volume_id or size") + backends_services.check_backed_type_available(configuration.backend) if configuration.backend not in BACKENDS_WITH_VOLUMES_SUPPORT: raise ServerClientError( - f"Volumes are not supported for {configuration.backend.value} backend. " - f"Supported backends: {[b.value for b in BACKENDS_WITH_VOLUMES_SUPPORT]}." + f"Volumes are not supported for {configuration.backend.value} backend." + f" Available backends with volumes support: {[b.value for b in BACKENDS_WITH_VOLUMES_SUPPORT]}." ) if configuration.name is not None: validate_dstack_resource_name(configuration.name) From 1290743d469231a229d4d2fd81aba9a9022cf566 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 10 Mar 2025 12:43:31 +0500 Subject: [PATCH 4/5] Mention ComputeWith* classes in backends guide --- contributing/BACKENDS.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/contributing/BACKENDS.md b/contributing/BACKENDS.md index d4f0392ac..82d8fb4bf 100644 --- a/contributing/BACKENDS.md +++ b/contributing/BACKENDS.md @@ -117,13 +117,8 @@ Refer to examples: ##### 2.4.4. Create the backend compute class Under the backend directory you've created, create the `compute.py` file and define the -backend compute class there (should extend `dstack._internal.core.backends.base.compute.Compute`). - -You'll have to implement `get_offers`, `run_job` and `terminate_instance`. 
-You may need to implement `update_provisioning_data`, see its docstring for details. - -For VM-based backends, also implement the `create_instance` method and add the backend name to -[`BACKENDS_WITH_CREATE_INSTANCE_SUPPORT`](`https://github.com/dstackai/dstack/blob/master/src/dstack/_internal/core/backends/__init__.py`). +backend compute class that extends the `dstack._internal.core.backends.base.compute.Compute` class. +It can also extend and implement `ComputeWith*` classes to support additional features such as fleets, volumes, gateways, placement groups, etc. Refer to examples: [datacrunch](https://github.com/dstackai/dstack/blob/master/src/dstack/_internal/core/backends/datacrunch/compute.py), From 0f9417ba2a919f92cd4dc20dada954bc0dfd1ce5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 11 Mar 2025 10:30:47 +0500 Subject: [PATCH 5/5] Fixes after review --- contributing/BACKENDS.md | 2 +- src/dstack/_internal/server/services/backends/__init__.py | 2 +- src/dstack/_internal/server/services/gateways/__init__.py | 4 ++-- src/dstack/_internal/server/services/volumes.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/contributing/BACKENDS.md b/contributing/BACKENDS.md index 82d8fb4bf..272bbeb75 100644 --- a/contributing/BACKENDS.md +++ b/contributing/BACKENDS.md @@ -118,7 +118,7 @@ Refer to examples: Under the backend directory you've created, create the `compute.py` file and define the backend compute class that extends the `dstack._internal.core.backends.base.compute.Compute` class. -It can also extend and implement `ComputeWith*` classes to support additional features such as fleets, volumes, gateways, placement groups, etc. +It can also extend and implement `ComputeWith*` classes to support additional features such as fleets, volumes, gateways, placement groups, etc. For example, it should extend `ComputeWithCreateInstanceSupport` to support fleets. 
Refer to examples: [datacrunch](https://github.com/dstackai/dstack/blob/master/src/dstack/_internal/core/backends/datacrunch/compute.py), diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index a6081ea25..63992d148 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -369,7 +369,7 @@ async def get_instance_offers( return sorted(offers, key=lambda i: not i[1].availability.is_available()) -def check_backed_type_available(backend_type: BackendType): +def check_backend_type_available(backend_type: BackendType): if backend_type not in list_available_backend_types(): raise BackendNotAvailable( f"Backend {backend_type.value} not available." diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 253e4847c..715829d44 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -39,7 +39,7 @@ from dstack._internal.server.db import get_db from dstack._internal.server.models import GatewayComputeModel, GatewayModel, ProjectModel from dstack._internal.server.services.backends import ( - check_backed_type_available, + check_backend_type_available, get_project_backend_by_type_or_error, get_project_backend_with_model_by_type_or_error, ) @@ -538,7 +538,7 @@ def gateway_model_to_gateway(gateway_model: GatewayModel) -> Gateway: def _validate_gateway_configuration(configuration: GatewayConfiguration): - check_backed_type_available(configuration.backend) + check_backend_type_available(configuration.backend) if configuration.backend not in BACKENDS_WITH_GATEWAY_SUPPORT: raise ServerClientError( f"Gateways are not supported for {configuration.backend.value} backend." 
diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 17cf5ae99..2da77d48f 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -376,7 +376,7 @@ async def generate_volume_name(session: AsyncSession, project: ProjectModel) -> def _validate_volume_configuration(configuration: VolumeConfiguration): if configuration.volume_id is None and configuration.size is None: raise ServerClientError("Volume must specify either volume_id or size") - backends_services.check_backed_type_available(configuration.backend) + backends_services.check_backend_type_available(configuration.backend) if configuration.backend not in BACKENDS_WITH_VOLUMES_SUPPORT: raise ServerClientError( f"Volumes are not supported for {configuration.backend.value} backend."