diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index fe5f609c2f..a78649a9d3 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -138,7 +138,16 @@ async def get_plan( spec: FleetSpec, ) -> FleetPlan: # TODO: refactor offers logic into a separate module to avoid depending on runs - await _check_ssh_hosts_not_yet_added(session, spec) + current_fleet: Optional[Fleet] = None + current_fleet_id: Optional[uuid.UUID] = None + if spec.configuration.name is not None: + current_fleet_model = await get_project_fleet_model_by_name( + session=session, project=project, name=spec.configuration.name + ) + if current_fleet_model is not None: + current_fleet = fleet_model_to_fleet(current_fleet_model) + current_fleet_id = current_fleet_model.id + await _check_ssh_hosts_not_yet_added(session, spec, current_fleet_id) offers = [] if spec.configuration.ssh_config is None: @@ -148,13 +157,6 @@ async def get_plan( requirements=_get_fleet_requirements(spec), ) offers = [offer for _, offer in offers_with_backends] - current_fleet = None - if spec.configuration.name is not None: - current_fleet = await get_fleet_by_name( - session=session, - project=project, - name=spec.configuration.name, - ) plan = FleetPlan( project_name=project.name, user=user.name, @@ -540,13 +542,18 @@ def _check_can_manage_ssh_fleets(user: UserModel, project: ProjectModel): raise ForbiddenError() -async def _check_ssh_hosts_not_yet_added(session: AsyncSession, spec: FleetSpec): +async def _check_ssh_hosts_not_yet_added( + session: AsyncSession, spec: FleetSpec, current_fleet_id: Optional[uuid.UUID] = None +): if spec.configuration.ssh_config and spec.configuration.ssh_config.hosts: # there are manually listed hosts, need to check them for existence active_instances = await list_active_remote_instances(session=session) existing_hosts = set() for instance in active_instances: + # ignore instances belonging to the same fleet -- in-place update/recreate + if current_fleet_id is not None and instance.fleet_id == current_fleet_id: + continue instance_conn_info = RemoteConnectionInfo.parse_raw( cast(str, instance.remote_connection_info) ) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 1a42df0949..9ed269dec3 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -18,6 +18,7 @@ InstanceConfiguration, InstanceStatus, InstanceType, + RemoteConnectionInfo, Resources, ) from dstack._internal.core.models.placement import ( @@ -418,6 +419,7 @@ async def create_instance( session: AsyncSession, project: ProjectModel, pool: PoolModel, + fleet: Optional[FleetModel] = None, status: InstanceStatus = InstanceStatus.IDLE, created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), finished_at: Optional[datetime] = None, @@ -430,6 +432,7 @@ async def create_instance( instance_num: int = 0, backend: BackendType = BackendType.DATACRUNCH, region: str = "eu-west", + remote_connection_info: Optional[RemoteConnectionInfo] = None, ) -> InstanceModel: if instance_id is None: instance_id = uuid.uuid4() @@ -495,6 +498,7 @@ async def create_instance( name="test_instance", instance_num=instance_num, pool=pool, + fleet=fleet, project=project, status=status, unreachable=False, @@ -510,6 +514,7 @@ async def create_instance( profile=profile.json(), requirements=requirements.json(), instance_configuration=instance_configuration.json(), + remote_connection_info=remote_connection_info.json() if remote_connection_info else None, job=job, ) session.add(im) diff --git a/src/tests/_internal/server/services/test_fleets.py b/src/tests/_internal/server/services/test_fleets.py new file mode 100644 index 0000000000..3e7e9a892f --- /dev/null +++ b/src/tests/_internal/server/services/test_fleets.py @@ -0,0 +1,170 @@ +from typing import Optional, Union +from unittest.mock import Mock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.backends.base import Backend +from dstack._internal.core.errors import ServerClientError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.fleets import ( + FleetConfiguration, + FleetSpec, + SSHHostParams, + SSHParams, +) +from dstack._internal.core.models.instances import RemoteConnectionInfo +from dstack._internal.server.models import FleetModel, ProjectModel +from dstack._internal.server.services.backends import get_project_backends +from dstack._internal.server.services.fleets import get_plan +from dstack._internal.server.testing.common import ( + create_fleet, + create_instance, + create_pool, + create_project, + create_user, + get_fleet_spec, +) + + +class TestGetPlanSSHFleetHostsValidation: + @pytest.fixture + def get_project_backends_mock(self, monkeypatch: pytest.MonkeyPatch) -> list[Backend]: + mock = Mock(spec_set=get_project_backends, return_value=[]) + monkeypatch.setattr("dstack._internal.server.services.backends.get_project_backends", mock) + return mock + + def get_ssh_fleet_spec( + self, name: Optional[str], hosts: list[Union[SSHHostParams, str]] + ) -> FleetSpec: + ssh_config = SSHParams(hosts=hosts, network=None) + fleet_conf = FleetConfiguration(name=name, ssh_config=ssh_config) + return get_fleet_spec(conf=fleet_conf) + + async def create_fleet( + self, session: AsyncSession, project: ProjectModel, spec: FleetSpec + ) -> FleetModel: + assert spec.configuration.ssh_config is not None, spec.configuration + pool = await create_pool(session=session, project=project) + fleet = await create_fleet(session=session, project=project, spec=spec) + for host in spec.configuration.ssh_config.hosts: + if isinstance(host, SSHHostParams): + hostname = host.hostname + else: + hostname = host + rci = RemoteConnectionInfo(host=hostname, port=22, ssh_user="admin", ssh_keys=[]) + await create_instance( + session=session, + project=project, + pool=pool, + fleet=fleet, + backend=BackendType.REMOTE, + remote_connection_info=rci, + ) + return fleet + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.usefixtures("test_db", "get_project_backends_mock") + async def test_ok_same_fleet_update(self, session: AsyncSession): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + old_fleet_spec = self.get_ssh_fleet_spec(name="my-fleet", hosts=["192.168.100.201"]) + await self.create_fleet(session, project, old_fleet_spec) + new_fleet_spec = self.get_ssh_fleet_spec( + name="my-fleet", hosts=["192.168.100.201", "192.168.100.202"] + ) + plan = await get_plan(session=session, project=project, user=user, spec=new_fleet_spec) + assert plan.current_resource is not None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.usefixtures("test_db", "get_project_backends_mock") + async def test_ok_deleted_instances_ignored(self, session: AsyncSession): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + deleted_fleet_spec = self.get_ssh_fleet_spec(name="my-fleet", hosts=["192.168.100.201"]) + deleted_fleet = await self.create_fleet(session, project, deleted_fleet_spec) + for instance in deleted_fleet.instances: + instance.deleted = True + deleted_fleet.deleted = True + await session.commit() + fleet_spec = self.get_ssh_fleet_spec( + name="my-fleet", hosts=["192.168.100.201", "192.168.100.202"] + ) + plan = await get_plan(session=session, project=project, user=user, spec=fleet_spec) + assert plan.current_resource is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.usefixtures("test_db", "get_project_backends_mock") + async def test_ok_no_common_hosts_with_another_fleet(self, session: AsyncSession): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + another_fleet_spec = self.get_ssh_fleet_spec( + name="another-fleet", hosts=["192.168.100.201"] + ) + await self.create_fleet(session, project, another_fleet_spec) + fleet_spec = self.get_ssh_fleet_spec(name="new-fleet", hosts=["192.168.100.202"]) + plan = await get_plan(session=session, project=project, user=user, spec=fleet_spec) + assert plan.current_resource is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.usefixtures("test_db", "get_project_backends_mock") + async def test_error_another_fleet_same_project(self, session: AsyncSession): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + another_fleet_spec = self.get_ssh_fleet_spec( + name="another-fleet", hosts=["192.168.100.201"] + ) + await self.create_fleet(session, project, another_fleet_spec) + fleet_spec = self.get_ssh_fleet_spec( + name="new-fleet", hosts=["192.168.100.201", "192.168.100.202"] + ) + with pytest.raises( + ServerClientError, match=r"Instances \[192\.168\.100\.201\] are already assigned" + ): + await get_plan(session=session, project=project, user=user, spec=fleet_spec) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.usefixtures("test_db", "get_project_backends_mock") + async def test_error_another_fleet_another_project(self, session: AsyncSession): + another_user = await create_user(session=session, name="another-user") + another_project = await create_project( + session=session, owner=another_user, name="another-project" + ) + another_fleet_spec = self.get_ssh_fleet_spec( + name="another-fleet", hosts=["192.168.100.201"] + ) + await self.create_fleet(session, another_project, another_fleet_spec) + user = await create_user(session=session, name="my-user") + project = await create_project(session=session, owner=user, name="my-project") + fleet_spec = self.get_ssh_fleet_spec( + name="my-fleet", hosts=["192.168.100.201", "192.168.100.202"] + ) + with pytest.raises( + ServerClientError, match=r"Instances \[192\.168\.100\.201\] are already assigned" + ): + await get_plan(session=session, project=project, user=user, spec=fleet_spec) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.usefixtures("test_db", "get_project_backends_mock") + async def test_error_fleet_spec_without_name(self, session: AsyncSession): + # Even if the user apply the same configuration again, we cannot be sure if it is the same + # fleet or a brand new fleet, as we identify fleets by name. + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + existing_fleet_spec = self.get_ssh_fleet_spec( + name="autogenerated-fleet-name", hosts=["192.168.100.201"] + ) + await self.create_fleet(session, project, existing_fleet_spec) + fleet_spec_without_name = self.get_ssh_fleet_spec(name=None, hosts=["192.168.100.201"]) + with pytest.raises( + ServerClientError, match=r"Instances \[192\.168\.100\.201\] are already assigned" + ): + await get_plan( + session=session, project=project, user=user, spec=fleet_spec_without_name + )