From 3c695cc1551b2516e15cc4e6bf41f9f36161b5cc Mon Sep 17 00:00:00 2001 From: Esteban Galvis Date: Tue, 28 May 2024 10:53:47 -0500 Subject: [PATCH] feat(api): :sparkles: Updated create client endpoint --- menuflow/api/client.py | 25 +++++++++++++--- menuflow/api/responses.py | 60 ++++++++++++++++----------------------- menuflow/db/flow.py | 3 +- menuflow/menu.py | 10 ++++--- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/menuflow/api/client.py b/menuflow/api/client.py index ccd121c..8853c79 100644 --- a/menuflow/api/client.py +++ b/menuflow/api/client.py @@ -20,7 +20,9 @@ log: Logger = getLogger("menuflow.api.client") -async def _create_client(user_id: UserID | None, data: Dict) -> web.Response: +async def _create_client( + data: Dict, *, user_id: Optional[UserID] = None, flow_id: Optional[int] = None +) -> MenuClient | web.Response: homeserver = data.get("homeserver", None) access_token = data.get("access_token", None) device_id = data.get("device_id", None) @@ -47,7 +49,11 @@ async def _create_client(user_id: UserID | None, data: Dict) -> web.Response: elif whoami.device_id and device_id and whoami.device_id != device_id: return resp.device_id_mismatch(whoami.device_id) client: MenuClient = await MenuClient.get( - whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id + whoami.user_id, + homeserver=homeserver, + access_token=access_token, + device_id=device_id, + flow_id=flow_id, ) client.enabled = data.get("enabled", True) client.autojoin = data.get("autojoin", True) @@ -72,7 +78,14 @@ async def create_client(request: web.Request) -> web.Response: data = await request.json() except JSONDecodeError: return resp.body_not_json - return await _create_client(None, data) + + new_flow_id = None + if MenuClient.menuflow.config["menuflow.load_flow_from"] == "database": + example_flow = {"menu": {"flow_variables": {}, "nodes": []}} + new_flow = DBFlow(flow=example_flow) + new_flow_id = await new_flow.insert() + + return await _create_client(data, flow_id=new_flow_id) @routes.post("/room/{room_id}/set_variables") @@ -119,7 +132,7 @@ async def create_flow(request: web.Request) -> web.Response: await new_flow.insert() message = "Flow created successfully" - return resp.ok({"error": message}) + return resp.ok({"detail": {"message": message}}) @routes.get("/flow") @@ -128,9 +141,13 @@ async def get_flow(request: web.Request) -> web.Response: client_mxid = request.query.get("client_mxid", None) if flow_id: flow = await DBFlow.get_by_id(int(flow_id)) + if not flow: + return resp.not_found(f"Flow with ID {flow_id} not found") data = flow.serialize() elif client_mxid: flow = await DBFlow.get_by_mxid(client_mxid) + if not flow: + return resp.not_found(f"Flow with mxid {client_mxid} not found") data = flow.serialize() else: flows = await DBFlow.all() diff --git a/menuflow/api/responses.py b/menuflow/api/responses.py index cb8c10c..976122f 100644 --- a/menuflow/api/responses.py +++ b/menuflow/api/responses.py @@ -10,51 +10,40 @@ class _Response: @property def body_not_json(self) -> web.Response: return web.json_response( - { - "error": "Request body is not JSON", - "errcode": "body_not_json", - }, + {"detail": {"message": "Request body is not JSON"}}, status=HTTPStatus.BAD_REQUEST, ) @property def bad_client_access_token(self) -> web.Response: return web.json_response( - { - "error": "Invalid access token", - "errcode": "bad_client_access_token", - }, + {"detail": {"message": "Invalid access token"}}, status=HTTPStatus.BAD_REQUEST, ) @property def bad_client_access_details(self) -> web.Response: return web.json_response( - { - "error": "Invalid homeserver or access token", - "errcode": "bad_client_access_details", - }, + {"detail": {"message": "Invalid homeserver or access token"}}, status=HTTPStatus.BAD_REQUEST, ) @property def bad_client_connection_details(self) -> web.Response: return web.json_response( - { - "error": "Could not connect to homeserver", - "errcode": "bad_client_connection_details", - }, + {"detail": {"message": "Could not connect to homeserver"}}, status=HTTPStatus.BAD_REQUEST, ) def mxid_mismatch(self, found: str) -> web.Response: return web.json_response( { - "error": ( - "The Matrix user ID of the client and the user ID of the access token don't " - f"match. Access token is for user {found}" - ), - "errcode": "mxid_mismatch", + "detail": { + "message": f""" + The Matrix user ID of the client and the user ID of the access token don't + match. Access token is for user {found} + """ + } }, status=HTTPStatus.BAD_REQUEST, ) @@ -62,10 +51,12 @@ def mxid_mismatch(self, found: str) -> web.Response: def device_id_mismatch(self, found: str) -> web.Response: return web.json_response( { - "error": ( - "The Matrix device ID of the client and the device ID of the access token " - f"don't match. Access token is for device {found}" - ), + "detail": { + "message": """ + The Matrix device ID of the client and the device ID of the access token + don't match. Access token is for device {found} + """ + }, "errcode": "mxid_mismatch", }, status=HTTPStatus.BAD_REQUEST, @@ -74,10 +65,7 @@ def device_id_mismatch(self, found: str) -> web.Response: @property def user_exists(self) -> web.Response: return web.json_response( - { - "error": "There is already a client with the user ID of that token", - "errcode": "user_exists", - }, + {"detail": {"message": "There is already a client with the user ID of that token"}}, status=HTTPStatus.CONFLICT, ) @@ -90,18 +78,20 @@ def created(data: dict) -> web.Response: def bad_request(self, message: str) -> web.Response: return web.json_response( - { - "error": message, - "errcode": "bad_request", - }, + {"detail": {"message": message}}, status=HTTPStatus.BAD_REQUEST, ) def client_not_found(self, user_id: str) -> web.Response: + return web.json_response( + {"detail": {"message": f"Client with user ID {user_id} not found"}}, + status=HTTPStatus.NOT_FOUND, + ) + + def not_found(self, message: str) -> web.Response: return web.json_response( { - "error": f"Client with user ID {user_id} not found", - "errcode": "client_not_found", + "detail": {"message": message}, }, status=HTTPStatus.NOT_FOUND, ) diff --git a/menuflow/db/flow.py b/menuflow/db/flow.py index 306aa13..18fd1d3 100644 --- a/menuflow/db/flow.py +++ b/menuflow/db/flow.py @@ -24,9 +24,10 @@ def _from_row(cls, row: Record) -> Union["Flow", None]: def values(self) -> Dict[str, Any]: return json.dumps(self.flow) - async def insert(self) -> str: + async def insert(self) -> int: q = "INSERT INTO flow (flow) VALUES ($1)" await self.db.execute(q, self.values) + return await self.db.fetchval("SELECT MAX(id) FROM flow") async def update(self) -> None: q = "UPDATE flow SET flow=$1 WHERE id=$2" diff --git a/menuflow/menu.py b/menuflow/menu.py index a02e311..51b8708 100644 --- a/menuflow/menu.py +++ b/menuflow/menu.py @@ -3,7 +3,7 @@ import asyncio import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Optional, cast from aiohttp import ClientSession, TraceConfig from mautrix.client import Client, InternalEventType @@ -256,9 +256,10 @@ async def get( cls, user_id: UserID, *, - homeserver: str | None = None, - access_token: str | None = None, - device_id: DeviceID | None = None, + homeserver: Optional[str] = None, + access_token: Optional[str] = None, + device_id: Optional[DeviceID] = None, + flow_id: Optional[int] = None, ) -> Client | None: try: return cls.cache[user_id] @@ -276,6 +277,7 @@ async def get( homeserver=homeserver, access_token=access_token, device_id=device_id or "", + flow=flow_id, ) await user.insert() await user.postinit()