Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 13 additions & 36 deletions cli/dstack/_internal/backend/aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

import boto3
from botocore.client import BaseClient
from boto3 import Session

from dstack._internal.backend.aws import utils as aws_utils
from dstack._internal.backend.aws.compute import AWSCompute
from dstack._internal.backend.aws.config import AWSConfig
from dstack._internal.backend.aws.logs import AWSLogging
Expand All @@ -21,31 +21,29 @@ def __init__(
):
self.backend_config = backend_config
if self.backend_config.credentials is not None:
self._session = boto3.session.Session(
self._session = Session(
region_name=self.backend_config.region_name,
aws_access_key_id=self.backend_config.credentials.get("access_key"),
aws_secret_access_key=self.backend_config.credentials.get("secret_key"),
)
else:
self._session = boto3.session.Session(region_name=self.backend_config.region_name)
self._session = Session(region_name=self.backend_config.region_name)
self._storage = AWSStorage(
s3_client=self._s3_client(), bucket_name=self.backend_config.bucket_name
s3_client=aws_utils.get_s3_client(self._session),
bucket_name=self.backend_config.bucket_name,
)
self._compute = AWSCompute(
ec2_client=self._ec2_client(),
iam_client=self._iam_client(),
bucket_name=self.backend_config.bucket_name,
region_name=self.backend_config.region_name,
subnet_id=self.backend_config.subnet_id,
session=self._session,
backend_config=self.backend_config,
)
self._secrets_manager = AWSSecretsManager(
secretsmanager_client=self._secretsmanager_client(),
iam_client=self._iam_client(),
sts_client=self._sts_client(),
secretsmanager_client=aws_utils.get_secretsmanager_client(self._session),
iam_client=aws_utils.get_iam_client(self._session),
sts_client=aws_utils.get_sts_client(self._session),
bucket_name=self.backend_config.bucket_name,
)
self._logging = AWSLogging(
logs_client=self._logs_client(),
logs_client=aws_utils.get_logs_client(self._session),
bucket_name=self.backend_config.bucket_name,
)

Expand All @@ -72,27 +70,6 @@ def logging(self) -> AWSLogging:

def create_run(self, repo_id: str) -> str:
self._logging.create_log_groups_if_not_exist(
self._logs_client(), self.backend_config.bucket_name, repo_id
aws_utils.get_logs_client(self._session), self.backend_config.bucket_name, repo_id
)
return base_runs.create_run(self._storage)

def _s3_client(self) -> BaseClient:
return self._get_client("s3")

def _ec2_client(self) -> BaseClient:
return self._get_client("ec2")

def _iam_client(self) -> BaseClient:
return self._get_client("iam")

def _logs_client(self) -> BaseClient:
return self._get_client("logs")

def _secretsmanager_client(self) -> BaseClient:
return self._get_client("secretsmanager")

def _sts_client(self) -> BaseClient:
return self._get_client("sts")

def _get_client(self, client_name: str) -> BaseClient:
return self._session.client(client_name)
56 changes: 30 additions & 26 deletions cli/dstack/_internal/backend/aws/compute.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,48 @@
from typing import Optional

from botocore.client import BaseClient
from boto3 import Session

from dstack._internal.backend.aws import runners
from dstack._internal.backend.aws import utils as aws_utils
from dstack._internal.backend.aws.config import AWSConfig
from dstack._internal.backend.base.compute import Compute
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.instance import InstanceType, LaunchedInstanceInfo
from dstack._internal.core.job import Job
from dstack._internal.core.request import RequestHead
from dstack._internal.core.runners import Runner


class AWSCompute(Compute):
def __init__(
self,
ec2_client: BaseClient,
iam_client: BaseClient,
bucket_name: str,
region_name: str,
subnet_id: str,
session: Session,
backend_config: AWSConfig,
):
self.ec2_client = ec2_client
self.iam_client = iam_client
self.bucket_name = bucket_name
self.region_name = region_name
self.subnet_id = subnet_id
self.session = session
self.iam_client = aws_utils.get_iam_client(session)
self.backend_config = backend_config

def get_request_head(self, job: Job, request_id: Optional[str]) -> RequestHead:
return runners.get_request_head(
ec2_client=self.ec2_client,
ec2_client=self._get_ec2_client(region=job.location),
job=job,
request_id=request_id,
)

def get_instance_type(self, job: Job) -> Optional[InstanceType]:
return runners.get_instance_type(
ec2_client=self.ec2_client,
ec2_client=self._get_ec2_client(),
requirements=job.requirements,
)

def run_instance(self, job: Job, instance_type: InstanceType) -> str:
return runners.run_instance_retry(
ec2_client=self.ec2_client,
def run_instance(self, job: Job, instance_type: InstanceType) -> LaunchedInstanceInfo:
return runners.run_instance(
session=self.session,
iam_client=self.iam_client,
bucket_name=self.bucket_name,
region_name=self.region_name,
subnet_id=self.subnet_id,
bucket_name=self.backend_config.bucket_name,
region_name=self.backend_config.region_name,
extra_regions=self.backend_config.extra_regions,
subnet_id=self.backend_config.subnet_id,
runner_id=job.runner_id,
instance_type=instance_type,
spot=job.requirements.spot,
Expand All @@ -52,14 +51,19 @@ def run_instance(self, job: Job, instance_type: InstanceType) -> str:
ssh_key_pub=job.ssh_key_pub,
)

def terminate_instance(self, request_id: str):
def terminate_instance(self, runner: Runner):
runners.terminate_instance(
ec2_client=self.ec2_client,
request_id=request_id,
ec2_client=self._get_ec2_client(region=runner.job.location),
request_id=runner.request_id,
)

def cancel_spot_request(self, request_id: str):
def cancel_spot_request(self, runner: Runner):
runners.cancel_spot_request(
ec2_client=self.ec2_client,
request_id=request_id,
ec2_client=self._get_ec2_client(region=runner.job.location),
request_id=runner.request_id,
)

def _get_ec2_client(self, region: Optional[str] = None):
if region is None:
return aws_utils.get_ec2_client(self.session)
return aws_utils.get_ec2_client(self.session, region_name=region)
6 changes: 5 additions & 1 deletion cli/dstack/_internal/backend/aws/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Dict, List, Optional

from pydantic import BaseModel

Expand All @@ -10,6 +10,7 @@
class AWSConfig(BackendConfig, BaseModel):
bucket_name: str
region_name: Optional[str] = DEFAULT_REGION_NAME
extra_regions: List[str] = []
subnet_id: Optional[str] = None
credentials: Optional[Dict] = None

Expand All @@ -20,6 +21,8 @@ def serialize(self) -> Dict:
}
if self.region_name:
config_data["region"] = self.region_name
if self.extra_regions:
config_data["extra_regions"] = self.extra_regions
if self.subnet_id:
config_data["subnet"] = self.subnet_id
return config_data
Expand All @@ -35,5 +38,6 @@ def deserialize(cls, config_data: Dict) -> Optional["AWSConfig"]:
return cls(
bucket_name=bucket_name,
region_name=config_data.get("region"),
extra_regions=config_data.get("extra_regions", []),
subnet_id=config_data.get("subnet"),
)
Loading