diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 5513109b6..ea0b45851 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -22,7 +22,6 @@ generate_unique_instance_name, generate_unique_volume_name, get_gateway_user_data, - get_job_instance_name, get_user_data, merge_tags, ) @@ -39,11 +38,10 @@ InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, - SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData from dstack._internal.core.models.resources import Memory, Range -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run +from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.core.models.volumes import ( Volume, VolumeAttachmentData, @@ -69,7 +67,6 @@ class AWSVolumeBackendData(CoreModel): class AWSCompute( - Compute, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, ComputeWithReservationSupport, @@ -77,6 +74,7 @@ class AWSCompute( ComputeWithGatewaySupport, ComputeWithPrivateGatewaySupport, ComputeWithVolumeSupport, + Compute, ): def __init__(self, config: AWSConfig): super().__init__() @@ -285,44 +283,6 @@ def create_instance( continue raise NoCapacityError() - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - # TODO: run_job is the same for vm-based backends, refactor - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=get_job_instance_name(run, job), # TODO: generate name - ssh_keys=[ - SSHKey(public=project_ssh_public_key.strip()), - ], - user=run.user, - volumes=volumes, - reservation=run.run_spec.configuration.reservation, - ) - instance_offer = instance_offer.copy() - if len(volumes) > 0: - volume = volumes[0] - if ( - volume.provisioning_data is not None - and volume.provisioning_data.availability_zone is not None - ): - if instance_offer.availability_zones is None: - instance_offer.availability_zones = [ - volume.provisioning_data.availability_zone - ] - instance_offer.availability_zones = [ - z - for z in instance_offer.availability_zones - if z == volume.provisioning_data.availability_zone - ] - return self.create_instance(instance_offer, instance_config) - def create_placement_group( self, placement_group: PlacementGroup, diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 66576722b..f0925b8b9 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -45,7 +45,6 @@ generate_unique_gateway_instance_name, generate_unique_instance_name, get_gateway_user_data, - get_job_instance_name, get_user_data, merge_tags, ) @@ -62,11 +61,9 @@ InstanceOffer, InstanceOfferWithAvailability, InstanceType, - SSHKey, ) from dstack._internal.core.models.resources import Memory, Range -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume +from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -75,10 +72,10 @@ class AzureCompute( - Compute, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, ComputeWithGatewaySupport, + Compute, ): def __init__(self, config: AzureConfig, credential: TokenCredential): super().__init__() @@ -198,25 +195,6 @@ def create_instance( backend_data=None, ) - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=get_job_instance_name(run, job), # TODO: generate name - ssh_keys=[ - SSHKey(public=project_ssh_public_key.strip()), - ], - user=run.user, - ) - return self.create_instance(instance_offer, instance_config) - def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 7f9290590..c1185d084 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -25,6 +25,7 @@ from dstack._internal.core.models.instances import ( InstanceConfiguration, InstanceOfferWithAvailability, + SSHKey, ) from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run @@ -144,6 +145,51 @@ def create_instance( """ pass + def run_job( + self, + run: Run, + job: Job, + instance_offer: InstanceOfferWithAvailability, + project_ssh_public_key: str, + project_ssh_private_key: str, + volumes: List[Volume], + ) -> JobProvisioningData: + """ + The default `run_job()` implementation for all backends that support `create_instance()`. + Override only if custom `run_job()` behavior is required. + """ + instance_config = InstanceConfiguration( + project_name=run.project_name, + instance_name=get_job_instance_name(run, job), + user=run.user, + ssh_keys=[SSHKey(public=project_ssh_public_key.strip())], + volumes=volumes, + reservation=run.run_spec.configuration.reservation, + ) + instance_offer = instance_offer.copy() + self._restrict_instance_offer_az_to_volumes_az(instance_offer, volumes) + return self.create_instance(instance_offer, instance_config) + + def _restrict_instance_offer_az_to_volumes_az( + self, + instance_offer: InstanceOfferWithAvailability, + volumes: List[Volume], + ): + if len(volumes) == 0: + return + volume = volumes[0] + if ( + volume.provisioning_data is not None + and volume.provisioning_data.availability_zone is not None + ): + if instance_offer.availability_zones is None: + instance_offer.availability_zones = [volume.provisioning_data.availability_zone] + instance_offer.availability_zones = [ + z + for z in instance_offer.availability_zones + if z == volume.provisioning_data.availability_zone + ] + class ComputeWithMultinodeSupport: """ diff --git a/src/dstack/_internal/core/backends/cudo/compute.py b/src/dstack/_internal/core/backends/cudo/compute.py index 977cdd69e..11534f4fc 100644 --- a/src/dstack/_internal/core/backends/cudo/compute.py +++ b/src/dstack/_internal/core/backends/cudo/compute.py @@ -6,7 +6,6 @@ from dstack._internal.core.backends.base.compute import ( ComputeWithCreateInstanceSupport, generate_unique_instance_name, - get_job_instance_name, get_shim_commands, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -18,10 +17,8 @@ InstanceAvailability, InstanceConfiguration, InstanceOfferWithAvailability, - SSHKey, ) -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume +from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -31,8 +28,8 @@ class CudoCompute( - Compute, ComputeWithCreateInstanceSupport, + Compute, ): def __init__(self, config: CudoConfig): super().__init__() @@ -55,25 +52,6 @@ def get_offers( ] return offers - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=get_job_instance_name(run, job), - ssh_keys=[ - SSHKey(public=project_ssh_public_key.strip()), - ], - user=run.user, - ) - return self.create_instance(instance_offer, instance_config) - def create_instance( self, instance_offer: InstanceOfferWithAvailability, diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 7c84df3e0..ff40d2b81 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -15,11 +15,9 @@ InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, - SSHKey, ) from dstack._internal.core.models.resources import Memory, Range -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume +from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger logger = get_logger("datacrunch.compute") @@ -35,8 +33,8 @@ class DataCrunchCompute( - Compute, ComputeWithCreateInstanceSupport, + Compute, ): def __init__(self, config: DataCrunchConfig): super().__init__() @@ -152,25 +150,6 @@ def create_instance( backend_data=None, ) - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=job.job_spec.job_name, # TODO: generate name - ssh_keys=[ - SSHKey(public=project_ssh_public_key.strip()), - ], - user=run.user, - ) - return self.create_instance(instance_offer, instance_config) - def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ) -> None: diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 66996fe1c..5ba315ab2 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -20,7 +20,6 @@ generate_unique_instance_name, generate_unique_volume_name, get_gateway_user_data, - get_job_instance_name, get_shim_commands, get_user_data, merge_tags, @@ -46,10 +45,9 @@ InstanceOfferWithAvailability, InstanceType, Resources, - SSHKey, ) from dstack._internal.core.models.resources import Memory, Range -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run +from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.core.models.volumes import ( Volume, VolumeAttachmentData, @@ -74,11 +72,11 @@ class GCPVolumeDiskBackendData(CoreModel): class GCPCompute( - Compute, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, ComputeWithGatewaySupport, ComputeWithVolumeSupport, + Compute, ): def __init__(self, config: GCPConfig): super().__init__() @@ -373,44 +371,6 @@ def update_provisioning_data( f"Failed to get instance IP address. Instance status: {instance.status}" ) - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - # TODO: run_job is the same for vm-based backends, refactor - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=get_job_instance_name(run, job), # TODO: generate name - ssh_keys=[ - SSHKey(public=project_ssh_public_key.strip()), - ], - user=run.user, - volumes=volumes, - reservation=run.run_spec.configuration.reservation, - ) - instance_offer = instance_offer.copy() - if len(volumes) > 0: - volume = volumes[0] - if ( - volume.provisioning_data is not None - and volume.provisioning_data.availability_zone is not None - ): - if instance_offer.availability_zones is None: - instance_offer.availability_zones = [ - volume.provisioning_data.availability_zone - ] - instance_offer.availability_zones = [ - z - for z in instance_offer.availability_zones - if z == volume.provisioning_data.availability_zone - ] - return self.create_instance(instance_offer, instance_config) - def create_gateway( self, configuration: GatewayComputeConfiguration, diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index aa7c2017c..b5213c74d 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -58,8 +58,8 @@ class KubernetesCompute( - Compute, ComputeWithGatewaySupport, + Compute, ): def __init__(self, config: KubernetesConfig): super().__init__() diff --git a/src/dstack/_internal/core/backends/lambdalabs/compute.py b/src/dstack/_internal/core/backends/lambdalabs/compute.py index e002f3d68..ba5859b9d 100644 --- a/src/dstack/_internal/core/backends/lambdalabs/compute.py +++ b/src/dstack/_internal/core/backends/lambdalabs/compute.py @@ -8,7 +8,6 @@ Compute, ComputeWithCreateInstanceSupport, generate_unique_instance_name, - get_job_instance_name, get_shim_commands, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -20,17 +19,15 @@ InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, - SSHKey, ) -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume +from dstack._internal.core.models.runs import JobProvisioningData, Requirements MAX_INSTANCE_NAME_LEN = 60 class LambdaCompute( - Compute, ComputeWithCreateInstanceSupport, + Compute, ): def __init__(self, config: LambdaConfig): super().__init__() @@ -106,28 +103,6 @@ def update_provisioning_data( ) thread.start() - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=get_job_instance_name(run, job), # TODO: generate name - ssh_keys=[ - SSHKey( - public=project_ssh_public_key.strip(), private=project_ssh_private_key.strip() - ), - SSHKey(public=run.run_spec.ssh_key_pub.strip()), - ], - user=run.user, - ) - return self.create_instance(instance_offer, instance_config) - def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): diff --git a/src/dstack/_internal/core/backends/local/compute.py b/src/dstack/_internal/core/backends/local/compute.py index 86ee04de6..0b3eea8ab 100644 --- a/src/dstack/_internal/core/backends/local/compute.py +++ b/src/dstack/_internal/core/backends/local/compute.py @@ -23,9 +23,9 @@ class LocalCompute( - Compute, ComputeWithCreateInstanceSupport, ComputeWithVolumeSupport, + Compute, ): def get_offers( self, requirements: Optional[Requirements] = None diff --git a/src/dstack/_internal/core/backends/oci/compute.py b/src/dstack/_internal/core/backends/oci/compute.py index 5c41bc2b6..ec3fc3807 100644 --- a/src/dstack/_internal/core/backends/oci/compute.py +++ b/src/dstack/_internal/core/backends/oci/compute.py @@ -9,7 +9,6 @@ ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, generate_unique_instance_name, - get_job_instance_name, get_user_data, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -23,11 +22,9 @@ InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, - SSHKey, ) from dstack._internal.core.models.resources import Memory, Range -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume +from dstack._internal.core.models.runs import JobProvisioningData, Requirements SUPPORTED_SHAPE_FAMILIES = [ "VM.Standard2.", @@ -49,9 +46,9 @@ class OCICompute( - Compute, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, + Compute, ): def __init__(self, config: OCIConfig): super().__init__() @@ -98,23 +95,6 @@ def get_offers( return offers_with_availability - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=get_job_instance_name(run, job), - ssh_keys=[SSHKey(public=project_ssh_public_key.strip())], - user=run.user, - ) - return self.create_instance(instance_offer, instance_config) - def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ) -> None: diff --git a/src/dstack/_internal/core/backends/runpod/compute.py b/src/dstack/_internal/core/backends/runpod/compute.py index 16d873498..5e5950094 100644 --- a/src/dstack/_internal/core/backends/runpod/compute.py +++ b/src/dstack/_internal/core/backends/runpod/compute.py @@ -41,8 +41,8 @@ class RunpodCompute( - Compute, ComputeWithVolumeSupport, + Compute, ): _last_cleanup_time = None diff --git a/src/dstack/_internal/core/backends/tensordock/compute.py b/src/dstack/_internal/core/backends/tensordock/compute.py index 2ca1cec7d..01e605b97 100644 --- a/src/dstack/_internal/core/backends/tensordock/compute.py +++ b/src/dstack/_internal/core/backends/tensordock/compute.py @@ -7,7 +7,6 @@ from dstack._internal.core.backends.base.compute import ( ComputeWithCreateInstanceSupport, generate_unique_instance_name, - get_job_instance_name, get_shim_commands, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -19,10 +18,8 @@ InstanceAvailability, InstanceConfiguration, InstanceOfferWithAvailability, - SSHKey, ) -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume +from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -33,8 +30,8 @@ class TensorDockCompute( - Compute, ComputeWithCreateInstanceSupport, + Compute, ): def __init__(self, config: TensorDockConfig): super().__init__() @@ -117,26 +114,6 @@ def create_instance( backend_data=None, ) - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=get_job_instance_name(run, job), # TODO: generate name - ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), - SSHKey(public=project_ssh_public_key.strip()), - ], - user=run.user, - ) - return self.create_instance(instance_offer, instance_config) - def terminate_instance( self, instance_id: str, region: str, backend_data: Optional[str] = None ): diff --git a/src/dstack/_internal/core/backends/vultr/compute.py b/src/dstack/_internal/core/backends/vultr/compute.py index fb43dfb1b..f63e9951e 100644 --- a/src/dstack/_internal/core/backends/vultr/compute.py +++ b/src/dstack/_internal/core/backends/vultr/compute.py @@ -9,7 +9,6 @@ ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, generate_unique_instance_name, - get_job_instance_name, get_user_data, ) from dstack._internal.core.backends.base.offers import get_catalog_offers @@ -22,10 +21,8 @@ InstanceConfiguration, InstanceOffer, InstanceOfferWithAvailability, - SSHKey, ) -from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume +from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -34,9 +31,9 @@ class VultrCompute( - Compute, ComputeWithCreateInstanceSupport, ComputeWithMultinodeSupport, + Compute, ): def __init__(self, config: VultrConfig): super().__init__() @@ -60,23 +57,6 @@ def get_offers( ] return offers - def run_job( - self, - run: Run, - job: Job, - instance_offer: InstanceOfferWithAvailability, - project_ssh_public_key: str, - project_ssh_private_key: str, - volumes: List[Volume], - ) -> JobProvisioningData: - instance_config = InstanceConfiguration( - project_name=run.project_name, - instance_name=get_job_instance_name(run, job), - ssh_keys=[SSHKey(public=project_ssh_public_key.strip())], - user=run.user, - ) - return self.create_instance(instance_offer, instance_config) - def create_instance( self, instance_offer: InstanceOfferWithAvailability, instance_config: InstanceConfiguration ) -> JobProvisioningData: