Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions docs/docs/reference/api/python/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ from dstack.api import Task, GPU, Client, Resources
client = Client.from_config()

task = Task(
name="my-awesome-run", # If not specified, a random name is assigned
image="ghcr.io/huggingface/text-generation-inference:latest",
env={"MODEL_ID": "TheBloke/Llama-2-13B-chat-GPTQ"},
commands=[
Expand All @@ -23,8 +24,7 @@ task = Task(
resources=Resources(gpu=GPU(memory="24GB")),
)

run = client.runs.submit(
run_name="my-awesome-run", # If not specified, a random name is assigned
run = client.runs.apply_configuration(
configuration=task,
repo=None, # Specify to mount additional files
)
Expand All @@ -42,10 +42,9 @@ finally:
```

!!! info "NOTE:"
1. The `configuration` argument in the `submit` method can be either `dstack.api.Task` or `dstack.api.Service`.
2. If you create `dstack.api.Task` or `dstack.api.Service`, you may specify the `image` argument. If `image` isn't
specified, the default image will be used. For a private Docker registry, ensure you also pass the `registry_auth` argument.
3. The `repo` argument in the `submit` method allows the mounting of a local folder, a remote repo, or a
1. The `configuration` argument in the `apply_configuration` method can be either `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`.
2. When you create `dstack.api.Task`, `dstack.api.Service`, or `dstack.api.DevEnvironment`, you can specify the `image` argument. If `image` isn't specified, the default image will be used. For a private Docker registry, ensure you also pass the `registry_auth` argument.
3. The `repo` argument in the `apply_configuration` method allows the mounting of a local folder, a remote repo, or a
programmatically created repo. In this case, the `commands` argument can refer to the files within this repo.
4. The `attach` method waits for the run to start and, for `dstack.api.Task` sets up an SSH tunnel and forwards
configured `ports` to `localhost`.
Expand Down Expand Up @@ -109,6 +108,17 @@ finally:
registry_auth: dstack.api.RegistryAuth
resources: dstack.api.Resources

### `dstack.api.DevEnvironment` { #dstack.api.DevEnvironment data-toc-label="DevEnvironment" }

#SCHEMA# dstack.api.DevEnvironment
overrides:
show_root_heading: false
show_root_toc_entry: false
heading_level: 4
item_id_mapping:
registry_auth: dstack.api.RegistryAuth
resources: dstack.api.Resources

### `dstack.api.Run` { #dstack.api.Run data-toc-label="Run" }

::: dstack.api.Run
Expand Down
6 changes: 3 additions & 3 deletions examples/misc/airflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ DSTACK_VENV_PYTHON_BINARY_PATH = f"{DSTACK_VENV_PATH}/bin/python"
def pipeline(...):
...
@task.external_python(task_id="external_python", python=DSTACK_VENV_PYTHON_BINARY_PATH)
def dstack_api_submit_venv() -> str:
def dstack_api_submit_venv():
from dstack.api import Client, Task

task = Task(
name="my-airflow-task",
commands=[
"echo 'Running dstack task via Airflow'",
"sleep 10",
Expand All @@ -61,8 +62,7 @@ def pipeline(...):
# or set explicitly from Ariflow Variables.
client = Client.from_config()

run = client.runs.submit(
run_name="my-airflow-task",
run = client.runs.apply_configuration(
configuration=task,
)
run.attach()
Expand Down
8 changes: 4 additions & 4 deletions examples/misc/airflow/dags/dstack_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def dstack_cli_apply_venv() -> str:
)

@task.external_python(task_id="external_python", python=DSTACK_VENV_PYTHON_BINARY_PATH)
def dstack_api_submit_venv() -> str:
def dstack_api_submit_venv():
"""
This task shows how to run the dstack API when
dstack is installed into a separate virtual environment available to Airflow.
Expand All @@ -63,18 +63,18 @@ def dstack_api_submit_venv() -> str:
from dstack.api import Client, Task

task = Task(
name="my-airflow-task",
commands=[
"echo 'Running dstack task via Airflow'",
"sleep 10",
"echo 'Finished'",
]
],
)
# Pick up config from `~/.dstack/config.yml`
# or set explicitly from Ariflow Variables.
client = Client.from_config()

run = client.runs.submit(
run_name="my-airflow-task",
run = client.runs.apply_configuration(
configuration=task,
)
run.attach()
Expand Down
25 changes: 8 additions & 17 deletions src/dstack/_internal/cli/commands/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import time
from typing import Any, Dict, List, Optional, Union

import requests
from rich.live import Live
from rich.table import Table

Expand Down Expand Up @@ -64,22 +63,14 @@ def _command(self, args: argparse.Namespace):

def _get_run_jobs_metrics(api: Client, run: Run) -> List[JobMetrics]:
metrics = []
try:
for job in run._run.jobs:
job_metrics = api.client.metrics.get_job_metrics(
project_name=api.project,
run_name=run.name,
replica_num=job.job_spec.replica_num,
job_num=job.job_spec.job_num,
)
metrics.append(job_metrics)
except requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
raise CLIError(
"Metrics API is not supported for server versions before 0.18.18. "
"Update the server to use `dstack stats`."
)
raise
for job in run._run.jobs:
job_metrics = api.client.metrics.get_job_metrics(
project_name=api.project,
run_name=run.name,
replica_num=job.job_spec.replica_num,
job_num=job.job_spec.job_num,
)
metrics.append(job_metrics)
return metrics


Expand Down
43 changes: 4 additions & 39 deletions src/dstack/_internal/cli/services/configurators/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path
from typing import List, Optional

import requests
from rich.table import Table

from dstack._internal.cli.services.configurators.base import (
Expand Down Expand Up @@ -32,7 +31,6 @@
from dstack._internal.utils.common import local_time
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import convert_ssh_key_to_pem, generate_public_key, pkey_from_str
from dstack.api._public import Client
from dstack.api.utils import load_profile

logger = get_logger(__name__)
Expand Down Expand Up @@ -60,7 +58,10 @@ def apply_configuration(
_preprocess_spec(spec)

with console.status("Getting apply plan..."):
plan = _get_plan(api=self.api, spec=spec)
plan = self.api.client.fleets.get_plan(
project_name=self.api.project,
spec=spec,
)
_print_plan_header(plan)

action_message = ""
Expand Down Expand Up @@ -234,42 +235,6 @@ def _resolve_ssh_key(ssh_key_path: Optional[str]) -> Optional[SSHKey]:
exit()


def _get_plan(api: Client, spec: FleetSpec) -> FleetPlan:
try:
return api.client.fleets.get_plan(
project_name=api.project,
spec=spec,
)
except requests.exceptions.HTTPError as e:
# Handle older server versions that do not have /get_plan for fleets
# TODO: Can be removed in 0.19
if e.response.status_code == 405:
logger.warning(
"Fleet apply plan is not fully supported before 0.18.17. "
"Update the server to view full-featured apply plan."
)
user = api.client.users.get_my_user()
spec.configuration_path = None
current_resource = None
if spec.configuration.name is not None:
try:
current_resource = api.client.fleets.get(
project_name=api.project, name=spec.configuration.name
)
except ResourceNotExistsError:
pass
return FleetPlan(
project_name=api.project,
user=user.username,
spec=spec,
current_resource=current_resource,
offers=[],
total_offers=0,
max_offer_price=0,
)
raise e


def _print_plan_header(plan: FleetPlan):
def th(s: str) -> str:
return f"[bold]{s}[/bold]"
Expand Down
23 changes: 4 additions & 19 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,26 +85,11 @@ def apply_configuration(
)
profile = load_profile(Path.cwd(), configurator_args.profile)
with console.status("Getting apply plan..."):
run_plan = self.api.runs.get_plan(
run_plan = self.api.runs.get_run_plan(
configuration=conf,
repo=repo,
configuration_path=configuration_path,
backends=profile.backends,
regions=profile.regions,
instance_types=profile.instance_types,
reservation=profile.reservation,
spot_policy=profile.spot_policy,
retry_policy=profile.retry_policy,
utilization_policy=profile.utilization_policy,
max_duration=profile.max_duration,
stop_duration=profile.stop_duration,
max_price=profile.max_price,
working_dir=conf.working_dir,
run_name=conf.name,
creation_policy=profile.creation_policy,
termination_policy=profile.termination_policy,
termination_policy_idle=profile.termination_idle_time,
idle_duration=profile.idle_duration,
profile=profile,
)

print_run_plan(run_plan, offers_limit=configurator_args.max_offers)
Expand Down Expand Up @@ -163,8 +148,8 @@ def apply_configuration(

try:
with console.status("Applying plan..."):
run = self.api.runs.exec_plan(
run_plan, repo, reserve_ports=not command_args.detach
run = self.api.runs.apply_plan(
run_plan=run_plan, repo=repo, reserve_ports=not command_args.detach
)
except ServerClientError as e:
raise CLIError(e.msg)
Expand Down
21 changes: 9 additions & 12 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from dstack._internal.core.models.fleets import FleetConfiguration
from dstack._internal.core.models.gateways import GatewayConfiguration
from dstack._internal.core.models.profiles import ProfileParams, parse_off_duration
from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.models.repos.virtual import VirtualRepo
from dstack._internal.core.models.resources import Range, ResourcesSpec
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
Expand Down Expand Up @@ -93,7 +91,9 @@ class BaseRunConfiguration(CoreModel):
Optional[str],
Field(description="The run name. If not specified, a random name is generated"),
] = None
image: Annotated[Optional[str], Field(description="The name of the Docker image to run")]
image: Annotated[Optional[str], Field(description="The name of the Docker image to run")] = (
None
)
user: Annotated[
Optional[str],
Field(
Expand All @@ -104,7 +104,7 @@ class BaseRunConfiguration(CoreModel):
),
] = None
privileged: Annotated[bool, Field(description="Run the container in privileged mode")] = False
entrypoint: Annotated[Optional[str], Field(description="The Docker entrypoint")]
entrypoint: Annotated[Optional[str], Field(description="The Docker entrypoint")] = None
working_dir: Annotated[
Optional[str],
Field(
Expand All @@ -119,17 +119,17 @@ class BaseRunConfiguration(CoreModel):
home_dir: str = "/root"
registry_auth: Annotated[
Optional[RegistryAuth], Field(description="Credentials for pulling a private Docker image")
]
] = None
python: Annotated[
Optional[PythonVersion],
Field(description="The major version of Python. Mutually exclusive with `image`"),
]
] = None
nvcc: Annotated[
Optional[bool],
Field(
description="Use image with NVIDIA CUDA Compiler (NVCC) included. Mutually exclusive with `image`"
),
]
] = None
single_branch: Annotated[
Optional[bool],
Field(
Expand Down Expand Up @@ -178,9 +178,6 @@ def validate_user(cls, v) -> Optional[str]:
UnixUser.parse(v)
return v

def get_repo(self) -> Repo:
return VirtualRepo()


class BaseRunConfigurationWithPorts(BaseRunConfiguration):
ports: Annotated[
Expand Down Expand Up @@ -209,7 +206,7 @@ def check_image_or_commands_present(cls, values):

class DevEnvironmentConfigurationParams(CoreModel):
ide: Annotated[Literal["vscode"], Field(description="The IDE to run")]
version: Annotated[Optional[str], Field(description="The version of the IDE")]
version: Annotated[Optional[str], Field(description="The version of the IDE")] = None
init: Annotated[CommandsList, Field(description="The bash commands to run on startup")] = []
inactivity_duration: Annotated[
Optional[Union[Literal["off"], int, bool, str]],
Expand All @@ -225,7 +222,7 @@ class DevEnvironmentConfigurationParams(CoreModel):
" Defaults to `off`"
)
),
]
] = None

@validator("inactivity_duration", pre=True, allow_reuse=True)
def parse_inactivity_duration(
Expand Down
Loading