diff --git a/docs/docs/reference/server/config.yml.md b/docs/docs/reference/server/config.yml.md index 84190def57..c96b0b84cd 100644 --- a/docs/docs/reference/server/config.yml.md +++ b/docs/docs/reference/server/config.yml.md @@ -162,8 +162,8 @@ There are two ways to configure AWS: using an access key or using the default cr For the regions without configured `vpc_ids`, enable default VPCs by setting `default_vpcs` to `true`. ??? info "Private subnets" - By default, `dstack` utilizes public subnets and permits inbound SSH traffic exclusively for any provisioned instances. - If you want `dstack` to use private subnets, set `public_ips` to `false`. + By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic. + If you want `dstack` to use private subnets and provision instances without public IPs, set `public_ips` to `false`. ```yaml projects: @@ -176,8 +176,8 @@ There are two ways to configure AWS: using an access key or using the default cr public_ips: false ``` - Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets - (e.g., through VPC peering). + Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets. + Additionally, private subnets must have outbound internet connectivity provided by NAT Gateway, Transit Gateway, or other mechanism. #### Azure @@ -287,6 +287,44 @@ There are two ways to configure Azure: using a client secret or using the defaul } ``` +??? info "VPC" + By default, `dstack` creates new Azure networks and subnets for every configured region. + It's possible to use custom networks by specifying `vpc_ids`: + + ```yaml + projects: + - name: main + backends: + - type: azure + creds: + type: default + regions: [westeurope] + vpc_ids: + westeurope: myNetworkResourceGroup/myNetworkName + ``` + + +??? 
info "Private subnets" + By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic. + If you want `dstack` to use private subnets and provision instances without public IPs, + specify custom networks using `vpc_ids` and set `public_ips` to `false`. + + ```yaml + projects: + - name: main + backends: + - type: azure + creds: + type: default + regions: [westeurope] + vpc_ids: + westeurope: myNetworkResourceGroup/myNetworkName + public_ips: false + ``` + + Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets. + Additionally, private subnets must have outbound internet connectivity provided by [NAT Gateway or other mechanism](https://learn.microsoft.com/en-us/azure/nat-gateway/nat-overview). + #### GCP There are two ways to configure GCP: using a service account or using the default credentials. @@ -441,8 +479,8 @@ gcloud projects list --format="json(projectId)" * Allow `INGRESS` traffic on ports `22`, `80`, `443`, with the target tag `dstack-gateway-instance` ??? info "Private subnets" - By default, `dstack` utilizes public subnets and permits inbound SSH traffic exclusively for any provisioned instances. - If you want `dstack` to use private subnets, set `public_ips` to `false`. + By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic. + If you want `dstack` to use private subnets and provision instances without public IPs, set `public_ips` to `false`. ```yaml projects: @@ -455,7 +493,8 @@ gcloud projects list --format="json(projectId)" public_ips: false ``` - Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets (e.g., through VPC peering). Additionally, [Cloud NAT](https://cloud.google.com/nat/docs/overview) must be configured to provide access to external resources for provisioned instances. 
+ Using private subnets assumes that both the `dstack` server and users can access the configured VPC's private subnets. + Additionally, [Cloud NAT](https://cloud.google.com/nat/docs/overview) must be configured to provide access to external resources for provisioned instances. #### Lambda diff --git a/src/dstack/_internal/core/backends/aws/resources.py b/src/dstack/_internal/core/backends/aws/resources.py index 4e3914503a..729697cf19 100644 --- a/src/dstack/_internal/core/backends/aws/resources.py +++ b/src/dstack/_internal/core/backends/aws/resources.py @@ -5,7 +5,7 @@ import botocore.exceptions import dstack.version as version -from dstack._internal.core.errors import ComputeError, ComputeResourceNotFoundError +from dstack._internal.core.errors import BackendError, ComputeError, ComputeResourceNotFoundError def get_image_id(ec2_client: botocore.client.BaseClient, cuda: bool) -> str: @@ -408,7 +408,7 @@ def make_tags(tags: Dict[str, str]) -> List[Dict[str, str]]: def validate_tags(tags: Dict[str, str]): for k, v in tags.items(): if not _is_valid_tag(k, v): - raise ComputeError( + raise BackendError( "Invalid resource tags. 
" "See tags restrictions: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#tag-restrictions" ) diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 0d98cc2a20..f4ffd0f041 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple from azure.core.credentials import TokenCredential -from azure.core.exceptions import ResourceExistsError +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError from azure.mgmt import compute as compute_mgmt from azure.mgmt import network as network_mgmt from azure.mgmt.compute.models import ( @@ -33,6 +33,7 @@ from dstack import version from dstack._internal import settings +from dstack._internal.core.backends.azure import resources as azure_resources from dstack._internal.core.backends.azure import utils as azure_utils from dstack._internal.core.backends.azure.config import AzureConfig from dstack._internal.core.backends.base.compute import ( @@ -110,6 +111,19 @@ def create_instance( ssh_pub_keys = instance_config.get_public_keys() disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024) + allocate_public_ip = self.config.allocate_public_ips + network_resource_group, network, subnet = get_resource_group_network_subnet_or_error( + network_client=self._network_client, + resource_group=self.config.resource_group, + vpc_ids=self.config.vpc_ids, + location=location, + allocate_public_ip=allocate_public_ip, + ) + network_security_group = azure_utils.get_default_network_security_group_name( + resource_group=self.config.resource_group, + location=location, + ) + tags = { "owner": "dstack", "dstack_project": instance_config.project_name, @@ -122,18 +136,9 @@ def create_instance( subscription_id=self.config.subscription_id, location=location, resource_group=self.config.resource_group, - 
network_security_group=azure_utils.get_default_network_security_group_name( - resource_group=self.config.resource_group, - location=location, - ), - network=azure_utils.get_default_network_name( - resource_group=self.config.resource_group, - location=location, - ), - subnet=azure_utils.get_default_subnet_name( - resource_group=self.config.resource_group, - location=location, - ), + network_security_group=network_security_group, + network=network, + subnet=subnet, managed_identity=None, image_reference=_get_image_ref( compute_client=self._compute_client, @@ -149,6 +154,8 @@ def create_instance( spot=instance_offer.instance.resources.spot, disk_size=disk_size, computer_name="runnervm", + allocate_public_ip=allocate_public_ip, + network_resource_group=network_resource_group, tags=tags, ) logger.info("Request succeeded") @@ -157,11 +164,14 @@ def create_instance( resource_group=self.config.resource_group, vm=vm, ) + hostname = public_ip + if not allocate_public_ip: + hostname = private_ip return JobProvisioningData( backend=instance_offer.backend, instance_type=instance_offer.instance, instance_id=vm.name, - hostname=public_ip, + hostname=hostname, internal_ip=private_ip, region=location, price=instance_offer.price, @@ -211,6 +221,18 @@ def create_gateway( configuration.region, ) + network_resource_group, network, subnet = get_resource_group_network_subnet_or_error( + network_client=self._network_client, + resource_group=self.config.resource_group, + vpc_ids=self.config.vpc_ids, + location=configuration.region, + allocate_public_ip=True, + ) + network_security_group = azure_utils.get_default_network_security_group_name( + resource_group=self.config.resource_group, + location=configuration.region, + ) + tags = { "Name": configuration.instance_name, "owner": "dstack", @@ -225,27 +247,19 @@ def create_gateway( subscription_id=self.config.subscription_id, location=configuration.region, resource_group=self.config.resource_group, - 
network_security_group=azure_utils.get_gateway_network_security_group_name( - resource_group=self.config.resource_group, - location=configuration.region, - ), - network=azure_utils.get_default_network_name( - resource_group=self.config.resource_group, - location=configuration.region, - ), - subnet=azure_utils.get_default_subnet_name( - resource_group=self.config.resource_group, - location=configuration.region, - ), + network_security_group=network_security_group, + network=network, + subnet=subnet, managed_identity=None, image_reference=_get_gateway_image_ref(), - vm_size="Standard_B1s", + vm_size="Standard_B1ms", instance_name=configuration.instance_name, user_data=get_gateway_user_data(configuration.ssh_key_pub), ssh_pub_keys=[configuration.ssh_key_pub], spot=False, disk_size=30, computer_name="gatewayvm", + network_resource_group=network_resource_group, tags=tags, ) logger.info("Request succeeded") @@ -273,6 +287,57 @@ def terminate_gateway( ) +def get_resource_group_network_subnet_or_error( + network_client: network_mgmt.NetworkManagementClient, + resource_group: Optional[str], + vpc_ids: Optional[Dict[str, str]], + location: str, + allocate_public_ip: bool, +) -> Tuple[str, str, str]: + if vpc_ids is not None: + vpc_id = vpc_ids.get(location) + if vpc_id is None: + raise ComputeError(f"Network not configured for location {location}") + try: + resource_group, network_name = _parse_config_vpc_id(vpc_id) + except Exception: + raise ComputeError( + "Network specified in incorrect format." 
+ " Supported format for `vpc_ids` values: 'networkResourceGroupName/networkName'" + ) + elif resource_group is not None: + network_name = azure_utils.get_default_network_name(resource_group, location) + else: + raise ComputeError("`resource_group` or `vpc_ids` must be specified") + + try: + subnets = azure_resources.get_network_subnets( + network_client=network_client, + resource_group=resource_group, + network_name=network_name, + private=not allocate_public_ip, + ) + except ResourceNotFoundError: + raise ComputeError( + f"Network {network_name} not found in location {location} in resource group {resource_group}" + ) + + if len(subnets) == 0: + if not allocate_public_ip: + raise ComputeError( + f"Failed to find private subnets with outbound internet connectivity in network {network_name}" + ) + raise ComputeError(f"Failed to find subnets in network {network_name}") + + subnet_name = subnets[0] + return resource_group, network_name, subnet_name + + +def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]: + resource_group, network_name = vpc_id.split("/") + return resource_group, network_name + + class VMImageVariant(enum.Enum): GRID = enum.auto() CUDA = enum.auto() @@ -396,10 +461,19 @@ def _launch_instance( spot: bool, disk_size: int, computer_name: str, + allocate_public_ip: bool = True, + network_resource_group: Optional[str] = None, tags: Optional[Dict[str, str]] = None, ) -> VirtualMachine: if tags is None: tags = {} + if network_resource_group is None: + network_resource_group = resource_group + public_ip_address_configuration = None + if allocate_public_ip: + public_ip_address_configuration = VirtualMachinePublicIPAddressConfiguration( + name="public_ip_config", + ) try: poller = compute_client.virtual_machines.begin_create_or_update( resource_group, @@ -451,14 +525,12 @@ def _launch_instance( subnet=SubResource( id=azure_utils.get_subnet_id( subscription_id, - resource_group, + network_resource_group, network, subnet, ) ), - 
public_ip_address_configuration=VirtualMachinePublicIPAddressConfiguration( - name="public_ip_config", - ), + public_ip_address_configuration=public_ip_address_configuration, ) ], ) @@ -505,18 +577,21 @@ def _get_vm_public_private_ips( network_client: network_mgmt.NetworkManagementClient, resource_group: str, vm: VirtualMachine, -) -> Tuple[str, str]: +) -> Tuple[Optional[str], str]: nic_id = vm.network_profile.network_interfaces[0].id nic_name = azure_utils.get_resource_name_from_resource_id(nic_id) nic = network_client.network_interfaces.get( resource_group_name=resource_group, network_interface_name=nic_name, ) + + private_ip = nic.ip_configurations[0].private_ip_address + if nic.ip_configurations[0].public_ip_address is None: + return None, private_ip + public_ip_id = nic.ip_configurations[0].public_ip_address.id public_ip_name = azure_utils.get_resource_name_from_resource_id(public_ip_id) public_ip = network_client.public_ip_addresses.get(resource_group, public_ip_name) - - private_ip = nic.ip_configurations[0].private_ip_address return public_ip.ip_address, private_ip diff --git a/src/dstack/_internal/core/backends/azure/config.py b/src/dstack/_internal/core/backends/azure/config.py index 4e7cff268f..7a25bb91a9 100644 --- a/src/dstack/_internal/core/backends/azure/config.py +++ b/src/dstack/_internal/core/backends/azure/config.py @@ -4,3 +4,9 @@ class AzureConfig(AzureStoredConfig, BackendConfig): creds: AnyAzureCreds + + @property + def allocate_public_ips(self) -> bool: + if self.public_ips is not None: + return self.public_ips + return True diff --git a/src/dstack/_internal/core/backends/azure/resources.py b/src/dstack/_internal/core/backends/azure/resources.py index e03ba484d2..dff3b2db2c 100644 --- a/src/dstack/_internal/core/backends/azure/resources.py +++ b/src/dstack/_internal/core/backends/azure/resources.py @@ -1,13 +1,84 @@ import re -from typing import Dict +from typing import Dict, List -from dstack._internal.core.errors import ComputeError +from 
azure.mgmt import network as network_mgmt +from azure.mgmt.network.models import Subnet + +from dstack._internal.core.errors import BackendError + + +def get_network_subnets( + network_client: network_mgmt.NetworkManagementClient, + resource_group: str, + network_name: str, + private: bool, +) -> List[str]: + res = [] + subnets = network_client.subnets.list( + resource_group_name=resource_group, virtual_network_name=network_name + ) + for subnet in subnets: + if private: + if _is_eligible_private_subnet( + network_client=network_client, + resource_group=resource_group, + network_name=network_name, + subnet=subnet, + ): + res.append(subnet.name) + else: + if _is_eligible_public_subnet( + network_client=network_client, + resource_group=resource_group, + network_name=network_name, + subnet=subnet, + ): + res.append(subnet.name) + return res + + +def _is_eligible_public_subnet( + network_client: network_mgmt.NetworkManagementClient, + resource_group: str, + network_name: str, + subnet: Subnet, +) -> bool: + # Apparently, in Azure practically any subnet can be used + # to provision instances with public IPs + return True + + +def _is_eligible_private_subnet( + network_client: network_mgmt.NetworkManagementClient, + resource_group: str, + network_name: str, + subnet: Subnet, +) -> bool: + # Azure provides default outbound connectivity but it's deprecated + # and does not work with Flexible orchestration used in dstack, + # so we require an explicit outbound method such as NAT Gateway. + + if subnet.nat_gateway is not None: + return True + + vnet_peerings = list( + network_client.virtual_network_peerings.list( + resource_group_name=resource_group, + virtual_network_name=network_name, + ) + ) + if len(vnet_peerings) > 0: + # We currently assume that any peering can provide outbound connectivity. + # There can be a more elaborate check of the peering configuration. 
+ return True + + return False def validate_tags(tags: Dict[str, str]): for k, v in tags.items(): if not _is_valid_tag(k, v): - raise ComputeError( + raise BackendError( "Invalid Azure resource tags. " "See tags restrictions: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/tag-resources#limitations" ) diff --git a/src/dstack/_internal/core/backends/gcp/resources.py b/src/dstack/_internal/core/backends/gcp/resources.py index c34f376ebf..e7f8d2ea8a 100644 --- a/src/dstack/_internal/core/backends/gcp/resources.py +++ b/src/dstack/_internal/core/backends/gcp/resources.py @@ -11,7 +11,7 @@ from google.cloud import tpu_v2 import dstack.version as version -from dstack._internal.core.errors import ComputeError +from dstack._internal.core.errors import BackendError, ComputeError from dstack._internal.core.models.instances import Gpu from dstack._internal.utils.common import remove_prefix from dstack._internal.utils.logging import get_logger @@ -314,7 +314,7 @@ def get_accelerators( def validate_labels(labels: Dict[str, str]): for k, v in labels.items(): if not _is_valid_label(k, v): - raise ComputeError( + raise BackendError( "Invalid resource labels. 
" "See labels restrictions: https://cloud.google.com/compute/docs/labeling-resources#requirements" ) diff --git a/src/dstack/_internal/core/models/backends/azure.py b/src/dstack/_internal/core/models/backends/azure.py index 85a660cc11..21b5b41a0d 100644 --- a/src/dstack/_internal/core/models/backends/azure.py +++ b/src/dstack/_internal/core/models/backends/azure.py @@ -12,6 +12,8 @@ class AzureConfigInfo(CoreModel): tenant_id: str subscription_id: str locations: Optional[List[str]] = None + vpc_ids: Optional[Dict[str, str]] = None + public_ips: Optional[bool] = None tags: Optional[Dict[str, str]] = None @@ -47,6 +49,8 @@ class AzureConfigInfoWithCredsPartial(CoreModel): tenant_id: Optional[str] subscription_id: Optional[str] locations: Optional[List[str]] + vpc_ids: Optional[Dict[str, str]] + public_ips: Optional[bool] tags: Optional[Dict[str, str]] diff --git a/src/dstack/_internal/server/services/backends/configurators/aws.py b/src/dstack/_internal/server/services/backends/configurators/aws.py index 743517fb47..beeb00d090 100644 --- a/src/dstack/_internal/server/services/backends/configurators/aws.py +++ b/src/dstack/_internal/server/services/backends/configurators/aws.py @@ -6,7 +6,11 @@ from dstack._internal.core.backends.aws import AWSBackend, auth, compute, resources from dstack._internal.core.backends.aws.config import AWSConfig -from dstack._internal.core.errors import BackendAuthError, ComputeError, ServerClientError +from dstack._internal.core.errors import ( + BackendAuthError, + BackendError, + ServerClientError, +) from dstack._internal.core.models.backends.aws import ( AnyAWSConfigInfo, AWSAccessKeyCreds, @@ -144,7 +148,7 @@ def _check_tags_config(self, config: AWSConfigInfoWithCredsPartial): ) try: resources.validate_tags(config.tags) - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) def _check_vpc_config(self, session: Session, config: AWSConfigInfoWithCredsPartial): @@ -188,5 +192,5 @@ def 
_check_vpc_config(self, session: Session, config: AWSConfigInfoWithCredsPart for future in concurrent.futures.as_completed(futures): try: future.result() - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) diff --git a/src/dstack/_internal/server/services/backends/configurators/azure.py b/src/dstack/_internal/server/services/backends/configurators/azure.py index 0ef3155ad6..363d331c8b 100644 --- a/src/dstack/_internal/server/services/backends/configurators/azure.py +++ b/src/dstack/_internal/server/services/backends/configurators/azure.py @@ -1,5 +1,5 @@ import json -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Optional, Tuple from azure.core.credentials import TokenCredential @@ -18,10 +18,14 @@ ) from azure.mgmt.resource.resources.models import ResourceGroup -from dstack._internal.core.backends.azure import AzureBackend, auth, resources +from dstack._internal.core.backends.azure import AzureBackend, auth, compute, resources from dstack._internal.core.backends.azure import utils as azure_utils from dstack._internal.core.backends.azure.config import AzureConfig -from dstack._internal.core.errors import BackendAuthError, ComputeError, ServerClientError +from dstack._internal.core.errors import ( + BackendAuthError, + BackendError, + ServerClientError, +) from dstack._internal.core.models.backends.azure import ( AnyAzureConfigInfo, AzureClientCreds, @@ -139,7 +143,7 @@ def get_config_values(self, config: AzureConfigInfoWithCredsPartial) -> AzureCon config_values.locations = self._get_locations_element( selected=config.locations or DEFAULT_LOCATIONS ) - self._check_config(config) + self._check_config(config=config, credential=credential) return config_values def create_backend( @@ -161,6 +165,7 @@ def create_backend( subscription_id=config.subscription_id, resource_group=resource_group, locations=config.locations, + 
create_default_network=config.vpc_ids is None, ) return BackendModel( project_id=project.id, @@ -285,17 +290,19 @@ def _create_network_resources( subscription_id: str, resource_group: str, locations: List[str], + create_default_network: bool, ): def func(location: str): network_manager = NetworkManager( credential=credential, subscription_id=subscription_id ) - network_manager.create_virtual_network( - resource_group=resource_group, - location=location, - name=azure_utils.get_default_network_name(resource_group, location), - subnet_name=azure_utils.get_default_subnet_name(resource_group, location), - ) + if create_default_network: + network_manager.create_virtual_network( + resource_group=resource_group, + location=location, + name=azure_utils.get_default_network_name(resource_group, location), + subnet_name=azure_utils.get_default_subnet_name(resource_group, location), + ) network_manager.create_network_security_group( resource_group=resource_group, location=location, @@ -311,8 +318,11 @@ def func(location: str): for location in locations: executor.submit(func, location) - def _check_config(self, config: AzureConfigInfoWithCredsPartial): + def _check_config( + self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential + ): self._check_tags_config(config) + self._check_vpc_config(config=config, credential=credential) def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial): if not config.tags: @@ -323,9 +333,59 @@ def _check_tags_config(self, config: AzureConfigInfoWithCredsPartial): ) try: resources.validate_tags(config.tags) - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) + def _check_vpc_config( + self, config: AzureConfigInfoWithCredsPartial, credential: auth.AzureCredential + ): + if config.subscription_id is None: + return None + allocate_public_ip = config.public_ips if config.public_ips is not None else True + if config.public_ips is False and config.vpc_ids is None: + raise 
ServerClientError(msg="`vpc_ids` must be specified if `public_ips: false`.") + locations = config.locations + if locations is None: + locations = DEFAULT_LOCATIONS + if config.vpc_ids is not None: + vpc_ids_locations = list(config.vpc_ids.keys()) + not_configured_locations = [loc for loc in locations if loc not in vpc_ids_locations] + if len(not_configured_locations) > 0: + if config.locations is None: + raise ServerClientError( + f"`vpc_ids` not configured for regions {not_configured_locations}. " + "Configure `vpc_ids` for all regions or specify `regions`." + ) + raise ServerClientError( + f"`vpc_ids` not configured for regions {not_configured_locations}. " + "Configure `vpc_ids` for all regions specified in `regions`." + ) + network_client = network_mgmt.NetworkManagementClient( + credential=credential, + subscription_id=config.subscription_id, + ) + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [] + for location in locations: + future = executor.submit( + compute.get_resource_group_network_subnet_or_error, + network_client=network_client, + resource_group=None, + vpc_ids=config.vpc_ids, + location=location, + allocate_public_ip=allocate_public_ip, + ) + futures.append(future) + for future in as_completed(futures): + try: + future.result() + except BackendError as e: + raise ServerClientError(e.args[0]) + + +def _get_resource_group_name(project_name: str) -> str: + return f"dstack-{project_name}" + class ResourceManager: def __init__(self, credential: TokenCredential, subscription_id: str): @@ -347,10 +407,6 @@ def create_resource_group( return resource_group.name -def _get_resource_group_name(project_name: str) -> str: - return f"dstack-{project_name}" - - class NetworkManager: def __init__(self, credential: TokenCredential, subscription_id: str): self.network_client = network_mgmt.NetworkManagementClient( diff --git a/src/dstack/_internal/server/services/backends/configurators/gcp.py 
b/src/dstack/_internal/server/services/backends/configurators/gcp.py index 88e411aa3f..5eef9962cb 100644 --- a/src/dstack/_internal/server/services/backends/configurators/gcp.py +++ b/src/dstack/_internal/server/services/backends/configurators/gcp.py @@ -6,7 +6,7 @@ from dstack._internal.core.backends.gcp import GCPBackend, auth, resources from dstack._internal.core.backends.gcp.config import GCPConfig -from dstack._internal.core.errors import BackendAuthError, ComputeError, ServerClientError +from dstack._internal.core.errors import BackendAuthError, BackendError, ServerClientError from dstack._internal.core.models.backends.base import ( BackendType, ConfigElement, @@ -239,7 +239,7 @@ def _check_tags_config(self, config: GCPConfigInfoWithCredsPartial): ) try: resources.validate_labels(config.tags) - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) def _check_vpc_config( @@ -259,5 +259,5 @@ def _check_vpc_config( shared_vpc_project_id=config.vpc_project_id, allocate_public_ip=allocate_public_ip, ) - except ComputeError as e: + except BackendError as e: raise ServerClientError(e.args[0]) diff --git a/src/dstack/_internal/server/services/config.py b/src/dstack/_internal/server/services/config.py index 94f9fd8fbc..255de946b6 100644 --- a/src/dstack/_internal/server/services/config.py +++ b/src/dstack/_internal/server/services/config.py @@ -65,12 +65,20 @@ class AWSConfig(CoreModel): regions: Annotated[Optional[List[str]], Field(description="The list of AWS regions")] = None vpc_name: Annotated[ Optional[str], - Field(description="The VPC name. All configured regions must have a VPC with this name"), + Field( + description=( + "The name of custom VPCs. All configured regions must have a VPC with this name." + " If your custom VPCs don't have names or have different names in different regions, use `vpc_ids` instead." 
+ ) + ), ] = None vpc_ids: Annotated[ Optional[Dict[str, str]], Field( - description="The mapping from AWS regions to VPC IDs. If `default_vpcs: true`, omitted regions will use default VPCs" + description=( + "The mapping from AWS regions to VPC IDs." + " If `default_vpcs: true`, omitted regions will use default VPCs" + ) ), ] = None default_vpcs: Annotated[ @@ -86,7 +94,12 @@ class AWSConfig(CoreModel): public_ips: Annotated[ Optional[bool], Field( - description="A flag to enable/disable public IP assigning on instances. Defaults to `true`" + description=( + "A flag to enable/disable public IP assigning on instances." + " `public_ips: false` requires at least one private subnet with outbound internet connectivity" + " provided by a NAT Gateway or a Transit Gateway." + " Defaults to `true`" + ) ), ] = None tags: Annotated[ @@ -101,7 +114,29 @@ class AzureConfig(CoreModel): tenant_id: Annotated[str, Field(description="The tenant ID")] subscription_id: Annotated[str, Field(description="The subscription ID")] regions: Annotated[ - Optional[List[str]], Field(description="The list of Azure regions (locations)") + Optional[List[str]], + Field(description="The list of Azure regions (locations)"), + ] = None + vpc_ids: Annotated[ + Optional[Dict[str, str]], + Field( + description=( + "The mapping from configured Azure locations to network IDs." + " A network ID must have a format `networkResourceGroup/networkName`." + " If not specified, `dstack` will create a new network for every configured region" + ) + ), + ] = None + public_ips: Annotated[ + Optional[bool], + Field( + description=( + "A flag to enable/disable public IP assigning on instances." + " `public_ips: false` requires `vpc_ids` that specifies custom networks with outbound internet connectivity" + " provided by NAT Gateway or other mechanism." 
+ " Defaults to `true`" + ) + ), ] = None tags: Annotated[ Optional[Dict[str, str]], @@ -132,9 +167,9 @@ class GCPServiceAccountCreds(CoreModel): Optional[str], Field( description=( - "The contents of the service account file. " - "When configuring via `server/config.yml`, it's automatically filled from `filename`. " - "When configuring via UI, it has to be specified explicitly" + "The contents of the service account file." + " When configuring via `server/config.yml`, it's automatically filled from `filename`." + " When configuring via UI, it has to be specified explicitly" ) ), ] = None @@ -166,7 +201,7 @@ class GCPConfig(CoreModel): type: Annotated[Literal["gcp"], Field(description="The type of backend")] = "gcp" project_id: Annotated[str, Field(description="The project ID")] regions: Optional[List[str]] = None - vpc_name: Annotated[Optional[str], Field(description="The VPC name")] = None + vpc_name: Annotated[Optional[str], Field(description="The name of a custom VPC")] = None vpc_project_id: Annotated[ Optional[str], Field(description="The shared VPC hosted project ID. Required for shared VPC only"), @@ -190,7 +225,7 @@ class GCPAPIConfig(CoreModel): type: Annotated[Literal["gcp"], Field(description="The type of backend")] = "gcp" project_id: Annotated[str, Field(description="The project ID")] regions: Optional[List[str]] = None - vpc_name: Annotated[Optional[str], Field(description="The VPC name")] = None + vpc_name: Annotated[Optional[str], Field(description="The name of a custom VPC")] = None vpc_project_id: Annotated[ Optional[str], Field(description="The shared VPC hosted project ID. Required for shared VPC only"), @@ -216,9 +251,9 @@ class KubeconfigConfig(CoreModel): Optional[str], Field( description=( - "The contents of the kubeconfig file. " - "When configuring via `server/config.yml`, it's automatically filled from `filename`. " - "When configuring via UI, it has to be specified explicitly" + "The contents of the kubeconfig file." 
+ " When configuring via `server/config.yml`, it's automatically filled from `filename`." + " When configuring via UI, it has to be specified explicitly" ) ), ] = None @@ -312,8 +347,8 @@ class OCIConfig(CoreModel): Optional[str], Field( description=( - "Compartment where `dstack` will create all resources. " - "Omit to instruct `dstack` to create a new compartment" + "Compartment where `dstack` will create all resources." + " Omit to instruct `dstack` to create a new compartment" ) ), ] = None diff --git a/src/tests/_internal/core/backends/aws/test_resources.py b/src/tests/_internal/core/backends/aws/test_resources.py index 040c60e459..7b23bbf5b9 100644 --- a/src/tests/_internal/core/backends/aws/test_resources.py +++ b/src/tests/_internal/core/backends/aws/test_resources.py @@ -5,7 +5,7 @@ _is_valid_tag_value, validate_tags, ) -from dstack._internal.core.errors import ComputeError +from dstack._internal.core.errors import BackendError class TestIsValidTagKey: @@ -69,5 +69,5 @@ def test_validate_valid_tags(self): def test_validate_invalid_tags(self): tags = {"aws:ReservedKey": "SomeValue", "ValidKey": "Invalid#Value"} - with pytest.raises(ComputeError, match="Invalid resource tags"): + with pytest.raises(BackendError, match="Invalid resource tags"): validate_tags(tags) diff --git a/src/tests/_internal/core/backends/azure/test_resources.py b/src/tests/_internal/core/backends/azure/test_resources.py index 643fff0b94..12498fa9a0 100644 --- a/src/tests/_internal/core/backends/azure/test_resources.py +++ b/src/tests/_internal/core/backends/azure/test_resources.py @@ -5,7 +5,7 @@ _is_valid_tag_value, validate_tags, ) -from dstack._internal.core.errors import ComputeError +from dstack._internal.core.errors import BackendError class TestValidateTags: @@ -15,7 +15,7 @@ def test_valid_tags(self): def test_invalid_tags(self): tags = {"Invalid