Skip to content

Commit

Permalink
Merge branch 'deploy/hammer' into hammer/demo-tasks
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Chong <aaronchongth@gmail.com>
  • Loading branch information
aaronchongth committed Jun 26, 2024
2 parents f2c0b08 + d2e56f3 commit ce58404
Show file tree
Hide file tree
Showing 25 changed files with 301 additions and 308 deletions.
1 change: 1 addition & 0 deletions packages/api-server/.pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ disable=print-statement,
logging-fstring-interpolation,
line-too-long,
too-many-lines,
no-self-use,

# fastapi heavily uses singletons, disallowing global statements just ends up with things
# using nonlocal as singletons which are functionally the same as globals.
Expand Down
12 changes: 9 additions & 3 deletions packages/api-server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,18 +252,24 @@ Restart the `api-server` and the changes to the databse should be reflected.
### Running unit tests

```bash
npm test
pnpm test
```

By default in-memory sqlite database is used for testing, to test on another database, set the `RMF_API_SERVER_TEST_DB_URL` environment variable.

```bash
RMF_API_SERVER_TEST_DB_URL=<db_url> pnpm test
```

### Collecting code coverage

```bash
npm run test:cov
pnpm run test:cov
```

Generate coverage report
```bash
npm run test:report
pnpm run test:report
```

## Live reload
Expand Down
6 changes: 5 additions & 1 deletion packages/api-server/api_server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def pagination_query(
) -> Pagination:
limit = limit or 100
offset = offset or 0
return Pagination(limit=limit, offset=offset, order_by=order_by)
return Pagination(
limit=limit,
offset=offset,
order_by=order_by.split(",") if order_by else [],
)


# hacky way to get the sio user
Expand Down
4 changes: 1 addition & 3 deletions packages/api-server/api_server/models/pagination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional

from pydantic import BaseModel


class Pagination(BaseModel):
limit: int
offset: int
order_by: Optional[str]
order_by: list[str]
53 changes: 6 additions & 47 deletions packages/api-server/api_server/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from typing import Dict, Optional

import tortoise.functions as tfuncs
from tortoise.expressions import Q
from tortoise.queryset import MODEL, QuerySet

from api_server.models.pagination import Pagination
Expand All @@ -10,47 +6,10 @@
def add_pagination(
query: QuerySet[MODEL],
pagination: Pagination,
field_mappings: Optional[Dict[str, str]] = None,
group_by: str | None = None,
) -> QuerySet[MODEL]:
"""
Adds pagination and ordering to a query. If the order field starts with `label=`, it is
assumed to be a label and label sorting will used. In this case, the model must have
a reverse relation named "labels" and the `group_by` param is required.
:param field_mapping: A dict mapping the order fields to the fields used to build the
query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}`
will order the query result according to `db_field`.
:param group_by: Required when sorting by labels, must be the foreign key column of the label table.
"""
field_mappings = field_mappings or {}
annotations = {}
query = query.limit(pagination.limit).offset(pagination.offset)
if pagination.order_by is not None:
order_fields = []
order_values = pagination.order_by.split(",")
for v in order_values:
# perform the mapping after stripping the order prefix
order_prefix = ""
order_field = v
if v[0] in ["-", "+"]:
order_prefix = v[0]
order_field = v[1:]
order_field = field_mappings.get(order_field, order_field)

# add annotations required for sorting by labels
if order_field.startswith("label="):
f = order_field[6:]
annotations[f"label_sort_{f}"] = tfuncs.Max(
"labels__label_value_str",
_filter=Q(labels__label_name=f),
)
order_field = f"label_sort_{f}"

order_fields.append(order_prefix + order_field)

query = query.annotate(**annotations)
if group_by is not None:
query = query.group_by(group_by)
query = query.order_by(*order_fields)
return query
"""Adds pagination and ordering to a query"""
return (
query.limit(pagination.limit)
.offset(pagination.offset)
.order_by(*pagination.order_by)
)
102 changes: 88 additions & 14 deletions packages/api-server/api_server/repositories/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple, cast

import tortoise.functions as tfuncs
from fastapi import Depends, HTTPException
from tortoise.exceptions import FieldError, IntegrityError
from tortoise.expressions import Expression, Q
from tortoise.query_utils import Prefetch
from tortoise.queryset import QuerySet
from tortoise.transactions import in_transaction

