Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Awaitable, Callable, List
from typing import Awaitable, Callable, List, Optional

import sentry_sdk
from fastapi import FastAPI, Request, Response, status
Expand Down Expand Up @@ -62,6 +62,7 @@
CustomORJSONResponse,
check_client_server_compatibility,
error_detail,
get_client_version,
get_server_client_error_details,
)
from dstack._internal.settings import DSTACK_VERSION
Expand Down Expand Up @@ -319,8 +320,19 @@ async def check_client_version(request: Request, call_next):
or request.url.path in _NO_API_VERSION_CHECK_ROUTES
):
return await call_next(request)
try:
client_version = get_client_version(request)
except ValueError as e:
return CustomORJSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": [error_detail(str(e))]},
)
client_release: Optional[tuple[int, ...]] = None
if client_version is not None:
client_release = client_version.release
request.state.client_release = client_release
response = check_client_server_compatibility(
client_version=request.headers.get("x-api-version"),
client_version=client_version,
server_version=DSTACK_VERSION,
)
if response is not None:
Expand Down
21 changes: 15 additions & 6 deletions src/dstack/_internal/server/routers/runs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Tuple
from typing import Annotated, List, Optional, Tuple, cast

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.errors import ResourceNotExistsError
Expand Down Expand Up @@ -35,6 +35,11 @@
)


def use_legacy_default_working_dir(request: Request) -> bool:
client_release = cast(Optional[tuple[int, ...]], request.state.client_release)
return client_release is not None and client_release < (0, 19, 27)


@root_router.post(
"/list",
response_model=List[Run],
Expand Down Expand Up @@ -103,8 +108,9 @@ async def get_run(
)
async def get_plan(
body: GetRunPlanRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
session: Annotated[AsyncSession, Depends(get_session)],
user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())],
legacy_default_working_dir: Annotated[bool, Depends(use_legacy_default_working_dir)],
):
"""
Returns a run plan for the given run spec.
Expand All @@ -119,6 +125,7 @@ async def get_plan(
user=user,
run_spec=body.run_spec,
max_offers=body.max_offers,
legacy_default_working_dir=legacy_default_working_dir,
)
return CustomORJSONResponse(run_plan)

Expand All @@ -129,8 +136,9 @@ async def get_plan(
)
async def apply_plan(
body: ApplyRunPlanRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
session: Annotated[AsyncSession, Depends(get_session)],
user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())],
legacy_default_working_dir: Annotated[bool, Depends(use_legacy_default_working_dir)],
):
"""
Creates a new run or updates an existing run.
Expand All @@ -148,6 +156,7 @@ async def apply_plan(
project=project,
plan=body.plan,
force=body.force,
legacy_default_working_dir=legacy_default_working_dir,
)
)

Expand Down
19 changes: 16 additions & 3 deletions src/dstack/_internal/server/services/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData
from dstack._internal.core.models.runs import (
LEGACY_REPO_DIR,
ApplyRunPlanInput,
Job,
JobPlan,
Expand Down Expand Up @@ -308,6 +309,7 @@ async def get_plan(
user: UserModel,
run_spec: RunSpec,
max_offers: Optional[int],
legacy_default_working_dir: bool = False,
) -> RunPlan:
# Spec must be copied by parsing to calculate merged_profile
effective_run_spec = RunSpec.parse_obj(run_spec.dict())
Expand All @@ -317,7 +319,11 @@ async def get_plan(
spec=effective_run_spec,
)
effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict())
_validate_run_spec_and_set_defaults(user, effective_run_spec)
_validate_run_spec_and_set_defaults(
user=user,
run_spec=effective_run_spec,
legacy_default_working_dir=legacy_default_working_dir,
)

