diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index a6afa33a7..c1dc1da1b 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -1,5 +1,6 @@ import argparse +from rich.markup import escape from rich_argparse import RichHelpFormatter from dstack._internal.cli.commands.config import ConfigCommand @@ -65,7 +66,7 @@ def main(): check_for_updates() args.func(args) except (ClientError, CLIError) as e: - console.print(f"[error]{e}[/]") + console.print(f"[error]{escape(str(e))}[/]") logger.debug(e, exc_info=True) exit(1) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index a44aa52b6..5930d4a30 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -1,6 +1,7 @@ import asyncio import itertools import math +import re import uuid from datetime import timezone from typing import List, Optional, Tuple, cast @@ -290,6 +291,9 @@ async def create_instance( async def get_run_plan( session: AsyncSession, project: ProjectModel, user: UserModel, run_spec: RunSpec ) -> RunPlan: + if run_spec.run_name is not None: + _validate_run_name(run_spec.run_name) + profile = run_spec.profile # TODO: get_or_create_default_pool @@ -386,6 +390,7 @@ async def submit_run( project=project, ) else: + _validate_run_name(run_spec.run_name) await delete_runs(session=session, project=project, runs_names=[run_spec.run_name]) pool = await get_or_create_default_pool_by_name(session, project, run_spec.profile.pool_name) @@ -624,3 +629,11 @@ async def abort_runs_of_pool(session: AsyncSession, project_model: ProjectModel, active_run_names.append(run.run_spec.run_name) await stop_runs(session, project_model, active_run_names, abort=True) + + +# The run_name validation is not performed in pydantic models since +# the models are reused on the client, and we don't want to +# tie run_name validation to the client side. +def _validate_run_name(run_name: str): + if not re.match("^[a-z][a-z0-9-]{1,40}$", run_name): + raise ServerClientError("run_name should match regex '^[a-z][a-z0-9-]{1,40}$'") diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 30abbb495..85211cf26 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -494,6 +494,43 @@ async def test_submits_run_without_run_name(self, test_db, session: AsyncSession job = res.scalar() assert job is not None + @pytest.mark.asyncio + @pytest.mark.parametrize( + "run_name", + [ + "run_with_underscores", + "RunWithUppercase", + "тест_ран", + ], + ) + async def test_returns_400_if_bad_run_name( + self, test_db, session: AsyncSession, run_name: str + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo(session=session, project_id=project.id) + run_dict = get_dev_env_run_dict( + project_name=project.name, + username=user.name, + run_name=run_name, + repo_id=repo.name, + ) + body = {"run_spec": run_dict["run_spec"]} + with patch("uuid.uuid4") as uuid_mock, patch( + "dstack._internal.server.services.backends.get_project_backends" + ) as get_project_backends_mock: + get_project_backends_mock.return_value = [Mock()] + uuid_mock.return_value = run_dict["id"] + response = client.post( + f"/api/project/{project.name}/runs/submit", + headers=get_auth_headers(user.token), + json=body, + ) + assert response.status_code == 400 + @pytest.mark.asyncio async def test_returns_400_if_repo_does_not_exist(self, test_db, session: AsyncSession): user = await create_user(session=session, global_role=GlobalRole.USER)