from api_server.authenticator import user_dep
from api_server.logging import LoggerAdapter, get_logger
from api_server.models import Labels, LogEntry, Pagination, Phases
from api_server.models import Status as TaskStatus
from api_server.models import (
LogEntry,
Pagination,
Phases,
TaskBookingLabel,
TaskEventLog,
TaskRequest,
Expand All @@ -25,7 +25,6 @@
from api_server.models.rmf_api.task_state import Category, Id, Phase
from api_server.models.tortoise_models import TaskRequest as DbTaskRequest
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.query import add_pagination
from api_server.rmf_io import task_events


Expand Down Expand Up @@ -140,17 +139,92 @@ async def save_task_state(self, task_state: TaskState) -> None:
)

async def query_task_states(
self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None
self,
task_id: list[str] | None = None,
category: list[str] | None = None,
assigned_to: list[str] | None = None,
start_time_between: tuple[datetime, datetime] | None = None,
finish_time_between: tuple[datetime, datetime] | None = None,
request_time_between: tuple[datetime, datetime] | None = None,
requester: list[str] | None = None,
status: list[str] | None = None,
label: Labels | None = None,
pagination: Optional[Pagination] = None,
) -> List[TaskState]:
filters = {}
if task_id is not None:
filters["id___in"] = task_id
if category is not None:
filters["category__in"] = category
if assigned_to is not None:
filters["assigned_to__in"] = assigned_to
if start_time_between is not None:
filters["unix_millis_start_time__gte"] = start_time_between[0]
filters["unix_millis_start_time__lte"] = start_time_between[1]
if finish_time_between is not None:
filters["unix_millis_finish_time__gte"] = finish_time_between[0]
filters["unix_millis_finish_time__lte"] = finish_time_between[1]
if request_time_between is not None:
filters["unix_millis_request_time__gte"] = request_time_between[0]
filters["unix_millis_request_time__lte"] = request_time_between[1]
if requester is not None:
filters["requester__in"] = requester
if status is not None:
valid_values = [member.value for member in TaskStatus]
filters["status__in"] = []
for status_string in status:
if status_string not in valid_values:
continue
filters["status__in"].append(TaskStatus(status_string))
query = DbTaskState.filter(**filters)

need_group_by = False
label_filters = {}
if label is not None:
label_filters.update(
{
f"label_filter_{k}": tfuncs.Count(
"id_",
_filter=Q(labels__label_name=k, labels__label_value_str=v),
)
for k, v in label.__root__.items()
}
)

if len(label_filters) > 0:
filter_gt = {f"{f}__gt": 0 for f in label_filters}
query = query.annotate(**label_filters).filter(**filter_gt)
need_group_by = True

if pagination:
order_fields: list[str] = []
annotations: dict[str, Expression] = {}
# add annotations required for sorting by labels
for f in pagination.order_by:
order_prefix = f[0] if f[0] == "-" else ""
order_field = f[1:] if order_prefix == "-" else f
if order_field.startswith("label="):
f = order_field[6:]
annotations[f"label_sort_{f}"] = tfuncs.Max(
"labels__label_value_str",
_filter=Q(labels__label_name=f),
)
order_field = f"label_sort_{f}"

order_fields.append(order_prefix + order_field)

query = (
query.annotate(**annotations)
.limit(pagination.limit)
.offset(pagination.offset)
.order_by(*order_fields)
)
need_group_by = True

if need_group_by:
query = query.group_by("id_", "labels__state_id")

try:
if pagination:
query = add_pagination(
query,
pagination,
# TODO(koonpeng): remove this mapping after `pickup` and `destination` query is removed.
{"pickup": "label=pickup", "destination": "label=destination"},
group_by="labels__state_id",
)
# TODO: enforce with authz
results = await query.values_list("data", flat=True)
return [TaskState(**r) for r in results]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def get_scheduled_tasks(
.offset(pagination.offset)
)
if pagination.order_by:
q.order_by(*pagination.order_by.split(","))
q.order_by(*pagination.order_by)
return await ttm.ScheduledTaskPydanticList.from_queryset(q)


Expand Down
97 changes: 22 additions & 75 deletions packages/api-server/api_server/routes/tasks/tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from datetime import datetime
from typing import List, Optional, Tuple, cast

