Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 2 additions & 42 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
generate_unique_instance_name,
generate_unique_volume_name,
get_gateway_user_data,
get_job_instance_name,
get_user_data,
merge_tags,
)
Expand All @@ -39,11 +38,10 @@
InstanceConfiguration,
InstanceOffer,
InstanceOfferWithAvailability,
SSHKey,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.resources import Memory, Range
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.core.models.volumes import (
Volume,
VolumeAttachmentData,
Expand All @@ -69,14 +67,14 @@ class AWSVolumeBackendData(CoreModel):


class AWSCompute(
Compute,
ComputeWithCreateInstanceSupport,
ComputeWithMultinodeSupport,
ComputeWithReservationSupport,
ComputeWithPlacementGroupSupport,
ComputeWithGatewaySupport,
ComputeWithPrivateGatewaySupport,
ComputeWithVolumeSupport,
Compute,
):
def __init__(self, config: AWSConfig):
super().__init__()
Expand Down Expand Up @@ -285,44 +283,6 @@ def create_instance(
continue
raise NoCapacityError()

def run_job(
self,
run: Run,
job: Job,
instance_offer: InstanceOfferWithAvailability,
project_ssh_public_key: str,
project_ssh_private_key: str,
volumes: List[Volume],
) -> JobProvisioningData:
# TODO: run_job is the same for vm-based backends, refactor
instance_config = InstanceConfiguration(
project_name=run.project_name,
instance_name=get_job_instance_name(run, job), # TODO: generate name
ssh_keys=[
SSHKey(public=project_ssh_public_key.strip()),
],
user=run.user,
volumes=volumes,
reservation=run.run_spec.configuration.reservation,
)
instance_offer = instance_offer.copy()
if len(volumes) > 0:
volume = volumes[0]
if (
volume.provisioning_data is not None
and volume.provisioning_data.availability_zone is not None
):
if instance_offer.availability_zones is None:
instance_offer.availability_zones = [
volume.provisioning_data.availability_zone
]
instance_offer.availability_zones = [
z
for z in instance_offer.availability_zones
if z == volume.provisioning_data.availability_zone
]
return self.create_instance(instance_offer, instance_config)

def create_placement_group(
self,
placement_group: PlacementGroup,
Expand Down
26 changes: 2 additions & 24 deletions src/dstack/_internal/core/backends/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
generate_unique_gateway_instance_name,
generate_unique_instance_name,
get_gateway_user_data,
get_job_instance_name,
get_user_data,
merge_tags,
)
Expand All @@ -62,11 +61,9 @@
InstanceOffer,
InstanceOfferWithAvailability,
InstanceType,
SSHKey,
)
from dstack._internal.core.models.resources import Memory, Range
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import Volume
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand All @@ -75,10 +72,10 @@


class AzureCompute(
Compute,
ComputeWithCreateInstanceSupport,
ComputeWithMultinodeSupport,
ComputeWithGatewaySupport,
Compute,
):
def __init__(self, config: AzureConfig, credential: TokenCredential):
super().__init__()
Expand Down Expand Up @@ -198,25 +195,6 @@ def create_instance(
backend_data=None,
)

def run_job(
self,
run: Run,
job: Job,
instance_offer: InstanceOfferWithAvailability,
project_ssh_public_key: str,
project_ssh_private_key: str,
volumes: List[Volume],
) -> JobProvisioningData:
instance_config = InstanceConfiguration(
project_name=run.project_name,
instance_name=get_job_instance_name(run, job), # TODO: generate name
ssh_keys=[
SSHKey(public=project_ssh_public_key.strip()),
],
user=run.user,
)
return self.create_instance(instance_offer, instance_config)

def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
):
Expand Down
46 changes: 46 additions & 0 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dstack._internal.core.models.instances import (
InstanceConfiguration,
InstanceOfferWithAvailability,
SSHKey,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
Expand Down Expand Up @@ -144,6 +145,51 @@ def create_instance(
"""
pass

def run_job(
self,
run: Run,
job: Job,
instance_offer: InstanceOfferWithAvailability,
project_ssh_public_key: str,
project_ssh_private_key: str,
volumes: List[Volume],
) -> JobProvisioningData:
"""
The default `run_job()` implementation for all backends that support `create_instance()`.
Override only if custom `run_job()` behavior is required.
"""
instance_config = InstanceConfiguration(
project_name=run.project_name,
instance_name=get_job_instance_name(run, job),
user=run.user,
ssh_keys=[SSHKey(public=project_ssh_public_key.strip())],
volumes=volumes,
reservation=run.run_spec.configuration.reservation,
)
instance_offer = instance_offer.copy()
self._restrict_instance_offer_az_to_volumes_az(instance_offer, volumes)
return self.create_instance(instance_offer, instance_config)

def _restrict_instance_offer_az_to_volumes_az(
self,
instance_offer: InstanceOfferWithAvailability,
volumes: List[Volume],
):
if len(volumes) == 0:
return
volume = volumes[0]
if (
volume.provisioning_data is not None
and volume.provisioning_data.availability_zone is not None
):
if instance_offer.availability_zones is None:
instance_offer.availability_zones = [volume.provisioning_data.availability_zone]
instance_offer.availability_zones = [
z
for z in instance_offer.availability_zones
if z == volume.provisioning_data.availability_zone
]


class ComputeWithMultinodeSupport:
"""
Expand Down
26 changes: 2 additions & 24 deletions src/dstack/_internal/core/backends/cudo/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from dstack._internal.core.backends.base.compute import (
ComputeWithCreateInstanceSupport,
generate_unique_instance_name,
get_job_instance_name,
get_shim_commands,
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
Expand All @@ -18,10 +17,8 @@
InstanceAvailability,
InstanceConfiguration,
InstanceOfferWithAvailability,
SSHKey,
)
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import Volume
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand All @@ -31,8 +28,8 @@


class CudoCompute(
Compute,
ComputeWithCreateInstanceSupport,
Compute,
):
def __init__(self, config: CudoConfig):
super().__init__()
Expand All @@ -55,25 +52,6 @@ def get_offers(
]
return offers

def run_job(
self,
run: Run,
job: Job,
instance_offer: InstanceOfferWithAvailability,
project_ssh_public_key: str,
project_ssh_private_key: str,
volumes: List[Volume],
) -> JobProvisioningData:
instance_config = InstanceConfiguration(
project_name=run.project_name,
instance_name=get_job_instance_name(run, job),
ssh_keys=[
SSHKey(public=project_ssh_public_key.strip()),
],
user=run.user,
)
return self.create_instance(instance_offer, instance_config)

def create_instance(
self,
instance_offer: InstanceOfferWithAvailability,
Expand Down
25 changes: 2 additions & 23 deletions src/dstack/_internal/core/backends/datacrunch/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
InstanceConfiguration,
InstanceOffer,
InstanceOfferWithAvailability,
SSHKey,
)
from dstack._internal.core.models.resources import Memory, Range
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import Volume
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.utils.logging import get_logger

logger = get_logger("datacrunch.compute")
Expand All @@ -35,8 +33,8 @@


class DataCrunchCompute(
Compute,
ComputeWithCreateInstanceSupport,
Compute,
):
def __init__(self, config: DataCrunchConfig):
super().__init__()
Expand Down Expand Up @@ -152,25 +150,6 @@ def create_instance(
backend_data=None,
)

def run_job(
self,
run: Run,
job: Job,
instance_offer: InstanceOfferWithAvailability,
project_ssh_public_key: str,
project_ssh_private_key: str,
volumes: List[Volume],
) -> JobProvisioningData:
instance_config = InstanceConfiguration(
project_name=run.project_name,
instance_name=job.job_spec.job_name, # TODO: generate name
ssh_keys=[
SSHKey(public=project_ssh_public_key.strip()),
],
user=run.user,
)
return self.create_instance(instance_offer, instance_config)

def terminate_instance(
self, instance_id: str, region: str, backend_data: Optional[str] = None
) -> None:
Expand Down
44 changes: 2 additions & 42 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
generate_unique_instance_name,
generate_unique_volume_name,
get_gateway_user_data,
get_job_instance_name,
get_shim_commands,
get_user_data,
merge_tags,
Expand All @@ -46,10 +45,9 @@
InstanceOfferWithAvailability,
InstanceType,
Resources,
SSHKey,
)
from dstack._internal.core.models.resources import Memory, Range
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.runs import JobProvisioningData, Requirements
from dstack._internal.core.models.volumes import (
Volume,
VolumeAttachmentData,
Expand All @@ -74,11 +72,11 @@ class GCPVolumeDiskBackendData(CoreModel):


class GCPCompute(
Compute,
ComputeWithCreateInstanceSupport,
ComputeWithMultinodeSupport,
ComputeWithGatewaySupport,
ComputeWithVolumeSupport,
Compute,
):
def __init__(self, config: GCPConfig):
super().__init__()
Expand Down Expand Up @@ -373,44 +371,6 @@ def update_provisioning_data(
f"Failed to get instance IP address. Instance status: {instance.status}"
)

def run_job(
self,
run: Run,
job: Job,
instance_offer: InstanceOfferWithAvailability,
project_ssh_public_key: str,
project_ssh_private_key: str,
volumes: List[Volume],
) -> JobProvisioningData:
# TODO: run_job is the same for vm-based backends, refactor
instance_config = InstanceConfiguration(
project_name=run.project_name,
instance_name=get_job_instance_name(run, job), # TODO: generate name
ssh_keys=[
SSHKey(public=project_ssh_public_key.strip()),
],
user=run.user,
volumes=volumes,
reservation=run.run_spec.configuration.reservation,
)
instance_offer = instance_offer.copy()
if len(volumes) > 0:
volume = volumes[0]
if (
volume.provisioning_data is not None
and volume.provisioning_data.availability_zone is not None
):
if instance_offer.availability_zones is None:
instance_offer.availability_zones = [
volume.provisioning_data.availability_zone
]
instance_offer.availability_zones = [
z
for z in instance_offer.availability_zones
if z == volume.provisioning_data.availability_zone
]
return self.create_instance(instance_offer, instance_config)

def create_gateway(
self,
configuration: GatewayComputeConfiguration,
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@


class KubernetesCompute(
Compute,
ComputeWithGatewaySupport,
Compute,
):
def __init__(self, config: KubernetesConfig):
super().__init__()
Expand Down
Loading
Loading