diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 747aef25d..d0e2b9398 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from pathlib import Path -from typing import Awaitable, Callable, List +from typing import Awaitable, Callable, List, Optional import sentry_sdk from fastapi import FastAPI, Request, Response, status @@ -62,6 +62,7 @@ CustomORJSONResponse, check_client_server_compatibility, error_detail, + get_client_version, get_server_client_error_details, ) from dstack._internal.settings import DSTACK_VERSION @@ -319,8 +320,19 @@ async def check_client_version(request: Request, call_next): or request.url.path in _NO_API_VERSION_CHECK_ROUTES ): return await call_next(request) + try: + client_version = get_client_version(request) + except ValueError as e: + return CustomORJSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": [error_detail(str(e))]}, + ) + client_release: Optional[tuple[int, ...]] = None + if client_version is not None: + client_release = client_version.release + request.state.client_release = client_release response = check_client_server_compatibility( - client_version=request.headers.get("x-api-version"), + client_version=client_version, server_version=DSTACK_VERSION, ) if response is not None: diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 11cbbfa0b..ba15af6e5 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -1,6 +1,6 @@ -from typing import List, Tuple +from typing import Annotated, List, Optional, Tuple, cast -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.errors import ResourceNotExistsError @@ -35,6 +35,11 @@ ) +def use_legacy_default_working_dir(request: Request) -> bool: + client_release = cast(Optional[tuple[int, ...]], request.state.client_release) + return client_release is not None and client_release < (0, 19, 27) + + @root_router.post( "/list", response_model=List[Run], @@ -103,8 +108,9 @@ async def get_run( ) async def get_plan( body: GetRunPlanRequest, - session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + session: Annotated[AsyncSession, Depends(get_session)], + user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())], + legacy_default_working_dir: Annotated[bool, Depends(use_legacy_default_working_dir)], ): """ Returns a run plan for the given run spec. @@ -119,6 +125,7 @@ async def get_plan( user=user, run_spec=body.run_spec, max_offers=body.max_offers, + legacy_default_working_dir=legacy_default_working_dir, ) return CustomORJSONResponse(run_plan) @@ -129,8 +136,9 @@ async def get_plan( ) async def apply_plan( body: ApplyRunPlanRequest, - session: AsyncSession = Depends(get_session), - user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), + session: Annotated[AsyncSession, Depends(get_session)], + user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())], + legacy_default_working_dir: Annotated[bool, Depends(use_legacy_default_working_dir)], ): """ Creates a new run or updates an existing run. @@ -148,6 +156,7 @@ async def apply_plan( project=project, plan=body.plan, force=body.force, + legacy_default_working_dir=legacy_default_working_dir, ) ) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 25ac750aa..14ad4c1c8 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -34,6 +34,7 @@ ) from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData from dstack._internal.core.models.runs import ( + LEGACY_REPO_DIR, ApplyRunPlanInput, Job, JobPlan, @@ -308,6 +309,7 @@ async def get_plan( user: UserModel, run_spec: RunSpec, max_offers: Optional[int], + legacy_default_working_dir: bool = False, ) -> RunPlan: # Spec must be copied by parsing to calculate merged_profile effective_run_spec = RunSpec.parse_obj(run_spec.dict()) @@ -317,7 +319,11 @@ async def get_plan( spec=effective_run_spec, ) effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict()) - _validate_run_spec_and_set_defaults(user, effective_run_spec) + _validate_run_spec_and_set_defaults( + user=user, + run_spec=effective_run_spec, + legacy_default_working_dir=legacy_default_working_dir, + ) profile = effective_run_spec.merged_profile creation_policy = profile.creation_policy @@ -413,6 +419,7 @@ async def apply_plan( project: ProjectModel, plan: ApplyRunPlanInput, force: bool, + legacy_default_working_dir: bool = False, ) -> Run: run_spec = plan.run_spec run_spec = await apply_plugin_policies( @@ -422,7 +429,9 @@ async def apply_plan( ) # Spec must be copied by parsing to calculate merged_profile run_spec = RunSpec.parse_obj(run_spec.dict()) - _validate_run_spec_and_set_defaults(user, run_spec) + _validate_run_spec_and_set_defaults( + user=user, run_spec=run_spec, legacy_default_working_dir=legacy_default_working_dir + ) if run_spec.run_name is None: return await submit_run( session=session, @@ -985,7 +994,9 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float: return job_submission.job_provisioning_data.price * duration_hours -def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec): +def _validate_run_spec_and_set_defaults( + user: UserModel, run_spec: RunSpec, legacy_default_working_dir: bool = False +): # This function may set defaults for null run_spec values, # although most defaults are resolved when building job_spec # so that we can keep both the original user-supplied value (null in run_spec) @@ -1040,6 +1051,8 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec): run_spec.ssh_key_pub = user.ssh_public_key else: raise ServerClientError("ssh_key_pub must be set if the user has no ssh_public_key") + if run_spec.configuration.working_dir is None and legacy_default_working_dir: + run_spec.configuration.working_dir = LEGACY_REPO_DIR _UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"] diff --git a/src/dstack/_internal/server/utils/routers.py b/src/dstack/_internal/server/utils/routers.py index 131ec5cc3..a625ccd9a 100644 --- a/src/dstack/_internal/server/utils/routers.py +++ b/src/dstack/_internal/server/utils/routers.py @@ -1,12 +1,13 @@ from typing import Any, Dict, List, Optional import orjson +import packaging.version from fastapi import HTTPException, Request, Response, status -from packaging import version from dstack._internal.core.errors import ServerClientError, ServerClientErrorCode from dstack._internal.core.models.common import CoreModel from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default +from dstack._internal.utils.version import parse_version class CustomORJSONResponse(Response): @@ -122,8 +123,15 @@ def get_request_size(request: Request) -> int: return int(request.headers["content-length"]) +def get_client_version(request: Request) -> Optional[packaging.version.Version]: + version = request.headers.get("x-api-version") + if version is None: + return None + return parse_version(version) + + def check_client_server_compatibility( - client_version: Optional[str], + client_version: Optional[packaging.version.Version], server_version: Optional[str], ) -> Optional[CustomORJSONResponse]: """ @@ -132,28 +140,18 @@ def check_client_server_compatibility( """ if client_version is None or server_version is None: return None - parsed_server_version = version.parse(server_version) - # latest allows client to bypass compatibility check (e.g. frontend) - if client_version == "latest": + parsed_server_version = parse_version(server_version) + if parsed_server_version is None: return None - try: - parsed_client_version = version.parse(client_version) - except version.InvalidVersion: - return CustomORJSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={ - "detail": get_server_client_error_details( - ServerClientError("Bad API version specified") - ) - }, - ) # We preserve full client backward compatibility across patch releases. # Server is always partially backward-compatible (so no check). - if parsed_client_version > parsed_server_version and ( - parsed_client_version.major > parsed_server_version.major - or parsed_client_version.minor > parsed_server_version.minor + if client_version > parsed_server_version and ( + client_version.major > parsed_server_version.major + or client_version.minor > parsed_server_version.minor ): - return error_incompatible_versions(client_version, server_version, ask_cli_update=False) + return error_incompatible_versions( + str(client_version), server_version, ask_cli_update=False + ) return None diff --git a/src/dstack/_internal/settings.py b/src/dstack/_internal/settings.py index 608bdd3e1..2c0035bad 100644 --- a/src/dstack/_internal/settings.py +++ b/src/dstack/_internal/settings.py @@ -1,9 +1,10 @@ import os from dstack import version +from dstack._internal.utils.version import parse_version DSTACK_VERSION = os.getenv("DSTACK_VERSION", version.__version__) -if DSTACK_VERSION == "0.0.0": +if parse_version(DSTACK_VERSION) is None: # The build backend (hatching) requires not None for versions, # but the code currently treats None as dev version. # TODO: update the code to treat 0.0.0 as dev version. diff --git a/src/dstack/_internal/utils/version.py b/src/dstack/_internal/utils/version.py new file mode 100644 index 000000000..bd8703d0e --- /dev/null +++ b/src/dstack/_internal/utils/version.py @@ -0,0 +1,22 @@ +from typing import Optional + +import packaging.version + + +def parse_version(version_string: str) -> Optional[packaging.version.Version]: + """ + Returns a `packaging.version.Version` instance or `None` if the version is dev/latest. + + Values parsed as the dev/latest version: + * the "latest" literal + * any "0.0.0" release, e.g., "0.0.0", "0.0.0a1", "0.0.0.dev0" + """ + if version_string == "latest": + return None + try: + version = packaging.version.parse(version_string) + except packaging.version.InvalidVersion as e: + raise ValueError(f"Invalid version: {version_string}") from e + if version.release == (0, 0, 0): + return None + return version diff --git a/src/tests/_internal/server/utils/test_routers.py b/src/tests/_internal/server/utils/test_routers.py index 9dfda50ac..d3ea11213 100644 --- a/src/tests/_internal/server/utils/test_routers.py +++ b/src/tests/_internal/server/utils/test_routers.py @@ -1,13 +1,16 @@ from typing import Optional +import packaging.version import pytest from dstack._internal.server.utils.routers import check_client_server_compatibility class TestCheckClientServerCompatibility: - @pytest.mark.parametrize("client_version", ["12.12.12", None]) - def test_returns_none_if_server_version_is_none(self, client_version: Optional[str]): + @pytest.mark.parametrize("client_version", [packaging.version.parse("12.12.12"), None]) + def test_returns_none_if_server_version_is_none( + self, client_version: Optional[packaging.version.Version] + ): assert ( check_client_server_compatibility( client_version=client_version, @@ -27,12 +30,10 @@ def test_returns_none_if_server_version_is_none(self, client_version: Optional[s ("1.0.5", "1.0.6"), ], ) - def test_returns_none_if_compatible( - self, client_version: Optional[str], server_version: Optional[str] - ): + def test_returns_none_if_compatible(self, client_version: str, server_version: str): assert ( check_client_server_compatibility( - client_version=client_version, + client_version=packaging.version.parse(client_version), server_version=server_version, ) is None @@ -46,10 +47,10 @@ def test_returns_none_if_compatible( ], ) def test_returns_error_if_client_version_larger( - self, client_version: Optional[str], server_version: Optional[str] + self, client_version: str, server_version: str ): res = check_client_server_compatibility( - client_version=client_version, + client_version=packaging.version.parse(client_version), server_version=server_version, ) assert res is not None @@ -63,7 +64,7 @@ def test_returns_error_if_client_version_larger( ) def test_returns_none_if_client_version_is_latest(self, server_version: Optional[str]): res = check_client_server_compatibility( - client_version="latest", + client_version=None, server_version=server_version, ) assert res is None diff --git a/src/tests/_internal/utils/test_version.py b/src/tests/_internal/utils/test_version.py new file mode 100644 index 000000000..d24d0d163 --- /dev/null +++ b/src/tests/_internal/utils/test_version.py @@ -0,0 +1,17 @@ +import packaging.version +import pytest + +from dstack._internal.utils.version import parse_version + + +class TestParseVersion: + @pytest.mark.parametrize("version", ["0.0.0", "0.0.0.dev0", "0.0.0alpha", "latest"]) + def test_latest(self, version: str): + assert parse_version(version) is None + + def test_release(self): + assert parse_version("0.19.27") == packaging.version.parse("0.19.27") + + def test_error_invalid_version(self): + with pytest.raises(ValueError, match=r"Invalid version: 0\.0invalid"): + parse_version("0.0invalid")