import tortoise.functions as tfuncs
from fastapi import Body, Depends, HTTPException, Path, Query
from rx import operators as rxops
from tortoise.expressions import Q

from api_server import models as mdl
from api_server.dependencies import (
Expand All @@ -17,7 +15,6 @@
)
from api_server.fast_io import FastIORouter, SubscriptionRequest
from api_server.logging import LoggerAdapter, get_logger
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.repositories import RmfRepository, TaskRepository
from api_server.response import RawJSONResponse
from api_server.rmf_io import task_events, tasks_service
Expand Down Expand Up @@ -102,12 +99,12 @@ async def query_task_states(
),
pickup: Optional[str] = Query(
None,
description="comma separated list of pickup names. [deprecated] use `label` instead",
description="pickup name. [deprecated] use `label` instead",
deprecated=True,
),
destination: Optional[str] = Query(
None,
description="comma separated list of destination names, [deprecated] use `label` instead",
description="destination name, [deprecated] use `label` instead",
deprecated=True,
),
assigned_to: Optional[str] = Query(
Expand All @@ -127,76 +124,26 @@ async def query_task_states(
),
pagination: mdl.Pagination = Depends(pagination_query),
):
"""
Note that sorting by `pickup` and `destination` is mutually exclusive and sorting
by either of them will filter only tasks which has those labels.
"""
filters = {}
if task_id is not None:
filters["id___in"] = task_id.split(",")
if category is not None:
filters["category__in"] = category.split(",")
if request_time_between is not None:
filters["unix_millis_request_time__gte"] = request_time_between[0]
filters["unix_millis_request_time__lte"] = request_time_between[1]
if requester is not None:
filters["requester__in"] = requester.split(",")
if assigned_to is not None:
filters["assigned_to__in"] = assigned_to.split(",")
if start_time_between is not None:
filters["unix_millis_start_time__gte"] = start_time_between[0]
filters["unix_millis_start_time__lte"] = start_time_between[1]
if finish_time_between is not None:
filters["unix_millis_finish_time__gte"] = finish_time_between[0]
filters["unix_millis_finish_time__lte"] = finish_time_between[1]
if status is not None:
valid_values = [member.value for member in mdl.Status]
filters["status__in"] = []
for status_string in status.split(","):
if status_string not in valid_values:
continue
filters["status__in"].append(mdl.Status(status_string))
query = DbTaskState.filter(**filters)

label_filters = {}
if pickup is not None:
label_filters["label_filter_pickup"] = tfuncs.Count(
"id_",
_filter=Q(
labels__label_name="pickup",
labels__label_value_str__in=pickup.split(","),
),
)
if destination is not None:
label_filters["label_filter_destination"] = tfuncs.Count(
"id_",
_filter=Q(
labels__label_name="destination",
labels__label_value_str__in=destination.split(","),
),
)
if label is not None:
labels = mdl.Labels.from_strings(label.split(","))
label_filters.update(
{
f"label_filter_{k}": tfuncs.Count(
"id_", _filter=Q(labels__label_name=k, labels__label_value_str=v)
)
for k, v in labels.__root__.items()
}
)

if len(label_filters) > 0:
filter_gt = {f"{f}__gt": 0 for f in label_filters}
query = (
query.annotate(**label_filters)
.group_by(
"labels__state_id"
) # need to group by a related field to make tortoise-orm generate joins
.filter(**filter_gt)
)

return await task_repo.query_task_states(query, pagination)
labels = (
mdl.Labels.from_strings(label.split(",")) if label else mdl.Labels(__root__={})
)
if pickup:
labels.__root__["pickup"] = pickup
if destination:
labels.__root__["destination"] = destination

return await task_repo.query_task_states(
task_id=task_id.split(",") if task_id else None,
category=category.split(",") if category else None,
assigned_to=assigned_to.split(",") if assigned_to else None,
start_time_between=start_time_between,
finish_time_between=finish_time_between,
request_time_between=request_time_between,
requester=requester.split(",") if requester else None,
status=status.split(",") if status else None,
label=labels,
pagination=pagination,
)


@router.get("/{task_id}/state", response_model=mdl.TaskState)
Expand Down
Loading

0 comments on commit ce58404

Please sign in to comment.