diff --git a/packages/api-server/.pylintrc b/packages/api-server/.pylintrc index 8a1354566..dff6626a9 100644 --- a/packages/api-server/.pylintrc +++ b/packages/api-server/.pylintrc @@ -152,6 +152,7 @@ disable=print-statement, logging-fstring-interpolation, line-too-long, too-many-lines, + no-self-use, # fastapi heavily uses singletons, disallowing global statements just ends up with things # using nonlocal as singletons which are functionally the same as globals. diff --git a/packages/api-server/README.md b/packages/api-server/README.md index 405ef6aa3..3c2548a6c 100644 --- a/packages/api-server/README.md +++ b/packages/api-server/README.md @@ -252,18 +252,24 @@ Restart the `api-server` and the changes to the databse should be reflected. ### Running unit tests ```bash -npm test +pnpm test +``` + +By default in-memory sqlite database is used for testing, to test on another database, set the `RMF_API_SERVER_TEST_DB_URL` environment variable. + +```bash +RMF_API_SERVER_TEST_DB_URL= pnpm test ``` ### Collecting code coverage ```bash -npm run test:cov +pnpm run test:cov ``` Generate coverage report ```bash -npm run test:report +pnpm run test:report ``` ## Live reload diff --git a/packages/api-server/api_server/dependencies.py b/packages/api-server/api_server/dependencies.py index 199c86061..8706a42e7 100644 --- a/packages/api-server/api_server/dependencies.py +++ b/packages/api-server/api_server/dependencies.py @@ -20,7 +20,11 @@ def pagination_query( ) -> Pagination: limit = limit or 100 offset = offset or 0 - return Pagination(limit=limit, offset=offset, order_by=order_by) + return Pagination( + limit=limit, + offset=offset, + order_by=order_by.split(",") if order_by else [], + ) # hacky way to get the sio user diff --git a/packages/api-server/api_server/models/pagination.py b/packages/api-server/api_server/models/pagination.py index b88a78b4c..9832a379d 100644 --- a/packages/api-server/api_server/models/pagination.py +++ b/packages/api-server/api_server/models/pagination.py @@ -1,9 +1,7 @@ -from typing import Optional - from pydantic import BaseModel class Pagination(BaseModel): limit: int offset: int - order_by: Optional[str] + order_by: list[str] diff --git a/packages/api-server/api_server/query.py b/packages/api-server/api_server/query.py index 35dcb62c1..c0ea9db33 100644 --- a/packages/api-server/api_server/query.py +++ b/packages/api-server/api_server/query.py @@ -1,7 +1,3 @@ -from typing import Dict, Optional - -import tortoise.functions as tfuncs -from tortoise.expressions import Q from tortoise.queryset import MODEL, QuerySet from api_server.models.pagination import Pagination @@ -10,47 +6,10 @@ def add_pagination( query: QuerySet[MODEL], pagination: Pagination, - field_mappings: Optional[Dict[str, str]] = None, - group_by: str | None = None, ) -> QuerySet[MODEL]: - """ - Adds pagination and ordering to a query. If the order field starts with `label=`, it is - assumed to be a label and label sorting will used. In this case, the model must have - a reverse relation named "labels" and the `group_by` param is required. - - :param field_mapping: A dict mapping the order fields to the fields used to build the - query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}` - will order the query result according to `db_field`. - :param group_by: Required when sorting by labels, must be the foreign key column of the label table. - """ - field_mappings = field_mappings or {} - annotations = {} - query = query.limit(pagination.limit).offset(pagination.offset) - if pagination.order_by is not None: - order_fields = [] - order_values = pagination.order_by.split(",") - for v in order_values: - # perform the mapping after stripping the order prefix - order_prefix = "" - order_field = v - if v[0] in ["-", "+"]: - order_prefix = v[0] - order_field = v[1:] - order_field = field_mappings.get(order_field, order_field) - - # add annotations required for sorting by labels - if order_field.startswith("label="): - f = order_field[6:] - annotations[f"label_sort_{f}"] = tfuncs.Max( - "labels__label_value_str", - _filter=Q(labels__label_name=f), - ) - order_field = f"label_sort_{f}" - - order_fields.append(order_prefix + order_field) - - query = query.annotate(**annotations) - if group_by is not None: - query = query.group_by(group_by) - query = query.order_by(*order_fields) - return query + """Adds pagination and ordering to a query""" + return ( + query.limit(pagination.limit) + .offset(pagination.offset) + .order_by(*pagination.order_by) + ) diff --git a/packages/api-server/api_server/repositories/tasks.py b/packages/api-server/api_server/repositories/tasks.py index 5fb1ee963..f47b33453 100644 --- a/packages/api-server/api_server/repositories/tasks.py +++ b/packages/api-server/api_server/repositories/tasks.py @@ -2,18 +2,18 @@ from datetime import datetime from typing import Dict, List, Optional, Sequence, Tuple, cast +import tortoise.functions as tfuncs from fastapi import Depends, HTTPException from tortoise.exceptions import FieldError, IntegrityError +from tortoise.expressions import Expression, Q from tortoise.query_utils import Prefetch -from tortoise.queryset import QuerySet from tortoise.transactions import in_transaction from api_server.authenticator import user_dep from api_server.logging import LoggerAdapter, get_logger +from api_server.models import Labels, LogEntry, Pagination, Phases +from api_server.models import Status as TaskStatus from api_server.models import ( - LogEntry, - Pagination, - Phases, TaskBookingLabel, TaskEventLog, TaskRequest, @@ -25,7 +25,6 @@ from api_server.models.rmf_api.task_state import Category, Id, Phase from api_server.models.tortoise_models import TaskRequest as DbTaskRequest from api_server.models.tortoise_models import TaskState as DbTaskState -from api_server.query import add_pagination from api_server.rmf_io import task_events @@ -140,17 +139,92 @@ async def save_task_state(self, task_state: TaskState) -> None: ) async def query_task_states( - self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None + self, + task_id: list[str] | None = None, + category: list[str] | None = None, + assigned_to: list[str] | None = None, + start_time_between: tuple[datetime, datetime] | None = None, + finish_time_between: tuple[datetime, datetime] | None = None, + request_time_between: tuple[datetime, datetime] | None = None, + requester: list[str] | None = None, + status: list[str] | None = None, + label: Labels | None = None, + pagination: Optional[Pagination] = None, ) -> List[TaskState]: + filters = {} + if task_id is not None: + filters["id___in"] = task_id + if category is not None: + filters["category__in"] = category + if assigned_to is not None: + filters["assigned_to__in"] = assigned_to + if start_time_between is not None: + filters["unix_millis_start_time__gte"] = start_time_between[0] + filters["unix_millis_start_time__lte"] = start_time_between[1] + if finish_time_between is not None: + filters["unix_millis_finish_time__gte"] = finish_time_between[0] + filters["unix_millis_finish_time__lte"] = finish_time_between[1] + if request_time_between is not None: + filters["unix_millis_request_time__gte"] = request_time_between[0] + filters["unix_millis_request_time__lte"] = request_time_between[1] + if requester is not None: + filters["requester__in"] = requester + if status is not None: + valid_values = [member.value for member in TaskStatus] + filters["status__in"] = [] + for status_string in status: + if status_string not in valid_values: + continue + filters["status__in"].append(TaskStatus(status_string)) + query = DbTaskState.filter(**filters) + + need_group_by = False + label_filters = {} + if label is not None: + label_filters.update( + { + f"label_filter_{k}": tfuncs.Count( + "id_", + _filter=Q(labels__label_name=k, labels__label_value_str=v), + ) + for k, v in label.__root__.items() + } + ) + + if len(label_filters) > 0: + filter_gt = {f"{f}__gt": 0 for f in label_filters} + query = query.annotate(**label_filters).filter(**filter_gt) + need_group_by = True + + if pagination: + order_fields: list[str] = [] + annotations: dict[str, Expression] = {} + # add annotations required for sorting by labels + for f in pagination.order_by: + order_prefix = f[0] if f[0] == "-" else "" + order_field = f[1:] if order_prefix == "-" else f + if order_field.startswith("label="): + f = order_field[6:] + annotations[f"label_sort_{f}"] = tfuncs.Max( + "labels__label_value_str", + _filter=Q(labels__label_name=f), + ) + order_field = f"label_sort_{f}" + + order_fields.append(order_prefix + order_field) + + query = ( + query.annotate(**annotations) + .limit(pagination.limit) + .offset(pagination.offset) + .order_by(*order_fields) + ) + need_group_by = True + + if need_group_by: + query = query.group_by("id_", "labels__state_id") + try: - if pagination: - query = add_pagination( - query, - pagination, - # TODO(koonpeng): remove this mapping after `pickup` and `destination` query is removed. - {"pickup": "label=pickup", "destination": "label=destination"}, - group_by="labels__state_id", - ) # TODO: enforce with authz results = await query.values_list("data", flat=True) return [TaskState(**r) for r in results] diff --git a/packages/api-server/api_server/routes/tasks/scheduled_tasks.py b/packages/api-server/api_server/routes/tasks/scheduled_tasks.py index a0e5a455f..58fed9199 100644 --- a/packages/api-server/api_server/routes/tasks/scheduled_tasks.py +++ b/packages/api-server/api_server/routes/tasks/scheduled_tasks.py @@ -182,7 +182,7 @@ async def get_scheduled_tasks( .offset(pagination.offset) ) if pagination.order_by: - q.order_by(*pagination.order_by.split(",")) + q.order_by(*pagination.order_by) return await ttm.ScheduledTaskPydanticList.from_queryset(q) diff --git a/packages/api-server/api_server/routes/tasks/tasks.py b/packages/api-server/api_server/routes/tasks/tasks.py index b5b7671dd..89940ad4b 100644 --- a/packages/api-server/api_server/routes/tasks/tasks.py +++ b/packages/api-server/api_server/routes/tasks/tasks.py @@ -1,10 +1,8 @@ from datetime import datetime from typing import List, Optional, Tuple, cast -import tortoise.functions as tfuncs from fastapi import Body, Depends, HTTPException, Path, Query from rx import operators as rxops -from tortoise.expressions import Q from api_server import models as mdl from api_server.dependencies import ( @@ -17,7 +15,6 @@ ) from api_server.fast_io import FastIORouter, SubscriptionRequest from api_server.logging import LoggerAdapter, get_logger -from api_server.models.tortoise_models import TaskState as DbTaskState from api_server.repositories import RmfRepository, TaskRepository from api_server.response import RawJSONResponse from api_server.rmf_io import task_events, tasks_service @@ -102,12 +99,12 @@ async def query_task_states( ), pickup: Optional[str] = Query( None, - description="comma separated list of pickup names. [deprecated] use `label` instead", + description="pickup name. [deprecated] use `label` instead", deprecated=True, ), destination: Optional[str] = Query( None, - description="comma separated list of destination names, [deprecated] use `label` instead", + description="destination name, [deprecated] use `label` instead", deprecated=True, ), assigned_to: Optional[str] = Query( @@ -127,76 +124,26 @@ async def query_task_states( ), pagination: mdl.Pagination = Depends(pagination_query), ): - """ - Note that sorting by `pickup` and `destination` is mutually exclusive and sorting - by either of them will filter only tasks which has those labels. - """ - filters = {} - if task_id is not None: - filters["id___in"] = task_id.split(",") - if category is not None: - filters["category__in"] = category.split(",") - if request_time_between is not None: - filters["unix_millis_request_time__gte"] = request_time_between[0] - filters["unix_millis_request_time__lte"] = request_time_between[1] - if requester is not None: - filters["requester__in"] = requester.split(",") - if assigned_to is not None: - filters["assigned_to__in"] = assigned_to.split(",") - if start_time_between is not None: - filters["unix_millis_start_time__gte"] = start_time_between[0] - filters["unix_millis_start_time__lte"] = start_time_between[1] - if finish_time_between is not None: - filters["unix_millis_finish_time__gte"] = finish_time_between[0] - filters["unix_millis_finish_time__lte"] = finish_time_between[1] - if status is not None: - valid_values = [member.value for member in mdl.Status] - filters["status__in"] = [] - for status_string in status.split(","): - if status_string not in valid_values: - continue - filters["status__in"].append(mdl.Status(status_string)) - query = DbTaskState.filter(**filters) - - label_filters = {} - if pickup is not None: - label_filters["label_filter_pickup"] = tfuncs.Count( - "id_", - _filter=Q( - labels__label_name="pickup", - labels__label_value_str__in=pickup.split(","), - ), - ) - if destination is not None: - label_filters["label_filter_destination"] = tfuncs.Count( - "id_", - _filter=Q( - labels__label_name="destination", - labels__label_value_str__in=destination.split(","), - ), - ) - if label is not None: - labels = mdl.Labels.from_strings(label.split(",")) - label_filters.update( - { - f"label_filter_{k}": tfuncs.Count( - "id_", _filter=Q(labels__label_name=k, labels__label_value_str=v) - ) - for k, v in labels.__root__.items() - } - ) - - if len(label_filters) > 0: - filter_gt = {f"{f}__gt": 0 for f in label_filters} - query = ( - query.annotate(**label_filters) - .group_by( - "labels__state_id" - ) # need to group by a related field to make tortoise-orm generate joins - .filter(**filter_gt) - ) - - return await task_repo.query_task_states(query, pagination) + labels = ( + mdl.Labels.from_strings(label.split(",")) if label else mdl.Labels(__root__={}) + ) + if pickup: + labels.__root__["pickup"] = pickup + if destination: + labels.__root__["destination"] = destination + + return await task_repo.query_task_states( + task_id=task_id.split(",") if task_id else None, + category=category.split(",") if category else None, + assigned_to=assigned_to.split(",") if assigned_to else None, + start_time_between=start_time_between, + finish_time_between=finish_time_between, + request_time_between=request_time_between, + requester=requester.split(",") if requester else None, + status=status.split(",") if status else None, + label=labels, + pagination=pagination, + ) @router.get("/{task_id}/state", response_model=mdl.TaskState) diff --git a/packages/api-server/api_server/routes/tasks/test_tasks.py b/packages/api-server/api_server/routes/tasks/test_tasks.py index 5c3c7f673..ed7385c7a 100644 --- a/packages/api-server/api_server/routes/tasks/test_tasks.py +++ b/packages/api-server/api_server/routes/tasks/test_tasks.py @@ -1,10 +1,13 @@ +from typing import cast from unittest.mock import patch from uuid import uuid4 import pydantic from api_server import models as mdl -from api_server.rmf_io import tasks_service +from api_server.models import TaskEventLog, TaskState +from api_server.repositories import TaskRepository +from api_server.rmf_io import task_events, tasks_service from api_server.test import ( AppFixture, make_task_booking_label, @@ -45,15 +48,12 @@ def setUpClass(cls): ] cls.task_logs = [make_task_log(task_id=f"test_{x}") for x in task_ids] - with cls.client.websocket_connect("/_internal") as ws: - for x in cls.task_states: - ws.send_text( - mdl.TaskStateUpdate(type="task_state_update", data=x).json() - ) - for x in cls.task_logs: - ws.send_text( - mdl.TaskEventLogUpdate(type="task_log_update", data=x).json() - ) + portal = cls.get_portal() + repo = TaskRepository(cls.admin_user) + for x in cls.task_states: + portal.call(repo.save_task_state, x) + for x in cls.task_logs: + portal.call(repo.save_task_log, x) def test_get_task_state(self): resp = self.client.get(f"/tasks/{self.task_states[0].booking.id}/state") @@ -95,7 +95,7 @@ def test_query_task_states(self): def test_query_task_states_filter_by_label(self): resp = self.client.get("/tasks?label=not_existing") - self.assertEqual(200, resp.status_code) + self.assertEqual(200, resp.status_code, resp.content) results = pydantic.parse_raw_as(list[mdl.TaskState], resp.content) self.assertEqual(0, len(results)) @@ -173,15 +173,10 @@ def test_query_task_states_sort_by_label(self): def test_sub_task_state(self): task_id = self.task_states[0].booking.id - gen = self.subscribe_sio(f"/tasks/{task_id}/state") - with self.client.websocket_connect("/_internal") as ws: - ws.send_text( - mdl.TaskStateUpdate( - type="task_state_update", data=self.task_states[0] - ).json() - ) - state = next(gen) - self.assertEqual(task_id, state.booking.id) # type: ignore + with self.subscribe_sio(f"/tasks/{task_id}/state") as sub: + task_events.task_states.on_next(self.task_states[0]) + state = TaskState(**next(sub)) + self.assertEqual(task_id, cast(TaskState, state).booking.id) def test_get_task_booking_label(self): resp = self.client.get(f"/tasks/{self.task_states[0].booking.id}/booking_label") @@ -264,15 +259,10 @@ def test_get_task_log(self): def test_sub_task_log(self): task_id = self.task_logs[0].task_id - gen = self.subscribe_sio(f"/tasks/{task_id}/log") - with self.client.websocket_connect("/_internal") as ws: - ws.send_text( - mdl.TaskEventLogUpdate( - type="task_log_update", data=self.task_logs[0] - ).json() - ) - log = next(gen) - self.assertEqual(task_id, log.task_id) # type: ignore + with self.subscribe_sio(f"/tasks/{task_id}/log") as sub: + task_events.task_event_logs.on_next(self.task_logs[0]) + log = TaskEventLog(**next(sub)) + self.assertEqual(task_id, cast(TaskEventLog, log).task_id) def test_activity_discovery(self): with patch.object(tasks_service(), "call") as mock: diff --git a/packages/api-server/api_server/routes/test_building_map.py b/packages/api-server/api_server/routes/test_building_map.py index e73a65495..dbc0e75e2 100644 --- a/packages/api-server/api_server/routes/test_building_map.py +++ b/packages/api-server/api_server/routes/test_building_map.py @@ -10,7 +10,8 @@ class TestBuildingMapRoute(AppFixture): def test_get_building_map(self): building_map = make_building_map() - rmf_events.building_map.on_next(building_map) + portal = self.get_portal() + portal.call(building_map.save) resp = try_until( lambda: self.client.get("/building_map"), lambda x: x.status_code == 200 diff --git a/packages/api-server/api_server/routes/test_dispensers.py b/packages/api-server/api_server/routes/test_dispensers.py index 7d453593d..0e135b17b 100644 --- a/packages/api-server/api_server/routes/test_dispensers.py +++ b/packages/api-server/api_server/routes/test_dispensers.py @@ -1,7 +1,7 @@ from typing import List from uuid import uuid4 -from api_server.rmf_io import rmf_events +from api_server.models import DispenserState from api_server.test import AppFixture, make_dispenser_state @@ -11,8 +11,9 @@ def setUpClass(cls): super().setUpClass() cls.dispenser_states = [make_dispenser_state(f"test_{uuid4()}")] + portal = cls.get_portal() for x in cls.dispenser_states: - rmf_events.dispenser_states.on_next(x) + portal.call(x.save) def test_get_dispensers(self): resp = self.client.get("/dispensers") @@ -31,7 +32,8 @@ def test_get_dispenser_state(self): self.assertEqual(self.dispenser_states[0].guid, state["guid"]) def test_sub_dispenser_state(self): - msg = next( - self.subscribe_sio(f"/dispensers/{self.dispenser_states[0].guid}/state") - ) - self.assertEqual(self.dispenser_states[0].guid, msg.guid) # type: ignore + with self.subscribe_sio( + f"/dispensers/{self.dispenser_states[0].guid}/state" + ) as sub: + msg = DispenserState(**next(sub)) + self.assertEqual(self.dispenser_states[0].guid, msg.guid) diff --git a/packages/api-server/api_server/routes/test_doors.py b/packages/api-server/api_server/routes/test_doors.py index 0bae2681a..ec4a29f28 100644 --- a/packages/api-server/api_server/routes/test_doors.py +++ b/packages/api-server/api_server/routes/test_doors.py @@ -2,7 +2,7 @@ from rmf_door_msgs.msg import DoorMode as RmfDoorMode -from api_server.rmf_io import rmf_events +from api_server.models import DoorState from api_server.test import AppFixture, make_building_map, make_door_state @@ -11,11 +11,13 @@ class TestDoorsRoute(AppFixture): def setUpClass(cls): super().setUpClass() cls.building_map = make_building_map() + portal = cls.get_portal() + portal.call(cls.building_map.save) + cls.door_states = [make_door_state(f"test_{uuid4()}")] - rmf_events.building_map.on_next(cls.building_map) for x in cls.door_states: - rmf_events.door_states.on_next(x) + portal.call(x.save) def test_get_doors(self): resp = self.client.get("/doors") @@ -31,8 +33,9 @@ def test_get_door_state(self): self.assertEqual(self.door_states[0].door_name, state["door_name"]) def test_sub_door_state(self): - msg = next(self.subscribe_sio(f"/doors/{self.door_states[0].door_name}/state")) - self.assertEqual(self.door_states[0].door_name, msg.door_name) # type: ignore + with self.subscribe_sio(f"/doors/{self.door_states[0].door_name}/state") as sub: + msg = DoorState(**next(sub)) + self.assertEqual(self.door_states[0].door_name, msg.door_name) def test_post_door_request(self): resp = self.client.post( diff --git a/packages/api-server/api_server/routes/test_fleets.py b/packages/api-server/api_server/routes/test_fleets.py index 4ed0c77f4..3b55b640d 100644 --- a/packages/api-server/api_server/routes/test_fleets.py +++ b/packages/api-server/api_server/routes/test_fleets.py @@ -1,6 +1,13 @@ from urllib.parse import urlencode -from api_server.models import FleetLogUpdate, FleetStateUpdate, MutexGroups +from api_server.models import ( + FleetLog, + FleetLogUpdate, + FleetState, + FleetStateUpdate, + MutexGroups, +) +from api_server.rmf_io import fleet_events from api_server.test import ( AppFixture, make_fleet_log, @@ -11,45 +18,48 @@ class TestFleetsRoute(AppFixture): def test_fleet_states(self): - # subscribe to fleet states fleet_state = make_fleet_state("test_fleet") - gen = self.subscribe_sio(f"/fleets/{fleet_state.name}/state") - with self.client.websocket_connect("/_internal") as ws: + with self.client.websocket_connect("/_internal") as ws, self.subscribe_sio( + f"/fleets/{fleet_state.name}/state" + ) as sub: ws.send_text( FleetStateUpdate(type="fleet_state_update", data=fleet_state).json() ) - msg = next(gen) - self.assertEqual(fleet_state.name, msg.name) # type: ignore + msg = FleetState(**next(sub)) + self.assertEqual(fleet_state.name, msg.name) - # get fleet state - resp = self.client.get(f"/fleets/{fleet_state.name}/state") - self.assertEqual(200, resp.status_code) - state = resp.json() - self.assertEqual(fleet_state.name, state["name"]) + # get fleet state + resp = self.client.get(f"/fleets/{fleet_state.name}/state") + self.assertEqual(200, resp.status_code) + state = resp.json() + self.assertEqual(fleet_state.name, state["name"]) - # query fleets - resp = self.client.get(f"/fleets?fleet_name={fleet_state.name}") - self.assertEqual(200, resp.status_code) - resp_json = resp.json() - self.assertEqual(1, len(resp_json)) - self.assertEqual(fleet_state.name, resp_json[0]["name"]) + # query fleets + resp = self.client.get(f"/fleets?fleet_name={fleet_state.name}") + self.assertEqual(200, resp.status_code) + resp_json = resp.json() + self.assertEqual(1, len(resp_json), resp_json) + self.assertEqual(fleet_state.name, resp_json[0]["name"]) def test_fleet_logs(self): fleet_log = make_fleet_log() - gen = self.subscribe_sio(f"/fleets/{fleet_log.name}/log") - with self.client.websocket_connect("/_internal") as ws: + with self.client.websocket_connect("/_internal") as ws, self.subscribe_sio( + f"/fleets/{fleet_log.name}/log" + ) as sub: + fleet_events.fleet_logs.on_next(fleet_log) + ws.send_text(FleetLogUpdate(type="fleet_log_update", data=fleet_log).json()) - msg = next(gen) - self.assertEqual(fleet_log.name, msg.name) # type: ignore + msg = FleetLog(**next(sub)) + self.assertEqual(fleet_log.name, msg.name) - # Since there are no sample fleet logs, we cannot check the log contents - resp = self.client.get(f"/fleets/{fleet_log.name}/log") - self.assertEqual(200, resp.status_code) - self.assertEqual(fleet_log.name, resp.json()["name"]) + # Since there are no sample fleet logs, we cannot check the log contents + resp = self.client.get(f"/fleets/{fleet_log.name}/log") + self.assertEqual(200, resp.status_code) + self.assertEqual(fleet_log.name, resp.json()["name"]) def test_decommission_robot(self): # add a new robot diff --git a/packages/api-server/api_server/routes/test_ingestors.py b/packages/api-server/api_server/routes/test_ingestors.py index ecf3c90d5..810180508 100644 --- a/packages/api-server/api_server/routes/test_ingestors.py +++ b/packages/api-server/api_server/routes/test_ingestors.py @@ -1,7 +1,7 @@ from typing import List from uuid import uuid4 -from api_server.rmf_io import rmf_events +from api_server.models import IngestorState from api_server.test import AppFixture, make_ingestor_state @@ -11,8 +11,9 @@ def setUpClass(cls): super().setUpClass() cls.ingestor_states = [make_ingestor_state(f"test_{uuid4()}")] + portal = cls.get_portal() for x in cls.ingestor_states: - rmf_events.ingestor_states.on_next(x) + portal.call(x.save) def test_get_ingestors(self): resp = self.client.get("/ingestors") @@ -31,7 +32,8 @@ def test_get_ingestor_state(self): self.assertEqual(self.ingestor_states[0].guid, state["guid"]) def test_sub_ingestor_state(self): - msg = next( - self.subscribe_sio(f"/ingestors/{self.ingestor_states[0].guid}/state") - ) - self.assertEqual(self.ingestor_states[0].guid, msg.guid) # type: ignore + with self.subscribe_sio( + f"/ingestors/{self.ingestor_states[0].guid}/state" + ) as sub: + msg = IngestorState(**next(sub)) + self.assertEqual(self.ingestor_states[0].guid, msg.guid) diff --git a/packages/api-server/api_server/routes/test_lifts.py b/packages/api-server/api_server/routes/test_lifts.py index 2d11a2c10..4473dc03a 100644 --- a/packages/api-server/api_server/routes/test_lifts.py +++ b/packages/api-server/api_server/routes/test_lifts.py @@ -2,7 +2,7 @@ from rmf_lift_msgs.msg import LiftRequest as RmfLiftRequest -from api_server.rmf_io import rmf_events +from api_server.models import LiftState from api_server.test import AppFixture, make_building_map, make_lift_state @@ -11,11 +11,12 @@ class TestLiftsRoute(AppFixture): def setUpClass(cls): super().setUpClass() cls.building_map = make_building_map() - cls.lift_states = [make_lift_state(f"test_{uuid4()}")] + portal = cls.get_portal() + portal.call(cls.building_map.save) - rmf_events.building_map.on_next(cls.building_map) + cls.lift_states = [make_lift_state(f"test_{uuid4()}")] for x in cls.lift_states: - rmf_events.lift_states.on_next(x) + portal.call(x.save) def test_get_lifts(self): resp = self.client.get("/lifts") @@ -31,8 +32,9 @@ def test_get_lift_state(self): self.assertEqual(self.lift_states[0].lift_name, state["lift_name"]) def test_sub_lift_state(self): - msg = next(self.subscribe_sio(f"/lifts/{self.lift_states[0].lift_name}/state")) - self.assertEqual(self.lift_states[0].lift_name, msg.lift_name) # type: ignore + with self.subscribe_sio(f"/lifts/{self.lift_states[0].lift_name}/state") as sub: + msg = LiftState(**next(sub)) + self.assertEqual(self.lift_states[0].lift_name, msg.lift_name) def test_request_lift(self): resp = self.client.post( diff --git a/packages/api-server/api_server/test/__init__.py b/packages/api-server/api_server/test/__init__.py index 0cfa7abe6..10fa57f6c 100644 --- a/packages/api-server/api_server/test/__init__.py +++ b/packages/api-server/api_server/test/__init__.py @@ -4,6 +4,5 @@ from .mocks import * from .test_data import * from .test_fixtures import * -from .test_utils import * test_user = User(username="test_user", is_admin=True) diff --git a/packages/api-server/api_server/test/test_fixtures.py b/packages/api-server/api_server/test/test_fixtures.py index 261bdb923..c63c9695d 100644 --- a/packages/api-server/api_server/test/test_fixtures.py +++ b/packages/api-server/api_server/test/test_fixtures.py @@ -1,14 +1,19 @@ import asyncio +import contextlib import inspect import os import os.path import time import unittest import unittest.mock -from typing import Awaitable, Callable, Optional, TypeVar, Union +from typing import Awaitable, Callable, Generator, Optional, TypeVar, Union from uuid import uuid4 -from api_server.app import app, on_sio_connect +import pydantic +from anyio.abc import BlockingPortal +from tortoise import Tortoise + +from api_server.app import app, app_config from api_server.models import User from .mocks import patch_sio @@ -80,20 +85,62 @@ async def async_try_until( class AppFixture(unittest.TestCase): @classmethod def setUpClass(cls): + async def clean_db(): + # connect to the db to drop it + await Tortoise.init(db_url=app_config.db_url, modules={"models": []}) + await Tortoise._drop_databases() # pylint: disable=protected-access + # connect to it again to recreate it + await Tortoise.init( + db_url=app_config.db_url, modules={"models": []}, _create_db=True + ) + await Tortoise.close_connections() + + asyncio.run(clean_db()) + cls.admin_user = User(username="admin", is_admin=True) cls.client = TestClient() cls.client.headers["Content-Type"] = "application/json" cls.client.__enter__() cls.addClassCleanup(cls.client.__exit__) + @classmethod + def get_portal(cls) -> BlockingPortal: + if not cls.client.portal: + raise AssertionError( + "missing client portal, is the client context entered?" + ) + return cls.client.portal + + @contextlib.contextmanager def subscribe_sio(self, room: str, *, user="admin"): """ Subscribes to a socketio room and return a generator of messages Returns a tuple of (success: bool, messages: Any). """ - - def impl(): - with patch_sio() as mock_sio: + if self.client.portal is None: + raise AssertionError( + "self.client.portal is None, make sure this is called within a test context" + ) + portal = self.client.portal + + on_sio_connect = app.sio.handlers["/"]["connect"] + on_subscribe = app.sio.handlers["/"]["subscribe"] + on_disconnect = app.sio.handlers["/"]["disconnect"] + + def gen() -> Generator[dict, None, None]: + async def wait_for_msgs(): + async with condition: + if len(msgs) == 0: + await condition.wait() + return msgs.pop(0) + + while True: + # TODO: type check is ignored because pyright is outdated + yield portal.call(asyncio.wait_for, wait_for_msgs(), 5) # type: ignore + + with patch_sio() as mock_sio: + connected = False + try: msgs = [] condition = asyncio.Condition() @@ -102,39 +149,23 @@ async def handle_resp(emit_room, msg, *_args, **_kwargs): raise Exception("Failed to subscribe") if emit_room == room: async with condition: + if isinstance(msg, pydantic.BaseModel): + msg = msg.dict() msgs.append(msg) condition.notify() mock_sio.emit.side_effect = handle_resp - self.assertIsNotNone(self.client.portal) - assert self.client.portal is not None - - self.client.portal.call( + portal.call( on_sio_connect, "test", {}, {"token": self.client.token(user)} ) - # pylint: disable=protected-access - self.client.portal.call(app._on_subscribe, "test", {"room": room}) - - yield - - async def wait_for_msgs(): - async with condition: - if len(msgs) == 0: - await condition.wait() - return msgs.pop(0) - - try: - while True: - yield self.client.portal.call( - asyncio.wait_for, wait_for_msgs(), 5 - ) - finally: - self.client.portal.call(app._on_disconnect, "test") - - gen = impl() - next(gen) - return gen + connected = True + portal.call(on_subscribe, "test", {"room": room}) + + yield gen() + finally: + if connected: + portal.call(on_disconnect, "test") def setUp(self): self.test_time = 0 diff --git a/packages/api-server/api_server/test/test_stress.py b/packages/api-server/api_server/test/test_stress.py index a18ad6ea2..5b1a51ee6 100644 --- a/packages/api-server/api_server/test/test_stress.py +++ b/packages/api-server/api_server/test/test_stress.py @@ -18,6 +18,5 @@ def test_stress(self): rmf_events.door_states.on_next(door_state) while True: - gen = self.subscribe_sio(f"/doors/{door_state.door_name}/state") - next(gen) - gen.close() + with self.subscribe_sio(f"/doors/{door_state.door_name}/state") as sub: + next(sub) diff --git a/packages/api-server/api_server/test/test_utils.py b/packages/api-server/api_server/test/test_utils.py deleted file mode 100644 index 71ae280f1..000000000 --- a/packages/api-server/api_server/test/test_utils.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Optional, Sequence - -from tortoise import Tortoise - - -async def init_db(models: Optional[Sequence[str]] = None): - models = models or ["api_server.models.tortoise_models"] - await Tortoise.init( - db_url="sqlite://:memory:", - modules={"models": models}, - ) - await Tortoise.generate_schemas() diff --git a/packages/api-server/api_server/test_sio_auth.py b/packages/api-server/api_server/test_sio_auth.py index 0ddbda633..a47f4085a 100644 --- a/packages/api-server/api_server/test_sio_auth.py +++ b/packages/api-server/api_server/test_sio_auth.py @@ -1,35 +1,20 @@ import asyncio -from typing import Optional from unittest.mock import AsyncMock, patch from api_server.app import app, on_sio_connect +from api_server.authenticator import authenticator from .test.test_fixtures import AppFixture class TestSioAuth(AppFixture): - @staticmethod - def try_connect(token: Optional[str]): - with patch.object(app, "sio") as mock: + def test_token_is_verified(self): + with patch.object( + authenticator, "verify_token" + ) as mock_verify_token, patch.object(app, "sio") as mock_sio: # set up mocks session = {} - mock.get_session = AsyncMock(return_value=session) + mock_sio.get_session = AsyncMock(return_value=session) - loop = asyncio.new_event_loop() - fut = asyncio.Future(loop=loop) - - async def result(): - fut.set_result(await on_sio_connect("test", {}, {"token": token})) - - loop.run_until_complete(result()) - loop.close() - return fut.result() - - def test_fail_with_no_token(self): - self.assertFalse(self.try_connect(None)) - - def test_fail_with_invalid_token(self): - self.assertFalse(self.try_connect("invalid")) - - def test_success_with_valid_token(self): - self.assertTrue(self.try_connect(self.client.token("admin"))) + asyncio.run(on_sio_connect("test", {}, {"token": "test-token"})) + mock_verify_token.assert_awaited_once_with("test-token") diff --git a/packages/api-server/scripts/sqlite_test_config.py b/packages/api-server/scripts/sqlite_test_config.py deleted file mode 100644 index 2d15de704..000000000 --- a/packages/api-server/scripts/sqlite_test_config.py +++ /dev/null @@ -1,8 +0,0 @@ -from base_test_config import config - -config.update( - { - "db_url": "sqlite://:memory:", - "timezone": "Asia/Singapore", - } -) diff --git a/packages/api-server/scripts/test.py b/packages/api-server/scripts/test.py index e164368e6..77e8dc11f 100644 --- a/packages/api-server/scripts/test.py +++ b/packages/api-server/scripts/test.py @@ -1,8 +1,6 @@ import os -os.environ[ - "RMF_API_SERVER_CONFIG" -] = f"{os.path.dirname(__file__)}/sqlite_test_config.py" +os.environ["RMF_API_SERVER_CONFIG"] = f"{os.path.dirname(__file__)}/test_config.py" import unittest diff --git a/packages/api-server/scripts/base_test_config.py b/packages/api-server/scripts/test_config.py similarity index 60% rename from packages/api-server/scripts/base_test_config.py rename to packages/api-server/scripts/test_config.py index 35fce9a44..0e3ef18cb 100644 --- a/packages/api-server/scripts/base_test_config.py +++ b/packages/api-server/scripts/test_config.py @@ -4,7 +4,7 @@ here = os.path.dirname(__file__) -test_port = os.environ.get("RMF_SERVER_TEST_PORT", "8000") +test_port = os.environ.get("RMF_API_SERVER_TEST_PORT", "8000") config.update( { "host": "127.0.0.1", @@ -12,5 +12,7 @@ "log_level": "CRITICAL", "jwt_public_key": f"{here}/test.pub", "iss": "test", + "db_url": os.environ.get("RMF_API_SERVER_TEST_DB_URL", "sqlite://:memory:"), + "timezone": "Asia/Singapore", } ) diff --git a/packages/dashboard/src/components/tasks/tasks-app.tsx b/packages/dashboard/src/components/tasks/tasks-app.tsx index 3980d3e37..ae9812988 100644 --- a/packages/dashboard/src/components/tasks/tasks-app.tsx +++ b/packages/dashboard/src/components/tasks/tasks-app.tsx @@ -191,8 +191,8 @@ export const TasksApp = React.memo( filterColumn && filterColumn === 'id_' ? filterValue : undefined, filterColumn && filterColumn === 'category' ? filterValue : undefined, filterColumn && filterColumn === 'requester' ? filterValue : undefined, - filterColumn && filterColumn === 'pickup' ? filterValue : undefined, - filterColumn && filterColumn === 'destination' ? filterValue : undefined, + filterColumn && filterColumn === 'label=pickup' ? filterValue : undefined, + filterColumn && filterColumn === 'label=destination' ? filterValue : undefined, filterColumn && filterColumn === 'assigned_to' ? filterValue : undefined, filterColumn && filterColumn === 'status' ? filterValue : undefined, undefined, diff --git a/packages/react-components/lib/tasks/task-table-datagrid.tsx b/packages/react-components/lib/tasks/task-table-datagrid.tsx index 9b2ad8952..506975169 100644 --- a/packages/react-components/lib/tasks/task-table-datagrid.tsx +++ b/packages/react-components/lib/tasks/task-table-datagrid.tsx @@ -218,7 +218,7 @@ export function TaskDataGridTable({ filterable: true, }, { - field: 'pickup', + field: 'label=pickup', headerName: 'Pickup', width: 150, editable: false, @@ -233,7 +233,7 @@ export function TaskDataGridTable({ filterable: true, }, { - field: 'destination', + field: 'label=destination', headerName: 'Destination', width: 150, editable: false,