From 124fd309ac1a319c86c6d4fdca3d820e32cd721c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 13 Mar 2025 15:31:53 +0500 Subject: [PATCH 1/7] Improve Python API --- .../cli/services/configurators/run.py | 23 +- .../_internal/core/models/configurations.py | 16 +- src/dstack/_internal/core/models/profiles.py | 30 +-- .../_internal/core/models/repos/local.py | 4 +- .../_internal/core/models/repos/remote.py | 8 +- src/dstack/_internal/core/models/runs.py | 14 +- src/dstack/api/__init__.py | 4 + src/dstack/api/_public/__init__.py | 30 +-- src/dstack/api/_public/repos.py | 70 +++--- src/dstack/api/_public/resources.py | 105 --------- src/dstack/api/_public/runs.py | 220 +++++++++++++----- src/dstack/api/server/__init__.py | 18 +- 12 files changed, 264 insertions(+), 278 deletions(-) delete mode 100644 src/dstack/api/_public/resources.py diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index 4f2ca9512..561fbe286 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -85,26 +85,11 @@ def apply_configuration( ) profile = load_profile(Path.cwd(), configurator_args.profile) with console.status("Getting apply plan..."): - run_plan = self.api.runs.get_plan( + run_plan = self.api.runs.get_run_plan( configuration=conf, repo=repo, configuration_path=configuration_path, - backends=profile.backends, - regions=profile.regions, - instance_types=profile.instance_types, - reservation=profile.reservation, - spot_policy=profile.spot_policy, - retry_policy=profile.retry_policy, - utilization_policy=profile.utilization_policy, - max_duration=profile.max_duration, - stop_duration=profile.stop_duration, - max_price=profile.max_price, - working_dir=conf.working_dir, - run_name=conf.name, - creation_policy=profile.creation_policy, - termination_policy=profile.termination_policy, - termination_policy_idle=profile.termination_idle_time, - idle_duration=profile.idle_duration, + profile=profile, ) print_run_plan(run_plan, offers_limit=configurator_args.max_offers) @@ -163,8 +148,8 @@ def apply_configuration( try: with console.status("Applying plan..."): - run = self.api.runs.exec_plan( - run_plan, repo, reserve_ports=not command_args.detach + run = self.api.runs.apply_plan( + run_plan=run_plan, repo=repo, reserve_ports=not command_args.detach ) except ServerClientError as e: raise CLIError(e.msg) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 090acc505..565c2d6c7 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -93,7 +93,9 @@ class BaseRunConfiguration(CoreModel): Optional[str], Field(description="The run name. If not specified, a random name is generated"), ] = None - image: Annotated[Optional[str], Field(description="The name of the Docker image to run")] + image: Annotated[Optional[str], Field(description="The name of the Docker image to run")] = ( + None + ) user: Annotated[ Optional[str], Field( @@ -104,7 +106,7 @@ class BaseRunConfiguration(CoreModel): ), ] = None privileged: Annotated[bool, Field(description="Run the container in privileged mode")] = False - entrypoint: Annotated[Optional[str], Field(description="The Docker entrypoint")] + entrypoint: Annotated[Optional[str], Field(description="The Docker entrypoint")] = None working_dir: Annotated[ Optional[str], Field( @@ -119,17 +121,17 @@ class BaseRunConfiguration(CoreModel): home_dir: str = "/root" registry_auth: Annotated[ Optional[RegistryAuth], Field(description="Credentials for pulling a private Docker image") - ] + ] = None python: Annotated[ Optional[PythonVersion], Field(description="The major version of Python. Mutually exclusive with `image`"), - ] + ] = None nvcc: Annotated[ Optional[bool], Field( description="Use image with NVIDIA CUDA Compiler (NVCC) included. Mutually exclusive with `image`" ), - ] + ] = None single_branch: Annotated[ Optional[bool], Field( @@ -209,7 +211,7 @@ def check_image_or_commands_present(cls, values): class DevEnvironmentConfigurationParams(CoreModel): ide: Annotated[Literal["vscode"], Field(description="The IDE to run")] - version: Annotated[Optional[str], Field(description="The version of the IDE")] + version: Annotated[Optional[str], Field(description="The version of the IDE")] = None init: Annotated[CommandsList, Field(description="The bash commands to run on startup")] = [] inactivity_duration: Annotated[ Optional[Union[Literal["off"], int, bool, str]], @@ -225,7 +227,7 @@ class DevEnvironmentConfigurationParams(CoreModel): " Defaults to `off`" ) ), - ] + ] = None @validator("inactivity_duration", pre=True, allow_reuse=True) def parse_inactivity_duration( diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index 7ae385dc8..630c26baa 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -154,13 +154,13 @@ class ProfileParams(CoreModel): backends: Annotated[ Optional[List[BackendType]], Field(description="The backends to consider for provisioning (e.g., `[aws, gcp]`)"), - ] + ] = None regions: Annotated[ Optional[List[str]], Field( description="The regions to consider for provisioning (e.g., `[eu-west-1, us-west4, westeurope]`)" ), - ] + ] = None availability_zones: Annotated[ Optional[List[str]], Field( @@ -172,7 +172,7 @@ class ProfileParams(CoreModel): Field( description="The cloud-specific instance types to consider for provisioning (e.g., `[p3.8xlarge, n1-standard-4]`)" ), - ] + ] = None reservation: Annotated[ Optional[str], Field( @@ -181,17 +181,17 @@ class ProfileParams(CoreModel): " Supports AWS Capacity Reservations and Capacity Blocks" ) ), - ] + ] = None spot_policy: Annotated[ Optional[SpotPolicy], Field( description="The policy for provisioning spot or on-demand instances: `spot`, `on-demand`, or `auto`. Defaults to `on-demand`" ), - ] + ] = None retry: Annotated[ Optional[Union[ProfileRetry, bool]], Field(description="The policy for resubmitting the run. Defaults to `false`"), - ] + ] = None max_duration: Annotated[ Optional[Union[Literal["off"], str, int, bool]], Field( @@ -201,7 +201,7 @@ class ProfileParams(CoreModel): " Use `off` for unlimited duration. Defaults to `off`" ) ), - ] + ] = None stop_duration: Annotated[ Optional[Union[Literal["off"], str, int, bool]], Field( @@ -212,17 +212,17 @@ class ProfileParams(CoreModel): " Use `off` for unlimited duration. Defaults to `5m`" ) ), - ] + ] = None max_price: Annotated[ Optional[float], Field(description="The maximum instance price per hour, in dollars", gt=0.0), - ] + ] = None creation_policy: Annotated[ Optional[CreationPolicy], Field( description="The policy for using instances from fleets. Defaults to `reuse-or-create`" ), - ] + ] = None idle_duration: Annotated[ Optional[Union[Literal["off"], str, int, bool]], Field( @@ -231,26 +231,26 @@ class ProfileParams(CoreModel): " Defaults to `5m` for runs and `3d` for fleets. Use `off` for unlimited duration" ) ), - ] + ] = None utilization_policy: Annotated[ Optional[UtilizationPolicy], Field(description="Run termination policy based on utilization"), - ] + ] = None # Deprecated: termination_policy: Annotated[ Optional[TerminationPolicy], Field( description="Deprecated in favor of `idle_duration`", ), - ] + ] = None termination_idle_time: Annotated[ Optional[Union[str, int]], Field( description="Deprecated in favor of `idle_duration`", ), - ] + ] = None # The policy for resubmitting the run. Deprecated in favor of `retry` - retry_policy: Optional[ProfileRetryPolicy] + retry_policy: Optional[ProfileRetryPolicy] = None _validate_max_duration = validator("max_duration", pre=True, allow_reuse=True)( parse_max_duration diff --git a/src/dstack/_internal/core/models/repos/local.py b/src/dstack/_internal/core/models/repos/local.py index 07773f2f9..2aa09caa4 100644 --- a/src/dstack/_internal/core/models/repos/local.py +++ b/src/dstack/_internal/core/models/repos/local.py @@ -41,10 +41,10 @@ def from_dir(repo_dir: PathLike) -> "LocalRepo": Creates an instance of a local repo from a local path. Args: - repo_dir: The path to a local folder + repo_dir: The path to a local folder. Returns: - A local repo instance + A local repo instance. """ return LocalRepo(repo_dir=repo_dir) diff --git a/src/dstack/_internal/core/models/repos/remote.py b/src/dstack/_internal/core/models/repos/remote.py index e8001722a..4f1c2c0be 100644 --- a/src/dstack/_internal/core/models/repos/remote.py +++ b/src/dstack/_internal/core/models/repos/remote.py @@ -100,10 +100,10 @@ def from_dir(repo_dir: PathLike) -> "RemoteRepo": Creates an instance of a remote repo from a local path. Args: - repo_dir: The path to a local folder + repo_dir: The path to a local folder. Returns: - A remote repo instance + A remote repo instance. """ return RemoteRepo(local_repo_dir=repo_dir) @@ -115,12 +115,12 @@ def from_url( Creates an instance of a remote repo from a URL. Args: - repo_url: The URL of a remote Git repo + repo_url: The URL of a remote Git repo. repo_branch: The name of the remote branch. Must be specified if `hash` is not specified. repo_hash: The hash of the revision. Must be specified if `branch` is not specified. Returns: - A remote repo instance + A remote repo instance. """ if repo_branch is None and repo_hash is None: raise ValueError("Either `repo_branch` or `repo_hash` must be specified.") diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index cef7b1a75..bf4022f1f 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -312,7 +312,7 @@ class RunSpec(CoreModel): run_name: Annotated[ Optional[str], Field(description="The run name. If not set, the run name is generated automatically."), - ] + ] = None repo_id: Annotated[ Optional[str], Field( @@ -322,15 +322,15 @@ class RunSpec(CoreModel): " If not specified, a default virtual repo is used." ) ), - ] + ] = None repo_data: Annotated[ Optional[AnyRunRepoData], Field( discriminator="repo_type", description="The repo data such as the current branch and commit.", ), - ] - repo_code_hash: Annotated[Optional[str], Field(description="The hash of the repo diff")] + ] = None + repo_code_hash: Annotated[Optional[str], Field(description="The hash of the repo diff")] = None working_dir: Annotated[ Optional[str], Field( @@ -340,7 +340,7 @@ class RunSpec(CoreModel): ' Defaults to `"."`.' ) ), - ] + ] = None configuration_path: Annotated[ Optional[str], Field( @@ -349,9 +349,9 @@ class RunSpec(CoreModel): " It can be omitted when using the programmatic API." ) ), - ] + ] = None configuration: Annotated[AnyRunConfiguration, Field(discriminator="type")] - profile: Annotated[Optional[Profile], Field(description="The profile parameters")] + profile: Annotated[Optional[Profile], Field(description="The profile parameters")] = None ssh_key_pub: Annotated[ str, Field( diff --git a/src/dstack/api/__init__.py b/src/dstack/api/__init__.py index d1fe55a07..0a4dce191 100644 --- a/src/dstack/api/__init__.py +++ b/src/dstack/api/__init__.py @@ -2,6 +2,9 @@ from dstack._internal.core.errors import ClientError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import RegistryAuth +from dstack._internal.core.models.configurations import ( + DevEnvironmentConfiguration as _DevEnvironmentConfiguration, +) from dstack._internal.core.models.configurations import ScalingSpec as Scaling from dstack._internal.core.models.configurations import ( ServiceConfiguration as _ServiceConfiguration, @@ -22,3 +25,4 @@ Service = _ServiceConfiguration Task = _TaskConfiguration +DevEnvironment = _DevEnvironmentConfiguration diff --git a/src/dstack/api/_public/__init__.py b/src/dstack/api/_public/__init__.py index 4a2af8702..44ad85597 100644 --- a/src/dstack/api/_public/__init__.py +++ b/src/dstack/api/_public/__init__.py @@ -15,12 +15,14 @@ class Client: """ - High-level API client for interacting with dstack server + High-level API client for interacting with the `dstack` server Attributes: + project: The project name. runs: Operations with runs. repos: Operations with repositories. backends: Operations with backends. + client: Low-level API client that supports all API endpoints. """ def __init__( @@ -56,13 +58,13 @@ def from_config( Creates a Client using the default configuration from `~/.dstack/config.yml` if it exists. Args: - project_name: The name of the project, required if `server_url` and `user_token` are specified - server_url: The dstack server URL (e.g. `http://localhost:3000/` or `https://sky.dstack.ai`) - user_token: The dstack user token - ssh_identity_file: The private SSH key path for SSH tunneling + project_name: The name of the project. required if `server_url` and `user_token` are specified. + server_url: The dstack server URL (e.g. `http://localhost:3000/` or `https://sky.dstack.ai`). + user_token: The dstack user token. + ssh_identity_file: The private SSH key path for SSH tunneling. Returns: - A client instance + A client instance. """ if server_url is not None and user_token is not None: if project_name is None: @@ -76,6 +78,14 @@ def from_config( ssh_identity_file=ssh_identity_file, ) + @property + def project(self) -> str: + return self._project + + @property + def runs(self) -> RunCollection: + return self._runs + @property def repos(self) -> RepoCollection: return self._repos @@ -84,14 +94,6 @@ def repos(self) -> RepoCollection: def backends(self) -> BackendCollection: return self._backends - @property - def runs(self) -> RunCollection: - return self._runs - @property def client(self) -> APIClient: return self._client - - @property - def project(self) -> str: - return self._project diff --git a/src/dstack/api/_public/repos.py b/src/dstack/api/_public/repos.py index 7034eb825..f94429846 100644 --- a/src/dstack/api/_public/repos.py +++ b/src/dstack/api/_public/repos.py @@ -78,25 +78,6 @@ def init( raise ConfigurationError(*e.args) self._api_client.repos.init(self._project, repo.repo_id, repo.get_repo_info(), creds) - def is_initialized( - self, - repo: Repo, - ) -> bool: - # """ - # Checks if the remote repo is initialized in the project - # - # Args: - # repo: repo to check - # - # Returns: - # repo is initialized - # """ - try: - self._api_client.repos.get(self._project, repo.repo_id, include_creds=False) - return True - except ResourceNotExistsError: - return False - def load( self, repo_dir: PathLike, @@ -105,22 +86,22 @@ def load( git_identity_file: Optional[PathLike] = None, oauth_token: Optional[str] = None, ) -> Union[RemoteRepo, LocalRepo]: - # """ - # Loads the repo from the local directory using global config - # - # Args: - # repo_dir: repo root directory - # local: do not try to load `RemoteRepo` first - # init: initialize the repo if it's not initialized - # git_identity_file: path to an SSH private key to access the remote repo - # oauth_token: GitHub OAuth token to access the remote repo - # - # Raises: - # ConfigurationError: if the repo is not initialized and `init` is `False` - # - # Returns: - # repo: initialized repo - # """ + """ + Loads the repo from the local directory using global config + + Args: + repo_dir: Repo root directory. + local: Do not try to load `RemoteRepo` first. + init: Initialize the repo if it's not initialized. + git_identity_file: Path to an SSH private key to access the remote repo. + oauth_token: GitHub OAuth token to access the remote repo. + + Raises: + ConfigurationError: If the repo is not initialized and `init` is `False`. + + Returns: + repo: Initialized repo. + """ config = ConfigManager() if not init: logger.debug("Loading repo config") @@ -155,6 +136,25 @@ def load( ) return repo + def is_initialized( + self, + repo: Repo, + ) -> bool: + """ + Checks if the remote repo is initialized in the project + + Args: + repo: The repo to check. + + Returns: + Whether the repo is initialized or not. + """ + try: + self._api_client.repos.get(self._project, repo.repo_id, include_creds=False) + return True + except ResourceNotExistsError: + return False + def get_ssh_keypair(key_path: Optional[PathLike], dstack_key_path: Path) -> str: """Returns a path to the private key""" diff --git a/src/dstack/api/_public/resources.py b/src/dstack/api/_public/resources.py deleted file mode 100644 index d32ec4bd4..000000000 --- a/src/dstack/api/_public/resources.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import List, Optional, Union - -from dstack._internal.core.models.resources import ( - DEFAULT_CPU_COUNT, - DEFAULT_GPU_COUNT, - DEFAULT_MEMORY_SIZE, - ComputeCapabilityLike, - DiskLike, - DiskSpec, - DiskSpecSchema, - GPULike, - GPUSpec, - GPUSpecSchema, - IntRangeLike, - MemoryLike, - MemoryRangeLike, - ResourcesSpec, - ResourcesSpecSchema, -) - - -# TODO(andrey): This method looks like a workaround and possibly must be reworked (replaced with something else). -# Currently it's only used by the `dstack pool add` command. -def Resources( - *, - cpu: IntRangeLike = DEFAULT_CPU_COUNT, - memory: MemoryRangeLike = DEFAULT_MEMORY_SIZE, - gpu: Optional[GPULike] = None, - shm_size: Optional[MemoryLike] = None, - disk: Optional[DiskLike] = None, -) -> ResourcesSpec: - """ - Creates required resources specification. - - Args: - cpu (Optional[Range[int]]): The number of CPUs - memory (Optional[Range[Memory]]): The size of RAM memory (e.g., `"16GB"`) - gpu (Optional[GPUSpec]): The GPU spec - shm_size (Optional[Range[Memory]]): The of shared memory (e.g., `"8GB"`). If you are using parallel communicating processes (e.g., dataloaders in PyTorch), you may need to configure this. - disk (Optional[DiskSpec]): The disk spec - - Returns: - resources specification - """ - return ResourcesSpec.parse_obj( - ResourcesSpecSchema( - cpu=cpu, - memory=memory, - gpu=gpu, - shm_size=shm_size, - disk=disk, - ) - ) - - -def GPU( - *, - name: Optional[Union[List[str], str]] = None, - count: IntRangeLike = DEFAULT_GPU_COUNT, - memory: Optional[MemoryRangeLike] = None, - total_memory: Optional[MemoryRangeLike] = None, - compute_capability: Optional[ComputeCapabilityLike] = None, -) -> GPUSpec: - """ - Creates GPU specification. - - Args: - name (Optional[List[str]]): The name of the GPU (e.g., `"A100"` or `"H100"`) - count (Optional[Range[int]]): The number of GPUs - memory (Optional[Range[Memory]]): The size of a single GPU memory (e.g., `"16GB"`) - total_memory (Optional[Range[Memory]]): The total size of all GPUs memory (e.g., `"32GB"`) - compute_capability (Optional[float]): The minimum compute capability of the GPU (e.g., `7.5`) - - Returns: - GPU specification - """ - return GPUSpec.parse_obj( - GPUSpecSchema( - name=name, - count=count, - memory=memory, - total_memory=total_memory, - compute_capability=compute_capability, - ) - ) - - -def Disk( - *, - size: MemoryRangeLike, -) -> DiskSpec: - """ - Creates disk specification. - - Args: - size (Range[Memory]): The size of the disk (e.g., `"100GB"`) - - Returns: - disk specification - """ - return DiskSpec.parse_obj( - DiskSpecSchema( - size=size, - ) - ) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 3ef4c6163..44c0294bb 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -16,7 +16,6 @@ from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import ApplyAction from dstack._internal.core.models.configurations import AnyRunConfiguration, PortMapping from dstack._internal.core.models.profiles import ( CreationPolicy, @@ -188,14 +187,14 @@ def logs( job_num: int = 0, ) -> Iterable[bytes]: """ - Iterate through run's log messages + Iterate through run's log messages. Args: - start_time: minimal log timestamp - diagnose: return runner logs if `True` + start_time: Minimal log timestamp. + diagnose: Return runner logs if `True`. Yields: - log messages + Log messages. """ if diagnose is False and self._ssh_attach is not None: yield from self._attached_logs() @@ -225,17 +224,17 @@ def logs( def refresh(self): """ - Get up-to-date run info + Get up-to-date run info. """ self._run = self._api_client.runs.get(self._project, self._run.run_spec.run_name) logger.debug("Refreshed run %s: %s", self.name, self.status) def stop(self, abort: bool = False): """ - Terminate the instance and detach + Terminate the instance and detach. Args: - abort: gracefully stop the run if `False` + abort: Gracefully stop the run if `False`. """ self._api_client.runs.stop(self._project, [self.name], abort) logger.debug("%s run %s", "Aborted" if abort else "Stopped", self.name) @@ -253,7 +252,7 @@ def attach( Establish an SSH tunnel to the instance and update SSH config Args: - ssh_identity_file: SSH keypair to access instances + ssh_identity_file: SSH keypair to access instances. Raises: dstack.api.PortUsedError: If ports are in use or the run is attached by another process. @@ -390,7 +389,7 @@ def __repr__(self) -> str: class RunCollection: """ - Operations with runs + Operations with runs. """ def __init__( @@ -403,6 +402,117 @@ def __init__( self._project = project self._client = client + def get_run_plan( + self, + configuration: AnyRunConfiguration, + repo: Optional[Repo] = None, + profile: Optional[Profile] = None, + configuration_path: Optional[str] = None, + ) -> RunPlan: + """ + Get a run plan. + Use this method to see the run plan before applying the cofiguration. + + Args: + configuration (Union[Task, Service, DevEnvironment]): The run configuration. + repo (Union[LocalRepo, RemoteRepo, VirtualRepo]): The repo to mount to the run. + profile: The profile to use for the run. + configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file. + + Returns: + Run plan. + """ + if repo is None: + repo = configuration.get_repo() + if repo is None: + raise ConfigurationError("Repo is required for this type of configuration") + + run_spec = RunSpec( + run_name=configuration.name, + repo_id=repo.repo_id, + repo_data=repo.run_repo_data, + repo_code_hash=None, # `apply_plan` will fill it + working_dir=configuration.working_dir, + configuration_path=configuration_path, + configuration=configuration, + profile=profile, + ssh_key_pub=Path(self._client.ssh_identity_file + ".pub").read_text().strip(), + ) + logger.debug("Getting run plan") + run_plan = self._api_client.runs.get_plan(self._project, run_spec) + return run_plan + + def apply_plan( + self, + run_plan: RunPlan, + repo: Repo, + reserve_ports: bool = True, + ) -> Run: + """ + Apply the run plan. + Use this method to apply run plans returned by `get_run_plan`. + + Args: + run_plan: Result of `get_run_plan` call. + repo: Repo to use for the run. + reserve_ports: Reserve local ports before submit. + + Returns: + Submitted run. + """ + ports_lock = None + if reserve_ports: + # TODO handle multiple jobs + ports_lock = _reserve_ports(run_plan.job_plans[0].job_spec) + + with tempfile.TemporaryFile("w+b") as fp: + run_plan.run_spec.repo_code_hash = repo.write_code_file(fp) + fp.seek(0) + self._api_client.repos.upload_code( + self._project, repo.repo_id, run_plan.run_spec.repo_code_hash, fp + ) + run = self._api_client.runs.apply_plan(self._project, run_plan) + return self._model_to_submitted_run(run, ports_lock) + + def apply_configuration( + self, + configuration: AnyRunConfiguration, + repo: Optional[Repo] = None, + profile: Optional[Profile] = None, + configuration_path: Optional[str] = None, + reserve_ports: bool = True, + ) -> Run: + """ + Apply the run configuration. + Use this method to apply configurations without getting a run plan first. + + Args: + configuration (Union[Task, Service, DevEnvironment]): The run configuration. + repo (Union[LocalRepo, RemoteRepo, VirtualRepo]): The repo to mount to the run. + profile: The profile to use for the run. + configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file. + + Returns: + Submitted run. + """ + if repo is None: + repo = configuration.get_repo() + if repo is None: + raise ConfigurationError("Repo is required for this type of configuration") + + run_plan = self.get_run_plan( + configuration=configuration, + repo=repo, + profile=profile, + configuration_path=configuration_path, + ) + run = self.apply_plan( + run_plan=run_plan, + repo=repo, + reserve_ports=reserve_ports, + ) + return run + def submit( self, configuration: AnyRunConfiguration, @@ -420,27 +530,28 @@ def submit( run_name: Optional[str] = None, reserve_ports: bool = True, ) -> Run: - """ - Submit a run + # """ + # Submit a run - Args: - configuration (Union[Task, Service]): A run configuration. - configuration_path: The path to the configuration file, relative to the root directory of the repo. - repo (Union[LocalRepo, RemoteRepo, VirtualRepo]): A repo to mount to the run. - backends: A list of allowed backend for provisioning. - regions: A list of cloud regions for provisioning. - resources: The requirements to run the configuration. Overrides the configuration's resources. - spot_policy: A spot policy for provisioning. - retry_policy (RetryPolicy): A retry policy. - max_duration: The max instance running duration in seconds. - max_price: The max instance price in dollars per hour for provisioning. - working_dir: A working directory relative to the repo root directory - run_name: A desired name of the run. Must be unique in the project. If not specified, a random name is assigned. - reserve_ports: Whether local ports should be reserved in advance. + # Args: + # configuration (Union[Task, Service]): A run configuration. + # configuration_path: The path to the configuration file, relative to the root directory of the repo. + # repo (Union[LocalRepo, RemoteRepo, VirtualRepo]): A repo to mount to the run. + # backends: A list of allowed backend for provisioning. + # regions: A list of cloud regions for provisioning. + # resources: The requirements to run the configuration. Overrides the configuration's resources. + # spot_policy: A spot policy for provisioning. + # retry_policy (RetryPolicy): A retry policy. + # max_duration: The max instance running duration in seconds. + # max_price: The max instance price in dollars per hour for provisioning. + # working_dir: A working directory relative to the repo root directory + # run_name: A desired name of the run. Must be unique in the project. If not specified, a random name is assigned. + # reserve_ports: Whether local ports should be reserved in advance. - Returns: - submitted run - """ + # Returns: + # Submitted run. + # """ + logger.warning("The submit() method is deprecated in favor of apply_configuration().") if repo is None: repo = configuration.get_repo() if repo is None: @@ -465,7 +576,7 @@ def submit( ) return self.exec_plan(run_plan, repo, reserve_ports=reserve_ports) - # TODO: [Andrey] I guess we need to drop profile-related fields (currently retry is not reflected there) + # Deprecated in favor of get_run_plan() def get_plan( self, configuration: AnyRunConfiguration, @@ -497,7 +608,7 @@ def get_plan( # Returns: # run plan # """ - + logger.warning("The get_plan() method is deprecated in favor of get_run_plan().") if repo is None: repo = configuration.get_repo() if repo is None: @@ -566,45 +677,28 @@ def exec_plan( reserve_ports: bool = True, ) -> Run: # """ - # Execute run plan - # + # Execute the run plan. + # Args: - # run_plan: result of `get_plan` call - # repo: repo to use for the run - # reserve_ports: reserve local ports before submit - # + # run_plan: Result of `get_run_plan` call. + # repo: Repo to use for the run. + # reserve_ports: Reserve local ports before submit. + # Returns: - # submitted run + # Submitted run. # """ - ports_lock = None - if reserve_ports: - # TODO handle multiple jobs - ports_lock = _reserve_ports(run_plan.job_plans[0].job_spec) - - with tempfile.TemporaryFile("w+b") as fp: - run_plan.run_spec.repo_code_hash = repo.write_code_file(fp) - fp.seek(0) - self._api_client.repos.upload_code( - self._project, repo.repo_id, run_plan.run_spec.repo_code_hash, fp - ) - # Calling submit when action is CREATE since apply_plan is not backward-compatible. - # Otherwise, apply_plan can replace submit, i.e. it creates the run if it does not exist. - # TODO: Remove in 0.19 - if run_plan.action == ApplyAction.UPDATE: - run = self._api_client.runs.apply_plan(self._project, run_plan) - else: - run = self._api_client.runs.submit(self._project, run_plan.run_spec) - return self._model_to_submitted_run(run, ports_lock) + logger.warning("The exec_plan() method is deprecated in favor of apply_plan().") + return self.apply_plan(run_plan=run_plan, repo=repo, reserve_ports=reserve_ports) def list(self, all: bool = False) -> List[Run]: """ - List runs + List runs. Args: - all: show all runs (active and finished) if `True` + all: Show all runs (active and finished) if `True`. Returns: - list of runs + List of runs. """ # Return only one page of latest runs (<=100). Returning all the pages may be costly. # TODO: Consider introducing `since` filter with a reasonable default. @@ -623,13 +717,13 @@ def list(self, all: bool = False) -> List[Run]: def get(self, run_name: str) -> Optional[Run]: """ - Get run by run name + Get run by run name. Args: - run_name: run name + run_name: Run name. Returns: - The run or `None` if not found + The run or `None` if not found. """ try: run = self._api_client.runs.get(self._project, run_name) diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index 5e8cf3d9d..8a8058965 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -29,22 +29,26 @@ class APIClient: """ - Low-level API client for interacting with dstack server. Implements all API endpoints + Low-level API client for interacting with the `dstack` server. + Supports all HTTP API endpoints. Attributes: users: operations with users projects: operations with projects backends: operations with backends + fleets: operations with fleets runs: operations with runs + metrics: operations with metrics logs: operations with logs gateways: operations with gateways + volumes: operations with volumes """ def __init__(self, base_url: str, token: str): """ Args: - base_url: API endpoints prefix, e.g. `http://127.0.0.1:3000/` - token: API token + base_url: The API endpoints prefix, e.g. `http://127.0.0.1:3000/`. + token: The API token. """ self._base_url = base_url.rstrip("/") self._token = token @@ -70,6 +74,10 @@ def projects(self) -> ProjectsAPIClient: def backends(self) -> BackendsAPIClient: return BackendsAPIClient(self._request) + @property + def fleets(self) -> FleetsAPIClient: + return FleetsAPIClient(self._request) + @property def repos(self) -> ReposAPIClient: return ReposAPIClient(self._request) @@ -94,10 +102,6 @@ def secrets(self) -> SecretsAPIClient: def gateways(self) -> GatewaysAPIClient: return GatewaysAPIClient(self._request) - @property - def fleets(self) -> FleetsAPIClient: - return FleetsAPIClient(self._request) - @property def volumes(self) -> VolumesAPIClient: return VolumesAPIClient(self._request) From d4b24cb51a802823a1788c5a693175b15f81c393 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 13 Mar 2025 15:35:07 +0500 Subject: [PATCH 2/7] Improve Python API * Introduce new client.runs methods: `get_run_plan`, `apply_plan`, `apply_configuration` * Deprecate old client.runs methods. * Improve Python API docs. * Update examples and docs to use the new . --- docs/docs/reference/api/python/index.md | 22 ++++++++++++++----- examples/misc/airflow/README.md | 6 ++--- examples/misc/airflow/dags/dstack_tasks.py | 8 +++---- .../_internal/core/models/repos/local.py | 2 +- .../_internal/core/models/repos/remote.py | 2 +- .../_internal/core/models/repos/virtual.py | 2 +- src/dstack/api/_public/repos.py | 2 +- 7 files changed, 27 insertions(+), 17 deletions(-) diff --git a/docs/docs/reference/api/python/index.md b/docs/docs/reference/api/python/index.md index 52e9d2f4b..f2974c1ba 100644 --- a/docs/docs/reference/api/python/index.md +++ b/docs/docs/reference/api/python/index.md @@ -14,6 +14,7 @@ from dstack.api import Task, GPU, Client, Resources client = Client.from_config() task = Task( + name="my-awesome-run", # If not specified, a random name is assigned image="ghcr.io/huggingface/text-generation-inference:latest", env={"MODEL_ID": "TheBloke/Llama-2-13B-chat-GPTQ"}, commands=[ @@ -23,8 +24,7 @@ task = Task( resources=Resources(gpu=GPU(memory="24GB")), ) -run = client.runs.submit( - run_name="my-awesome-run", # If not specified, a random name is assigned +run = client.runs.apply_configuration( configuration=task, repo=None, # Specify to mount additional files ) @@ -42,10 +42,9 @@ finally: ``` !!! info "NOTE:" - 1. The `configuration` argument in the `submit` method can be either `dstack.api.Task` or `dstack.api.Service`. - 2. If you create `dstack.api.Task` or `dstack.api.Service`, you may specify the `image` argument. If `image` isn't - specified, the default image will be used. For a private Docker registry, ensure you also pass the `registry_auth` argument. - 3. The `repo` argument in the `submit` method allows the mounting of a local folder, a remote repo, or a + 1. The `configuration` argument in the `apply_configuration` method can be either `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`. + 2. When you create `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`, you can specify the `image` argument. If `image` isn't specified, the default image will be used. For a private Docker registry, ensure you also pass the `registry_auth` argument. + 3. The `repo` argument in the `apply_configuration` method allows the mounting of a local folder, a remote repo, or a programmatically created repo. In this case, the `commands` argument can refer to the files within this repo. 4. The `attach` method waits for the run to start and, for `dstack.api.Task` sets up an SSH tunnel and forwards configured `ports` to `localhost`. @@ -109,6 +108,17 @@ finally: registry_auth: dstack.api.RegistryAuth resources: dstack.api.Resources +### `dstack.api.DevEnvironment` { #dstack.api.DevEnvironment data-toc-label="DevEnvironment" } + +#SCHEMA# dstack.api.DevEnvironment + overrides: + show_root_heading: false + show_root_toc_entry: false + heading_level: 4 + item_id_mapping: + registry_auth: dstack.api.RegistryAuth + resources: dstack.api.Resources + ### `dstack.api.Run` { #dstack.api.Run data-toc-label="Run" } ::: dstack.api.Run diff --git a/examples/misc/airflow/README.md b/examples/misc/airflow/README.md index d9959fb9a..13598d8e0 100644 --- a/examples/misc/airflow/README.md +++ b/examples/misc/airflow/README.md @@ -47,10 +47,11 @@ DSTACK_VENV_PYTHON_BINARY_PATH = f"{DSTACK_VENV_PATH}/bin/python" def pipeline(...): ... @task.external_python(task_id="external_python", python=DSTACK_VENV_PYTHON_BINARY_PATH) - def dstack_api_submit_venv() -> str: + def dstack_api_submit_venv(): from dstack.api import Client, Task task = Task( + name="my-airflow-task", commands=[ "echo 'Running dstack task via Airflow'", "sleep 10", @@ -61,8 +62,7 @@ def pipeline(...): # or set explicitly from Ariflow Variables. client = Client.from_config() - run = client.runs.submit( - run_name="my-airflow-task", + run = client.runs.apply_configuration( configuration=task, ) run.attach() diff --git a/examples/misc/airflow/dags/dstack_tasks.py b/examples/misc/airflow/dags/dstack_tasks.py index 6b02f096a..1f8bc1649 100644 --- a/examples/misc/airflow/dags/dstack_tasks.py +++ b/examples/misc/airflow/dags/dstack_tasks.py @@ -54,7 +54,7 @@ def dstack_cli_apply_venv() -> str: ) @task.external_python(task_id="external_python", python=DSTACK_VENV_PYTHON_BINARY_PATH) - def dstack_api_submit_venv() -> str: + def dstack_api_submit_venv(): """ This task shows how to run the dstack API when dstack is installed into a separate virtual environment available to Airflow. @@ -63,18 +63,18 @@ def dstack_api_submit_venv() -> str: from dstack.api import Client, Task task = Task( + name="my-airflow-task", commands=[ "echo 'Running dstack task via Airflow'", "sleep 10", "echo 'Finished'", - ] + ], ) # Pick up config from `~/.dstack/config.yml` # or set explicitly from Ariflow Variables. client = Client.from_config() - run = client.runs.submit( - run_name="my-airflow-task", + run = client.runs.apply_configuration( configuration=task, ) run.attach() diff --git a/src/dstack/_internal/core/models/repos/local.py b/src/dstack/_internal/core/models/repos/local.py index 2aa09caa4..1bc815f12 100644 --- a/src/dstack/_internal/core/models/repos/local.py +++ b/src/dstack/_internal/core/models/repos/local.py @@ -26,7 +26,7 @@ class LocalRepo(Repo): Example: ```python - run = client.runs.submit( + run = client.runs.apply_configuration( configuration=..., repo=LocalRepo.from_dir("."), # Mount the current folder to the run ) diff --git a/src/dstack/_internal/core/models/repos/remote.py b/src/dstack/_internal/core/models/repos/remote.py index 4f1c2c0be..39ddaf816 100644 --- a/src/dstack/_internal/core/models/repos/remote.py +++ b/src/dstack/_internal/core/models/repos/remote.py @@ -84,7 +84,7 @@ class RemoteRepo(Repo): Finally, you can pass the repo object to the run: ```python - run = client.runs.submit( + run = client.runs.apply_configuration( configuration=..., repo=repo, ) diff --git a/src/dstack/_internal/core/models/repos/virtual.py b/src/dstack/_internal/core/models/repos/virtual.py index 4ce79b2c7..4a975481a 100644 --- a/src/dstack/_internal/core/models/repos/virtual.py +++ b/src/dstack/_internal/core/models/repos/virtual.py @@ -30,7 +30,7 @@ class VirtualRepo(Repo): virtual_repo.add_file_from_package(package=some_package, path="requirements.txt") virtual_repo.add_file_from_package(package=some_package, path="train.py") - run = client.runs.submit( + run = client.runs.apply_configuration( configuration=..., repo=virtual_repo, ) diff --git a/src/dstack/api/_public/repos.py b/src/dstack/api/_public/repos.py index f94429846..ced2bf8c9 100644 --- a/src/dstack/api/_public/repos.py +++ b/src/dstack/api/_public/repos.py @@ -55,7 +55,7 @@ def init( Once the repo is initialized, you can pass the repo object to the run: ```python - run = client.runs.submit( + run = client.runs.apply_configuration( configuration=..., repo=repo, ) From a19db915cb64eed8fc295d5e76d99be620b44189 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 13 Mar 2025 15:48:15 +0500 Subject: [PATCH 3/7] Deprecate /api/project/{project_name}/runs/submit --- src/dstack/_internal/server/routers/runs.py | 33 ++++++++++----------- src/dstack/api/server/_runs.py | 9 ------ 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 5a7cc5356..ff5fff70f 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -126,23 +126,6 @@ async def apply_plan( ) -# apply_plan replaces submit_run since it can create new runs. -# submit_run can be deprecated in the future. -@project_router.post("/submit") -async def submit_run( - body: SubmitRunRequest, - session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), -) -> Run: - user, project = user_project - return await runs.submit_run( - session=session, - user=user, - project=project, - run_spec=body.run_spec, - ) - - @project_router.post("/stop") async def stop_runs( body: StopRunsRequest, @@ -172,3 +155,19 @@ async def delete_runs( """ _, project = user_project await runs.delete_runs(session=session, project=project, runs_names=body.runs_names) + + +# apply_plan replaces submit_run since it can create new runs. +@project_router.post("/submit", deprecated=True) +async def submit_run( + body: SubmitRunRequest, + session: AsyncSession = Depends(get_session), + user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), +) -> Run: + user, project = user_project + return await runs.submit_run( + session=session, + user=user, + project=project, + run_spec=body.run_spec, + ) diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index e65bc05ca..77143f9ad 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -24,7 +24,6 @@ GetRunRequest, ListRunsRequest, StopRunsRequest, - SubmitRunRequest, ) from dstack.api.server._group import APIClientGroup @@ -83,14 +82,6 @@ def apply_plan( ) return parse_obj_as(Run.__response__, resp.json()) - def submit(self, project_name: str, run_spec: RunSpec) -> Run: - body = SubmitRunRequest(run_spec=run_spec) - resp = self._request( - f"/api/project/{project_name}/runs/submit", - body=body.json(exclude=_get_run_spec_excludes(run_spec)), - ) - return parse_obj_as(Run.__response__, resp.json()) - def stop(self, project_name: str, runs_names: List[str], abort: bool): body = StopRunsRequest(runs_names=runs_names, abort=abort) self._request(f"/api/project/{project_name}/runs/stop", body=body.json()) From a12df281d4287bfdf898a991a2d8b28fa673f821 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 13 Mar 2025 15:51:27 +0500 Subject: [PATCH 4/7] Remove missing fleets get_plan handling --- .../cli/services/configurators/fleet.py | 43 ++----------------- 1 file changed, 4 insertions(+), 39 deletions(-) diff --git a/src/dstack/_internal/cli/services/configurators/fleet.py b/src/dstack/_internal/cli/services/configurators/fleet.py index 35699f216..60f2f9267 100644 --- a/src/dstack/_internal/cli/services/configurators/fleet.py +++ b/src/dstack/_internal/cli/services/configurators/fleet.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import List, Optional -import requests from rich.table import Table from dstack._internal.cli.services.configurators.base import ( @@ -32,7 +31,6 @@ from dstack._internal.utils.common import local_time from dstack._internal.utils.logging import get_logger from dstack._internal.utils.ssh import convert_ssh_key_to_pem, generate_public_key, pkey_from_str -from dstack.api._public import Client from dstack.api.utils import load_profile logger = get_logger(__name__) @@ -60,7 +58,10 @@ def apply_configuration( _preprocess_spec(spec) with console.status("Getting apply plan..."): - plan = _get_plan(api=self.api, spec=spec) + plan = self.api.client.fleets.get_plan( + project_name=self.api.project, + spec=spec, + ) _print_plan_header(plan) action_message = "" @@ -234,42 +235,6 @@ def _resolve_ssh_key(ssh_key_path: Optional[str]) -> Optional[SSHKey]: exit() -def _get_plan(api: Client, spec: FleetSpec) -> FleetPlan: - try: - return api.client.fleets.get_plan( - project_name=api.project, - spec=spec, - ) - except requests.exceptions.HTTPError as e: - # Handle older server versions that do not have /get_plan for fleets - # TODO: Can be removed in 0.19 - if e.response.status_code == 405: - logger.warning( - "Fleet apply plan is not fully supported before 0.18.17. " - "Update the server to view full-featured apply plan." - ) - user = api.client.users.get_my_user() - spec.configuration_path = None - current_resource = None - if spec.configuration.name is not None: - try: - current_resource = api.client.fleets.get( - project_name=api.project, name=spec.configuration.name - ) - except ResourceNotExistsError: - pass - return FleetPlan( - project_name=api.project, - user=user.username, - spec=spec, - current_resource=current_resource, - offers=[], - total_offers=0, - max_offer_price=0, - ) - raise e - - def _print_plan_header(plan: FleetPlan): def th(s: str) -> str: return f"[bold]{s}[/bold]" From 1856b5036089b9e11d41cf2e2fc8f2387de61f3a Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 13 Mar 2025 15:58:34 +0500 Subject: [PATCH 5/7] Remove missing get_job_metrics handling --- src/dstack/_internal/cli/commands/stats.py | 25 +++++++--------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/dstack/_internal/cli/commands/stats.py b/src/dstack/_internal/cli/commands/stats.py index 12f52346b..761be846f 100644 --- a/src/dstack/_internal/cli/commands/stats.py +++ b/src/dstack/_internal/cli/commands/stats.py @@ -2,7 +2,6 @@ import time from typing import Any, Dict, List, Optional, Union -import requests from rich.live import Live from rich.table import Table @@ -64,22 +63,14 @@ def _command(self, args: argparse.Namespace): def _get_run_jobs_metrics(api: Client, run: Run) -> List[JobMetrics]: metrics = [] - try: - for job in run._run.jobs: - job_metrics = api.client.metrics.get_job_metrics( - project_name=api.project, - run_name=run.name, - replica_num=job.job_spec.replica_num, - job_num=job.job_spec.job_num, - ) - metrics.append(job_metrics) - except requests.exceptions.HTTPError as e: - if e.response.status_code == 404: - raise CLIError( - "Metrics API is not supported for server versions before 0.18.18. " - "Update the server to use `dstack stats`." - ) - raise + for job in run._run.jobs: + job_metrics = api.client.metrics.get_job_metrics( + project_name=api.project, + run_name=run.name, + replica_num=job.job_spec.replica_num, + job_num=job.job_spec.job_num, + ) + metrics.append(job_metrics) return metrics From e96f64e14a0863ea64f20679a8fce3fd0b6ce5a3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 14 Mar 2025 11:40:48 +0500 Subject: [PATCH 6/7] Document reserve_ports --- src/dstack/api/_public/runs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 44c0294bb..11e791c85 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -455,7 +455,7 @@ def apply_plan( Args: run_plan: Result of `get_run_plan` call. repo: Repo to use for the run. - reserve_ports: Reserve local ports before submit. + reserve_ports: Reserve local ports before applying. Use if you'll attach to the run. Returns: Submitted run. @@ -491,6 +491,7 @@ def apply_configuration( repo (Union[LocalRepo, RemoteRepo, VirtualRepo]): The repo to mount to the run. profile: The profile to use for the run. configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file. + reserve_ports: Reserve local ports before applying. Use if you'll attach to the run. Returns: Submitted run. From a939524ad535a63cd794105eb7b3e5636ea77fee Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Fri, 14 Mar 2025 13:11:50 +0500 Subject: [PATCH 7/7] Do not upload repo code diff when no repo --- .../_internal/core/models/configurations.py | 5 -- src/dstack/_internal/core/models/runs.py | 5 +- .../background/tasks/process_running_jobs.py | 15 ++++-- src/dstack/api/_public/runs.py | 51 ++++++++++--------- 4 files changed, 42 insertions(+), 34 deletions(-) diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 565c2d6c7..8cdc11eb6 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -11,8 +11,6 @@ from dstack._internal.core.models.fleets import FleetConfiguration from dstack._internal.core.models.gateways import GatewayConfiguration from dstack._internal.core.models.profiles import ProfileParams, parse_off_duration -from dstack._internal.core.models.repos.base import Repo -from dstack._internal.core.models.repos.virtual import VirtualRepo from dstack._internal.core.models.resources import Range, ResourcesSpec from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser @@ -180,9 +178,6 @@ def validate_user(cls, v) -> Optional[str]: UnixUser.parse(v) return v - def get_repo(self) -> Repo: - return VirtualRepo() - class BaseRunConfigurationWithPorts(BaseRunConfiguration): ports: Annotated[ diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index bf4022f1f..e3b820e86 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -330,7 +330,10 @@ class RunSpec(CoreModel): description="The repo data such as the current branch and commit.", ), ] = None - repo_code_hash: Annotated[Optional[str], Field(description="The hash of the repo diff")] = None + repo_code_hash: Annotated[ + Optional[str], + Field(description="The hash of the repo diff. Can be omitted if there is no repo diff."), + ] = None working_dir: Annotated[ Optional[str], Field( diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 489e9bb9d..cb8491a12 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -743,20 +743,29 @@ def _get_cluster_info( async def _get_job_code( - session: AsyncSession, project: ProjectModel, repo: RepoModel, code_hash: str + session: AsyncSession, project: ProjectModel, repo: RepoModel, code_hash: Optional[str] ) -> bytes: + if code_hash is None: + return b"" code_model = await get_code_model(session=session, repo=repo, code_hash=code_hash) if code_model is None: return b"" - storage = get_default_storage() - if storage is None or code_model.blob is not None: + if code_model.blob is not None: return code_model.blob + storage = get_default_storage() + if storage is None: + return b"" blob = await common_utils.run_async( storage.get_code, project.name, repo.name, code_hash, ) + if blob is None: + logger.error( + "Failed to get repo code hash %s from storage for repo %s", code_hash, repo.name + ) + return b"" return blob diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 11e791c85..226ee111f 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -26,6 +26,7 @@ UtilizationPolicy, ) from dstack._internal.core.models.repos.base import Repo +from dstack._internal.core.models.repos.virtual import VirtualRepo from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import ( Job, @@ -415,7 +416,8 @@ def get_run_plan( Args: configuration (Union[Task, Service, DevEnvironment]): The run configuration. - repo (Union[LocalRepo, RemoteRepo, VirtualRepo]): The repo to mount to the run. + repo (Union[LocalRepo, RemoteRepo, VirtualRepo, None]): + The repo to use for the run. Pass `None` if repo is not needed. profile: The profile to use for the run. configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file. @@ -423,9 +425,7 @@ def get_run_plan( Run plan. """ if repo is None: - repo = configuration.get_repo() - if repo is None: - raise ConfigurationError("Repo is required for this type of configuration") + repo = VirtualRepo() run_spec = RunSpec( run_name=configuration.name, @@ -445,7 +445,7 @@ def get_run_plan( def apply_plan( self, run_plan: RunPlan, - repo: Repo, + repo: Optional[Repo] = None, reserve_ports: bool = True, ) -> Run: """ @@ -453,8 +453,9 @@ def apply_plan( Use this method to apply run plans returned by `get_run_plan`. Args: - run_plan: Result of `get_run_plan` call. - repo: Repo to use for the run. + run_plan: The result of `get_run_plan` call. + repo (Union[LocalRepo, RemoteRepo, VirtualRepo, None]): + The repo to use for the run. Should be the same repo that is passed to `get_run_plan`. reserve_ports: Reserve local ports before applying. Use if you'll attach to the run. Returns: @@ -465,12 +466,20 @@ def apply_plan( # TODO handle multiple jobs ports_lock = _reserve_ports(run_plan.job_plans[0].job_spec) - with tempfile.TemporaryFile("w+b") as fp: - run_plan.run_spec.repo_code_hash = repo.write_code_file(fp) - fp.seek(0) - self._api_client.repos.upload_code( - self._project, repo.repo_id, run_plan.run_spec.repo_code_hash, fp - ) + if repo is None: + repo = VirtualRepo() + else: + # Do not upload the diff without a repo (a default virtual repo) + # since upload_code() requires a repo to be initialized. + with tempfile.TemporaryFile("w+b") as fp: + run_plan.run_spec.repo_code_hash = repo.write_code_file(fp) + fp.seek(0) + self._api_client.repos.upload_code( + project_name=self._project, + repo_id=repo.repo_id, + code_hash=run_plan.run_spec.repo_code_hash, + fp=fp, + ) run = self._api_client.runs.apply_plan(self._project, run_plan) return self._model_to_submitted_run(run, ports_lock) @@ -488,7 +497,8 @@ def apply_configuration( Args: configuration (Union[Task, Service, DevEnvironment]): The run configuration. - repo (Union[LocalRepo, RemoteRepo, VirtualRepo]): The repo to mount to the run. + repo (Union[LocalRepo, RemoteRepo, VirtualRepo, None]): + The repo to use for the run. Pass `None` if repo is not needed. profile: The profile to use for the run. configuration_path: The path to the configuration file. Omit if the configuration is not loaded from a file. reserve_ports: Reserve local ports before applying. Use if you'll attach to the run. @@ -496,11 +506,6 @@ def apply_configuration( Returns: Submitted run. """ - if repo is None: - repo = configuration.get_repo() - if repo is None: - raise ConfigurationError("Repo is required for this type of configuration") - run_plan = self.get_run_plan( configuration=configuration, repo=repo, @@ -554,9 +559,7 @@ def submit( # """ logger.warning("The submit() method is deprecated in favor of apply_configuration().") if repo is None: - repo = configuration.get_repo() - if repo is None: - raise ConfigurationError("Repo is required for this type of configuration") + repo = VirtualRepo() # TODO: Add Git credentials to RemoteRepo and if they are set, pass them here to RepoCollection.init self._client.repos.init(repo) @@ -611,9 +614,7 @@ def get_plan( # """ logger.warning("The get_plan() method is deprecated in favor of get_run_plan().") if repo is None: - repo = configuration.get_repo() - if repo is None: - raise ConfigurationError("Repo is required for this type of configuration") + repo = VirtualRepo() if working_dir is None: working_dir = "."