diff --git a/cli/dstack/_internal/backend/aws/__init__.py b/cli/dstack/_internal/backend/aws/__init__.py index 9d57d15b4..612ed1f6d 100644 --- a/cli/dstack/_internal/backend/aws/__init__.py +++ b/cli/dstack/_internal/backend/aws/__init__.py @@ -1,46 +1,25 @@ -from datetime import datetime -from typing import Generator, List, Optional +from typing import Optional import boto3 from botocore.client import BaseClient -from dstack._internal.backend.aws import logs from dstack._internal.backend.aws.compute import AWSCompute from dstack._internal.backend.aws.config import AWSConfig +from dstack._internal.backend.aws.logs import AWSLogging from dstack._internal.backend.aws.secrets import AWSSecretsManager from dstack._internal.backend.aws.storage import AWSStorage -from dstack._internal.backend.base import Backend -from dstack._internal.backend.base import artifacts as base_artifacts -from dstack._internal.backend.base import cache as base_cache -from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base import repos as base_repos +from dstack._internal.backend.base import ComponentBasedBackend from dstack._internal.backend.base import runs as base_runs -from dstack._internal.backend.base import secrets as base_secrets -from dstack._internal.backend.base import tags as base_tags -from dstack._internal.core.artifact import Artifact -from dstack._internal.core.instance import InstanceType -from dstack._internal.core.job import Job, JobHead, JobStatus -from dstack._internal.core.log_event import LogEvent -from dstack._internal.core.repo import RemoteRepoCredentials, RepoHead, RepoSpec -from dstack._internal.core.repo.base import Repo -from dstack._internal.core.run import RunHead -from dstack._internal.core.secret import Secret -from dstack._internal.core.tag import TagHead -from dstack._internal.utils.common import PathLike -class AwsBackend(Backend): +class AwsBackend(ComponentBasedBackend): NAME = "aws" - backend_config: AWSConfig - _storage: AWSStorage - _compute: AWSCompute - _secrets_manager: AWSSecretsManager def __init__( self, backend_config: AWSConfig, ): - super().__init__(backend_config=backend_config) + self.backend_config = backend_config if self.backend_config.credentials is not None: self._session = boto3.session.Session( region_name=self.backend_config.region_name, @@ -65,6 +44,10 @@ def __init__( sts_client=self._sts_client(), bucket_name=self.backend_config.bucket_name, ) + self._logging = AWSLogging( + logs_client=self._logs_client(), + bucket_name=self.backend_config.bucket_name, + ) @classmethod def load(cls) -> Optional["AwsBackend"]: @@ -75,6 +58,24 @@ def load(cls) -> Optional["AwsBackend"]: backend_config=config, ) + def storage(self) -> AWSStorage: + return self._storage + + def compute(self) -> AWSCompute: + return self._compute + + def secrets_manager(self) -> AWSSecretsManager: + return self._secrets_manager + + def logging(self) -> AWSLogging: + return self._logging + + def create_run(self, repo_id: str) -> str: + self._logging.create_log_groups_if_not_exist( + self._logs_client(), self.backend_config.bucket_name, repo_id + ) + return base_runs.create_run(self._storage) + def _s3_client(self) -> BaseClient: return self._get_client("s3") @@ -95,196 +96,3 @@ def _sts_client(self) -> BaseClient: def _get_client(self, client_name: str) -> BaseClient: return self._session.client(client_name) - - def predict_instance_type(self, job: Job) -> Optional[InstanceType]: - return base_jobs.predict_job_instance(self._compute, job) - - def 
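# --- Illustrative sketch, not part of the patch -----------------------------
# With ComponentBasedBackend (added to dstack/_internal/backend/base/__init__.py
# later in this diff), a provider backend such as AwsBackend above only exposes
# its four components plus load(); every other Backend operation (jobs, runs,
# artifacts, tags, repos, secrets, signed URLs) is inherited. The "Example*"
# names below are hypothetical placeholders, not real dstack classes.
from typing import Optional

from dstack._internal.backend.base import ComponentBasedBackend
from dstack._internal.backend.base.compute import Compute
from dstack._internal.backend.base.logs import Logging
from dstack._internal.backend.base.secrets import SecretsManager
from dstack._internal.backend.base.storage import Storage


class ExampleBackend(ComponentBasedBackend):
    NAME = "example"

    def __init__(
        self,
        backend_config,
        storage: Storage,
        compute: Compute,
        secrets_manager: SecretsManager,
        logging: Logging,
    ):
        self.backend_config = backend_config
        self._storage = storage
        self._compute = compute
        self._secrets_manager = secrets_manager
        self._logging = logging

    @classmethod
    def load(cls) -> Optional["ExampleBackend"]:
        # A real backend builds its components from a stored config here,
        # as AwsBackend.load() does above; the placeholder returns None.
        return None

    def storage(self) -> Storage:
        return self._storage

    def compute(self) -> Compute:
        return self._compute

    def secrets_manager(self) -> SecretsManager:
        return self._secrets_manager

    def logging(self) -> Logging:
        return self._logging
# -----------------------------------------------------------------------------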
create_run(self, repo_id: str) -> str: - logs.create_log_groups_if_not_exist( - self._logs_client(), self.backend_config.bucket_name, repo_id - ) - return base_runs.create_run(self._storage) - - def create_job(self, job: Job): - base_jobs.create_job(self._storage, job) - - def get_job(self, repo_id: str, job_id: str) -> Optional[Job]: - return base_jobs.get_job(self._storage, repo_id, job_id) - - def list_jobs(self, repo_id: str, run_name: str) -> List[Job]: - return base_jobs.list_jobs(self._storage, repo_id, run_name) - - def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus): - base_jobs.run_job(self._storage, self._compute, job, failed_to_start_job_new_status) - - def stop_job(self, repo_id: str, abort: bool, job_id: str): - base_jobs.stop_job(self._storage, self._compute, repo_id, job_id, abort) - - def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]: - return base_jobs.list_job_heads(self._storage, repo_id, run_name) - - def delete_job_head(self, repo_id: str, job_id: str): - base_jobs.delete_job_head(self._storage, repo_id, job_id) - - def list_run_heads( - self, - repo_id: str, - run_name: Optional[str] = None, - include_request_heads: bool = True, - interrupted_job_new_status: JobStatus = JobStatus.FAILED, - ) -> List[RunHead]: - job_heads = self.list_job_heads(repo_id=repo_id, run_name=run_name) - return base_runs.get_run_heads( - self._storage, - self._compute, - job_heads, - include_request_heads, - interrupted_job_new_status, - ) - - def poll_logs( - self, - repo_id: str, - run_name: str, - start_time: datetime, - end_time: Optional[datetime] = None, - descending: bool = False, - diagnose: bool = False, - ) -> Generator[LogEvent, None, None]: - return logs.poll_logs( - self._storage, - self._logs_client(), - self.backend_config.bucket_name, - repo_id, - run_name, - start_time, - end_time, - descending, - diagnose, - ) - - def list_run_artifact_files( - self, repo_id: str, run_name: str, prefix: str, recursive: bool = False - ) -> List[Artifact]: - return base_artifacts.list_run_artifact_files( - self._storage, repo_id, run_name, prefix, recursive - ) - - def download_run_artifact_files( - self, - repo_id: str, - run_name: str, - output_dir: Optional[PathLike], - files_path: Optional[PathLike] = None, - ): - artifacts = self.list_run_artifact_files( - repo_id, run_name=run_name, prefix="", recursive=True - ) - base_artifacts.download_run_artifact_files( - storage=self._storage, - repo_id=repo_id, - artifacts=artifacts, - output_dir=output_dir, - files_path=files_path, - ) - - def upload_job_artifact_files( - self, - repo_id: str, - job_id: str, - artifact_name: str, - artifact_path: PathLike, - local_path: PathLike, - ): - base_artifacts.upload_job_artifact_files( - storage=self._storage, - repo_id=repo_id, - job_id=job_id, - artifact_name=artifact_name, - artifact_path=artifact_path, - local_path=local_path, - ) - - def list_tag_heads(self, repo_id: str) -> List[TagHead]: - return base_tags.list_tag_heads(self._storage, repo_id) - - def get_tag_head(self, repo_id: str, tag_name: str) -> Optional[TagHead]: - return base_tags.get_tag_head(self._storage, repo_id, tag_name) - - def add_tag_from_run( - self, repo_id: str, tag_name: str, run_name: str, run_jobs: Optional[List[Job]] - ): - base_tags.create_tag_from_run( - self._storage, - repo_id, - tag_name, - run_name, - run_jobs, - ) - - def add_tag_from_local_dirs( - self, - repo: Repo, - hub_user_name: str, - tag_name: str, - local_dirs: List[str], - artifact_paths: List[str], 
- ): - base_tags.create_tag_from_local_dirs( - storage=self._storage, - repo=repo, - hub_user_name=hub_user_name, - tag_name=tag_name, - local_dirs=local_dirs, - artifact_paths=artifact_paths, - ) - - def delete_tag_head(self, repo_id: str, tag_head: TagHead): - base_tags.delete_tag(self._storage, repo_id, tag_head) - - def list_repo_heads(self) -> List[RepoHead]: - return base_repos.list_repo_heads(self._storage) - - def update_repo_last_run_at(self, repo_spec: RepoSpec, last_run_at: int): - base_repos.update_repo_last_run_at( - self._storage, - repo_spec, - last_run_at, - ) - - def get_repo_credentials(self, repo_id: str) -> Optional[RemoteRepoCredentials]: - return base_repos.get_repo_credentials(self._secrets_manager, repo_id) - - def save_repo_credentials(self, repo_id: str, repo_credentials: RemoteRepoCredentials): - base_repos.save_repo_credentials(self._secrets_manager, repo_id, repo_credentials) - - def delete_repo(self, repo_id: str): - base_repos.delete_repo(self._storage, repo_id) - - def list_secret_names(self, repo_id: str) -> List[str]: - return base_secrets.list_secret_names(self._storage, repo_id) - - def get_secret(self, repo_id: str, secret_name: str) -> Optional[Secret]: - return base_secrets.get_secret(self._secrets_manager, repo_id, repo_id) - - def add_secret(self, repo_id: str, secret: Secret): - base_secrets.add_secret(self._storage, self._secrets_manager, repo_id, secret) - - def update_secret(self, repo_id: str, secret: Secret): - base_secrets.update_secret(self._storage, self._secrets_manager, repo_id, secret) - - def delete_secret(self, repo_id: str, secret_name: str): - base_secrets.delete_secret(self._storage, self._secrets_manager, repo_id, repo_id) - - def get_signed_download_url(self, object_key: str) -> str: - return self._storage.get_signed_download_url(object_key) - - def get_signed_upload_url(self, object_key: str) -> str: - return self._storage.get_signed_upload_url(object_key) - - def delete_configuration_cache( - self, repo_id: str, hub_user_name: str, configuration_path: str - ): - base_cache.delete_configuration_cache( - self._storage, repo_id, hub_user_name, configuration_path - ) diff --git a/cli/dstack/_internal/backend/aws/config.py b/cli/dstack/_internal/backend/aws/config.py index ef1f31670..cc63a20bf 100644 --- a/cli/dstack/_internal/backend/aws/config.py +++ b/cli/dstack/_internal/backend/aws/config.py @@ -1,38 +1,17 @@ -import os from typing import Dict, Optional -from dstack._internal.backend.base.config import BackendConfig +from pydantic import BaseModel -_DEFAULT_REGION_NAME = "us-east-1" +from dstack._internal.backend.base.config import BackendConfig +DEFAULT_REGION_NAME = "us-east-1" -class AWSConfig(BackendConfig): - bucket_name = None - region_name = None - profile_name = None - subnet_id = None - credentials = None - def __init__( - self, - bucket_name: Optional[str] = None, - region_name: Optional[str] = None, - profile_name: Optional[str] = None, - subnet_id: Optional[str] = None, - credentials: Optional[Dict] = None, - ): - self.bucket_name = bucket_name or os.getenv("DSTACK_AWS_S3_BUCKET") - self.region_name = ( - region_name - or os.getenv("DSTACK_AWS_REGION") - or os.getenv("AWS_DEFAULT_REGION") - or _DEFAULT_REGION_NAME - ) - self.profile_name = ( - profile_name or os.getenv("DSTACK_AWS_PROFILE") or os.getenv("AWS_PROFILE") - ) - self.subnet_id = subnet_id or os.getenv("DSTACK_AWS_EC2_SUBNET") - self.credentials = credentials +class AWSConfig(BackendConfig, BaseModel): + bucket_name: str + region_name: Optional[str] = 
DEFAULT_REGION_NAME + subnet_id: Optional[str] = None + credentials: Optional[Dict] = None def serialize(self) -> Dict: config_data = { @@ -41,33 +20,20 @@ def serialize(self) -> Dict: } if self.region_name: config_data["region"] = self.region_name - if self.profile_name: - config_data["profile"] = self.profile_name if self.subnet_id: config_data["subnet"] = self.subnet_id return config_data @classmethod - def deserialize(cls, config_data: Dict, auth_data: Dict = None) -> Optional["AWSConfig"]: - backend = config_data.get("backend") or config_data.get("type") - if backend != "aws": + def deserialize(cls, config_data: Dict) -> Optional["AWSConfig"]: + if config_data.get("backend") != "aws": + return None + try: + bucket_name = config_data["bucket"] + except KeyError: return None - bucket_name = ( - config_data.get("bucket") - or config_data.get("bucket_name") - or config_data.get("s3_bucket_name") - ) - region_name = config_data.get("region_name") or _DEFAULT_REGION_NAME - profile_name = config_data.get("profile_name") - subnet_id = ( - config_data.get("subnet_id") - or config_data.get("ec2_subnet_id") - or config_data.get("subnet") - ) return cls( bucket_name=bucket_name, - region_name=region_name, - profile_name=profile_name, - subnet_id=subnet_id, - credentials=auth_data, + region_name=config_data.get("region"), + subnet_id=config_data.get("subnet"), ) diff --git a/cli/dstack/_internal/backend/aws/logs.py b/cli/dstack/_internal/backend/aws/logs.py index f7f8c7d8e..b3b8c62b4 100644 --- a/cli/dstack/_internal/backend/aws/logs.py +++ b/cli/dstack/_internal/backend/aws/logs.py @@ -4,7 +4,7 @@ from botocore.client import BaseClient from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base.logs import fix_log_event_urls, render_log_event +from dstack._internal.backend.base.logs import Logging, fix_log_event_urls, render_log_event from dstack._internal.backend.base.storage import Storage from dstack._internal.core.log_event import LogEvent from dstack._internal.utils.common import ( @@ -13,58 +13,70 @@ ) -def poll_logs( - storage: Storage, - logs_client: BaseClient, - bucket_name: str, - repo_id: str, - run_name: str, - start_time: datetime, - end_time: Optional[datetime], - descending: bool, - diagnose: bool, -) -> Generator[LogEvent, None, None]: - jobs = base_jobs.list_jobs(storage, repo_id, run_name) - jobs_map = {j.job_id: j for j in jobs} - if diagnose: - runner_id = jobs[0].runner_id - log_group = f"/dstack/runners/{bucket_name}" - log_stream = runner_id - else: - log_group = f"/dstack/jobs/{bucket_name}/{repo_id}" - log_stream = run_name - filter_logs_events_kwargs = _filter_logs_events_kwargs( - log_group=log_group, - log_stream=log_stream, - start_time=start_time, - end_time=end_time, - next_token=None, - ) - try: - paginator = logs_client.get_paginator("filter_log_events") - pages = paginator.paginate(**filter_logs_events_kwargs) - # aws sdk doesn't provide a way to order events by descending so we have to do it by hand - if descending: - pages = reversed(list(pages)) - for page in pages: - events = page["events"] - if descending: - events = reversed(page["events"]) - for event in events: - event["timestamp"] = timestamps_in_milliseconds_to_datetime(event["timestamp"]) - log_event = render_log_event(event) - if not diagnose: - log_event = fix_log_event_urls(log_event, jobs_map) - yield log_event - except Exception as e: - if ( - hasattr(e, "response") - and e.response.get("Error") - and e.response["Error"].get("Code") == 
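# --- Illustrative sketch, not part of the patch -----------------------------
# The pydantic-based AWSConfig above now stores and reads a single set of keys
# ("backend", "bucket", "region", "subnet"). The values below are made-up
# examples.
from dstack._internal.backend.aws.config import AWSConfig

config = AWSConfig.deserialize(
    {
        "backend": "aws",
        "bucket": "dstack-bucket-example",  # required
        "region": "eu-west-1",              # optional
        "subnet": "subnet-0123456789",      # optional
    }
)
assert config is not None and config.bucket_name == "dstack-bucket-example"
assert AWSConfig.deserialize({"backend": "gcp"}) is None  # other backend -> None
assert AWSConfig.deserialize({"backend": "aws"}) is None  # missing bucket -> None
# -----------------------------------------------------------------------------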
"ResourceNotFoundException" - ): - return +class AWSLogging(Logging): + def __init__(self, logs_client: BaseClient, bucket_name: str): + self.logs_client = logs_client + self.bucket_name = bucket_name + + def poll_logs( + self, + storage: Storage, + repo_id: str, + run_name: str, + start_time: datetime, + end_time: Optional[datetime], + descending: bool, + diagnose: bool, + ) -> Generator[LogEvent, None, None]: + jobs = base_jobs.list_jobs(storage, repo_id, run_name) + jobs_map = {j.job_id: j for j in jobs} + if diagnose: + runner_id = jobs[0].runner_id + log_group = f"/dstack/runners/{self.bucket_name}" + log_stream = runner_id else: - raise e + log_group = f"/dstack/jobs/{self.bucket_name}/{repo_id}" + log_stream = run_name + filter_logs_events_kwargs = _filter_logs_events_kwargs( + log_group=log_group, + log_stream=log_stream, + start_time=start_time, + end_time=end_time, + next_token=None, + ) + try: + paginator = self.logs_client.get_paginator("filter_log_events") + pages = paginator.paginate(**filter_logs_events_kwargs) + # aws sdk doesn't provide a way to order events by descending so we have to do it by hand + if descending: + pages = reversed(list(pages)) + for page in pages: + events = page["events"] + if descending: + events = reversed(page["events"]) + for event in events: + event["timestamp"] = timestamps_in_milliseconds_to_datetime(event["timestamp"]) + log_event = render_log_event(event) + if not diagnose: + log_event = fix_log_event_urls(log_event, jobs_map) + yield log_event + except Exception as e: + if ( + hasattr(e, "response") + and e.response.get("Error") + and e.response["Error"].get("Code") == "ResourceNotFoundException" + ): + return + else: + raise e + + def create_log_groups_if_not_exist( + self, logs_client: BaseClient, bucket_name: str, repo_id: str + ): + _create_log_group_if_not_exists( + logs_client, bucket_name, f"/dstack/jobs/{bucket_name}/{repo_id}" + ) + _create_log_group_if_not_exists(logs_client, bucket_name, f"/dstack/runners/{bucket_name}") def _filter_logs_events_kwargs( @@ -86,13 +98,6 @@ def _filter_logs_events_kwargs( return filter_logs_events_kwargs -def create_log_groups_if_not_exist(logs_client: BaseClient, bucket_name: str, repo_id: str): - _create_log_group_if_not_exists( - logs_client, bucket_name, f"/dstack/jobs/{bucket_name}/{repo_id}" - ) - _create_log_group_if_not_exists(logs_client, bucket_name, f"/dstack/runners/{bucket_name}") - - def _create_log_group_if_not_exists( logs_client: BaseClient, bucket_name: str, log_group_name: str ): @@ -107,7 +112,3 @@ def _create_log_group_if_not_exists( "dstack_bucket": bucket_name, }, ) - - -def create_log_stream(logs_client: BaseClient, log_group_name: str, run_name: str): - logs_client.create_log_stream(logGroupName=log_group_name, logStreamName=run_name) diff --git a/cli/dstack/_internal/backend/azure/__init__.py b/cli/dstack/_internal/backend/azure/__init__.py index c86cfad5f..d51ed4c42 100644 --- a/cli/dstack/_internal/backend/azure/__init__.py +++ b/cli/dstack/_internal/backend/azure/__init__.py @@ -1,5 +1,4 @@ -from datetime import datetime -from typing import Generator, List, Optional +from typing import Optional from azure.core.credentials import TokenCredential from azure.identity import ClientSecretCredential, DefaultAzureCredential @@ -9,33 +8,14 @@ from dstack._internal.backend.azure.logs import AzureLogging from dstack._internal.backend.azure.secrets import AzureSecretsManager from dstack._internal.backend.azure.storage import AzureStorage -from dstack._internal.backend.base import 
Backend -from dstack._internal.backend.base import artifacts as base_artifacts -from dstack._internal.backend.base import cache as base_cache -from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base import repos as base_repos -from dstack._internal.backend.base import runs as base_runs -from dstack._internal.backend.base import secrets as base_secrets -from dstack._internal.backend.base import tags as base_tags -from dstack._internal.core.artifact import Artifact -from dstack._internal.core.instance import InstanceType -from dstack._internal.core.job import Job, JobHead, JobStatus -from dstack._internal.core.log_event import LogEvent -from dstack._internal.core.repo.base import Repo -from dstack._internal.core.repo.head import RepoHead -from dstack._internal.core.repo.remote import RemoteRepoCredentials -from dstack._internal.core.repo.spec import RepoSpec -from dstack._internal.core.run import RunHead -from dstack._internal.core.secret import Secret -from dstack._internal.core.tag import TagHead -from dstack._internal.utils.common import PathLike +from dstack._internal.backend.base import ComponentBasedBackend -class AzureBackend(Backend): +class AzureBackend(ComponentBasedBackend): NAME = "azure" def __init__(self, backend_config: AzureConfig, credential: Optional[TokenCredential] = None): - super().__init__(backend_config=backend_config) + self.backend_config = backend_config if credential is None: if backend_config.credentials["type"] == "client": credential = ClientSecretCredential( @@ -71,190 +51,14 @@ def load(cls) -> Optional["AzureBackend"]: return None return cls(backend_config=config, credential=DefaultAzureCredential()) - def predict_instance_type(self, job: Job) -> Optional[InstanceType]: - return base_jobs.predict_job_instance(self._compute, job) + def storage(self) -> AzureStorage: + return self._storage - def create_run(self, repo_id: str) -> str: - return base_runs.create_run(self._storage) + def compute(self) -> AzureCompute: + return self._compute - def create_job(self, job: Job): - base_jobs.create_job(self._storage, job) + def secrets_manager(self) -> AzureSecretsManager: + return self._secrets_manager - def get_job(self, repo_id: str, job_id: str) -> Optional[Job]: - return base_jobs.get_job(self._storage, repo_id, job_id) - - def list_jobs(self, repo_id: str, run_name: str) -> List[Job]: - return base_jobs.list_jobs(self._storage, repo_id, run_name) - - def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus): - base_jobs.run_job(self._storage, self._compute, job, failed_to_start_job_new_status) - - def stop_job(self, repo_id: str, abort: bool, job_id: str): - base_jobs.stop_job(self._storage, self._compute, repo_id, job_id, abort) - - def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]: - return base_jobs.list_job_heads(self._storage, repo_id, run_name) - - def delete_job_head(self, repo_id: str, job_id: str): - base_jobs.delete_job_head(self._storage, repo_id, job_id) - - def list_run_heads( - self, - repo_id: str, - run_name: Optional[str] = None, - include_request_heads: bool = True, - interrupted_job_new_status: JobStatus = JobStatus.FAILED, - ) -> List[RunHead]: - job_heads = self.list_job_heads(repo_id=repo_id, run_name=run_name) - return base_runs.get_run_heads( - self._storage, - self._compute, - job_heads, - include_request_heads, - interrupted_job_new_status, - ) - - def poll_logs( - self, - repo_id: str, - run_name: str, - start_time: datetime, - end_time: 
Optional[datetime] = None, - descending: bool = False, - diagnose: bool = False, - ) -> Generator[LogEvent, None, None]: - yield from self._logging.poll_logs( - storage=self._storage, - repo_id=repo_id, - run_name=run_name, - start_time=start_time, - end_time=end_time, - descending=descending, - diagnose=diagnose, - ) - - def list_run_artifact_files( - self, repo_id: str, run_name: str, prefix: str, recursive: bool = False - ) -> List[Artifact]: - return base_artifacts.list_run_artifact_files( - self._storage, repo_id, run_name, prefix, recursive - ) - - def download_run_artifact_files( - self, - repo_id: str, - run_name: str, - output_dir: Optional[PathLike], - files_path: Optional[PathLike] = None, - ): - artifacts = self.list_run_artifact_files( - repo_id, run_name=run_name, prefix="", recursive=True - ) - base_artifacts.download_run_artifact_files( - storage=self._storage, - repo_id=repo_id, - artifacts=artifacts, - output_dir=output_dir, - files_path=files_path, - ) - - def upload_job_artifact_files( - self, - repo_id: str, - job_id: str, - artifact_name: str, - artifact_path: PathLike, - local_path: PathLike, - ): - base_artifacts.upload_job_artifact_files( - storage=self._storage, - repo_id=repo_id, - job_id=job_id, - artifact_name=artifact_name, - artifact_path=artifact_path, - local_path=local_path, - ) - - def list_tag_heads(self, repo_id: str) -> List[TagHead]: - return base_tags.list_tag_heads(self._storage, repo_id) - - def get_tag_head(self, repo_id: str, tag_name: str) -> Optional[TagHead]: - return base_tags.get_tag_head(self._storage, repo_id, tag_name) - - def add_tag_from_run( - self, repo_id: str, tag_name: str, run_name: str, run_jobs: Optional[List[Job]] - ): - base_tags.create_tag_from_run( - self._storage, - repo_id, - tag_name, - run_name, - run_jobs, - ) - - def add_tag_from_local_dirs( - self, - repo: Repo, - hub_user_name: str, - tag_name: str, - local_dirs: List[str], - artifact_paths: List[str], - ): - base_tags.create_tag_from_local_dirs( - storage=self._storage, - repo=repo, - hub_user_name=hub_user_name, - tag_name=tag_name, - local_dirs=local_dirs, - artifact_paths=artifact_paths, - ) - - def delete_tag_head(self, repo_id: str, tag_head: TagHead): - base_tags.delete_tag(self._storage, repo_id, tag_head) - - def list_repo_heads(self) -> List[RepoHead]: - return base_repos.list_repo_heads(self._storage) - - def update_repo_last_run_at(self, repo_spec: RepoSpec, last_run_at: int): - base_repos.update_repo_last_run_at( - self._storage, - repo_spec, - last_run_at, - ) - - def get_repo_credentials(self, repo_id: str) -> Optional[RemoteRepoCredentials]: - return base_repos.get_repo_credentials(self._secrets_manager, repo_id) - - def save_repo_credentials(self, repo_id: str, repo_credentials: RemoteRepoCredentials): - base_repos.save_repo_credentials(self._secrets_manager, repo_id, repo_credentials) - - def delete_repo(self, repo_id: str): - base_repos.delete_repo(self._storage, repo_id) - - def list_secret_names(self, repo_id: str) -> List[str]: - return base_secrets.list_secret_names(self._storage, repo_id) - - def get_secret(self, repo_id: str, secret_name: str) -> Optional[Secret]: - return base_secrets.get_secret(self._secrets_manager, repo_id, secret_name) - - def add_secret(self, repo_id: str, secret: Secret): - base_secrets.add_secret(self._storage, self._secrets_manager, repo_id, secret) - - def update_secret(self, repo_id: str, secret: Secret): - base_secrets.update_secret(self._storage, self._secrets_manager, repo_id, secret) - - def delete_secret(self, 
repo_id: str, secret_name: str): - base_secrets.delete_secret(self._storage, self._secrets_manager, repo_id, secret_name) - - def get_signed_download_url(self, object_key: str) -> str: - return self._storage.get_signed_download_url(object_key) - - def get_signed_upload_url(self, object_key: str) -> str: - return self._storage.get_signed_upload_url(object_key) - - def delete_configuration_cache( - self, repo_id: str, hub_user_name: str, configuration_path: str - ): - base_cache.delete_configuration_cache( - self._storage, repo_id, hub_user_name, configuration_path - ) + def logging(self) -> AzureLogging: + return self._logging diff --git a/cli/dstack/_internal/backend/azure/config.py b/cli/dstack/_internal/backend/azure/config.py index 1694bbc6e..de12b2d7e 100644 --- a/cli/dstack/_internal/backend/azure/config.py +++ b/cli/dstack/_internal/backend/azure/config.py @@ -1,73 +1,31 @@ from typing import Dict, Optional -from dstack._internal.backend.base.config import BackendConfig +from pydantic import BaseModel, ValidationError +from typing_extensions import Literal +from dstack._internal.backend.base.config import BackendConfig -class AzureConfig(BackendConfig): - NAME = "azure" - def __init__( - self, - tenant_id: str, - subscription_id: str, - location: str, - resource_group: str, - storage_account: str, - vault_url: str, - network: str, - subnet: str, - credentials: Optional[Dict], - ): - self.subscription_id = subscription_id - self.tenant_id = tenant_id - self.location = location - self.resource_group = resource_group - self.storage_account = storage_account - self.vault_url = vault_url - self.network = network - self.subnet = subnet - self.credentials = credentials +class AzureConfig(BackendConfig, BaseModel): + backend: Literal["azure"] = "azure" + tenant_id: str + subscription_id: str + location: str + resource_group: str + storage_account: str + vault_url: str + network: str + subnet: str + credentials: Optional[Dict] = None def serialize(self) -> Dict: - res = { - "backend": "azure", - "tenant_id": self.tenant_id, - "subscription_id": self.subscription_id, - "location": self.location, - "resource_group": self.resource_group, - "storage_account": self.storage_account, - "vault_url": self.vault_url, - "network": self.network, - "subnet": self.subnet, - } - return res + return self.dict(exclude={"credentials"}) @classmethod - def deserialize(cls, data: Dict) -> Optional["AzureConfig"]: - if data.get("backend") != "azure": + def deserialize(cls, config_data: Dict) -> Optional["AzureConfig"]: + if config_data.get("backend") != "azure": return None - try: - tenant_id = data["tenant_id"] - subscription_id = data["subscription_id"] - location = data["location"] - resource_group = data["resource_group"] - storage_account = data["storage_account"] - vault_url = data["vault_url"] - network = data["network"] - subnet = data["subnet"] - credentials = data.get("credentials") - except KeyError: + return cls.parse_obj(config_data) + except ValidationError: return None - - return cls( - tenant_id=tenant_id, - subscription_id=subscription_id, - location=location, - resource_group=resource_group, - storage_account=storage_account, - vault_url=vault_url, - network=network, - subnet=subnet, - credentials=credentials, - ) diff --git a/cli/dstack/_internal/backend/azure/logs.py b/cli/dstack/_internal/backend/azure/logs.py index 758f2b433..a8c114eb7 100644 --- a/cli/dstack/_internal/backend/azure/logs.py +++ b/cli/dstack/_internal/backend/azure/logs.py @@ -8,13 +8,13 @@ from 
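# --- Illustrative sketch, not part of the patch -----------------------------
# AzureConfig above is now a pydantic model, so (de)serialization reduces to
# dict()/parse_obj() with ValidationError mapped to None. Field values below
# are made-up examples.
from dstack._internal.backend.azure.config import AzureConfig

data = {
    "backend": "azure",
    "tenant_id": "00000000-0000-0000-0000-000000000000",
    "subscription_id": "00000000-0000-0000-0000-000000000001",
    "location": "westeurope",
    "resource_group": "dstack-rg",
    "storage_account": "dstackstorage",
    "vault_url": "https://dstack-vault.vault.azure.net/",
    "network": "dstack-network",
    "subnet": "dstack-subnet",
}
config = AzureConfig.deserialize(data)
assert config is not None and config.location == "westeurope"
assert config.serialize() == data                           # credentials are never serialized
assert AzureConfig.deserialize({"backend": "aws"}) is None  # other backend -> None
# -----------------------------------------------------------------------------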
dstack._internal.backend.azure.utils import DSTACK_LOGS_TABLE_NAME, get_logs_workspace_name from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base.logs import fix_log_event_urls +from dstack._internal.backend.base.logs import Logging, fix_log_event_urls from dstack._internal.backend.base.storage import Storage from dstack._internal.core.log_event import LogEvent from dstack._internal.utils.common import get_current_datetime -class AzureLogging: +class AzureLogging(Logging): def __init__( self, credential: TokenCredential, diff --git a/cli/dstack/_internal/backend/base/__init__.py b/cli/dstack/_internal/backend/base/__init__.py index c1fe61659..9cb56d735 100644 --- a/cli/dstack/_internal/backend/base/__init__.py +++ b/cli/dstack/_internal/backend/base/__init__.py @@ -2,9 +2,15 @@ from datetime import datetime from typing import Generator, List, Optional +from dstack._internal.backend.base import artifacts as base_artifacts +from dstack._internal.backend.base import cache as base_cache from dstack._internal.backend.base import jobs as base_jobs +from dstack._internal.backend.base import repos as base_repos +from dstack._internal.backend.base import runs as base_runs +from dstack._internal.backend.base import secrets as base_secrets +from dstack._internal.backend.base import tags as base_tags from dstack._internal.backend.base.compute import Compute -from dstack._internal.backend.base.config import BackendConfig +from dstack._internal.backend.base.logs import Logging from dstack._internal.backend.base.secrets import SecretsManager from dstack._internal.backend.base.storage import Storage from dstack._internal.core.artifact import Artifact @@ -22,12 +28,6 @@ class Backend(ABC): NAME = None - def __init__( - self, - backend_config: BackendConfig, - ): - self.backend_config = backend_config - @classmethod @abstractmethod def load(cls) -> Optional["Backend"]: @@ -229,3 +229,209 @@ def get_signed_download_url(self, object_key: str) -> str: @abstractmethod def get_signed_upload_url(self, object_key: str) -> str: pass + + +class ComponentBasedBackend(Backend): + @abstractmethod + def storage(self) -> Storage: + pass + + @abstractmethod + def compute(self) -> Compute: + pass + + @abstractmethod + def secrets_manager(self) -> SecretsManager: + pass + + @abstractmethod + def logging(self) -> Logging: + pass + + def predict_instance_type(self, job: Job) -> Optional[InstanceType]: + return base_jobs.predict_job_instance(self.compute(), job) + + def create_run(self, repo_id: str) -> str: + return base_runs.create_run(self.storage()) + + def create_job(self, job: Job): + base_jobs.create_job(self.storage(), job) + + def get_job(self, repo_id: str, job_id: str) -> Optional[Job]: + return base_jobs.get_job(self.storage(), repo_id, job_id) + + def list_jobs(self, repo_id: str, run_name: str) -> List[Job]: + return base_jobs.list_jobs(self.storage(), repo_id, run_name) + + def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus): + base_jobs.run_job(self.storage(), self.compute(), job, failed_to_start_job_new_status) + + def stop_job(self, repo_id: str, abort: bool, job_id: str): + base_jobs.stop_job(self.storage(), self.compute(), repo_id, job_id, abort) + + def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]: + return base_jobs.list_job_heads(self.storage(), repo_id, run_name) + + def delete_job_head(self, repo_id: str, job_id: str): + base_jobs.delete_job_head(self.storage(), repo_id, job_id) + + def list_run_heads( + 
self, + repo_id: str, + run_name: Optional[str] = None, + include_request_heads: bool = True, + interrupted_job_new_status: JobStatus = JobStatus.FAILED, + ) -> List[RunHead]: + job_heads = self.list_job_heads(repo_id=repo_id, run_name=run_name) + return base_runs.get_run_heads( + self.storage(), + self.compute(), + job_heads, + include_request_heads, + interrupted_job_new_status, + ) + + def poll_logs( + self, + repo_id: str, + run_name: str, + start_time: datetime, + end_time: Optional[datetime] = None, + descending: bool = False, + diagnose: bool = False, + ) -> Generator[LogEvent, None, None]: + return self.logging().poll_logs( + self.storage(), + repo_id, + run_name, + start_time, + end_time, + descending, + diagnose, + ) + + def list_run_artifact_files( + self, repo_id: str, run_name: str, prefix: str, recursive: bool = False + ) -> List[Artifact]: + return base_artifacts.list_run_artifact_files( + self.storage(), repo_id, run_name, prefix, recursive + ) + + def download_run_artifact_files( + self, + repo_id: str, + run_name: str, + output_dir: Optional[PathLike], + files_path: Optional[PathLike] = None, + ): + artifacts = self.list_run_artifact_files( + repo_id, run_name=run_name, prefix="", recursive=True + ) + base_artifacts.download_run_artifact_files( + storage=self.storage(), + repo_id=repo_id, + artifacts=artifacts, + output_dir=output_dir, + files_path=files_path, + ) + + def upload_job_artifact_files( + self, + repo_id: str, + job_id: str, + artifact_name: str, + artifact_path: PathLike, + local_path: PathLike, + ): + base_artifacts.upload_job_artifact_files( + storage=self.storage(), + repo_id=repo_id, + job_id=job_id, + artifact_name=artifact_name, + artifact_path=artifact_path, + local_path=local_path, + ) + + def list_tag_heads(self, repo_id: str) -> List[TagHead]: + return base_tags.list_tag_heads(self.storage(), repo_id) + + def get_tag_head(self, repo_id: str, tag_name: str) -> Optional[TagHead]: + return base_tags.get_tag_head(self.storage(), repo_id, tag_name) + + def add_tag_from_run( + self, repo_id: str, tag_name: str, run_name: str, run_jobs: Optional[List[Job]] + ): + base_tags.create_tag_from_run( + self.storage(), + repo_id, + tag_name, + run_name, + run_jobs, + ) + + def add_tag_from_local_dirs( + self, + repo: Repo, + hub_user_name: str, + tag_name: str, + local_dirs: List[str], + artifact_paths: List[str], + ): + base_tags.create_tag_from_local_dirs( + storage=self.storage(), + repo=repo, + hub_user_name=hub_user_name, + tag_name=tag_name, + local_dirs=local_dirs, + artifact_paths=artifact_paths, + ) + + def delete_tag_head(self, repo_id: str, tag_head: TagHead): + base_tags.delete_tag(self.storage(), repo_id, tag_head) + + def list_repo_heads(self) -> List[RepoHead]: + return base_repos.list_repo_heads(self.storage()) + + def update_repo_last_run_at(self, repo_spec: RepoSpec, last_run_at: int): + base_repos.update_repo_last_run_at( + self.storage(), + repo_spec, + last_run_at, + ) + + def get_repo_credentials(self, repo_id: str) -> Optional[RemoteRepoCredentials]: + return base_repos.get_repo_credentials(self.secrets_manager(), repo_id) + + def save_repo_credentials(self, repo_id: str, repo_credentials: RemoteRepoCredentials): + base_repos.save_repo_credentials(self.secrets_manager(), repo_id, repo_credentials) + + def delete_repo(self, repo_id: str): + base_repos.delete_repo(self.storage(), repo_id) + + def list_secret_names(self, repo_id: str) -> List[str]: + return base_secrets.list_secret_names(self.storage(), repo_id) + + def get_secret(self, 
repo_id: str, secret_name: str) -> Optional[Secret]: + return base_secrets.get_secret(self.secrets_manager(), repo_id, secret_name) + + def add_secret(self, repo_id: str, secret: Secret): + base_secrets.add_secret(self.storage(), self.secrets_manager(), repo_id, secret) + + def update_secret(self, repo_id: str, secret: Secret): + base_secrets.update_secret(self.storage(), self.secrets_manager(), repo_id, secret) + + def delete_secret(self, repo_id: str, secret_name: str): + base_secrets.delete_secret(self.storage(), self.secrets_manager(), repo_id, secret_name) + + def get_signed_download_url(self, object_key: str) -> str: + return self.storage().get_signed_download_url(object_key) + + def get_signed_upload_url(self, object_key: str) -> str: + return self.storage().get_signed_upload_url(object_key) + + def delete_configuration_cache( + self, repo_id: str, hub_user_name: str, configuration_path: str + ): + base_cache.delete_configuration_cache( + self.storage(), repo_id, hub_user_name, configuration_path + ) diff --git a/cli/dstack/_internal/backend/base/logs.py b/cli/dstack/_internal/backend/base/logs.py index 93950556f..c6d8ccec2 100644 --- a/cli/dstack/_internal/backend/base/logs.py +++ b/cli/dstack/_internal/backend/base/logs.py @@ -1,9 +1,10 @@ import json import re import urllib.parse -from typing import Any, Dict, Optional +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, Generator, Optional -from dstack._internal.backend.base import jobs from dstack._internal.backend.base.storage import Storage from dstack._internal.core.job import Job from dstack._internal.core.log_event import LogEvent, LogEventSource @@ -15,6 +16,21 @@ POLL_LOGS_RATE_SECS = 1 +class Logging(ABC): + @abstractmethod + def poll_logs( + self, + storage: Storage, + repo_id: str, + run_name: str, + start_time: datetime, + end_time: Optional[datetime], + descending: bool, + diagnose: bool, + ) -> Generator[LogEvent, None, None]: + pass + + def render_log_event( event: Dict[str, Any], ) -> LogEvent: diff --git a/cli/dstack/_internal/backend/gcp/__init__.py b/cli/dstack/_internal/backend/gcp/__init__.py index 25b225258..d32a04ee5 100644 --- a/cli/dstack/_internal/backend/gcp/__init__.py +++ b/cli/dstack/_internal/backend/gcp/__init__.py @@ -1,54 +1,30 @@ import warnings -from datetime import datetime -from typing import Generator, List, Optional +from typing import Optional import google.auth from google.auth._default import _CLOUD_SDK_CREDENTIALS_WARNING from google.auth.credentials import Credentials -from google.oauth2 import service_account -from dstack._internal.backend.base import Backend -from dstack._internal.backend.base import artifacts as base_artifacts -from dstack._internal.backend.base import cache as base_cache -from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base import repos as base_repos -from dstack._internal.backend.base import runs as base_runs -from dstack._internal.backend.base import secrets as base_secrets -from dstack._internal.backend.base import tags as base_tags +from dstack._internal.backend.base import ComponentBasedBackend from dstack._internal.backend.gcp.auth import authenticate from dstack._internal.backend.gcp.compute import GCPCompute from dstack._internal.backend.gcp.config import GCPConfig from dstack._internal.backend.gcp.logs import GCPLogging from dstack._internal.backend.gcp.secrets import GCPSecretsManager from dstack._internal.backend.gcp.storage import GCPStorage -from 
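# --- Illustrative sketch, not part of the patch -----------------------------
# The Logging ABC above is what ComponentBasedBackend.poll_logs() delegates to,
# so each provider only has to yield LogEvent objects with this signature. The
# toy in-memory implementation below (which ignores the time window and the
# diagnose flag for brevity) is for illustration only, not real dstack code.
from datetime import datetime
from typing import Generator, List, Optional

from dstack._internal.backend.base.logs import Logging
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.log_event import LogEvent


class InMemoryLogging(Logging):
    def __init__(self, events: List[LogEvent]):
        self._events = events  # assumed to be stored in ascending time order

    def poll_logs(
        self,
        storage: Storage,
        repo_id: str,
        run_name: str,
        start_time: datetime,
        end_time: Optional[datetime],
        descending: bool,
        diagnose: bool,
    ) -> Generator[LogEvent, None, None]:
        events = list(self._events)
        if descending:
            events.reverse()
        yield from events
# -----------------------------------------------------------------------------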
dstack._internal.cli.common import console -from dstack._internal.core.artifact import Artifact -from dstack._internal.core.instance import InstanceType -from dstack._internal.core.job import Job, JobHead, JobStatus -from dstack._internal.core.log_event import LogEvent -from dstack._internal.core.repo import RemoteRepoCredentials, RepoHead, RepoSpec -from dstack._internal.core.repo.base import Repo -from dstack._internal.core.run import RunHead -from dstack._internal.core.secret import Secret -from dstack._internal.core.tag import TagHead -from dstack._internal.utils.common import PathLike warnings.filterwarnings("ignore", message=_CLOUD_SDK_CREDENTIALS_WARNING) -class GCPBackend(Backend): +class GCPBackend(ComponentBasedBackend): NAME = "gcp" - backend_config: GCPConfig - _storage: GCPStorage - _compute: GCPCompute - _secrets_manager: GCPSecretsManager def __init__( self, backend_config: GCPConfig, credentials: Optional[Credentials] = None, ): - super().__init__(backend_config=backend_config) + self.backend_config = backend_config if credentials is None: credentials = authenticate(backend_config) self._storage = GCPStorage( @@ -68,6 +44,18 @@ def __init__( credentials=credentials, ) + def storage(self) -> GCPStorage: + return self._storage + + def compute(self) -> GCPCompute: + return self._compute + + def secrets_manager(self) -> GCPSecretsManager: + return self._secrets_manager + + def logging(self) -> GCPLogging: + return self._logging + @classmethod def load(cls) -> Optional["GCPBackend"]: config = GCPConfig.load() @@ -78,194 +66,3 @@ def load(cls) -> Optional["GCPBackend"]: backend_config=config, credentials=credentials, ) - - def predict_instance_type(self, job: Job) -> Optional[InstanceType]: - return base_jobs.predict_job_instance(self._compute, job) - - def create_run(self, repo_id: str) -> str: - return base_runs.create_run(self._storage) - - def create_job(self, job: Job): - if job.artifact_specs and any(art_spec.mount for art_spec in job.artifact_specs): - console.print("Mount artifacts are not currently supported for 'gcp' backend") - exit(1) - base_jobs.create_job(self._storage, job) - - def get_job(self, repo_id: str, job_id: str) -> Optional[Job]: - return base_jobs.get_job(self._storage, repo_id, job_id) - - def list_jobs(self, repo_id: str, run_name: str) -> List[Job]: - return base_jobs.list_jobs(self._storage, repo_id, run_name) - - def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus): - base_jobs.run_job(self._storage, self._compute, job, failed_to_start_job_new_status) - - def stop_job(self, repo_id: str, abort: bool, job_id: str): - base_jobs.stop_job(self._storage, self._compute, repo_id, job_id, abort) - - def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]: - return base_jobs.list_job_heads(self._storage, repo_id, run_name) - - def delete_job_head(self, repo_id: str, job_id: str): - base_jobs.delete_job_head(self._storage, repo_id, job_id) - - def list_run_heads( - self, - repo_id: str, - run_name: Optional[str] = None, - include_request_heads: bool = True, - interrupted_job_new_status: JobStatus = JobStatus.FAILED, - ) -> List[RunHead]: - job_heads = self.list_job_heads(repo_id=repo_id, run_name=run_name) - return base_runs.get_run_heads( - self._storage, - self._compute, - job_heads, - include_request_heads, - interrupted_job_new_status, - ) - - def poll_logs( - self, - repo_id: str, - run_name: str, - start_time: datetime, - end_time: Optional[datetime] = None, - descending: bool = False, - diagnose: bool = 
False, - ) -> Generator[LogEvent, None, None]: - yield from self._logging.poll_logs( - storage=self._storage, - repo_id=repo_id, - run_name=run_name, - start_time=start_time, - end_time=end_time, - descending=descending, - diagnose=diagnose, - ) - - def list_run_artifact_files( - self, repo_id: str, run_name: str, prefix: str, recursive: bool = False - ) -> List[Artifact]: - return base_artifacts.list_run_artifact_files( - self._storage, repo_id, run_name, prefix, recursive - ) - - def download_run_artifact_files( - self, - repo_id: str, - run_name: str, - output_dir: Optional[PathLike], - files_path: Optional[PathLike] = None, - ): - artifacts = self.list_run_artifact_files( - repo_id, run_name=run_name, prefix="", recursive=True - ) - base_artifacts.download_run_artifact_files( - storage=self._storage, - repo_id=repo_id, - artifacts=artifacts, - output_dir=output_dir, - files_path=files_path, - ) - - def upload_job_artifact_files( - self, - repo_id: str, - job_id: str, - artifact_name: str, - artifact_path: PathLike, - local_path: PathLike, - ): - base_artifacts.upload_job_artifact_files( - storage=self._storage, - repo_id=repo_id, - job_id=job_id, - artifact_name=artifact_name, - artifact_path=artifact_path, - local_path=local_path, - ) - - def list_tag_heads(self, repo_id: str) -> List[TagHead]: - return base_tags.list_tag_heads(self._storage, repo_id) - - def get_tag_head(self, repo_id: str, tag_name: str) -> Optional[TagHead]: - return base_tags.get_tag_head(self._storage, repo_id, tag_name) - - def add_tag_from_run( - self, repo_id: str, tag_name: str, run_name: str, run_jobs: Optional[List[Job]] - ): - base_tags.create_tag_from_run( - self._storage, - repo_id, - tag_name, - run_name, - run_jobs, - ) - - def add_tag_from_local_dirs( - self, - repo: Repo, - hub_user_name: str, - tag_name: str, - local_dirs: List[str], - artifact_paths: List[str], - ): - base_tags.create_tag_from_local_dirs( - storage=self._storage, - repo=repo, - hub_user_name=hub_user_name, - tag_name=tag_name, - local_dirs=local_dirs, - artifact_paths=artifact_paths, - ) - - def delete_tag_head(self, repo_id: str, tag_head: TagHead): - base_tags.delete_tag(self._storage, repo_id, tag_head) - - def list_repo_heads(self) -> List[RepoHead]: - return base_repos.list_repo_heads(self._storage) - - def update_repo_last_run_at(self, repo_spec: RepoSpec, last_run_at: int): - base_repos.update_repo_last_run_at( - self._storage, - repo_spec, - last_run_at, - ) - - def get_repo_credentials(self, repo_id: str) -> Optional[RemoteRepoCredentials]: - return base_repos.get_repo_credentials(self._secrets_manager, repo_id) - - def save_repo_credentials(self, repo_id: str, repo_credentials: RemoteRepoCredentials): - base_repos.save_repo_credentials(self._secrets_manager, repo_id, repo_credentials) - - def delete_repo(self, repo_id: str): - base_repos.delete_repo(self._storage, repo_id) - - def list_secret_names(self, repo_id: str) -> List[str]: - return base_secrets.list_secret_names(self._storage, repo_id) - - def get_secret(self, repo_id: str, secret_name: str) -> Optional[Secret]: - return base_secrets.get_secret(self._secrets_manager, repo_id, secret_name) - - def add_secret(self, repo_id: str, secret: Secret): - base_secrets.add_secret(self._storage, self._secrets_manager, repo_id, secret) - - def update_secret(self, repo_id: str, secret: Secret): - base_secrets.update_secret(self._storage, self._secrets_manager, repo_id, secret) - - def delete_secret(self, repo_id: str, secret_name: str): - 
base_secrets.delete_secret(self._storage, self._secrets_manager, repo_id, secret_name) - - def get_signed_download_url(self, object_key: str) -> str: - return self._storage.get_signed_download_url(object_key) - - def get_signed_upload_url(self, object_key: str) -> str: - return self._storage.get_signed_upload_url(object_key) - - def delete_configuration_cache( - self, repo_id: str, hub_user_name: str, configuration_path: str - ): - base_cache.delete_configuration_cache( - self._storage, repo_id, hub_user_name, configuration_path - ) diff --git a/cli/dstack/_internal/backend/gcp/compute.py b/cli/dstack/_internal/backend/gcp/compute.py index cacf79de0..854865272 100644 --- a/cli/dstack/_internal/backend/gcp/compute.py +++ b/cli/dstack/_internal/backend/gcp/compute.py @@ -732,7 +732,7 @@ def _create_firewall_rules( A Firewall object. """ firewall_rule = compute_v1.Firewall() - firewall_rule.name = f"dstack-runner-allow-incoming-" + network.replace("/", "-") + firewall_rule.name = f"dstack-ssh-in-" + network.replace("/", "-") firewall_rule.direction = "INGRESS" allowed_ssh_port = compute_v1.Allowed() diff --git a/cli/dstack/_internal/backend/gcp/config.py b/cli/dstack/_internal/backend/gcp/config.py index e346b30ce..509b0269f 100644 --- a/cli/dstack/_internal/backend/gcp/config.py +++ b/cli/dstack/_internal/backend/gcp/config.py @@ -1,28 +1,21 @@ from typing import Dict, Optional +from pydantic import BaseModel, ValidationError +from typing_extensions import Literal + from dstack._internal.backend.base.config import BackendConfig -class GCPConfig(BackendConfig): - def __init__( - self, - project_id: str, - region: str, - zone: str, - bucket_name: str, - vpc: str, - subnet: str, - credentials_file: Optional[str] = None, - credentials: Optional[Dict] = None, - ): - self.project_id = project_id - self.region = region - self.zone = zone - self.bucket_name = bucket_name - self.vpc = vpc - self.subnet = subnet - self.credentials_file = credentials_file - self.credentials = credentials +class GCPConfig(BackendConfig, BaseModel): + backend: Literal["gcp"] = "gcp" + project_id: str + region: str + zone: str + bucket_name: str + vpc: str + subnet: str + credentials_file: Optional[str] = None + credentials: Optional[Dict] = None def serialize(self) -> Dict: res = { @@ -43,21 +36,12 @@ def deserialize(cls, config_data: Dict) -> Optional["GCPConfig"]: if config_data.get("backend") != "gcp": return None try: - project_id = config_data["project"] - region = config_data["region"] - zone = config_data["zone"] - bucket_name = config_data["bucket"] - vpc = config_data["vpc"] - subnet = config_data["subnet"] - except KeyError: + return cls.parse_obj( + { + **config_data, + "project_id": config_data["project"], + "bucket_name": config_data["bucket"], + } + ) + except ValidationError: return None - return cls( - project_id=project_id, - region=region, - zone=zone, - bucket_name=bucket_name, - vpc=vpc, - subnet=subnet, - credentials_file=config_data.get("credentials_file"), - credentials=config_data.get("credentials"), - ) diff --git a/cli/dstack/_internal/backend/gcp/logs.py b/cli/dstack/_internal/backend/gcp/logs.py index 7798593a6..8b1d8144f 100644 --- a/cli/dstack/_internal/backend/gcp/logs.py +++ b/cli/dstack/_internal/backend/gcp/logs.py @@ -5,7 +5,7 @@ from google.oauth2 import service_account from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base.logs import fix_log_event_urls +from dstack._internal.backend.base.logs import Logging, fix_log_event_urls from 
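# --- Illustrative sketch, not part of the patch -----------------------------
# GCPConfig.deserialize() above maps the stored "project"/"bucket" keys onto the
# model's project_id/bucket_name fields and lets pydantic validate the rest.
# The values below are made-up examples.
from dstack._internal.backend.gcp.config import GCPConfig

config = GCPConfig.deserialize(
    {
        "backend": "gcp",
        "project": "example-project",
        "region": "europe-west1",
        "zone": "europe-west1-b",
        "bucket": "dstack-bucket-example",
        "vpc": "default",
        "subnet": "default",
    }
)
assert config is not None
assert config.project_id == "example-project"
assert config.bucket_name == "dstack-bucket-example"
# -----------------------------------------------------------------------------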
dstack._internal.backend.base.storage import Storage from dstack._internal.core.job import Job from dstack._internal.core.log_event import LogEvent, LogEventSource @@ -13,7 +13,7 @@ LOGS_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%z" -class GCPLogging: +class GCPLogging(Logging): def __init__( self, project_id: str, bucket_name: str, credentials: Optional[service_account.Credentials] ): diff --git a/cli/dstack/_internal/backend/lambdalabs/__init__.py b/cli/dstack/_internal/backend/lambdalabs/__init__.py index 3c177fdbd..4b7b6ec76 100644 --- a/cli/dstack/_internal/backend/lambdalabs/__init__.py +++ b/cli/dstack/_internal/backend/lambdalabs/__init__.py @@ -1,35 +1,18 @@ -from datetime import datetime -from typing import Generator, List, Optional +from typing import Optional import boto3 from botocore.client import BaseClient -from dstack._internal.backend.aws import logs +from dstack._internal.backend.aws.logs import AWSLogging from dstack._internal.backend.aws.secrets import AWSSecretsManager from dstack._internal.backend.aws.storage import AWSStorage -from dstack._internal.backend.base import Backend -from dstack._internal.backend.base import artifacts as base_artifacts -from dstack._internal.backend.base import cache as base_cache -from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base import repos as base_repos +from dstack._internal.backend.base import ComponentBasedBackend from dstack._internal.backend.base import runs as base_runs -from dstack._internal.backend.base import secrets as base_secrets -from dstack._internal.backend.base import tags as base_tags from dstack._internal.backend.lambdalabs.compute import LambdaCompute from dstack._internal.backend.lambdalabs.config import LambdaConfig -from dstack._internal.core.artifact import Artifact -from dstack._internal.core.instance import InstanceType -from dstack._internal.core.job import Job, JobHead, JobStatus -from dstack._internal.core.log_event import LogEvent -from dstack._internal.core.repo import RemoteRepoCredentials, RepoHead, RepoSpec -from dstack._internal.core.repo.base import Repo -from dstack._internal.core.run import RunHead -from dstack._internal.core.secret import Secret -from dstack._internal.core.tag import TagHead -from dstack._internal.utils.common import PathLike -class LambdaBackend(Backend): +class LambdaBackend(ComponentBasedBackend): NAME = "lambda" def __init__( @@ -52,6 +35,10 @@ def __init__( sts_client=self._sts_client(), bucket_name=self.backend_config.storage_config.bucket, ) + self._logging = AWSLogging( + logs_client=self._logs_client(), + bucket_name=self.backend_config.storage_config.bucket, + ) @classmethod def load(cls) -> Optional["LambdaBackend"]: @@ -60,6 +47,24 @@ def load(cls) -> Optional["LambdaBackend"]: return None return cls(config) + def storage(self) -> AWSStorage: + return self._storage + + def compute(self) -> LambdaCompute: + return self._compute + + def secrets_manager(self) -> AWSSecretsManager: + return self._secrets_manager + + def logging(self) -> AWSLogging: + return self._logging + + def create_run(self, repo_id: str) -> str: + self._logging.create_log_groups_if_not_exist( + self._logs_client(), self.backend_config.storage_config.bucket, repo_id + ) + return base_runs.create_run(self._storage) + def _s3_client(self) -> BaseClient: return self._get_client("s3") @@ -80,196 +85,3 @@ def _sts_client(self) -> BaseClient: def _get_client(self, client_name: str) -> BaseClient: return self._session.client(client_name) - - def 
predict_instance_type(self, job: Job) -> Optional[InstanceType]: - return base_jobs.predict_job_instance(self._compute, job) - - def create_run(self, repo_id: str) -> str: - logs.create_log_groups_if_not_exist( - self._logs_client(), self.backend_config.storage_config.bucket, repo_id - ) - return base_runs.create_run(self._storage) - - def create_job(self, job: Job): - base_jobs.create_job(self._storage, job) - - def get_job(self, repo_id: str, job_id: str) -> Optional[Job]: - return base_jobs.get_job(self._storage, repo_id, job_id) - - def list_jobs(self, repo_id: str, run_name: str) -> List[Job]: - return base_jobs.list_jobs(self._storage, repo_id, run_name) - - def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus): - base_jobs.run_job(self._storage, self._compute, job, failed_to_start_job_new_status) - - def stop_job(self, repo_id: str, abort: bool, job_id: str): - base_jobs.stop_job(self._storage, self._compute, repo_id, job_id, abort) - - def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]: - return base_jobs.list_job_heads(self._storage, repo_id, run_name) - - def delete_job_head(self, repo_id: str, job_id: str): - base_jobs.delete_job_head(self._storage, repo_id, job_id) - - def list_run_heads( - self, - repo_id: str, - run_name: Optional[str] = None, - include_request_heads: bool = True, - interrupted_job_new_status: JobStatus = JobStatus.FAILED, - ) -> List[RunHead]: - job_heads = self.list_job_heads(repo_id=repo_id, run_name=run_name) - return base_runs.get_run_heads( - self._storage, - self._compute, - job_heads, - include_request_heads, - interrupted_job_new_status, - ) - - def poll_logs( - self, - repo_id: str, - run_name: str, - start_time: datetime, - end_time: Optional[datetime] = None, - descending: bool = False, - diagnose: bool = False, - ) -> Generator[LogEvent, None, None]: - return logs.poll_logs( - self._storage, - self._logs_client(), - self.backend_config.storage_config.bucket, - repo_id, - run_name, - start_time, - end_time, - descending, - diagnose, - ) - - def list_run_artifact_files( - self, repo_id: str, run_name: str, prefix: str, recursive: bool = False - ) -> List[Artifact]: - return base_artifacts.list_run_artifact_files( - self._storage, repo_id, run_name, prefix, recursive - ) - - def download_run_artifact_files( - self, - repo_id: str, - run_name: str, - output_dir: Optional[PathLike], - files_path: Optional[PathLike] = None, - ): - artifacts = self.list_run_artifact_files( - repo_id, run_name=run_name, prefix="", recursive=True - ) - base_artifacts.download_run_artifact_files( - storage=self._storage, - repo_id=repo_id, - artifacts=artifacts, - output_dir=output_dir, - files_path=files_path, - ) - - def upload_job_artifact_files( - self, - repo_id: str, - job_id: str, - artifact_name: str, - artifact_path: PathLike, - local_path: PathLike, - ): - base_artifacts.upload_job_artifact_files( - storage=self._storage, - repo_id=repo_id, - job_id=job_id, - artifact_name=artifact_name, - artifact_path=artifact_path, - local_path=local_path, - ) - - def list_tag_heads(self, repo_id: str) -> List[TagHead]: - return base_tags.list_tag_heads(self._storage, repo_id) - - def get_tag_head(self, repo_id: str, tag_name: str) -> Optional[TagHead]: - return base_tags.get_tag_head(self._storage, repo_id, tag_name) - - def add_tag_from_run( - self, repo_id: str, tag_name: str, run_name: str, run_jobs: Optional[List[Job]] - ): - base_tags.create_tag_from_run( - self._storage, - repo_id, - tag_name, - run_name, - run_jobs, 
- ) - - def add_tag_from_local_dirs( - self, - repo: Repo, - hub_user_name: str, - tag_name: str, - local_dirs: List[str], - artifact_paths: List[str], - ): - base_tags.create_tag_from_local_dirs( - storage=self._storage, - repo=repo, - hub_user_name=hub_user_name, - tag_name=tag_name, - local_dirs=local_dirs, - artifact_paths=artifact_paths, - ) - - def delete_tag_head(self, repo_id: str, tag_head: TagHead): - base_tags.delete_tag(self._storage, repo_id, tag_head) - - def list_repo_heads(self) -> List[RepoHead]: - return base_repos.list_repo_heads(self._storage) - - def update_repo_last_run_at(self, repo_spec: RepoSpec, last_run_at: int): - base_repos.update_repo_last_run_at( - self._storage, - repo_spec, - last_run_at, - ) - - def get_repo_credentials(self, repo_id: str) -> Optional[RemoteRepoCredentials]: - return base_repos.get_repo_credentials(self._secrets_manager, repo_id) - - def save_repo_credentials(self, repo_id: str, repo_credentials: RemoteRepoCredentials): - base_repos.save_repo_credentials(self._secrets_manager, repo_id, repo_credentials) - - def delete_repo(self, repo_id: str): - base_repos.delete_repo(self._storage, repo_id) - - def list_secret_names(self, repo_id: str) -> List[str]: - return base_secrets.list_secret_names(self._storage, repo_id) - - def get_secret(self, repo_id: str, secret_name: str) -> Optional[Secret]: - return base_secrets.get_secret(self._secrets_manager, repo_id, repo_id) - - def add_secret(self, repo_id: str, secret: Secret): - base_secrets.add_secret(self._storage, self._secrets_manager, repo_id, secret) - - def update_secret(self, repo_id: str, secret: Secret): - base_secrets.update_secret(self._storage, self._secrets_manager, repo_id, secret) - - def delete_secret(self, repo_id: str, secret_name: str): - base_secrets.delete_secret(self._storage, self._secrets_manager, repo_id, repo_id) - - def get_signed_download_url(self, object_key: str) -> str: - return self._storage.get_signed_download_url(object_key) - - def get_signed_upload_url(self, object_key: str) -> str: - return self._storage.get_signed_upload_url(object_key) - - def delete_configuration_cache( - self, repo_id: str, hub_user_name: str, configuration_path: str - ): - base_cache.delete_configuration_cache( - self._storage, repo_id, hub_user_name, configuration_path - ) diff --git a/cli/dstack/_internal/backend/lambdalabs/config.py b/cli/dstack/_internal/backend/lambdalabs/config.py index 6bd95137c..963794003 100644 --- a/cli/dstack/_internal/backend/lambdalabs/config.py +++ b/cli/dstack/_internal/backend/lambdalabs/config.py @@ -29,7 +29,9 @@ def serialize(self) -> Dict: @classmethod def deserialize(cls, config_data: Dict) -> Optional["LambdaConfig"]: + if config_data.get("backend") != "lambda": + return None try: - return LambdaConfig.parse_obj(config_data) + return cls.parse_obj(config_data) except ValidationError: return None diff --git a/cli/dstack/_internal/backend/local/__init__.py b/cli/dstack/_internal/backend/local/__init__.py index 4377ce418..4c791db82 100644 --- a/cli/dstack/_internal/backend/local/__init__.py +++ b/cli/dstack/_internal/backend/local/__init__.py @@ -1,48 +1,25 @@ -from datetime import datetime -from typing import Generator, List, Optional +from typing import Optional -from dstack._internal.backend.base import Backend -from dstack._internal.backend.base import artifacts as base_artifacts -from dstack._internal.backend.base import cache as base_cache -from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base import repos 
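# --- Illustrative sketch, not part of the patch -----------------------------
# With the guard added to LambdaConfig.deserialize() above, config data that
# belongs to another backend is rejected up front instead of being handed to
# parse_obj().
from dstack._internal.backend.lambdalabs.config import LambdaConfig

assert LambdaConfig.deserialize({"backend": "aws", "bucket": "whatever"}) is None
# -----------------------------------------------------------------------------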
as base_repos -from dstack._internal.backend.base import runs as base_runs -from dstack._internal.backend.base import secrets as base_secrets -from dstack._internal.backend.base import tags as base_tags -from dstack._internal.backend.local import artifacts, logs +from dstack._internal.backend.base import ComponentBasedBackend from dstack._internal.backend.local.compute import LocalCompute from dstack._internal.backend.local.config import LocalConfig +from dstack._internal.backend.local.logs import LocalLogging from dstack._internal.backend.local.secrets import LocalSecretsManager from dstack._internal.backend.local.storage import LocalStorage -from dstack._internal.core.artifact import Artifact -from dstack._internal.core.instance import InstanceType -from dstack._internal.core.job import Job, JobHead, JobStatus -from dstack._internal.core.log_event import LogEvent -from dstack._internal.core.repo import RemoteRepoCredentials, RepoHead, RepoSpec -from dstack._internal.core.repo.base import Repo -from dstack._internal.core.run import RunHead -from dstack._internal.core.secret import Secret -from dstack._internal.core.tag import TagHead -from dstack._internal.utils.common import PathLike -class LocalBackend(Backend): +class LocalBackend(ComponentBasedBackend): NAME = "local" - backend_config: LocalConfig - _storage: LocalStorage - _compute: LocalCompute - _secrets_manager: LocalSecretsManager def __init__( self, backend_config: LocalConfig, ): - super().__init__(backend_config=backend_config) - if self.backend_config is None: - return + self.backend_config = backend_config self._storage = LocalStorage(self.backend_config.backend_dir) self._compute = LocalCompute(self.backend_config) self._secrets_manager = LocalSecretsManager(self.backend_config.backend_dir) + self._logging = LocalLogging(self.backend_config) @classmethod def load(cls) -> Optional["LocalBackend"]: @@ -51,189 +28,14 @@ def load(cls) -> Optional["LocalBackend"]: return None return cls(backend_config=config) - def predict_instance_type(self, job: Job) -> Optional[InstanceType]: - return base_jobs.predict_job_instance(self._compute, job) + def storage(self) -> LocalStorage: + return self._storage - def create_run(self, repo_id: str) -> str: - return base_runs.create_run(self._storage) + def compute(self) -> LocalCompute: + return self._compute - def create_job(self, job: Job): - base_jobs.create_job(self._storage, job) + def secrets_manager(self) -> LocalSecretsManager: + return self._secrets_manager - def get_job(self, repo_id: str, job_id: str) -> Optional[Job]: - return base_jobs.get_job(self._storage, repo_id, job_id) - - def list_jobs(self, repo_id: str, run_name: str) -> List[Job]: - return base_jobs.list_jobs(self._storage, repo_id, run_name) - - def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus): - base_jobs.run_job(self._storage, self._compute, job, failed_to_start_job_new_status) - - def stop_job(self, repo_id: str, abort: bool, job_id: str): - base_jobs.stop_job(self._storage, self._compute, repo_id, job_id, abort) - - def list_job_heads(self, repo_id: str, run_name: Optional[str] = None) -> List[JobHead]: - return base_jobs.list_job_heads(self._storage, repo_id, run_name) - - def delete_job_head(self, repo_id: str, job_id: str): - base_jobs.delete_job_head(self._storage, repo_id, job_id) - - def list_run_heads( - self, - repo_id: str, - run_name: Optional[str] = None, - include_request_heads: bool = True, - interrupted_job_new_status: JobStatus = JobStatus.FAILED, - ) -> List[RunHead]: - job_heads 
= self.list_job_heads(repo_id=repo_id, run_name=run_name) - return base_runs.get_run_heads( - self._storage, - self._compute, - job_heads, - include_request_heads, - interrupted_job_new_status, - ) - - def poll_logs( - self, - repo_id: str, - run_name: str, - start_time: datetime, - end_time: Optional[datetime] = None, - descending: bool = False, - diagnose: bool = False, - ) -> Generator[LogEvent, None, None]: - return logs.poll_logs( - self.backend_config, - self._storage, - repo_id, - run_name, - start_time, - end_time, - descending, - diagnose, - ) - - def list_run_artifact_files( - self, repo_id: str, run_name: str, prefix: Optional[str], recursive: bool = False - ) -> List[Artifact]: - return base_artifacts.list_run_artifact_files( - self._storage, repo_id, run_name, prefix, recursive - ) - - def download_run_artifact_files( - self, - repo_id: str, - run_name: str, - output_dir: Optional[PathLike], - files_path: Optional[PathLike] = None, - ): - artifacts = self.list_run_artifact_files( - repo_id, run_name=run_name, prefix="", recursive=True - ) - base_artifacts.download_run_artifact_files( - storage=self._storage, - repo_id=repo_id, - artifacts=artifacts, - output_dir=output_dir, - files_path=files_path, - ) - - def upload_job_artifact_files( - self, - repo_id: str, - job_id: str, - artifact_name: str, - artifact_path: PathLike, - local_path: PathLike, - ): - base_artifacts.upload_job_artifact_files( - storage=self._storage, - repo_id=repo_id, - job_id=job_id, - artifact_name=artifact_name, - artifact_path=artifact_path, - local_path=local_path, - ) - - def list_tag_heads(self, repo_id: str) -> List[TagHead]: - return base_tags.list_tag_heads(self._storage, repo_id) - - def get_tag_head(self, repo_id: str, tag_name: str) -> Optional[TagHead]: - return base_tags.get_tag_head(self._storage, repo_id, tag_name) - - def add_tag_from_run( - self, repo_id: str, tag_name: str, run_name: str, run_jobs: Optional[List[Job]] - ): - base_tags.create_tag_from_run( - self._storage, - repo_id, - tag_name, - run_name, - run_jobs, - ) - - def add_tag_from_local_dirs( - self, - repo: Repo, - hub_user_name: str, - tag_name: str, - local_dirs: List[str], - artifact_paths: List[str], - ): - base_tags.create_tag_from_local_dirs( - storage=self._storage, - repo=repo, - hub_user_name=hub_user_name, - tag_name=tag_name, - local_dirs=local_dirs, - artifact_paths=artifact_paths, - ) - - def delete_tag_head(self, repo_id: str, tag_head: TagHead): - base_tags.delete_tag(self._storage, repo_id, tag_head) - - def list_repo_heads(self) -> List[RepoHead]: - return base_repos.list_repo_heads(self._storage) - - def update_repo_last_run_at(self, repo_spec: RepoSpec, last_run_at: int): - base_repos.update_repo_last_run_at(self._storage, repo_spec, last_run_at) - - def get_repo_credentials(self, repo_id: str) -> Optional[RemoteRepoCredentials]: - return base_repos.get_repo_credentials(self._secrets_manager, repo_id) - - def save_repo_credentials(self, repo_id: str, repo_credentials: RemoteRepoCredentials): - base_repos.save_repo_credentials(self._secrets_manager, repo_id, repo_credentials) - - def delete_repo(self, repo_id: str): - base_repos.delete_repo(self._storage, repo_id) - - def list_secret_names(self, repo_id: str) -> List[str]: - return base_secrets.list_secret_names(self._storage, repo_id) - - def get_secret(self, repo_id: str, secret_name: str) -> Optional[Secret]: - return base_secrets.get_secret(self._secrets_manager, repo_id, secret_name) - - def add_secret(self, repo_id: str, secret: Secret): - 
base_secrets.add_secret(self._storage, self._secrets_manager, repo_id, secret) - - def update_secret(self, repo_id: str, secret: Secret): - base_secrets.update_secret(self._storage, self._secrets_manager, repo_id, secret) - - def delete_secret(self, repo_id: str, secret_name: str): - base_secrets.delete_secret(self._storage, self._secrets_manager, repo_id, secret_name) - - def delete_configuration_cache( - self, repo_id: str, hub_user_name: str, configuration_path: str - ): - base_cache.delete_configuration_cache( - self._storage, repo_id, hub_user_name, configuration_path - ) - - def get_signed_download_url(self, object_key: str) -> str: - # Implemented by Hub - raise NotImplementedError() - - def get_signed_upload_url(self, object_key: str) -> str: - # Implemented by Hub - raise NotImplementedError() + def logging(self) -> LocalLogging: + return self._logging diff --git a/cli/dstack/_internal/backend/local/config.py b/cli/dstack/_internal/backend/local/config.py index c1a957fb6..8d4c3eb1e 100644 --- a/cli/dstack/_internal/backend/local/config.py +++ b/cli/dstack/_internal/backend/local/config.py @@ -1,26 +1,28 @@ from typing import Dict +from pydantic import BaseModel, ValidationError +from typing_extensions import Literal + from dstack._internal.backend.base.config import BackendConfig from dstack._internal.utils.common import get_dstack_dir -class LocalConfig(BackendConfig): - def __init__(self, namespace: str): - self.namespace = namespace - self.backend_dir = get_dstack_dir() / "local_backend" / self.namespace +class LocalConfig(BackendConfig, BaseModel): + backend: Literal["local"] = "local" + namespace: str def serialize(self) -> Dict: - return { - "backend": "local", - "namespace": self.namespace, - } + return self.dict() @classmethod def deserialize(cls, config_data: Dict) -> "LocalConfig": if config_data.get("backend") != "local": return None try: - namespace = config_data["namespace"] - except KeyError: + return cls.parse_obj(config_data) + except ValidationError: return None - return LocalConfig(namespace) + + @property + def backend_dir(self): + return get_dstack_dir() / "local_backend" / self.namespace diff --git a/cli/dstack/_internal/backend/local/logs.py b/cli/dstack/_internal/backend/local/logs.py index 1b6196a8a..7aef3ac58 100644 --- a/cli/dstack/_internal/backend/local/logs.py +++ b/cli/dstack/_internal/backend/local/logs.py @@ -5,7 +5,7 @@ from file_read_backwards import FileReadBackwards from dstack._internal.backend.base import jobs as base_jobs -from dstack._internal.backend.base.logs import fix_log_event_urls, render_log_event +from dstack._internal.backend.base.logs import Logging, fix_log_event_urls, render_log_event from dstack._internal.backend.base.storage import Storage from dstack._internal.backend.local.config import LocalConfig from dstack._internal.core.log_event import LogEvent @@ -13,46 +13,59 @@ LOGS_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f%z" -def poll_logs( - backend_config: LocalConfig, - storage: Storage, - repo_id: str, - run_name: str, - start_time: datetime, - end_time: Optional[datetime], - descending: bool, - diagnose: bool, -) -> Generator[LogEvent, None, None]: - jobs = base_jobs.list_jobs(storage, repo_id, run_name) - jobs_map = {j.job_id: j for j in jobs} - if diagnose: - runner_id = jobs[0].runner_id - logs_filepath = ( - backend_config.backend_dir / "logs" / "dstack" / "runners" / f"{runner_id}.log" - ) - else: - logs_filepath = ( - backend_config.backend_dir / "logs" / "dstack" / "jobs" / repo_id / f"{run_name}.log" - ) - if descending: - 
log_file = FileReadBackwards(logs_filepath) - else: - log_file = open(logs_filepath, "r") - found_log = False - with log_file as f: - for line in f: - event = _log_line_to_log_event(line) - if start_time <= event["timestamp"] and ( - end_time is None or event["timestamp"] <= end_time - ): - found_log = True - log_event = render_log_event(event) - if not diagnose: - log_event = fix_log_event_urls(log_event, jobs_map) - yield log_event - else: - if found_log: - break +class LocalLogging(Logging): + def __init__(self, backend_config: LocalConfig): + self.backend_config = backend_config + + def poll_logs( + self, + storage: Storage, + repo_id: str, + run_name: str, + start_time: datetime, + end_time: Optional[datetime], + descending: bool, + diagnose: bool, + ) -> Generator[LogEvent, None, None]: + jobs = base_jobs.list_jobs(storage, repo_id, run_name) + jobs_map = {j.job_id: j for j in jobs} + if diagnose: + runner_id = jobs[0].runner_id + logs_filepath = ( + self.backend_config.backend_dir + / "logs" + / "dstack" + / "runners" + / f"{runner_id}.log" + ) + else: + logs_filepath = ( + self.backend_config.backend_dir + / "logs" + / "dstack" + / "jobs" + / repo_id + / f"{run_name}.log" + ) + if descending: + log_file = FileReadBackwards(logs_filepath) + else: + log_file = open(logs_filepath, "r") + found_log = False + with log_file as f: + for line in f: + event = _log_line_to_log_event(line) + if start_time <= event["timestamp"] and ( + end_time is None or event["timestamp"] <= end_time + ): + found_log = True + log_event = render_log_event(event) + if not diagnose: + log_event = fix_log_event_urls(log_event, jobs_map) + yield log_event + else: + if found_log: + break def _log_line_to_log_event(line: str) -> Dict: diff --git a/cli/dstack/_internal/hub/services/backends/aws/configurator.py b/cli/dstack/_internal/hub/services/backends/aws/configurator.py index bd0a89161..ac79cc66e 100644 --- a/cli/dstack/_internal/hub/services/backends/aws/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/aws/configurator.py @@ -5,7 +5,7 @@ from boto3.session import Session from dstack._internal.backend.aws import AwsBackend -from dstack._internal.backend.aws.config import AWSConfig +from dstack._internal.backend.aws.config import DEFAULT_REGION_NAME, AWSConfig from dstack._internal.hub.db.models import Project from dstack._internal.hub.models import ( AWSBucketProjectElement, @@ -49,7 +49,7 @@ def configure_project( project_values = AWSProjectValues() session = Session() if session.region_name is None: - session = Session(region_name=project_config.region_name) + session = Session(region_name=project_config.region_name or DEFAULT_REGION_NAME) project_values.default_credentials = self._valid_credentials(session=session) @@ -71,13 +71,15 @@ def configure_project( self._raise_invalid_credentials_error(fields=[["credentials"]]) # TODO validate config values - project_values.region_name = self._get_hub_regions(default_region=session.region_name) - project_values.s3_bucket_name = self._get_hub_buckets( + project_values.region_name = self._get_hub_regions_element( + default_region=session.region_name + ) + project_values.s3_bucket_name = self._get_hub_buckets_element( session=session, - region=project_config.region_name, + region=session.region_name, default_bucket=project_config.s3_bucket_name, ) - project_values.ec2_subnet_id = self._get_hub_subnet( + project_values.ec2_subnet_id = self._get_hub_subnet_element( session=session, default_subnet=project_config.ec2_subnet_id ) return project_values @@ 
-122,7 +124,6 @@ def get_backend(self, project: Project) -> AwsBackend: or config_data.get("bucket_name") or config_data.get("s3_bucket_name"), region_name=config_data.get("region_name"), - profile_name=config_data.get("profile_name"), subnet_id=config_data.get("subnet_id") or config_data.get("ec2_subnet_id") or config_data.get("subnet"), @@ -145,13 +146,13 @@ def _raise_invalid_credentials_error(self, fields: Optional[List[List[str]]] = N fields=fields, ) - def _get_hub_regions(self, default_region: Optional[str]) -> ProjectElement: + def _get_hub_regions_element(self, default_region: Optional[str]) -> ProjectElement: element = ProjectElement(selected=default_region) for r in REGIONS: element.values.append(ProjectElementValue(value=r[1], label=r[0])) return element - def _get_hub_buckets( + def _get_hub_buckets_element( self, session: Session, region: str, default_bucket: Optional[str] ) -> AWSBucketProjectElement: if default_bucket is not None: @@ -193,7 +194,9 @@ def _validate_hub_bucket(self, session: Session, region: str, bucket_name: str): ) raise e - def _get_hub_subnet(self, session: Session, default_subnet: Optional[str]) -> ProjectElement: + def _get_hub_subnet_element( + self, session: Session, default_subnet: Optional[str] + ) -> ProjectElement: element = ProjectElement(selected=default_subnet) _ec2 = session.client("ec2") response = _ec2.describe_subnets() diff --git a/cli/dstack/_internal/hub/services/backends/gcp/configurator.py b/cli/dstack/_internal/hub/services/backends/gcp/configurator.py index 8597e7ee6..ac64004a5 100644 --- a/cli/dstack/_internal/hub/services/backends/gcp/configurator.py +++ b/cli/dstack/_internal/hub/services/backends/gcp/configurator.py @@ -178,7 +178,7 @@ def create_project(self, project_config: GCPProjectConfigWithCreds) -> Tuple[Dic "region": project_config.region, "zone": project_config.zone, "bucket_name": project_config.bucket_name, - "vpc": project_config.bucket_name, + "vpc": project_config.vpc, "subnet": project_config.subnet, } return config_data, auth_data
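Note on the config changes: LambdaConfig.deserialize and the reworked LocalConfig converge on the same pattern — a pydantic model carrying a literal "backend" discriminator, an early guard on that field, and ValidationError mapped to None. The sketch below restates that pattern in isolation; ExampleConfig and its namespace field are illustrative stand-ins rather than dstack classes, and the pydantic v1-style API (parse_obj, dict) is assumed because that is what the hunks above use.

from typing import Dict, Optional

from pydantic import BaseModel, ValidationError
from typing_extensions import Literal


class ExampleConfig(BaseModel):
    # "backend" acts as a discriminator so each backend only accepts its own config.
    backend: Literal["example"] = "example"
    namespace: str

    def serialize(self) -> Dict:
        return self.dict()

    @classmethod
    def deserialize(cls, config_data: Dict) -> Optional["ExampleConfig"]:
        # Reject configs written for a different backend before attempting to parse.
        if config_data.get("backend") != "example":
            return None
        try:
            return cls.parse_obj(config_data)
        except ValidationError:
            return None


# Round-trip check: serialize() output is accepted by deserialize(),
# while a config stamped for another backend is rejected.
config = ExampleConfig(namespace="default")
assert ExampleConfig.deserialize(config.serialize()) == config
assert ExampleConfig.deserialize({"backend": "aws"}) is None

With the guard in place, a project config written for one backend is never silently parsed by another backend's config class; deserialize simply yields None, letting the caller fall through to the other config classes.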