Skip to content
Merged
13 changes: 13 additions & 0 deletions cli/dstack/_internal/backend/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from datetime import datetime
from typing import Generator, List, Optional

import dstack._internal.core.build
from dstack._internal.backend.base import artifacts as base_artifacts
from dstack._internal.backend.base import build as base_build
from dstack._internal.backend.base import cache as base_cache
from dstack._internal.backend.base import jobs as base_jobs
from dstack._internal.backend.base import repos as base_repos
Expand All @@ -14,6 +16,7 @@
from dstack._internal.backend.base.secrets import SecretsManager
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.artifact import Artifact
from dstack._internal.core.build import BuildPlan
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job, JobHead, JobStatus
from dstack._internal.core.log_event import LogEvent
Expand Down Expand Up @@ -230,6 +233,10 @@ def get_signed_download_url(self, object_key: str) -> str:
def get_signed_upload_url(self, object_key: str) -> str:
pass

@abstractmethod
def predict_build_plan(self, job: Job) -> BuildPlan:
pass


class ComponentBasedBackend(Backend):
@abstractmethod
Expand Down Expand Up @@ -264,6 +271,7 @@ def list_jobs(self, repo_id: str, run_name: str) -> List[Job]:
return base_jobs.list_jobs(self.storage(), repo_id, run_name)

def run_job(self, job: Job, failed_to_start_job_new_status: JobStatus):
self.predict_build_plan(job) # raises exception on missing build
base_jobs.run_job(self.storage(), self.compute(), job, failed_to_start_job_new_status)

def stop_job(self, repo_id: str, abort: bool, job_id: str):
Expand Down Expand Up @@ -435,3 +443,8 @@ def delete_configuration_cache(
base_cache.delete_configuration_cache(
self.storage(), repo_id, hub_user_name, configuration_path
)

def predict_build_plan(self, job: Job) -> BuildPlan:
return base_build.predict_build_plan(
self.storage(), job, dstack._internal.core.build.DockerPlatform.amd64
)
63 changes: 63 additions & 0 deletions cli/dstack/_internal/backend/base/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from pathlib import Path
from platform import uname as platform_uname
from typing import Optional

import cpuinfo

from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.build import BuildNotFoundError, BuildPlan, DockerPlatform
from dstack._internal.core.job import Job
from dstack._internal.utils.escape import escape_head


def predict_build_plan(
storage: Storage, job: Job, platform: Optional[DockerPlatform]
) -> BuildPlan:
if job.build_policy in ["force-build", "build-only"]:
return BuildPlan.yes

if platform is None:
platform = guess_docker_platform()
if build_exists(storage, job, platform):
return BuildPlan.use

if job.build_commands:
if job.build_policy == "use-build":
raise BuildNotFoundError("Build not found. Run `dstack build` or add `--build` flag")
return BuildPlan.yes

if job.optional_build_commands and job.build_policy == "build":
return BuildPlan.yes
return BuildPlan.no


def build_exists(storage: Storage, job: Job, platform: DockerPlatform) -> bool:
prefix = _get_build_head_prefix(job, platform)
return len(storage.list_objects(prefix)) > 0


def _get_build_head_prefix(job: Job, platform: DockerPlatform) -> str:
parts = [
job.configuration_type.value,
job.configuration_path or "",
(Path("/workflow") / (job.working_dir or "")).as_posix(),
job.image_name,
platform.value,
# digest
# timestamp_utc
]
parts = ";".join(escape_head(p) for p in parts)
return f"builds/{job.repo_ref.repo_id}/{parts};"


def guess_docker_platform() -> DockerPlatform:
uname = platform_uname()
if uname.system == "Darwin":
brand = cpuinfo.get_cpu_info().get("brand_raw")
m_arch = "m1" in brand.lower() or "m2" in brand.lower()
arch = "arm64" if m_arch else "x86_64"
else:
arch = uname.machine
if uname.system == "Darwin" and arch in ["arm64", "aarch64"]:
return DockerPlatform.arm64
return DockerPlatform.amd64
2 changes: 2 additions & 0 deletions cli/dstack/_internal/backend/base/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import yaml

from dstack._internal.backend.base import runners
from dstack._internal.backend.base.build import predict_build_plan
from dstack._internal.backend.base.compute import Compute, NoCapacityError
from dstack._internal.backend.base.storage import Storage
from dstack._internal.core.build import DockerPlatform
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job, JobErrorCode, JobHead, JobStatus, SpotPolicy
Expand Down
7 changes: 7 additions & 0 deletions cli/dstack/_internal/backend/local/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Optional

from dstack._internal.backend.base import ComponentBasedBackend
from dstack._internal.backend.base import build as base_build
from dstack._internal.backend.local.compute import LocalCompute
from dstack._internal.backend.local.config import LocalConfig
from dstack._internal.backend.local.logs import LocalLogging
from dstack._internal.backend.local.secrets import LocalSecretsManager
from dstack._internal.backend.local.storage import LocalStorage
from dstack._internal.core.build import BuildPlan
from dstack._internal.core.job import Job


class LocalBackend(ComponentBasedBackend):
Expand Down Expand Up @@ -39,3 +42,7 @@ def secrets_manager(self) -> LocalSecretsManager:

def logging(self) -> LocalLogging:
return self._logging

def predict_build_plan(self, job: Job) -> BuildPlan:
# guess platform from uname
return base_build.predict_build_plan(self.storage(), job, platform=None)
8 changes: 4 additions & 4 deletions cli/dstack/_internal/cli/commands/build/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def _command(self, args: argparse.Namespace):
ssh_pub_key = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)