profile = effective_run_spec.merged_profile
creation_policy = profile.creation_policy
Expand Down Expand Up @@ -413,6 +419,7 @@ async def apply_plan(
project: ProjectModel,
plan: ApplyRunPlanInput,
force: bool,
legacy_default_working_dir: bool = False,
) -> Run:
run_spec = plan.run_spec
run_spec = await apply_plugin_policies(
Expand All @@ -422,7 +429,9 @@ async def apply_plan(
)
# Spec must be copied by parsing to calculate merged_profile
run_spec = RunSpec.parse_obj(run_spec.dict())
_validate_run_spec_and_set_defaults(user, run_spec)
_validate_run_spec_and_set_defaults(
user=user, run_spec=run_spec, legacy_default_working_dir=legacy_default_working_dir
)
if run_spec.run_name is None:
return await submit_run(
session=session,
Expand Down Expand Up @@ -985,7 +994,9 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float:
return job_submission.job_provisioning_data.price * duration_hours


def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
def _validate_run_spec_and_set_defaults(
user: UserModel, run_spec: RunSpec, legacy_default_working_dir: bool = False
):
# This function may set defaults for null run_spec values,
# although most defaults are resolved when building job_spec
# so that we can keep both the original user-supplied value (null in run_spec)
Expand Down Expand Up @@ -1040,6 +1051,8 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
run_spec.ssh_key_pub = user.ssh_public_key
else:
raise ServerClientError("ssh_key_pub must be set if the user has no ssh_public_key")
if run_spec.configuration.working_dir is None and legacy_default_working_dir:
run_spec.configuration.working_dir = LEGACY_REPO_DIR


_UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"]
Expand Down
38 changes: 18 additions & 20 deletions src/dstack/_internal/server/utils/routers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, Dict, List, Optional

import orjson
import packaging.version
from fastapi import HTTPException, Request, Response, status
from packaging import version

from dstack._internal.core.errors import ServerClientError, ServerClientErrorCode
from dstack._internal.core.models.common import CoreModel
from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default
from dstack._internal.utils.version import parse_version


class CustomORJSONResponse(Response):
Expand Down Expand Up @@ -122,8 +123,15 @@ def get_request_size(request: Request) -> int:
return int(request.headers["content-length"])


def get_client_version(request: Request) -> Optional[packaging.version.Version]:
version = request.headers.get("x-api-version")
if version is None:
return None
return parse_version(version)


def check_client_server_compatibility(
client_version: Optional[str],
client_version: Optional[packaging.version.Version],
server_version: Optional[str],
) -> Optional[CustomORJSONResponse]:
"""
Expand All @@ -132,28 +140,18 @@ def check_client_server_compatibility(
"""
if client_version is None or server_version is None:
return None
parsed_server_version = version.parse(server_version)
# latest allows client to bypass compatibility check (e.g. frontend)
if client_version == "latest":
parsed_server_version = parse_version(server_version)
if parsed_server_version is None:
return None
try:
parsed_client_version = version.parse(client_version)
except version.InvalidVersion:
return CustomORJSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
"detail": get_server_client_error_details(
ServerClientError("Bad API version specified")
)
},
)
# We preserve full client backward compatibility across patch releases.
# Server is always partially backward-compatible (so no check).
if parsed_client_version > parsed_server_version and (
parsed_client_version.major > parsed_server_version.major
or parsed_client_version.minor > parsed_server_version.minor
if client_version > parsed_server_version and (
client_version.major > parsed_server_version.major
or client_version.minor > parsed_server_version.minor
):
return error_incompatible_versions(client_version, server_version, ask_cli_update=False)
return error_incompatible_versions(
str(client_version), server_version, ask_cli_update=False
)
return None


Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/settings.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os

from dstack import version
from dstack._internal.utils.version import parse_version

DSTACK_VERSION = os.getenv("DSTACK_VERSION", version.__version__)
if DSTACK_VERSION == "0.0.0":
if parse_version(DSTACK_VERSION) is None:
# The build backend (hatching) requires not None for versions,
# but the code currently treats None as dev version.
# TODO: update the code to treat 0.0.0 as dev version.
Expand Down
22 changes: 22 additions & 0 deletions src/dstack/_internal/utils/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Optional

import packaging.version


def parse_version(version_string: str) -> Optional[packaging.version.Version]:
"""
Returns a `packaging.version.Version` instance or `None` if the version is dev/latest.

Values parsed as the dev/latest version:
* the "latest" literal
* any "0.0.0" release, e.g., "0.0.0", "0.0.0a1", "0.0.0.dev0"
"""
if version_string == "latest":
return None
try:
version = packaging.version.parse(version_string)
except packaging.version.InvalidVersion as e:
raise ValueError(f"Invalid version: {version_string}") from e
if version.release == (0, 0, 0):
return None
return version
19 changes: 10 additions & 9 deletions src/tests/_internal/server/utils/test_routers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Optional

import packaging.version
import pytest

from dstack._internal.server.utils.routers import check_client_server_compatibility


class TestCheckClientServerCompatibility:
@pytest.mark.parametrize("client_version", ["12.12.12", None])
def test_returns_none_if_server_version_is_none(self, client_version: Optional[str]):
@pytest.mark.parametrize("client_version", [packaging.version.parse("12.12.12"), None])
def test_returns_none_if_server_version_is_none(
self, client_version: Optional[packaging.version.Version]
):
assert (
check_client_server_compatibility(
client_version=client_version,
Expand All @@ -27,12 +30,10 @@ def test_returns_none_if_server_version_is_none(self, client_version: Optional[s
("1.0.5", "1.0.6"),
],
)
def test_returns_none_if_compatible(
self, client_version: Optional[str], server_version: Optional[str]
):
def test_returns_none_if_compatible(self, client_version: str, server_version: str):
assert (
check_client_server_compatibility(
client_version=client_version,
client_version=packaging.version.parse(client_version),
server_version=server_version,
)
is None
Expand All @@ -46,10 +47,10 @@ def test_returns_none_if_compatible(
],
)
def test_returns_error_if_client_version_larger(
self, client_version: Optional[str], server_version: Optional[str]
self, client_version: str, server_version: str
):
res = check_client_server_compatibility(
client_version=client_version,
client_version=packaging.version.parse(client_version),
server_version=server_version,
)
assert res is not None
Expand All @@ -63,7 +64,7 @@ def test_returns_error_if_client_version_larger(
)
def test_returns_none_if_client_version_is_latest(self, server_version: Optional[str]):
res = check_client_server_compatibility(
client_version="latest",
client_version=None,
server_version=server_version,
)
assert res is None
17 changes: 17 additions & 0 deletions src/tests/_internal/utils/test_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import packaging.version
import pytest

from dstack._internal.utils.version import parse_version


class TestParseVersion:
@pytest.mark.parametrize("version", ["0.0.0", "0.0.0.dev0", "0.0.0alpha", "latest"])
def test_latest(self, version: str):
assert parse_version(version) is None

def test_release(self):
assert parse_version("0.19.27") == packaging.version.parse("0.19.27")

def test_error_invalid_version(self):
with pytest.raises(ValueError, match=r"Invalid version: 0\.0invalid"):
parse_version("0.0invalid")