Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions backend/api/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,53 @@ class UpdateSavedViewInput:
related_sort_definition: str | None = None
related_parameters: str | None = None
visibility: SavedViewVisibility | None = None


@strawberry.enum
class TaskPresetScope(Enum):
PERSONAL = "PERSONAL"
GLOBAL = "GLOBAL"


@strawberry.input
class TaskGraphNodeInput:
node_id: str
title: str
description: str | None = None
priority: TaskPriority | None = None
estimated_time: int | None = None


@strawberry.input
class TaskGraphEdgeInput:
from_node_id: str
to_node_id: str


@strawberry.input
class TaskGraphInput:
nodes: list[TaskGraphNodeInput]
edges: list[TaskGraphEdgeInput]


@strawberry.input
class CreateTaskPresetInput:
name: str
key: str | None = None
scope: TaskPresetScope
graph: TaskGraphInput


@strawberry.input
class UpdateTaskPresetInput:
name: str | None = None
key: str | None = None
graph: TaskGraphInput | None = None


@strawberry.input
class ApplyTaskGraphInput:
patient_id: strawberry.ID
preset_id: strawberry.ID | None = None
graph: TaskGraphInput | None = None
assign_to_current_user: bool = False
3 changes: 3 additions & 0 deletions backend/api/resolvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from .query_metadata import QueryMetadataQuery
from .saved_view import SavedViewMutation, SavedViewQuery
from .task import TaskMutation, TaskQuery, TaskSubscription
from .task_preset import TaskPresetMutation, TaskPresetQuery
from .user import UserMutation, UserQuery


@strawberry.type
class Query(
PatientQuery,
TaskQuery,
TaskPresetQuery,
LocationQuery,
PropertyDefinitionQuery,
UserQuery,
Expand All @@ -28,6 +30,7 @@ class Query(
class Mutation(
PatientMutation,
TaskMutation,
TaskPresetMutation,
PropertyDefinitionMutation,
LocationMutation,
UserMutation,
Expand Down
99 changes: 97 additions & 2 deletions backend/api/resolvers/task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from collections.abc import AsyncGenerator
from typing import Any

import strawberry
from api.audit import audit_log
from api.context import Info
from api.errors import raise_forbidden
from api.inputs import CreateTaskInput, PaginationInput, PatientState, SortDirection, UpdateTaskInput
from api.inputs import (
ApplyTaskGraphInput,
CreateTaskInput,
PaginationInput,
PatientState,
SortDirection,
UpdateTaskInput,
)
from api.query.execute import count_unified_query, is_unset, unified_list_query
from api.query.inputs import (
QueryFilterClauseInput,
Expand All @@ -17,8 +25,16 @@
from api.services.checksum import validate_checksum
from api.services.datetime import normalize_datetime_to_utc
from api.services.property import PropertyService
from api.services.task_graph import (
apply_task_graph_to_patient,
graph_dict_from_preset_inputs,
insert_task_dependencies,
replace_incoming_task_dependencies,
validate_task_graph_dict,
)
from api.types.task import TaskType
from database import models
from database.models.task_preset import TaskPresetScope as DbTaskPresetScope
from graphql import GraphQLError
from sqlalchemy import and_, exists, or_, select
from sqlalchemy.orm import aliased, selectinload
Expand Down Expand Up @@ -751,6 +767,14 @@ async def create_task(self, info: Info, data: CreateTaskInput) -> TaskType:
"payload": {"task_id": task.id, "task_title": task.title},
},
)
if data.previous_task_ids:
await insert_task_dependencies(
info.context.db,
task.id,
[str(x) for x in data.previous_task_ids],
str(task.patient_id),
)
await info.context.db.commit()
return task

@strawberry.mutation
Expand Down Expand Up @@ -845,14 +869,23 @@ async def update_task(
"task",
)

return await BaseMutationResolver.update_and_notify(
result = await BaseMutationResolver.update_and_notify(
info,
task,
models.Task,
"task",
"patient",
task.patient_id,
)
if data.previous_task_ids is not None:
await replace_incoming_task_dependencies(
info.context.db,
str(id),
[str(x) for x in data.previous_task_ids],
str(task.patient_id),
)
await info.context.db.commit()
return result

@staticmethod
async def _update_task_field(
Expand Down Expand Up @@ -1005,6 +1038,68 @@ async def reopen_task(self, info: Info, id: strawberry.ID) -> TaskType:
lambda task: setattr(task, "done", False),
)

@strawberry.mutation
@audit_log("apply_task_graph")
async def apply_task_graph(
self,
info: Info,
data: ApplyTaskGraphInput,
) -> list[TaskType]:
user = info.context.user
if not user:
raise GraphQLError(
"Not authenticated",
extensions={"code": "UNAUTHENTICATED"},
)
auth_service = AuthorizationService(info.context.db)
if not await auth_service.can_access_patient_id(
user,
data.patient_id,
info.context,
):
raise_forbidden()
has_preset = data.preset_id is not None
has_graph = data.graph is not None
if has_preset == has_graph:
raise GraphQLError(
"Provide exactly one of presetId or graph",
extensions={"code": "BAD_REQUEST"},
)
graph_dict: dict[str, Any]
if data.preset_id:
pr = await info.context.db.execute(
select(models.TaskPreset).where(
models.TaskPreset.id == data.preset_id,
),
)
preset = pr.scalars().first()
if not preset:
raise GraphQLError(
"Preset not found",
extensions={"code": "NOT_FOUND"},
)
if (
preset.scope == DbTaskPresetScope.PERSONAL.value
and preset.owner_user_id != user.id
):
raise_forbidden()
graph_dict = preset.graph_json
else:
graph_dict = graph_dict_from_preset_inputs(
data.graph.nodes,
data.graph.edges,
)
validate_task_graph_dict(graph_dict)
assignee_id = user.id if data.assign_to_current_user else None
source_preset_id = str(data.preset_id) if data.preset_id else None
return await apply_task_graph_to_patient(
info.context.db,
str(data.patient_id),
graph_dict,
assignee_id,
source_preset_id,
)

@strawberry.mutation
@audit_log("delete_task")
async def delete_task(self, info: Info, id: strawberry.ID) -> bool:
Expand Down
Loading
Loading