run_plan = hub_client.get_run_plan(
provider_name=provider_name, provider_data=provider_data, args=args
configuration_path=configuration_path,
provider_name=provider_name,
provider_data=provider_data,
args=args,
)
console.print("dstack will execute the following plan:\n")
_print_run_plan(configuration_path, run_plan)
Expand All @@ -69,9 +72,6 @@ def _command(self, args: argparse.Namespace):
)
runs = list_runs_hub(hub_client, run_name=run_name)
run = runs[0]
if run.status == JobStatus.FAILED:
console.print("\nProvisioning failed\n")
exit(1)
_poll_run(
hub_client,
run,
Expand Down
15 changes: 13 additions & 2 deletions cli/dstack/_internal/cli/commands/run/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def _command(self, args: Namespace):
ssh_pub_key = _read_ssh_key_pub(config.repo_user_config.ssh_key_path)

run_plan = hub_client.get_run_plan(
provider_name=provider_name, provider_data=provider_data, args=args
configuration_path=configuration_path,
provider_name=provider_name,
provider_data=provider_data,
args=args,
)
console.print("dstack will execute the following plan:\n")
_print_run_plan(configuration_path, run_plan)
Expand Down Expand Up @@ -184,12 +187,20 @@ def _print_run_plan(configuration_file: str, run_plan: RunPlan):
table.add_column("INSTANCE")
table.add_column("RESOURCES")
table.add_column("SPOT POLICY")
table.add_column("BUILD")
job_plan = run_plan.job_plans[0]
instance = job_plan.instance_type.instance_name or "-"
instance_info = _format_resources(job_plan.instance_type)
spot = job_plan.job.spot_policy.value
build_plan = job_plan.build_plan.value.title()
table.add_row(
configuration_file, run_plan.hub_user_name, run_plan.project, instance, instance_info, spot
configuration_file,
run_plan.hub_user_name,
run_plan.project,
instance,
instance_info,
spot,
build_plan,
)
console.print(table)
console.print()
Expand Down
3 changes: 2 additions & 1 deletion cli/dstack/_internal/cli/commands/run/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def _parse_dev_environment_configuration_data(
"sea_green3]Command Palette[/sea_green3], executing [sea_green3]Shell Command: Install 'code' command in "
"PATH[/sea_green3], and restarting terminal.[/]\n"
)
provider_data["optional_build"].append("pip install -q --no-cache-dir ipykernel")
for key in ["optional_build", "commands"]:
provider_data[key].append("pip install -q --no-cache-dir ipykernel")
provider_data["commands"].extend(configuration_data.get("init") or [])
return provider_name, provider_data

Expand Down
18 changes: 18 additions & 0 deletions cli/dstack/_internal/core/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from enum import Enum

from dstack._internal.core.error import DstackError


class DockerPlatform(str, Enum):
amd64 = "amd64"
arm64 = "arm64"


class BuildPlan(str, Enum):
no = "no"
use = "use"
yes = "yes"


class BuildNotFoundError(DstackError):
code = "build_not_found"
2 changes: 2 additions & 0 deletions cli/dstack/_internal/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from pydantic import BaseModel

from dstack._internal.core.build import BuildPlan
from dstack._internal.core.instance import InstanceType
from dstack._internal.core.job import Job


class JobPlan(BaseModel):
job: Job
instance_type: InstanceType
build_plan: BuildPlan


class RunPlan(BaseModel):
Expand Down
6 changes: 6 additions & 0 deletions cli/dstack/_internal/hub/routers/runners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, status

from dstack._internal.core.build import BuildNotFoundError
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.job import Job, JobStatus
from dstack._internal.hub.models import StopRunners
Expand Down Expand Up @@ -29,6 +30,11 @@ async def run_runners(project_name: str, job: Job):
NoMatchingInstanceError.message, code=NoMatchingInstanceError.code
),
)
except BuildNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_detail(msg=e.message, code=e.code),
)


