diff --git a/cli/dstack/_internal/backend/aws/__init__.py b/cli/dstack/_internal/backend/aws/__init__.py
index 612ed1f6d..6e4f8326d 100644
--- a/cli/dstack/_internal/backend/aws/__init__.py
+++ b/cli/dstack/_internal/backend/aws/__init__.py
@@ -1,8 +1,8 @@
 from typing import Optional

-import boto3
-from botocore.client import BaseClient
+from boto3 import Session

+from dstack._internal.backend.aws import utils as aws_utils
 from dstack._internal.backend.aws.compute import AWSCompute
 from dstack._internal.backend.aws.config import AWSConfig
 from dstack._internal.backend.aws.logs import AWSLogging
@@ -21,31 +21,29 @@ def __init__(
     ):
         self.backend_config = backend_config
         if self.backend_config.credentials is not None:
-            self._session = boto3.session.Session(
+            self._session = Session(
                 region_name=self.backend_config.region_name,
                 aws_access_key_id=self.backend_config.credentials.get("access_key"),
                 aws_secret_access_key=self.backend_config.credentials.get("secret_key"),
             )
         else:
-            self._session = boto3.session.Session(region_name=self.backend_config.region_name)
+            self._session = Session(region_name=self.backend_config.region_name)
         self._storage = AWSStorage(
-            s3_client=self._s3_client(), bucket_name=self.backend_config.bucket_name
+            s3_client=aws_utils.get_s3_client(self._session),
+            bucket_name=self.backend_config.bucket_name,
         )
         self._compute = AWSCompute(
-            ec2_client=self._ec2_client(),
-            iam_client=self._iam_client(),
-            bucket_name=self.backend_config.bucket_name,
-            region_name=self.backend_config.region_name,
-            subnet_id=self.backend_config.subnet_id,
+            session=self._session,
+            backend_config=self.backend_config,
         )
         self._secrets_manager = AWSSecretsManager(
-            secretsmanager_client=self._secretsmanager_client(),
-            iam_client=self._iam_client(),
-            sts_client=self._sts_client(),
+            secretsmanager_client=aws_utils.get_secretsmanager_client(self._session),
+            iam_client=aws_utils.get_iam_client(self._session),
+            sts_client=aws_utils.get_sts_client(self._session),
             bucket_name=self.backend_config.bucket_name,
         )
         self._logging = AWSLogging(
-            logs_client=self._logs_client(),
+            logs_client=aws_utils.get_logs_client(self._session),
             bucket_name=self.backend_config.bucket_name,
         )
@@ -72,27 +70,6 @@ def logging(self) -> AWSLogging:
     def create_run(self, repo_id: str) -> str:
         self._logging.create_log_groups_if_not_exist(
-            self._logs_client(), self.backend_config.bucket_name, repo_id
+            aws_utils.get_logs_client(self._session), self.backend_config.bucket_name, repo_id
         )
         return base_runs.create_run(self._storage)
-
-    def _s3_client(self) -> BaseClient:
-        return self._get_client("s3")
-
-    def _ec2_client(self) -> BaseClient:
-        return self._get_client("ec2")
-
-    def _iam_client(self) -> BaseClient:
-        return self._get_client("iam")
-
-    def _logs_client(self) -> BaseClient:
-        return self._get_client("logs")
-
-    def _secretsmanager_client(self) -> BaseClient:
-        return self._get_client("secretsmanager")
-
-    def _sts_client(self) -> BaseClient:
-        return self._get_client("sts")
-
-    def _get_client(self, client_name: str) -> BaseClient:
-        return self._session.client(client_name)
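Note: the backend now derives every service client from a single boto3 `Session` instead of keeping one private `_get_client` wrapper per backend class. A minimal sketch of the idea; the helper name `make_backend_session` and the credentials dict shape are illustrative, not part of this PR:

```python
from typing import Dict, Optional

from boto3 import Session


def make_backend_session(region_name: str, credentials: Optional[Dict] = None) -> Session:
    """One Session per backend; service clients are derived from it on demand."""
    if credentials is not None:
        # Explicit keys, e.g. loaded from the hub's stored project config.
        return Session(
            region_name=region_name,
            aws_access_key_id=credentials.get("access_key"),
            aws_secret_access_key=credentials.get("secret_key"),
        )
    # Fall back to boto3's default credential chain (env vars, ~/.aws, instance role).
    return Session(region_name=region_name)
```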
diff --git a/cli/dstack/_internal/backend/aws/compute.py b/cli/dstack/_internal/backend/aws/compute.py
index 434406fbe..33320a0d8 100644
--- a/cli/dstack/_internal/backend/aws/compute.py
+++ b/cli/dstack/_internal/backend/aws/compute.py
@@ -1,49 +1,48 @@
 from typing import Optional

-from botocore.client import BaseClient
+from boto3 import Session

 from dstack._internal.backend.aws import runners
+from dstack._internal.backend.aws import utils as aws_utils
+from dstack._internal.backend.aws.config import AWSConfig
 from dstack._internal.backend.base.compute import Compute
-from dstack._internal.core.instance import InstanceType
+from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
 from dstack._internal.core.job import Job
 from dstack._internal.core.request import RequestHead
+from dstack._internal.core.runners import Runner


 class AWSCompute(Compute):
     def __init__(
         self,
-        ec2_client: BaseClient,
-        iam_client: BaseClient,
-        bucket_name: str,
-        region_name: str,
-        subnet_id: str,
+        session: Session,
+        backend_config: AWSConfig,
     ):
-        self.ec2_client = ec2_client
-        self.iam_client = iam_client
-        self.bucket_name = bucket_name
-        self.region_name = region_name
-        self.subnet_id = subnet_id
+        self.session = session
+        self.iam_client = aws_utils.get_iam_client(session)
+        self.backend_config = backend_config

     def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead:
         return runners.get_request_head(
-            ec2_client=self.ec2_client,
+            ec2_client=self._get_ec2_client(region=job.location),
             job=job,
             request_id=request_id,
         )

     def get_instance_type(self, job: Job) -> Optional[InstanceType]:
         return runners.get_instance_type(
-            ec2_client=self.ec2_client,
+            ec2_client=self._get_ec2_client(),
             requirements=job.requirements,
         )

-    def run_instance(self, job: Job, instance_type: InstanceType) -> str:
-        return runners.run_instance_retry(
-            ec2_client=self.ec2_client,
+    def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanceInfo:
+        return runners.run_instance(
+            session=self.session,
             iam_client=self.iam_client,
-            bucket_name=self.bucket_name,
-            region_name=self.region_name,
-            subnet_id=self.subnet_id,
+            bucket_name=self.backend_config.bucket_name,
+            region_name=self.backend_config.region_name,
+            extra_regions=self.backend_config.extra_regions,
+            subnet_id=self.backend_config.subnet_id,
             runner_id=job.runner_id,
             instance_type=instance_type,
             spot=job.requirements.spot,
@@ -52,14 +51,19 @@ def run_instance(self, job: Job, instance_type: InstanceType) -> str:
             ssh_key_pub=job.ssh_key_pub,
         )

-    def terminate_instance(self, request_id: str):
+    def terminate_instance(self, runner: Runner):
         runners.terminate_instance(
-            ec2_client=self.ec2_client,
-            request_id=request_id,
+            ec2_client=self._get_ec2_client(region=runner.job.location),
+            request_id=runner.request_id,
         )

-    def cancel_spot_request(self, request_id: str):
+    def cancel_spot_request(self, runner: Runner):
         runners.cancel_spot_request(
-            ec2_client=self.ec2_client,
-            request_id=request_id,
+            ec2_client=self._get_ec2_client(region=runner.job.location),
+            request_id=runner.request_id,
         )
+
+    def _get_ec2_client(self, region: Optional[str] = None):
+        if region is None:
+            return aws_utils.get_ec2_client(self.session)
+        return aws_utils.get_ec2_client(self.session, region_name=region)
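`terminate_instance` and `cancel_spot_request` now take the whole `Runner`, so the backend can rebuild an EC2 client for the region the instance was actually launched in (`runner.job.location`). One consequence is that a fresh client is constructed per call; if that ever shows up in profiles, a per-region cache is a natural refinement. A sketch under that assumption, not part of this PR:

```python
from typing import Dict, Optional

from boto3 import Session
from botocore.client import BaseClient

from dstack._internal.backend.aws import utils as aws_utils


class RegionalEC2Clients:
    """Illustrative per-region client cache; the PR itself creates clients per call."""

    def __init__(self, session: Session):
        self._session = session
        self._clients: Dict[Optional[str], BaseClient] = {}

    def get(self, region: Optional[str] = None) -> BaseClient:
        if region not in self._clients:
            if region is None:
                self._clients[region] = aws_utils.get_ec2_client(self._session)
            else:
                self._clients[region] = aws_utils.get_ec2_client(
                    self._session, region_name=region
                )
        return self._clients[region]
```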
diff --git a/cli/dstack/_internal/backend/aws/config.py b/cli/dstack/_internal/backend/aws/config.py
index cc63a20bf..215979f96 100644
--- a/cli/dstack/_internal/backend/aws/config.py
+++ b/cli/dstack/_internal/backend/aws/config.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 from pydantic import BaseModel

@@ -10,6 +10,7 @@ class AWSConfig(BackendConfig, BaseModel):
     bucket_name: str
     region_name: Optional[str] = DEFAULT_REGION_NAME
+    extra_regions: List[str] = []
     subnet_id: Optional[str] = None
     credentials: Optional[Dict] = None

@@ -20,6 +21,8 @@ def serialize(self) -> Dict:
         }
         if self.region_name:
             config_data["region"] = self.region_name
+        if self.extra_regions:
+            config_data["extra_regions"] = self.extra_regions
         if self.subnet_id:
             config_data["subnet"] = self.subnet_id
         return config_data
@@ -35,5 +38,6 @@ def deserialize(cls, config_data: Dict) -> Optional["AWSConfig"]:
         return cls(
             bucket_name=bucket_name,
             region_name=config_data.get("region"),
+            extra_regions=config_data.get("extra_regions", []),
             subnet_id=config_data.get("subnet"),
         )
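For reference, `serialize()` with the new field set produces output along these lines (bucket and region values are illustrative; the base keys of `config_data` come from the part of `serialize()` not shown in the hunk):

```python
from dstack._internal.backend.aws.config import AWSConfig

config = AWSConfig(
    bucket_name="dstack-test-eu-west-1",
    region_name="eu-west-1",
    extra_regions=["eu-west-2", "eu-north-1"],
)
data = config.serialize()
assert data["region"] == "eu-west-1"
assert data["extra_regions"] == ["eu-west-2", "eu-north-1"]
assert "subnet" not in data  # subnet_id is unset, so the key is omitted

# deserialize() applies the mirror-image default:
restored = AWSConfig.deserialize(data)
assert restored.extra_regions == ["eu-west-2", "eu-north-1"]
```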
if response.get("Reservations") and response["Reservations"][0].get("Instances"): + ec2_client.terminate_instances( + InstanceIds=[response["Reservations"][0]["Instances"][0]["InstanceId"]] + ) + + +def terminate_instance(ec2_client: BaseClient, request_id: str): + try: + ec2_client.terminate_instances(InstanceIds=[request_id]) + except Exception as e: + if ( + hasattr(e, "response") + and e.response.get("Error") + and e.response["Error"].get("Code") == "InvalidInstanceID.NotFound" + ): + pass + else: + raise e + + +def get_request_head( + ec2_client: BaseClient, + job: Job, + request_id: Optional[str], +) -> RequestHead: + spot = job.requirements.spot + if request_id is None: + message = ( + "The spot instance request ID is not specified" + if spot + else "The instance ID is not specified" + ) + return RequestHead(job_id=job.job_id, status=RequestStatus.TERMINATED, message=message) + + if spot: + try: + response = ec2_client.describe_spot_instance_requests( + SpotInstanceRequestIds=[request_id] + ) + if response.get("SpotInstanceRequests"): + status = response["SpotInstanceRequests"][0]["Status"] + if status["Code"] in [ + "fulfilled", + "request-canceled-and-instance-running", + "marked-for-stop-by-experiment", + "marked-for-stop", + "marked-for-termination", + ]: + request_status = RequestStatus.RUNNING + elif status["Code"] in [ + "not-scheduled-yet", + "pending-evaluation", + "pending-fulfillment", + ]: + request_status = RequestStatus.PENDING + elif status["Code"] in [ + "capacity-not-available", + "instance-stopped-no-capacity", + "instance-terminated-by-price", + "instance-stopped-by-price", + "instance-terminated-no-capacity", + "instance-stopped-by-experiment", + "instance-terminated-by-experiment", + "limit-exceeded", + "price-too-low", + ]: + request_status = RequestStatus.NO_CAPACITY + elif status["Code"] in [ + "instance-terminated-by-user", + "instance-stopped-by-user", + "canceled-before-fulfillment", + "instance-terminated-by-schedule", + "instance-terminated-by-service", + "spot-instance-terminated-by-user", + ]: + request_status = RequestStatus.TERMINATED + else: + raise Exception( + f"Unsupported EC2 spot instance request status code: {status['Code']}" + ) + return RequestHead( + job_id=job.job_id, status=request_status, message=status.get("Message") + ) + else: + return RequestHead( + job_id=job.job_id, status=RequestStatus.TERMINATED, message=None + ) + except Exception as e: + if ( + hasattr(e, "response") + and e.response.get("Error") + and e.response["Error"].get("Code") == "InvalidSpotInstanceRequestID.NotFound" + ): + return RequestHead( + job_id=job.job_id, + status=RequestStatus.TERMINATED, + message=e.response["Error"].get("Message"), + ) + else: + raise e + else: + try: + response = ec2_client.describe_instances(InstanceIds=[request_id]) + if response.get("Reservations") and response["Reservations"][0].get("Instances"): + state = response["Reservations"][0]["Instances"][0]["State"] + if state["Name"] in ["running"]: + request_status = RequestStatus.RUNNING + elif state["Name"] in ["pending"]: + request_status = RequestStatus.PENDING + elif state["Name"] in [ + "shutting-down", + "terminated", + "stopping", + "stopped", + ]: + request_status = RequestStatus.TERMINATED + else: + raise Exception(f"Unsupported EC2 instance state name: {state['Name']}") + return RequestHead(job_id=job.job_id, status=request_status, message=None) + else: + return RequestHead( + job_id=job.job_id, status=RequestStatus.TERMINATED, message=None + ) + except Exception as e: + if ( + 
hasattr(e, "response") + and e.response.get("Error") + and e.response["Error"].get("Code") == "InvalidInstanceID.NotFound" + ): + return RequestHead( + job_id=job.job_id, + status=RequestStatus.TERMINATED, + message=e.response["Error"].get("Message"), + ) + else: + raise e + + def _get_instance_types(ec2_client: BaseClient) -> List[InstanceType]: response = None instance_types = [] @@ -76,86 +285,167 @@ def _get_instance_types(ec2_client: BaseClient) -> List[InstanceType]: return instance_types -def _get_security_group_id(ec2_client: BaseClient, bucket_name: str, subnet_id: Optional[str]): - _subnet_postfix = (subnet_id.replace("-", "_") + "_") if subnet_id else "" - security_group_name = ( - "dstack_security_group_" + _subnet_postfix + bucket_name.replace("-", "_").lower() - ) - if not version.__is_release__: - security_group_name += "_stgn" - response = ec2_client.describe_security_groups( - Filters=[ - { - "Name": "group-name", - "Values": [ - security_group_name, - ], - }, - ], +def _get_instance_available_regions( + ec2_client: BaseClient, + instance_type: InstanceType, + extra_regions: List[str], +) -> List[str]: + resp = ec2_client.get_spot_placement_scores( + InstanceTypes=[instance_type.instance_name], + TargetCapacity=1, + RegionNames=extra_regions, ) - if response.get("SecurityGroups"): - security_group_id = response["SecurityGroups"][0]["GroupId"] - else: - group_specification = {} - if subnet_id: - subnets_response = ec2_client.describe_subnets(SubnetIds=[subnet_id]) - group_specification["VpcId"] = subnets_response["Subnets"][0]["VpcId"] - security_group = ec2_client.create_security_group( - Description="Generated by dstack", - GroupName=security_group_name, - TagSpecifications=[ - { - "ResourceType": "security-group", - "Tags": [ - {"Key": "owner", "Value": "dstack"}, - {"Key": "dstack_bucket", "Value": bucket_name}, - ], - }, - ], - **group_specification, - ) - security_group_id = security_group["GroupId"] - ip_permissions = [ - { - "FromPort": 22, - "ToPort": 22, - "IpProtocol": "tcp", - "IpRanges": [{"CidrIp": "0.0.0.0/0"}], - } - ] - ec2_client.authorize_security_group_ingress( - GroupId=security_group_id, IpPermissions=ip_permissions - ) - ec2_client.authorize_security_group_egress( - GroupId=security_group_id, - IpPermissions=[ - { - "IpProtocol": "-1", - } - ], - ) - return security_group_id - - -def _serialize_config_yaml(bucket_name: str, region_name: str): - return f"backend: aws\\n" f"bucket: {bucket_name}\\n" f"region: {region_name}" + spot_scores = resp["SpotPlacementScores"] + spot_scores = sorted(spot_scores, key=lambda x: -x["Score"]) + return [s["Region"] for s in spot_scores] -def _user_data( - bucket_name, - region_name, +def _run_instance_retry( + ec2_client: BaseClient, + iam_client: BaseClient, + bucket_name: str, + region_name: str, + subnet_id: Optional[str], runner_id: str, - resources: Resources, + instance_type: InstanceType, + spot: bool, + repo_id: str, + hub_user_name: str, ssh_key_pub: str, - port_range_from: int = 3000, - port_range_to: int = 4000, + attempts: int = 3, ) -> str: - sysctl_port_range_from = int((port_range_to - port_range_from) / 2) + port_range_from - sysctl_port_range_to = port_range_to - 1 - runner_port_range_from = port_range_from - runner_port_range_to = sysctl_port_range_from - 1 - user_data = f"""#!/bin/bash -if [ -e "/etc/fuse.conf" ] + try: + return _run_instance( + ec2_client, + iam_client, + bucket_name, + region_name, + subnet_id, + runner_id, + instance_type, + spot, + repo_id, + hub_user_name, + ssh_key_pub, + ) + 


-def _user_data(
-    bucket_name,
-    region_name,
+def _run_instance_retry(
+    ec2_client: BaseClient,
+    iam_client: BaseClient,
+    bucket_name: str,
+    region_name: str,
+    subnet_id: Optional[str],
     runner_id: str,
-    resources: Resources,
+    instance_type: InstanceType,
+    spot: bool,
+    repo_id: str,
+    hub_user_name: str,
     ssh_key_pub: str,
-    port_range_from: int = 3000,
-    port_range_to: int = 4000,
+    attempts: int = 3,
 ) -> str:
-    sysctl_port_range_from = int((port_range_to - port_range_from) / 2) + port_range_from
-    sysctl_port_range_to = port_range_to - 1
-    runner_port_range_from = port_range_from
-    runner_port_range_to = sysctl_port_range_from - 1
-    user_data = f"""#!/bin/bash
-if [ -e "/etc/fuse.conf" ]
+    try:
+        return _run_instance(
+            ec2_client,
+            iam_client,
+            bucket_name,
+            region_name,
+            subnet_id,
+            runner_id,
+            instance_type,
+            spot,
+            repo_id,
+            hub_user_name,
+            ssh_key_pub,
+        )
+    except botocore.exceptions.ClientError as e:
+        # FIXME: why retry on "InvalidParameterValue"
+        if e.response["Error"]["Code"] == "InvalidParameterValue":
+            if attempts > 0:
+                time.sleep(CREATE_INSTANCE_RETRY_RATE_SECS)
+                return _run_instance_retry(
+                    ec2_client,
+                    iam_client,
+                    bucket_name,
+                    region_name,
+                    subnet_id,
+                    runner_id,
+                    instance_type,
+                    spot,
+                    repo_id,
+                    hub_user_name,
+                    ssh_key_pub,
+                    attempts - 1,
+                )
+            else:
+                raise Exception("Failed to retry", e)
+        elif e.response["Error"]["Code"] == "InsufficientInstanceCapacity":
+            raise NoCapacityError()
+        raise e
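A plausible reading of the FIXME above: `RunInstances` can transiently reject a freshly created IAM instance profile with `InvalidParameterValue` until IAM propagation completes, so a short bounded retry papers over it. The generic shape of that pattern, as a sketch rather than the PR's code:

```python
import time


def retry_bounded(fn, is_transient, attempts: int = 3, delay_secs: int = 3):
    """Call fn(); retry while is_transient(exc) holds, at most `attempts` extra tries."""
    while True:
        try:
            return fn()
        except Exception as e:
            if attempts <= 0 or not is_transient(e):
                raise
            attempts -= 1
            time.sleep(delay_secs)
```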
+
+
+def _run_instance(
+    ec2_client: BaseClient,
+    iam_client: BaseClient,
+    bucket_name: str,
+    region_name: str,
+    subnet_id: Optional[str],
+    runner_id: str,
+    instance_type: InstanceType,
+    spot: bool,
+    repo_id: str,
+    hub_user_name: str,
+    ssh_key_pub: str,
+) -> str:
+    launch_specification = {}
+    if spot:
+        launch_specification["InstanceMarketOptions"] = {
+            "MarketType": "spot",
+            "SpotOptions": {
+                "SpotInstanceType": "one-time",
+                "InstanceInterruptionBehavior": "terminate",
+            },
+        }
+    if subnet_id:
+        launch_specification["NetworkInterfaces"] = [
+            {
+                "AssociatePublicIpAddress": True,
+                "DeviceIndex": 0,
+                "SubnetId": subnet_id,
+                "Groups": [_get_security_group_id(ec2_client, bucket_name, subnet_id)],
+            },
+        ]
+    else:
+        launch_specification["SecurityGroupIds"] = [
+            _get_security_group_id(ec2_client, bucket_name, subnet_id)
+        ]
+    tags = [
+        {"Key": "owner", "Value": "dstack"},
+        {"Key": "dstack_bucket", "Value": bucket_name},
+        {"Key": "dstack_repo", "Value": repo_id},
+        {"Key": "dstack_repo_user", "Value": hub_user_name},
+    ]
+    response = ec2_client.run_instances(
+        BlockDeviceMappings=[
+            {
+                "DeviceName": "/dev/sda1",
+                "Ebs": {
+                    "VolumeSize": 100,
+                    "VolumeType": "gp2",
+                },
+            }
+        ],
+        ImageId=_get_ami_image(ec2_client, len(instance_type.resources.gpus) > 0)[0],
+        InstanceType=instance_type.instance_name,
+        MinCount=1,
+        MaxCount=1,
+        IamInstanceProfile={
+            "Arn": _get_instance_profile_arn(iam_client, bucket_name),
+        },
+        UserData=_user_data(
+            bucket_name, region_name, runner_id, instance_type.resources, ssh_key_pub=ssh_key_pub
+        ),
+        TagSpecifications=[
+            {
+                "ResourceType": "instance",
+                "Tags": tags,
+            },
+        ],
+        **launch_specification,
+    )
+    if spot:
+        request_id = response["Instances"][0]["SpotInstanceRequestId"]
+        ec2_client.create_tags(Resources=[request_id], Tags=tags)
+    else:
+        request_id = response["Instances"][0]["InstanceId"]
+    return request_id
+
+
+def _user_data(
+    bucket_name,
+    region_name,
+    runner_id: str,
+    resources: Resources,
+    ssh_key_pub: str,
+    port_range_from: int = 3000,
+    port_range_to: int = 4000,
+) -> str:
+    sysctl_port_range_from = int((port_range_to - port_range_from) / 2) + port_range_from
+    sysctl_port_range_to = port_range_to - 1
+    runner_port_range_from = port_range_from
+    runner_port_range_to = sysctl_port_range_from - 1
+    user_data = f"""#!/bin/bash
+if [ -e "/etc/fuse.conf" ]
 then
   sudo sed "s/# *user_allow_other/user_allow_other/" /etc/fuse.conf > t
   sudo mv t /etc/fuse.conf
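For the default 3000 to 4000 window, the arithmetic in `_user_data` splits the range into a runner half and a sysctl (ephemeral-port) half. Worked out:

```python
port_range_from, port_range_to = 3000, 4000
sysctl_port_range_from = int((port_range_to - port_range_from) / 2) + port_range_from  # 3500
sysctl_port_range_to = port_range_to - 1                                               # 3999
runner_port_range_from = port_range_from                                               # 3000
runner_port_range_to = sysctl_port_range_from - 1                                      # 3499
```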
prefix={ami_name!r}") - - -def _run_instance( - ec2_client: BaseClient, - iam_client: BaseClient, - bucket_name: str, - region_name: str, - subnet_id: Optional[str], - runner_id: str, - instance_type: InstanceType, - spot: bool, - repo_id: str, - hub_user_name: str, - ssh_key_pub: str, -) -> str: - launch_specification = {} - if not version.__is_release__: - launch_specification["KeyName"] = "dstack_victor" - if spot: - launch_specification["InstanceMarketOptions"] = { - "MarketType": "spot", - "SpotOptions": { - "SpotInstanceType": "one-time", - "InstanceInterruptionBehavior": "terminate", - }, - } - if subnet_id: - launch_specification["NetworkInterfaces"] = [ - { - "AssociatePublicIpAddress": True, - "DeviceIndex": 0, - "SubnetId": subnet_id, - "Groups": [_get_security_group_id(ec2_client, bucket_name, subnet_id)], - }, - ] - else: - launch_specification["SecurityGroupIds"] = [ - _get_security_group_id(ec2_client, bucket_name, subnet_id) - ] - tags = [ - {"Key": "owner", "Value": "dstack"}, - {"Key": "dstack_bucket", "Value": bucket_name}, - {"Key": "dstack_repo", "Value": repo_id}, - {"Key": "dstack_repo_user", "Value": hub_user_name}, - ] - response = ec2_client.run_instances( - BlockDeviceMappings=[ - { - "DeviceName": "/dev/sda1", - "Ebs": { - "VolumeSize": 100, - "VolumeType": "gp2", - }, - } - ], - ImageId=_get_ami_image(ec2_client, len(instance_type.resources.gpus) > 0)[0], - InstanceType=instance_type.instance_name, - MinCount=1, - MaxCount=1, - IamInstanceProfile={ - "Arn": _get_instance_profile_arn(iam_client, bucket_name), - }, - UserData=_user_data( - bucket_name, region_name, runner_id, instance_type.resources, ssh_key_pub=ssh_key_pub - ), - TagSpecifications=[ - { - "ResourceType": "instance", - "Tags": tags, - }, - ], - **launch_specification, - ) - if spot: - request_id = response["Instances"][0]["SpotInstanceRequestId"] - ec2_client.create_tags(Resources=[request_id], Tags=tags) - else: - request_id = response["Instances"][0]["InstanceId"] - return request_id - - -def run_instance_retry( - ec2_client: BaseClient, - iam_client: BaseClient, - bucket_name: str, - region_name: str, - subnet_id: Optional[str], - runner_id: str, - instance_type: InstanceType, - spot: bool, - repo_id: str, - hub_user_name: str, - ssh_key_pub: str, - attempts: int = 3, -) -> str: - try: - return _run_instance( - ec2_client, - iam_client, - bucket_name, - region_name, - subnet_id, - runner_id, - instance_type, - spot, - repo_id, - hub_user_name, - ssh_key_pub, - ) - except botocore.exceptions.ClientError as e: - # FIXME: why retry on "InvalidParameterValue" - if e.response["Error"]["Code"] == "InvalidParameterValue": - if attempts > 0: - time.sleep(CREATE_INSTANCE_RETRY_RATE_SECS) - return run_instance_retry( - ec2_client, - iam_client, - bucket_name, - region_name, - subnet_id, - runner_id, - instance_type, - spot, - repo_id, - hub_user_name, - ssh_key_pub, - attempts - 1, - ) - else: - raise Exception("Failed to retry", e) - elif e.response["Error"]["Code"] == "InsufficientInstanceCapacity": - raise NoCapacityError() - raise e - - -def cancel_spot_request(ec2_client: BaseClient, request_id: str): - try: - ec2_client.cancel_spot_instance_requests(SpotInstanceRequestIds=[request_id]) - except botocore.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "InvalidSpotInstanceRequestID.NotFound": - return - else: - raise e - response = ec2_client.describe_instances( - Filters=[ - {"Name": "spot-instance-request-id", "Values": [request_id]}, - ], - ) - if response.get("Reservations") and 
response["Reservations"][0].get("Instances"): - ec2_client.terminate_instances( - InstanceIds=[response["Reservations"][0]["Instances"][0]["InstanceId"]] - ) - - -def terminate_instance(ec2_client: BaseClient, request_id: str): - try: - ec2_client.terminate_instances(InstanceIds=[request_id]) - except Exception as e: - if ( - hasattr(e, "response") - and e.response.get("Error") - and e.response["Error"].get("Code") == "InvalidInstanceID.NotFound" - ): - pass - else: - raise e - - -def get_request_head( - ec2_client: BaseClient, - job: Job, - request_id: Optional[str], -) -> RequestHead: - spot = job.requirements.spot - if request_id is None: - message = ( - "The spot instance request ID is not specified" - if spot - else "The instance ID is not specified" - ) - return RequestHead(job_id=job.job_id, status=RequestStatus.TERMINATED, message=message) - - if spot: - try: - response = ec2_client.describe_spot_instance_requests( - SpotInstanceRequestIds=[request_id] - ) - if response.get("SpotInstanceRequests"): - status = response["SpotInstanceRequests"][0]["Status"] - if status["Code"] in [ - "fulfilled", - "request-canceled-and-instance-running", - "marked-for-stop-by-experiment", - "marked-for-stop", - "marked-for-termination", - ]: - request_status = RequestStatus.RUNNING - elif status["Code"] in [ - "not-scheduled-yet", - "pending-evaluation", - "pending-fulfillment", - ]: - request_status = RequestStatus.PENDING - elif status["Code"] in [ - "capacity-not-available", - "instance-stopped-no-capacity", - "instance-terminated-by-price", - "instance-stopped-by-price", - "instance-terminated-no-capacity", - "instance-stopped-by-experiment", - "instance-terminated-by-experiment", - "limit-exceeded", - "price-too-low", - ]: - request_status = RequestStatus.NO_CAPACITY - elif status["Code"] in [ - "instance-terminated-by-user", - "instance-stopped-by-user", - "canceled-before-fulfillment", - "instance-terminated-by-schedule", - "instance-terminated-by-service", - "spot-instance-terminated-by-user", - ]: - request_status = RequestStatus.TERMINATED - else: - raise Exception( - f"Unsupported EC2 spot instance request status code: {status['Code']}" - ) - return RequestHead( - job_id=job.job_id, status=request_status, message=status.get("Message") - ) - else: - return RequestHead( - job_id=job.job_id, status=RequestStatus.TERMINATED, message=None - ) - except Exception as e: - if ( - hasattr(e, "response") - and e.response.get("Error") - and e.response["Error"].get("Code") == "InvalidSpotInstanceRequestID.NotFound" - ): - return RequestHead( - job_id=job.job_id, - status=RequestStatus.TERMINATED, - message=e.response["Error"].get("Message"), - ) - else: - raise e - else: - try: - response = ec2_client.describe_instances(InstanceIds=[request_id]) - if response.get("Reservations") and response["Reservations"][0].get("Instances"): - state = response["Reservations"][0]["Instances"][0]["State"] - if state["Name"] in ["running"]: - request_status = RequestStatus.RUNNING - elif state["Name"] in ["pending"]: - request_status = RequestStatus.PENDING - elif state["Name"] in [ - "shutting-down", - "terminated", - "stopping", - "stopped", - ]: - request_status = RequestStatus.TERMINATED - else: - raise Exception(f"Unsupported EC2 instance state name: {state['Name']}") - return RequestHead(job_id=job.job_id, status=request_status, message=None) - else: - return RequestHead( - job_id=job.job_id, status=RequestStatus.TERMINATED, message=None - ) - except Exception as e: - if ( - hasattr(e, "response") - and 
e.response.get("Error") - and e.response["Error"].get("Code") == "InvalidInstanceID.NotFound" - ): - return RequestHead( - job_id=job.job_id, - status=RequestStatus.TERMINATED, - message=e.response["Error"].get("Message"), - ) - else: - raise e diff --git a/cli/dstack/_internal/backend/aws/secrets.py b/cli/dstack/_internal/backend/aws/secrets.py index d3bf4752d..57059eccc 100644 --- a/cli/dstack/_internal/backend/aws/secrets.py +++ b/cli/dstack/_internal/backend/aws/secrets.py @@ -113,7 +113,7 @@ def _add_secret( {"Key": "dstack_bucket", "Value": bucket_name}, ], ) - role_name = runners.role_name(iam_client, bucket_name) + role_name = runners._get_role_name(iam_client, bucket_name) account_id = sts_client.get_caller_identity()["Account"] resource_policy = json.dumps( { diff --git a/cli/dstack/_internal/backend/aws/utils.py b/cli/dstack/_internal/backend/aws/utils.py index c244626a7..0b868299a 100644 --- a/cli/dstack/_internal/backend/aws/utils.py +++ b/cli/dstack/_internal/backend/aws/utils.py @@ -2,6 +2,36 @@ from typing import Any, List import botocore.exceptions +from boto3 import Session +from botocore.client import BaseClient + + +def get_s3_client(session: Session, **kwargs) -> BaseClient: + return _get_client(session, "s3", **kwargs) + + +def get_ec2_client(session: Session, **kwargs) -> BaseClient: + return _get_client(session, "ec2", **kwargs) + + +def get_iam_client(session: Session, **kwargs) -> BaseClient: + return _get_client(session, "iam", **kwargs) + + +def get_logs_client(session: Session, **kwargs) -> BaseClient: + return _get_client(session, "logs", **kwargs) + + +def get_secretsmanager_client(session: Session, **kwargs) -> BaseClient: + return _get_client(session, "secretsmanager", **kwargs) + + +def get_sts_client(session: Session, **kwargs) -> BaseClient: + return _get_client(session, "sts", **kwargs) + + +def _get_client(session: Session, client_name: str, **kwargs) -> BaseClient: + return session.client(client_name, **kwargs) def retry_operation_on_service_errors( diff --git a/cli/dstack/_internal/backend/azure/compute.py b/cli/dstack/_internal/backend/azure/compute.py index 3a43c9f7f..8205d2800 100644 --- a/cli/dstack/_internal/backend/azure/compute.py +++ b/cli/dstack/_internal/backend/azure/compute.py @@ -41,10 +41,10 @@ from dstack._internal.backend.base.compute import WS_PORT, Compute, choose_instance_type from dstack._internal.backend.base.config import BACKEND_CONFIG_FILENAME, RUNNER_CONFIG_FILENAME from dstack._internal.backend.base.runners import serialize_runner_yaml -from dstack._internal.core.instance import InstanceType +from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo from dstack._internal.core.job import Job from dstack._internal.core.request import RequestHead, RequestStatus -from dstack._internal.core.runners import Gpu, Resources +from dstack._internal.core.runners import Gpu, Resources, Runner from dstack._internal.utils.common import removeprefix @@ -82,7 +82,7 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]: ) return choose_instance_type(instance_types=instance_types, requirements=job.requirements) - def run_instance(self, job: Job, instance_type: InstanceType) -> str: + def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanceInfo: vm = _launch_instance( compute_client=self._compute_client, subscription_id=self.azure_config.subscription_id, @@ -105,7 +105,7 @@ def run_instance(self, job: Job, instance_type: InstanceType) -> str: ssh_pub_key=job.ssh_key_pub, 
diff --git a/cli/dstack/_internal/backend/base/compute.py b/cli/dstack/_internal/backend/base/compute.py
index cd08fa73d..61345b81b 100644
--- a/cli/dstack/_internal/backend/base/compute.py
+++ b/cli/dstack/_internal/backend/base/compute.py
@@ -3,10 +3,10 @@
 from typing import List, Optional

 from dstack._internal.core.error import DstackError
-from dstack._internal.core.instance import InstanceType
+from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
 from dstack._internal.core.job import Job, Requirements
 from dstack._internal.core.request import RequestHead
-from dstack._internal.core.runners import Resources
+from dstack._internal.core.runners import Resources, Runner

 WS_PORT = 10999

@@ -29,15 +29,15 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]:
         pass

     @abstractmethod
-    def run_instance(self, job: Job, instance_type: InstanceType) -> str:
+    def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanceInfo:
         pass

     @abstractmethod
-    def terminate_instance(self, request_id: str):
+    def terminate_instance(self, runner: Runner):
         pass

     @abstractmethod
-    def cancel_spot_request(self, request_id: str):
+    def cancel_spot_request(self, runner: Runner):
         pass
diff --git a/cli/dstack/_internal/backend/base/jobs.py b/cli/dstack/_internal/backend/base/jobs.py
index c21fb3c71..b793301a6 100644
--- a/cli/dstack/_internal/backend/base/jobs.py
+++ b/cli/dstack/_internal/backend/base/jobs.py
@@ -219,7 +219,9 @@ def _try_run_job(
     )
     runners.create_runner(storage, runner)
     try:
-        runner.request_id = compute.run_instance(job, instance_type)
+        launched_instance_info = compute.run_instance(job, instance_type)
+        runner.request_id = launched_instance_info.request_id
+        job.location = launched_instance_info.location
     except NoCapacityError:
         if job.spot_policy == SpotPolicy.AUTO and attempt == 0:
             return _try_run_job(
diff --git a/cli/dstack/_internal/backend/base/runners.py b/cli/dstack/_internal/backend/base/runners.py
index 8cdd3a91c..f1f253e80 100644
--- a/cli/dstack/_internal/backend/base/runners.py
+++ b/cli/dstack/_internal/backend/base/runners.py
@@ -37,9 +37,9 @@ def delete_runner(storage: Storage, runner: Runner):
 def stop_runner(storage: Storage, compute: Compute, runner: Runner):
     if runner.request_id:
         if runner.resources.spot:
-            compute.cancel_spot_request(runner.request_id)
+            compute.cancel_spot_request(runner)
         else:
-            compute.terminate_instance(runner.request_id)
+            compute.terminate_instance(runner)
     delete_runner(storage, runner)
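Passing the `Runner` through the whole `Compute` interface gives every backend both the request handle and the job's recorded launch location. A minimal hypothetical backend showing the new contract (the other abstract methods are omitted for brevity; this class is not part of the PR):

```python
from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
from dstack._internal.core.job import Job
from dstack._internal.core.runners import Runner


class ToyCompute:
    """Hypothetical Compute implementation illustrating the new signatures."""

    def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanceInfo:
        # Launch somewhere and report both the handle and where it landed.
        return LaunchedInstanceInfo(request_id="i-0123456789abcdef0", location="eu-north-1")

    def terminate_instance(self, runner: Runner):
        # Region-aware backends can rebuild the right client from the location.
        _ = (runner.request_id, runner.job.location)

    def cancel_spot_request(self, runner: Runner):
        self.terminate_instance(runner)
```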
diff --git a/cli/dstack/_internal/backend/gcp/compute.py b/cli/dstack/_internal/backend/gcp/compute.py
index 854865272..8c1027318 100644
--- a/cli/dstack/_internal/backend/gcp/compute.py
+++ b/cli/dstack/_internal/backend/gcp/compute.py
@@ -16,10 +16,10 @@
 from dstack._internal.backend.base.runners import serialize_runner_yaml
 from dstack._internal.backend.gcp import utils as gcp_utils
 from dstack._internal.backend.gcp.config import GCPConfig
-from dstack._internal.core.instance import InstanceType
+from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
 from dstack._internal.core.job import Job, Requirements
 from dstack._internal.core.request import RequestHead, RequestStatus
-from dstack._internal.core.runners import Gpu, Resources
+from dstack._internal.core.runners import Gpu, Resources, Runner

 DSTACK_INSTANCE_TAG = "dstack-runner-instance"

@@ -97,7 +97,7 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]:
             requirements=job.requirements,
         )

-    def run_instance(self, job: Job, instance_type: InstanceType) -> str:
+    def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanceInfo:
         instance = _launch_instance(
             instances_client=self.instances_client,
             firewalls_client=self.firewalls_client,
@@ -125,20 +125,20 @@ def run_instance(self, job: Job, instance_type: InstanceType) -> str:
             ),
             ssh_key_pub=job.ssh_key_pub,
         )
-        return instance.name
+        return LaunchedInstanceInfo(request_id=instance.name, location=self.gcp_config.zone)

-    def terminate_instance(self, request_id: str):
+    def terminate_instance(self, runner: Runner):
         _terminate_instance(
             client=self.instances_client,
             gcp_config=self.gcp_config,
-            instance_name=request_id,
+            instance_name=runner.request_id,
         )

-    def cancel_spot_request(self, request_id: str):
+    def cancel_spot_request(self, runner: Runner):
         _terminate_instance(
             client=self.instances_client,
             gcp_config=self.gcp_config,
-            instance_name=request_id,
+            instance_name=runner.request_id,
         )
diff --git a/cli/dstack/_internal/backend/lambdalabs/__init__.py b/cli/dstack/_internal/backend/lambdalabs/__init__.py
index 4b7b6ec76..2531bc358 100644
--- a/cli/dstack/_internal/backend/lambdalabs/__init__.py
+++ b/cli/dstack/_internal/backend/lambdalabs/__init__.py
@@ -1,8 +1,8 @@
 from typing import Optional

 import boto3
-from botocore.client import BaseClient

+from dstack._internal.backend.aws import utils as aws_utils
 from dstack._internal.backend.aws.logs import AWSLogging
 from dstack._internal.backend.aws.secrets import AWSSecretsManager
 from dstack._internal.backend.aws.storage import AWSStorage
@@ -27,16 +27,17 @@ def __init__(
             aws_secret_access_key=self.backend_config.storage_config.credentials.secret_key,
         )
         self._storage = AWSStorage(
-            s3_client=self._s3_client(), bucket_name=self.backend_config.storage_config.bucket
+            s3_client=aws_utils.get_s3_client(self._session),
+            bucket_name=self.backend_config.storage_config.bucket,
         )
         self._secrets_manager = AWSSecretsManager(
-            secretsmanager_client=self._secretsmanager_client(),
-            iam_client=self._iam_client(),
-            sts_client=self._sts_client(),
+            secretsmanager_client=aws_utils.get_secretsmanager_client(self._session),
+            iam_client=aws_utils.get_iam_client(self._session),
+            sts_client=aws_utils.get_sts_client(self._session),
             bucket_name=self.backend_config.storage_config.bucket,
         )
         self._logging = AWSLogging(
-            logs_client=self._logs_client(),
+            logs_client=aws_utils.get_logs_client(self._session),
             bucket_name=self.backend_config.storage_config.bucket,
         )

@@ -61,27 +62,8 @@ def logging(self) -> AWSLogging:
     def create_run(self, repo_id: str) -> str:
         self._logging.create_log_groups_if_not_exist(
-            self._logs_client(), self.backend_config.storage_config.bucket, repo_id
+            aws_utils.get_logs_client(self._session),
+            self.backend_config.storage_config.bucket,
+            repo_id,
         )
         return base_runs.create_run(self._storage)
-
-    def _s3_client(self) -> BaseClient:
-        return self._get_client("s3")
-
-    def _ec2_client(self) -> BaseClient:
-        return self._get_client("ec2")
-
-    def _iam_client(self) -> BaseClient:
-        return self._get_client("iam")
-
-    def _logs_client(self) -> BaseClient:
-        return self._get_client("logs")
-
-    def _secretsmanager_client(self) -> BaseClient:
-        return self._get_client("secretsmanager")
-
-    def _sts_client(self) -> BaseClient:
-        return self._get_client("sts")
-
-    def _get_client(self, client_name: str) -> BaseClient:
-        return self._session.client(client_name)
diff --git a/cli/dstack/_internal/backend/lambdalabs/compute.py b/cli/dstack/_internal/backend/lambdalabs/compute.py
index d9fad917e..0a09a5133 100644
--- a/cli/dstack/_internal/backend/lambdalabs/compute.py
+++ b/cli/dstack/_internal/backend/lambdalabs/compute.py
@@ -15,10 +15,10 @@
 from dstack._internal.backend.base.runners import serialize_runner_yaml
 from dstack._internal.backend.lambdalabs.api_client import LambdaAPIClient
 from dstack._internal.backend.lambdalabs.config import LambdaConfig
-from dstack._internal.core.instance import InstanceType
+from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
 from dstack._internal.core.job import Job
 from dstack._internal.core.request import RequestHead, RequestStatus
-from dstack._internal.core.runners import Gpu, Resources
+from dstack._internal.core.runners import Gpu, Resources, Runner
 from dstack._internal.hub.utils.ssh import HUB_PRIVATE_KEY_PATH, get_hub_ssh_public_key

 _WAIT_FOR_INSTANCE_ATTEMPTS = 120
@@ -115,21 +115,23 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]:
             requirements=job.requirements,
         )

-    def run_instance(self, job: Job, instance_type: InstanceType) -> str:
-        return _run_instance(
+    def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanceInfo:
+        region = _get_instance_region(instance_type, self.lambda_config.regions)
+        instance_id = _run_instance(
             api_client=self.api_client,
-            region_name=_get_instance_region(instance_type, self.lambda_config.regions),
+            region_name=region,
             instance_type_name=instance_type.instance_name,
             user_ssh_key=job.ssh_key_pub,
             hub_ssh_key=get_hub_ssh_public_key(),
             instance_name=_get_instance_name(job),
             launch_script=_get_launch_script(self.lambda_config, job, instance_type),
         )
+        return LaunchedInstanceInfo(request_id=instance_id, location=region)

-    def terminate_instance(self, request_id: str):
-        self.api_client.terminate_instances(instance_ids=[request_id])
+    def terminate_instance(self, runner: Runner):
+        self.api_client.terminate_instances(instance_ids=[runner.request_id])

-    def cancel_spot_request(self, request_id: str):
+    def cancel_spot_request(self, runner: Runner):
         pass
diff --git a/cli/dstack/_internal/backend/local/compute.py b/cli/dstack/_internal/backend/local/compute.py
index 568b63db1..0d9b9ee56 100644
--- a/cli/dstack/_internal/backend/local/compute.py
+++ b/cli/dstack/_internal/backend/local/compute.py
@@ -3,9 +3,10 @@
 from dstack._internal.backend.base.compute import Compute, choose_instance_type
 from dstack._internal.backend.local import runners
 from dstack._internal.backend.local.config import LocalConfig
-from dstack._internal.core.instance import InstanceType
+from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
 from dstack._internal.core.job import Job
 from dstack._internal.core.request import RequestHead
+from dstack._internal.core.runners import Runner


 class LocalCompute(Compute):
@@ -23,11 +24,12 @@ def get_instance_type(self, job: Job) -> Optional[InstanceType]:
         )
         return instance_type

-    def run_instance(self, job: Job, instance_type: InstanceType) -> str:
-        return runners.start_runner_process(self.backend_config, job.runner_id)
+    def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanceInfo:
+        pid = runners.start_runner_process(self.backend_config, job.runner_id)
+        return LaunchedInstanceInfo(request_id=pid, location=None)

-    def terminate_instance(self, request_id: str):
-        runners.stop_process(request_id)
+    def terminate_instance(self, runner: Runner):
+        runners.stop_process(runner.request_id)

-    def cancel_spot_request(self, request_id: str):
-        runners.stop_process(request_id)
+    def cancel_spot_request(self, runner: Runner):
+        runners.stop_process(runner.request_id)
diff --git a/cli/dstack/_internal/core/instance.py b/cli/dstack/_internal/core/instance.py
index 0b297f0bd..97bc7461e 100644
--- a/cli/dstack/_internal/core/instance.py
+++ b/cli/dstack/_internal/core/instance.py
@@ -9,3 +9,8 @@ class InstanceType(BaseModel):
     instance_name: str
     resources: Resources
     available_regions: Optional[List[str]] = None
+
+
+class LaunchedInstanceInfo(BaseModel):
+    request_id: str
+    location: Optional[str] = None
diff --git a/cli/dstack/_internal/core/job.py b/cli/dstack/_internal/core/job.py
index c771da299..b25405981 100644
--- a/cli/dstack/_internal/core/job.py
+++ b/cli/dstack/_internal/core/job.py
@@ -204,6 +204,7 @@ class Job(JobHead):
     app_specs: Optional[List[AppSpec]]
     runner_id: Optional[str]
     request_id: Optional[str]
+    location: Optional[str]
     tag_name: Optional[str]
     ssh_key_pub: Optional[str]
     build_policy: Optional[str]
@@ -304,6 +305,7 @@ def serialize(self) -> dict:
             else [],
             "runner_id": self.runner_id or "",
             "request_id": self.request_id or "",
+            "location": self.location or "",
             "tag_name": self.tag_name or "",
             "ssh_key_pub": self.ssh_key_pub or "",
             "repo_code_filename": self.repo_code_filename,
@@ -440,6 +442,7 @@ def unserialize(job_data: dict):
         app_specs=app_specs,
         runner_id=job_data.get("runner_id") or None,
         request_id=job_data.get("request_id") or None,
+        location=job_data.get("location") or None,
         tag_name=job_data.get("tag_name") or None,
         ssh_key_pub=job_data.get("ssh_key_pub") or None,
         instance_type=job_data.get("instance_type") or None,
diff --git a/cli/dstack/_internal/core/request.py b/cli/dstack/_internal/core/request.py
index 449625c83..f6d07245d 100644
--- a/cli/dstack/_internal/core/request.py
+++ b/cli/dstack/_internal/core/request.py
@@ -15,9 +15,3 @@ class RequestHead(BaseModel):
     job_id: str
     status: RequestStatus
     message: Optional[str]
-
-    def __str__(self) -> str:
-        return (
-            f'RequestStatus(job_id="{self.job_id}", status="{self.status.value}", '
-            f'message="{self.message})'
-        )
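`location` follows the same empty-string-for-None persistence convention as the neighboring fields, so the round trip looks like this (sketch):

```python
from dstack._internal.core.instance import LaunchedInstanceInfo

info = LaunchedInstanceInfo(request_id="sir-abc123", location="eu-north-1")
# _try_run_job (see base/jobs.py above) copies both fields:
#   runner.request_id = info.request_id
#   job.location = info.location
# serialize() then stores "" when location is None, and unserialize()
# maps "" back to None via `job_data.get("location") or None`.
```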
diff --git a/cli/dstack/_internal/hub/main.py b/cli/dstack/_internal/hub/main.py
index 27f8c9144..43a4459ef 100644
--- a/cli/dstack/_internal/hub/main.py
+++ b/cli/dstack/_internal/hub/main.py
@@ -30,10 +30,11 @@
     users,
 )
 from dstack._internal.hub.services.backends import local_backend_available
-from dstack._internal.hub.utils import logging
+from dstack._internal.hub.utils.logging import configure_logger
 from dstack._internal.hub.utils.ssh import generate_hub_ssh_key_pair
+from dstack._internal.utils import logging

-logging.configure_root_logger()
+configure_logger()
 logger = logging.get_logger(__name__)

@@ -78,7 +79,7 @@ async def app_logging(request: Request, call_next):
     path = request.url.path
     request_dict = {"method": request.method, "path": path}
     if path.startswith("/api/"):
-        logger.info(
+        logger.debug(
             {
                 "request": request_dict,
                 "process_time": process_time,
diff --git a/cli/dstack/_internal/hub/models/__init__.py b/cli/dstack/_internal/hub/models/__init__.py
index 15dd760e0..a60041eda 100644
--- a/cli/dstack/_internal/hub/models/__init__.py
+++ b/cli/dstack/_internal/hub/models/__init__.py
@@ -44,6 +44,7 @@ class AWSProjectConfigPartial(BaseModel):
     type: Literal["aws"] = "aws"
     region_name: Optional[str]
     region_name_title: Optional[str]
+    extra_regions: Optional[List[str]]
     s3_bucket_name: Optional[str]
     ec2_subnet_id: Optional[str]

@@ -52,6 +53,7 @@ class AWSProjectConfig(BaseModel):
     type: Literal["aws"] = "aws"
     region_name: str
     region_name_title: Optional[str]
+    extra_regions: List[str]
     s3_bucket_name: str
     ec2_subnet_id: Optional[str]

@@ -374,6 +376,7 @@ class AWSProjectValues(BaseModel):
     type: Literal["aws"] = "aws"
     default_credentials: bool = False
     region_name: Optional[ProjectElement]
+    extra_regions: Optional[ProjectMultiElement]
     s3_bucket_name: Optional[AWSBucketProjectElement]
     ec2_subnet_id: Optional[ProjectElement]
diff --git a/cli/dstack/_internal/hub/services/backends/aws/configurator.py b/cli/dstack/_internal/hub/services/backends/aws/configurator.py
index ac79cc66e..7fdfd6222 100644
--- a/cli/dstack/_internal/hub/services/backends/aws/configurator.py
+++ b/cli/dstack/_internal/hub/services/backends/aws/configurator.py
@@ -17,6 +17,7 @@
     AWSProjectValues,
     ProjectElement,
     ProjectElementValue,
+    ProjectMultiElement,
 )
 from dstack._internal.hub.services.backends.base import BackendConfigError, Configurator

@@ -33,6 +34,7 @@
     ("Europe, Paris", "eu-west-3"),
     ("Europe, Stockholm", "eu-north-1"),
 ]
+REGION_VALUES = [r[1] for r in REGIONS]


 class AWSConfigurator(Configurator):
@@ -41,9 +43,10 @@ class AWSConfigurator(Configurator):
     def configure_project(
         self, project_config: AWSProjectConfigWithCredsPartial
     ) -> AWSProjectValues:
-        if project_config.region_name is not None and project_config.region_name not in {
-            r[1] for r in REGIONS
-        }:
+        if (
+            project_config.region_name is not None
+            and project_config.region_name not in REGION_VALUES
+        ):
             raise BackendConfigError(f"Invalid AWS region {project_config.region_name}")

         project_values = AWSProjectValues()
@@ -71,22 +74,25 @@ def configure_project(
             self._raise_invalid_credentials_error(fields=[["credentials"]])

         # TODO validate config values
-        project_values.region_name = self._get_hub_regions_element(
-            default_region=session.region_name
+        project_values.region_name = self._get_hub_region_element(selected=session.region_name)
+        project_values.extra_regions = self._get_hub_extra_regions_element(
+            region=session.region_name,
+            selected=project_config.extra_regions or [],
         )
         project_values.s3_bucket_name = self._get_hub_buckets_element(
             session=session,
             region=session.region_name,
-            default_bucket=project_config.s3_bucket_name,
+            selected=project_config.s3_bucket_name,
         )
         project_values.ec2_subnet_id = self._get_hub_subnet_element(
-            session=session, default_subnet=project_config.ec2_subnet_id
+            session=session, selected=project_config.ec2_subnet_id
         )
         return project_values

     def create_project(self, project_config: AWSProjectConfigWithCreds) -> Tuple[Dict, Dict]:
         config_data = {
             "region_name": project_config.region_name,
+            "extra_regions": project_config.extra_regions,
             "s3_bucket_name": project_config.s3_bucket_name.replace("s3://", ""),
             "ec2_subnet_id": project_config.ec2_subnet_id,
         }
@@ -100,11 +106,13 @@ def get_project_config(
         region_name = json_config["region_name"]
         s3_bucket_name = json_config["s3_bucket_name"]
         ec2_subnet_id = json_config["ec2_subnet_id"]
+        extra_regions = json_config.get("extra_regions", [])
         if include_creds:
             json_auth = json.loads(project.auth)
             return AWSProjectConfigWithCreds(
                 region_name=region_name,
                 region_name_title=region_name,
+                extra_regions=extra_regions,
                 s3_bucket_name=s3_bucket_name,
                 ec2_subnet_id=ec2_subnet_id,
                 credentials=AWSProjectCreds.parse_obj(json_auth),
             )
@@ -112,6 +120,7 @@
         return AWSProjectConfig(
             region_name=region_name,
             region_name_title=region_name,
+            extra_regions=extra_regions,
             s3_bucket_name=s3_bucket_name,
             ec2_subnet_id=ec2_subnet_id,
         )
@@ -124,6 +133,7 @@ def get_backend(self, project: Project) -> AwsBackend:
             or config_data.get("bucket_name")
             or config_data.get("s3_bucket_name"),
             region_name=config_data.get("region_name"),
+            extra_regions=config_data.get("extra_regions", []),
             subnet_id=config_data.get("subnet_id")
             or config_data.get("ec2_subnet_id")
             or config_data.get("subnet"),
@@ -146,18 +156,27 @@ def _raise_invalid_credentials_error(self, fields: Optional[List[List[str]]] = N
             fields=fields,
         )

-    def _get_hub_regions_element(self, default_region: Optional[str]) -> ProjectElement:
-        element = ProjectElement(selected=default_region)
+    def _get_hub_region_element(self, selected: Optional[str]) -> ProjectElement:
+        element = ProjectElement(selected=selected)
         for r in REGIONS:
-            element.values.append(ProjectElementValue(value=r[1], label=r[0]))
+            element.values.append(ProjectElementValue(value=r[1], label=r[1]))
+        return element
+
+    def _get_hub_extra_regions_element(
+        self, region: str, selected: List[str]
+    ) -> ProjectMultiElement:
+        element = ProjectMultiElement(selected=selected)
+        for r in REGION_VALUES:
+            if r != region:
+                element.values.append(ProjectElementValue(value=r, label=r))
         return element

     def _get_hub_buckets_element(
-        self, session: Session, region: str, default_bucket: Optional[str]
+        self, session: Session, region: str, selected: Optional[str]
     ) -> AWSBucketProjectElement:
-        if default_bucket is not None:
-            self._validate_hub_bucket(session=session, region=region, bucket_name=default_bucket)
-        element = AWSBucketProjectElement(selected=default_bucket)
+        if selected is not None:
+            self._validate_hub_bucket(session=session, region=region, bucket_name=selected)
+        element = AWSBucketProjectElement(selected=selected)
         s3_client = session.client("s3")
         response = s3_client.list_buckets()
         for bucket in response["Buckets"]:
@@ -194,10 +213,8 @@ def _validate_hub_bucket(self, session: Session, region: str, bucket_name: str):
             )
             raise e

-    def _get_hub_subnet_element(
-        self, session: Session, default_subnet: Optional[str]
-    ) -> ProjectElement:
-        element = ProjectElement(selected=default_subnet)
+    def _get_hub_subnet_element(self, session: Session, selected: Optional[str]) -> ProjectElement:
+        element = ProjectElement(selected=selected)
         _ec2 = session.client("ec2")
         response = _ec2.describe_subnets()
         for subnet in response["Subnets"]:
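In the hub UI model the primary region is excluded from the multi-select, so a project whose region is, say, `eu-west-3` would be offered every other entry of `REGION_VALUES` (including `eu-north-1`) as an extra region. Roughly, assuming `ProjectMultiElement` exposes `selected` and `values` like its single-select counterpart:

```python
element = ProjectMultiElement(selected=["eu-north-1"])
# _get_hub_extra_regions_element then appends one ProjectElementValue
# per REGION_VALUES entry except the primary region itself.
```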
diff --git a/cli/dstack/_internal/hub/utils/logging.py b/cli/dstack/_internal/hub/utils/logging.py
index 9c175932e..a322d1a26 100644
--- a/cli/dstack/_internal/hub/utils/logging.py
+++ b/cli/dstack/_internal/hub/utils/logging.py
@@ -13,8 +13,8 @@ def filter(self, record: logging.LogRecord) -> bool:
         return True


-def configure_root_logger():
-    logger = logging.getLogger(None)
+def configure_logger():
+    root_logger = logging.getLogger(None)
     handler = logging.StreamHandler(stream=sys.stdout)
     handler.addFilter(AsyncioCancelledErrorFilter())
     formatter = logging.Formatter(
@@ -22,11 +22,7 @@ def configure_root_logger():
         datefmt="%Y-%m-%dT%H:%M:%S",
     )
     handler.setFormatter(formatter)
-    logger.addHandler(handler)
-    logger.setLevel(os.getenv("DSTACK_HUB_ROOT_LOG_LEVEL", "ERROR").upper())
-
-
-def get_logger(name: str) -> logging.Logger:
-    logger = logging.getLogger(name)
-    logger.setLevel(os.getenv("DSTACK_HUB_LOG_LEVEL", "ERROR").upper())
-    return logger
+    root_logger.addHandler(handler)
+    root_logger.setLevel(os.getenv("DSTACK_HUB_ROOT_LOG_LEVEL", "ERROR").upper())
+    dstack_logger = logging.getLogger("dstack")
+    dstack_logger.setLevel(os.getenv("DSTACK_HUB_LOG_LEVEL", "ERROR").upper())
diff --git a/cli/dstack/_internal/utils/logging.py b/cli/dstack/_internal/utils/logging.py
new file mode 100644
index 000000000..3d736313d
--- /dev/null
+++ b/cli/dstack/_internal/utils/logging.py
@@ -0,0 +1,5 @@
+import logging
+
+
+def get_logger(name: str) -> logging.Logger:
+    return logging.getLogger(name)
diff --git a/cli/tests/hub/common.py b/cli/tests/hub/common.py
index bfad2f291..51615c33b 100644
--- a/cli/tests/hub/common.py
+++ b/cli/tests/hub/common.py
@@ -29,6 +29,7 @@ async def create_project(
     if config is None:
         config = {
             "region_name": "eu-west-1",
+            "extra_regions": [],
             "s3_bucket_name": "dstack-test-eu-west-1",
             "ec2_subnet_id": None,
         }
diff --git a/cli/tests/hub/routers/test_projects.py b/cli/tests/hub/routers/test_projects.py
index 4ec246a01..c42742067 100644
--- a/cli/tests/hub/routers/test_projects.py
+++ b/cli/tests/hub/routers/test_projects.py
@@ -63,6 +63,7 @@ async def test_successfull_response_format(self, test_db):
                 "type": "aws",
                 "region_name": project_config["region_name"],
                 "region_name_title": project_config["region_name"],
+                "extra_regions": project_config["extra_regions"],
                 "s3_bucket_name": project_config["s3_bucket_name"],
                 "ec2_subnet_id": project_config["ec2_subnet_id"],
             },
@@ -104,6 +105,7 @@ async def test_successfull_response_format(self, test_db):
                 },
                 "region_name": project_config["region_name"],
                 "region_name_title": project_config["region_name"],
+                "extra_regions": project_config["extra_regions"],
                 "s3_bucket_name": project_config["s3_bucket_name"],
                 "ec2_subnet_id": project_config["ec2_subnet_id"],
             },
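The logging split leaves two knobs: `DSTACK_HUB_ROOT_LOG_LEVEL` controls the root logger (third-party noise), while `DSTACK_HUB_LOG_LEVEL` sets the `dstack` logger, which every module logger obtained via the new `dstack._internal.utils.logging.get_logger` inherits from. Roughly:

```python
import logging

# After configure_logger() runs with DSTACK_HUB_LOG_LEVEL=INFO:
module_logger = logging.getLogger("dstack._internal.backend.aws.runners")
# module_logger sets no level of its own, so it inherits from "dstack",
# while unrelated libraries still fall back to the root logger's level.
assert module_logger.getEffectiveLevel() == logging.getLogger("dstack").getEffectiveLevel()
```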
diff --git a/runner/internal/backend/aws/backend.go b/runner/internal/backend/aws/backend.go
index 239d08621..a7bc9d142 100644
--- a/runner/internal/backend/aws/backend.go
+++ b/runner/internal/backend/aws/backend.go
@@ -4,14 +4,15 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"github.com/dstackai/dstack/runner/internal/backend/aws/s3fs"
-	"github.com/dstackai/dstack/runner/internal/backend/base"
 	"io"
 	"io/ioutil"
 	"path"
 	"strings"
 	"time"

+	"github.com/dstackai/dstack/runner/internal/backend/aws/s3fs"
+	"github.com/dstackai/dstack/runner/internal/backend/base"
+
 	"github.com/docker/docker/api/types/mount"
 	"github.com/dstackai/dstack/runner/consts"
 	"github.com/dstackai/dstack/runner/internal/backend"
@@ -78,6 +79,12 @@ func (s *AWSBackend) Init(ctx context.Context, ID string) error {
 	}
 	s.runnerID = ID
 	err := base.LoadRunnerState(ctx, s.storage, ID, &s.State)
+	if err != nil {
+		return gerrors.Wrap(err)
+	}
+	if s.State.Job.Location != "" {
+		s.cliEC2 = NewClientEC2(s.State.Job.Location)
+	}
 	return gerrors.Wrap(err)
 }
diff --git a/runner/internal/backend/aws/ec2.go b/runner/internal/backend/aws/ec2.go
index e8d77a400..07db5b937 100644
--- a/runner/internal/backend/aws/ec2.go
+++ b/runner/internal/backend/aws/ec2.go
@@ -33,13 +33,13 @@ func NewClientEC2(region string) *ClientEC2 {

 func (ec *ClientEC2) CancelSpot(ctx context.Context, requestID string) error {
 	log.Trace(ctx, "Cancel spot instance", "ID", requestID)
-	id, err := ec.getInstanceID(ctx)
+	_, err := ec.cli.CancelSpotInstanceRequests(ctx, &ec2.CancelSpotInstanceRequestsInput{
+		SpotInstanceRequestIds: []string{requestID},
+	})
 	if err != nil {
 		return gerrors.Wrap(err)
 	}
-	_, err = ec.cli.CancelSpotInstanceRequests(ctx, &ec2.CancelSpotInstanceRequestsInput{
-		SpotInstanceRequestIds: []string{requestID},
-	})
+	id, err := ec.getInstanceID(ctx)
 	if err != nil {
 		return gerrors.Wrap(err)
 	}
diff --git a/runner/internal/backend/base/backend.go b/runner/internal/backend/base/backend.go
index aeb050eb3..b0c79ceeb 100644
--- a/runner/internal/backend/base/backend.go
+++ b/runner/internal/backend/base/backend.go
@@ -3,13 +3,14 @@ package base
 import (
 	"context"
 	"fmt"
+	"os"
+	"strings"
+
 	"github.com/dstackai/dstack/runner/internal/gerrors"
 	"github.com/dstackai/dstack/runner/internal/log"
 	"github.com/dstackai/dstack/runner/internal/models"
 	"github.com/dstackai/dstack/runner/internal/repo"
 	"gopkg.in/yaml.v2"
-	"os"
-	"strings"
 )

 func LoadRunnerState(ctx context.Context, storage Storage, id string, out interface{}) error {
diff --git a/runner/internal/models/backend.go b/runner/internal/models/backend.go
index f226da343..cf25ab807 100644
--- a/runner/internal/models/backend.go
+++ b/runner/internal/models/backend.go
@@ -48,6 +48,7 @@ type Job struct {
 	RepoCodeFilename string `yaml:"repo_code_filename"`
 	RequestID string `yaml:"request_id"`
+	Location string `yaml:"location"`
 	Requirements Requirements `yaml:"requirements"`
 	RunName string `yaml:"run_name"`
 	RunnerID string `yaml:"runner_id"`