246 changes: 27 additions & 219 deletions cli/dstack/_internal/backend/aws/__init__.py
@@ -1,46 +1,25 @@
from datetime import datetime
from typing import Generator, List, Optional
from typing import Optional

import boto3
from botocore.client import BaseClient

from dstack._internal.backend.aws import logs
from dstack._internal.backend.aws.compute import AWSCompute
from dstack._internal.backend.aws.config import AWSConfig
from dstack._internal.backend.aws.logs import AWSLogging
from dstack._internal.backend.aws.secrets import AWSSecretsManager
from dstack._internal.backend.aws.storage import AWSStorage
from dstack._internal.backend.base import Backend
from dstack._internal.backend.base import artifacts as base_artifacts
from dstack._internal.backend.base import cache as base_cache
from dstack._internal.backend.base import jobs as base_jobs
from dstack._internal.backend.base import repos as base_repos
from dstack._internal.backend.base import ComponentBasedBackend
from dstack._internal.backend.base import runs as base_runs
from dstack._internal.backend.base import secrets as base_secrets
from dstack._internal.backend.base import tags as base_tags
from dstack._internal.core.artifact import Artifact
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job, JobHead, JobStatus
from dstack._internal.core.log_event import LogEvent
from dstack._internal.core.repo import RemoteRepoCredentials, RepoHead, RepoSpec
from dstack._internal.core.repo.base import Repo
from dstack._internal.core.run import RunHead
from dstack._internal.core.secret import Secret
from dstack._internal.core.tag import TagHead
from dstack._internal.utils.common import PathLike


class AwsBackend(Backend):
class AwsBackend(ComponentBasedBackend):
NAME = "aws"
backend_config: AWSConfig
_storage: AWSStorage
_compute: AWSCompute
_secrets_manager: AWSSecretsManager

def __init__(
self,
backend_config: AWSConfig,
):
super().__init__(backend_config=backend_config)
self.backend_config = backend_config
if self.backend_config.credentials is not None:
self._session = boto3.session.Session(
region_name=self.backend_config.region_name,
@@ -65,6 +44,10 @@ def __init__(
sts_client=self._sts_client(),
bucket_name=self.backend_config.bucket_name,
)
self._logging = AWSLogging(
logs_client=self._logs_client(),
bucket_name=self.backend_config.bucket_name,
)

@classmethod
def load(cls) -> Optional["AwsBackend"]:
@@ -75,6 +58,24 @@ def load(cls) -> Optional["AwsBackend"]:
backend_config=config,
)

def storage(self) -> AWSStorage:
return self._storage

def compute(self) -> AWSCompute:
return self._compute

def secrets_manager(self) -> AWSSecretsManager:
return self._secrets_manager

def logging(self) -> AWSLogging:
return self._logging

def create_run(self, repo_id: str) -> str:
self._logging.create_log_groups_if_not_exist(
self._logs_client(), self.backend_config.bucket_name, repo_id
)
return base_runs.create_run(self._storage)

def _s3_client(self) -> BaseClient:
return self._get_client("s3")

@@ -95,196 +96,3 @@ def _sts_client(self) -> BaseClient:

def _get_client(self, client_name: str) -> BaseClient:
return self._session.client(client_name)

def predict_instance_type(self, job: Job) -> Optional[InstanceType]:
return base_jobs.predict_job_instance(self._compute, job)

def create_run(self, repo_id: str) -> str:
logs.create_log_groups_if_not_exist(
self._logs_client(), self.backend_config.bucket_name, repo_id
)
return base_runs.create_run(self._storage)

def create_job(self, job: Job):
base_jobs.create_job(self._storage, job)

def get_job(self, repo_id: str, job_id: str) -> Optional[Job]:
return base_jobs.get_job(self._storage, repo_id, job_id)

def list_jobs(self, repo_id: str, run_name: str) -> List[Job]:
return base_jobs.list_jobs(self._storage, repo_id, run_name)

def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus):
base_jobs.run_job(self._storage, self._compute, job, failed_to_start_job_new_status)

def stop_job(self, repo_id: str, abort: bool, job_id: str):
base_jobs.stop_job(self._storage, self._compute, repo_id, job_id, abort)

def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]:
return base_jobs.list_job_heads(self._storage, repo_id, run_name)

def delete_job_head(self, repo_id: str, job_id: str):
base_jobs.delete_job_head(self._storage, repo_id, job_id)

def list_run_heads(
self,
repo_id: str,
run_name: Optional[str] = None,
include_request_heads: bool = True,
interrupted_job_new_status: JobStatus = JobStatus.FAILED,
) -> List[RunHead]:
job_heads = self.list_job_heads(repo_id=repo_id, run_name=run_name)
return base_runs.get_run_heads(
self._storage,
self._compute,
job_heads,
include_request_heads,
interrupted_job_new_status,
)

def poll_logs(
self,
repo_id: str,
run_name: str,
start_time: datetime,
end_time: Optional[datetime] = None,
descending: bool = False,
diagnose: bool = False,
) -> Generator[LogEvent, None, None]:
return logs.poll_logs(
self._storage,
self._logs_client(),
self.backend_config.bucket_name,
repo_id,
run_name,
start_time,
end_time,
descending,
diagnose,
)

def list_run_artifact_files(
self, repo_id: str, run_name: str, prefix: str, recursive: bool = False
) -> List[Artifact]:
return base_artifacts.list_run_artifact_files(
self._storage, repo_id, run_name, prefix, recursive
)

def download_run_artifact_files(
self,
repo_id: str,
run_name: str,
output_dir: Optional[PathLike],
files_path: Optional[PathLike] = None,
):
artifacts = self.list_run_artifact_files(
repo_id, run_name=run_name, prefix="", recursive=True
)
base_artifacts.download_run_artifact_files(
storage=self._storage,
repo_id=repo_id,
artifacts=artifacts,
output_dir=output_dir,
files_path=files_path,
)

def upload_job_artifact_files(
self,
repo_id: str,
job_id: str,
artifact_name: str,
artifact_path: PathLike,
local_path: PathLike,
):
base_artifacts.upload_job_artifact_files(
storage=self._storage,
repo_id=repo_id,
job_id=job_id,
artifact_name=artifact_name,
artifact_path=artifact_path,
local_path=local_path,
)

def list_tag_heads(self, repo_id: str) -> List[TagHead]:
return base_tags.list_tag_heads(self._storage, repo_id)

def get_tag_head(self, repo_id: str, tag_name: str) -> Optional[TagHead]:
return base_tags.get_tag_head(self._storage, repo_id, tag_name)

def add_tag_from_run(
self, repo_id: str, tag_name: str, run_name: str, run_jobs: Optional[List[Job]]
):
base_tags.create_tag_from_run(
self._storage,
repo_id,
tag_name,
run_name,
run_jobs,
)

def add_tag_from_local_dirs(
self,
repo: Repo,
hub_user_name: str,
tag_name: str,
local_dirs: List[str],
artifact_paths: List[str],
):
base_tags.create_tag_from_local_dirs(
storage=self._storage,
repo=repo,
hub_user_name=hub_user_name,
tag_name=tag_name,
local_dirs=local_dirs,
artifact_paths=artifact_paths,
)

def delete_tag_head(self, repo_id: str, tag_head: TagHead):
base_tags.delete_tag(self._storage, repo_id, tag_head)

def list_repo_heads(self) -> List[RepoHead]:
return base_repos.list_repo_heads(self._storage)

def update_repo_last_run_at(self, repo_spec: RepoSpec, last_run_at: int):
base_repos.update_repo_last_run_at(
self._storage,
repo_spec,
last_run_at,
)

def get_repo_credentials(self, repo_id: str) -> Optional[RemoteRepoCredentials]:
return base_repos.get_repo_credentials(self._secrets_manager, repo_id)

def save_repo_credentials(self, repo_id: str, repo_credentials: RemoteRepoCredentials):
base_repos.save_repo_credentials(self._secrets_manager, repo_id, repo_credentials)

def delete_repo(self, repo_id: str):
base_repos.delete_repo(self._storage, repo_id)

def list_secret_names(self, repo_id: str) -> List[str]:
return base_secrets.list_secret_names(self._storage, repo_id)

def get_secret(self, repo_id: str, secret_name: str) -> Optional[Secret]:
return base_secrets.get_secret(self._secrets_manager, repo_id, secret_name)

def add_secret(self, repo_id: str, secret: Secret):
base_secrets.add_secret(self._storage, self._secrets_manager, repo_id, secret)

def update_secret(self, repo_id: str, secret: Secret):
base_secrets.update_secret(self._storage, self._secrets_manager, repo_id, secret)

def delete_secret(self, repo_id: str, secret_name: str):
base_secrets.delete_secret(self._storage, self._secrets_manager, repo_id, secret_name)

def get_signed_download_url(self, object_key: str) -> str:
return self._storage.get_signed_download_url(object_key)

def get_signed_upload_url(self, object_key: str) -> str:
return self._storage.get_signed_upload_url(object_key)

def delete_configuration_cache(
self, repo_id: str, hub_user_name: str, configuration_path: str
):
base_cache.delete_configuration_cache(
self._storage, repo_id, hub_user_name, configuration_path
)
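
The net effect of this file's change is that `AwsBackend` stops re-implementing every delegating method and instead exposes only its components (`storage()`, `compute()`, `secrets_manager()`, `logging()`), leaving the per-entity operations to the shared `ComponentBasedBackend` base. That base class is not part of this diff, so the following is a minimal sketch of the pattern, reconstructed from the methods deleted above; the accessor names match `AwsBackend`, but the exact method set and signatures in `dstack._internal.backend.base` are assumptions.

```python
# Sketch only: the real ComponentBasedBackend lives in
# dstack._internal.backend.base and is not shown in this diff.
from abc import ABC, abstractmethod
from typing import List, Optional

from dstack._internal.backend.base import jobs as base_jobs
from dstack._internal.backend.base import runs as base_runs
from dstack._internal.core.job import Job, JobHead


class ComponentBasedBackend(ABC):
    # Cloud backends implement only these accessors; the shared
    # per-entity operations below are written once against them.
    @abstractmethod
    def storage(self): ...

    @abstractmethod
    def compute(self): ...

    @abstractmethod
    def secrets_manager(self): ...

    @abstractmethod
    def logging(self): ...

    # Each method mirrors a delegating method deleted from AwsBackend,
    # with self._storage replaced by the storage() accessor.
    def create_job(self, job: Job):
        base_jobs.create_job(self.storage(), job)

    def list_jobs(self, repo_id: str, run_name: str) -> List[Job]:
        return base_jobs.list_jobs(self.storage(), repo_id, run_name)

    def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]:
        return base_jobs.list_job_heads(self.storage(), repo_id, run_name)

    def create_run(self, repo_id: str) -> str:
        return base_runs.create_run(self.storage())
```

Note that `AwsBackend` still overrides `create_run` above: AWS needs its CloudWatch log groups created before a run starts, so that method alone keeps backend-specific logic while everything else falls through to the base.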
66 changes: 16 additions & 50 deletions cli/dstack/_internal/backend/aws/config.py
@@ -1,38 +1,17 @@
import os
from typing import Dict, Optional

from dstack._internal.backend.base.config import BackendConfig
from pydantic import BaseModel

_DEFAULT_REGION_NAME = "us-east-1"
from dstack._internal.backend.base.config import BackendConfig

DEFAULT_REGION_NAME = "us-east-1"

class AWSConfig(BackendConfig):
bucket_name = None
region_name = None
profile_name = None
subnet_id = None
credentials = None

def __init__(
self,
bucket_name: Optional[str] = None,
region_name: Optional[str] = None,
profile_name: Optional[str] = None,
subnet_id: Optional[str] = None,
credentials: Optional[Dict] = None,
):
self.bucket_name = bucket_name or os.getenv("DSTACK_AWS_S3_BUCKET")
self.region_name = (
region_name
or os.getenv("DSTACK_AWS_REGION")
or os.getenv("AWS_DEFAULT_REGION")
or _DEFAULT_REGION_NAME
)
self.profile_name = (
profile_name or os.getenv("DSTACK_AWS_PROFILE") or os.getenv("AWS_PROFILE")
)
self.subnet_id = subnet_id or os.getenv("DSTACK_AWS_EC2_SUBNET")
self.credentials = credentials
class AWSConfig(BackendConfig, BaseModel):
bucket_name: str
region_name: Optional[str] = DEFAULT_REGION_NAME
subnet_id: Optional[str] = None
credentials: Optional[Dict] = None

def serialize(self) -> Dict:
config_data = {
@@ -41,33 +20,20 @@ def serialize(self) -> Dict:
}
if self.region_name:
config_data["region"] = self.region_name
if self.profile_name:
config_data["profile"] = self.profile_name
if self.subnet_id:
config_data["subnet"] = self.subnet_id
return config_data

@classmethod
def deserialize(cls, config_data: Dict, auth_data: Dict = None) -> Optional["AWSConfig"]:
backend = config_data.get("backend") or config_data.get("type")
if backend != "aws":
def deserialize(cls, config_data: Dict) -> Optional["AWSConfig"]:
if config_data.get("backend") != "aws":
return None
try:
bucket_name = config_data["bucket"]
except KeyError:
return None
bucket_name = (
config_data.get("bucket")
or config_data.get("bucket_name")
or config_data.get("s3_bucket_name")
)
region_name = config_data.get("region_name") or _DEFAULT_REGION_NAME
profile_name = config_data.get("profile_name")
subnet_id = (
config_data.get("subnet_id")
or config_data.get("ec2_subnet_id")
or config_data.get("subnet")
)
return cls(
bucket_name=bucket_name,
region_name=region_name,
profile_name=profile_name,
subnet_id=subnet_id,
credentials=auth_data,
region_name=config_data.get("region"),
subnet_id=config_data.get("subnet"),
)
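
With `AWSConfig` now a pydantic `BaseModel`, field validation and equality come for free, `bucket_name` is required rather than falling back to environment variables, and `deserialize` accepts only the canonical keys that `serialize` emits instead of the old aliases (`bucket_name`, `s3_bucket_name`, `ec2_subnet_id`, ...). A minimal usage sketch follows; the bucket, region, and subnet values are hypothetical, and it assumes the collapsed top of `serialize()` emits the `backend` and `bucket` keys that `deserialize()` reads back.

```python
from dstack._internal.backend.aws.config import AWSConfig

# Hypothetical values; only the keys come from serialize()/deserialize() above.
config_data = {
    "backend": "aws",
    "bucket": "my-dstack-bucket",
    "region": "eu-west-1",
    "subnet": "subnet-0123456789abcdef0",
}

config = AWSConfig.deserialize(config_data)
assert config is not None
assert config.bucket_name == "my-dstack-bucket"

# A config for a different backend (or one missing "bucket") yields None
# instead of raising.
assert AWSConfig.deserialize({"backend": "gcp"}) is None

# Round-trip: deserialize() reads the same short keys serialize() writes,
# assuming the collapsed part of serialize() includes "backend" and "bucket".
assert AWSConfig.deserialize(config.serialize()) == config
```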