@router.post("/{project_name}/runners/stop")
Expand Down
10 changes: 9 additions & 1 deletion cli/dstack/_internal/hub/routers/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi.responses import PlainTextResponse

from dstack._internal.backend.base import Backend
from dstack._internal.core.build import BuildNotFoundError
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.job import Job, JobStatus
from dstack._internal.core.plan import JobPlan, RunPlan
Expand Down Expand Up @@ -35,7 +36,14 @@ async def get_run_plan(
msg=NoMatchingInstanceError.message, code=NoMatchingInstanceError.code
),
)
job_plans.append(JobPlan(job=job, instance_type=instance_type))
try:
build = backend.predict_build_plan(job)
except BuildNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error_detail(msg=e.message, code=e.code),
)
job_plans.append(JobPlan(job=job, instance_type=instance_type, build_plan=build))
run_plan = RunPlan(project=project_name, hub_user_name=user.name, job_plans=job_plans)
return run_plan

Expand Down
5 changes: 5 additions & 0 deletions cli/dstack/api/hub/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests

from dstack._internal.core.artifact import Artifact
from dstack._internal.core.build import BuildNotFoundError
from dstack._internal.core.error import NoMatchingInstanceError
from dstack._internal.core.job import Job, JobHead
from dstack._internal.core.log_event import LogEvent
Expand Down Expand Up @@ -83,6 +84,8 @@ def get_run_plan(self, jobs: List[Job]) -> RunPlan:
body = resp.json()
if body["detail"]["code"] == NoMatchingInstanceError.code:
raise HubClientError(body["detail"]["msg"])
elif body["detail"]["code"] == BuildNotFoundError.code:
raise HubClientError(body["detail"]["msg"])
resp.raise_for_status()

def create_run(self) -> str:
Expand Down Expand Up @@ -168,6 +171,8 @@ def run_job(self, job: Job):
body = resp.json()
if body["detail"]["code"] == NoMatchingInstanceError.code:
raise HubClientError(body["detail"]["msg"])
elif body["detail"]["code"] == BuildNotFoundError.code:
raise HubClientError(body["detail"]["msg"])
resp.raise_for_status()

def stop_job(self, job_id: str, abort: bool):
Expand Down
3 changes: 2 additions & 1 deletion cli/dstack/api/hub/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def delete_configuration_cache(self, configuration_path: str):

def get_run_plan(
self,
configuration_path: str,
provider_name: str,
provider_data: Optional[Dict[str, Any]] = None,
args: Optional[argparse.Namespace] = None,
Expand All @@ -277,7 +278,7 @@ def get_run_plan(
run_name="dry-run",
ssh_key_pub="",
)
jobs = provider.get_jobs(repo=self.repo)
jobs = provider.get_jobs(repo=self.repo, configuration_path=configuration_path)
run_plan = self._api_client.get_run_plan(jobs)
return run_plan

Expand Down
16 changes: 12 additions & 4 deletions runner/internal/backend/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/dstackai/dstack/runner/internal/container"
"io"
"io/ioutil"
"path"
Expand Down Expand Up @@ -321,13 +322,20 @@ func (s *AWSBackend) GetRepoArchive(ctx context.Context, path, dir string) error
return gerrors.Wrap(base.GetRepoArchive(ctx, s.storage, path, dir))
}

func (s *AWSBackend) GetBuildDiffInfo(ctx context.Context, spec *container.BuildSpec) (*base.StorageObject, error) {
obj, err := base.GetBuildDiffInfo(ctx, s.storage, spec)
if err != nil {
return nil, gerrors.Wrap(err)
}
return obj, nil
}

func (s *AWSBackend) GetBuildDiff(ctx context.Context, key, dst string) error {
_ = base.DownloadFile(ctx, s.storage, key, dst)
return nil
return gerrors.Wrap(base.DownloadFile(ctx, s.storage, key, dst))
}

func (s *AWSBackend) PutBuildDiff(ctx context.Context, src, key string) error {
return gerrors.Wrap(base.UploadFile(ctx, s.storage, src, key))
func (s *AWSBackend) PutBuildDiff(ctx context.Context, src string, spec *container.BuildSpec) error {
return gerrors.Wrap(base.PutBuildDiff(ctx, s.storage, src, spec))
}

func (s *AWSBackend) GetTMPDir(ctx context.Context) string {
Expand Down
16 changes: 12 additions & 4 deletions runner/internal/backend/azure/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"github.com/dstackai/dstack/runner/internal/backend/base"
"github.com/dstackai/dstack/runner/internal/container"
"io"
"os"
"path"
Expand Down Expand Up @@ -236,13 +237,20 @@ func (azbackend *AzureBackend) GetRepoArchive(ctx context.Context, path, dir str
return gerrors.Wrap(base.GetRepoArchive(ctx, azbackend.storage, path, dir))
}

func (azbackend *AzureBackend) GetBuildDiffInfo(ctx context.Context, spec *container.BuildSpec) (*base.StorageObject, error) {
obj, err := base.GetBuildDiffInfo(ctx, azbackend.storage, spec)
if err != nil {
return nil, gerrors.Wrap(err)
}
return obj, nil
}

func (azbackend *AzureBackend) GetBuildDiff(ctx context.Context, key, dst string) error {
_ = base.DownloadFile(ctx, azbackend.storage, key, dst)
return nil
return gerrors.Wrap(base.DownloadFile(ctx, azbackend.storage, key, dst))
}

func (azbackend *AzureBackend) PutBuildDiff(ctx context.Context, src, key string) error {
return gerrors.Wrap(base.UploadFile(ctx, azbackend.storage, src, key))
func (azbackend *AzureBackend) PutBuildDiff(ctx context.Context, src string, spec *container.BuildSpec) error {
return gerrors.Wrap(base.PutBuildDiff(ctx, azbackend.storage, src, spec))
}

func (azbackend *AzureBackend) GetTMPDir(ctx context.Context) string {
Expand Down
Loading