Skip to content

Commit

Permalink
feat(api): ✨ Added new endpoints for flows administration
Browse files Browse the repository at this point in the history
  • Loading branch information
egalvis39 committed May 3, 2024
1 parent 024d8d2 commit 76c7a9b
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 31 deletions.
3 changes: 2 additions & 1 deletion menuflow/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from .db import upgrade_table
from .email_client import EmailClient
from .events import NatsPublisher
from .flow import Flow
from .flow_utils import FlowUtils
from .menu import MenuClient
from .repository.flow_utils import FlowUtils as FlowUtilsModel
from .repository.middlewares import EmailServer
from .server import MenuFlowServer

Expand Down Expand Up @@ -54,6 +54,7 @@ def prepare(self) -> None:
management_api = init_api(self.config, self.loop)
self.server = MenuFlowServer(management_api, self.config, self.loop)
self.flow_utils = FlowUtils()
Flow.init_cls(self.flow_utils)

async def start_email_connections(self):
self.log.debug("Starting email clients...")
Expand Down
107 changes: 105 additions & 2 deletions menuflow/api/client.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
from __future__ import annotations

from json import JSONDecodeError
from typing import Dict
from logging import Logger, getLogger
from typing import Dict, Optional

from aiohttp import web
from mautrix.client import Client as MatrixClient
from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequestError
from mautrix.types import UserID

from ..db.client import Client as DBClient
from ..db.flow import Flow as DBFlow
from ..menu import MenuClient
from ..room import Room
from ..utils import Util
from .base import routes
from .responses import resp

log: Logger = getLogger("menuflow.api.client")

async def _create_client(user_id: UserID | None, data: dict) -> web.Response:

async def _create_client(user_id: UserID | None, data: Dict) -> web.Response:
homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None)
device_id = data.get("device_id", None)
Expand Down Expand Up @@ -50,6 +56,16 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
return resp.created(client.to_dict())


async def _reload_flow(client: MenuClient, flow_content: Optional[Dict] = None) -> web.Response:
await client.flow_cls.load_flow(
flow_mxid=client.id, content=flow_content, config=client.menuflow.config
)
client.flow_cls.nodes_by_id = {}

util = Util(client.menuflow.config)
await util.cancel_tasks()


@routes.post("/client/new")
async def create_client(request: web.Request) -> web.Response:
try:
Expand All @@ -73,3 +89,90 @@ async def set_variables(request: web.Request) -> web.Response:
await room.set_variable(variable_id="external", value=variables)

return resp.ok


# Update or create new flow
@routes.put("/flow")
async def create_flow(request: web.Request) -> web.Response:
try:
data: Dict = await request.json()
except JSONDecodeError:
return resp.body_not_json

flow_id = data.get("id", None)
incoming_flow = data.get("flow", None)

if not incoming_flow:
return resp.bad_request("Incoming flow is required")

if flow_id:
db_clients = await DBClient.get_by_flow_id(flow_id)
flow = await DBFlow.get_by_id(flow_id)
flow.flow = incoming_flow
await flow.update()
for db_client in db_clients:
client = MenuClient.cache[db_client.id]
await _reload_flow(client, incoming_flow)
message = "Flow updated successfully"
else:
new_flow = DBFlow(flow=incoming_flow)
await new_flow.insert()
message = "Flow created successfully"

return resp.ok({"error": message})


@routes.get("/flow")
async def get_flow(request: web.Request) -> web.Response:
flow_id = request.query.get("id", None)
client_mxid = request.query.get("client_mxid", None)
if flow_id:
flow = await DBFlow.get_by_id(int(flow_id))
data = flow.serialize()
elif client_mxid:
flow = await DBFlow.get_by_mxid(client_mxid)
data = flow.serialize()
else:
flows = await DBFlow.all()
data = {"flows": flows}

return resp.ok(data)


@routes.patch("/client/{mxid}/flow")
async def update_client(request: web.Request) -> web.Response:
mxid = request.match_info["mxid"]
client: MenuClient = MenuClient.cache.get(mxid)
if not client:
return resp.client_not_found(mxid)

try:
data: Dict = await request.json()
except JSONDecodeError:
return resp.body_not_json

flow_id = data.get("flow_id", None)
if not flow_id:
return resp.bad_request("Flow ID is required")

flow_db = await DBFlow.get_by_id(flow_id)
if not flow_db:
return resp.bad_request("Flow not found")

client.flow = flow_id
await _reload_flow(client, flow_db.flow)

await client.update()
return resp.ok(client.to_dict())


@routes.post("/client/{mxid}/flow/reload")
async def reload_client_flow(request: web.Request) -> web.Response:
mxid = request.match_info["mxid"]
client: MenuClient = MenuClient.cache.get(mxid)
if not client:
return resp.client_not_found(mxid)

await _reload_flow(client)

return resp.ok(client.to_dict())
24 changes: 21 additions & 3 deletions menuflow/api/responses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from http import HTTPStatus
from typing import Optional

from aiohttp import web

Expand Down Expand Up @@ -80,13 +81,30 @@ def user_exists(self) -> web.Response:
status=HTTPStatus.CONFLICT,
)

@property
def ok(self) -> web.Response:
return web.json_response({}, status=HTTPStatus.OK)
def ok(self, data: Optional[str] = {}) -> web.Response:
return web.json_response(data, status=HTTPStatus.OK)

@staticmethod
def created(data: dict) -> web.Response:
return web.json_response(data, status=HTTPStatus.CREATED)

def bad_request(self, message: str) -> web.Response:
return web.json_response(
{
"error": message,
"errcode": "bad_request",
},
status=HTTPStatus.BAD_REQUEST,
)

def client_not_found(self, user_id: str) -> web.Response:
return web.json_response(
{
"error": f"Client with user ID {user_id} not found",
"errcode": "client_not_found",
},
status=HTTPStatus.NOT_FOUND,
)


resp = _Response()
1 change: 1 addition & 0 deletions menuflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def do_update(self, helper: ConfigUpdateHelper) -> None:
copy("menuflow.typing_notification")
copy("menuflow.send_events")
copy("menuflow.load_flow_from")
copy_dict("menuflow.regex")
copy("server.hostname")
copy("server.port")
copy("server.public_url")
Expand Down
6 changes: 6 additions & 0 deletions menuflow/db/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ async def get(cls, id: str) -> Client | None:
q = f"SELECT {cls._columns} FROM client WHERE id=$1"
return cls._from_row(await cls.db.fetchrow(q, id))

@classmethod
async def get_by_flow_id(cls, flow_id: int) -> list[Client]:
q = f"SELECT {cls._columns} FROM client WHERE flow=$1"
rows = await cls.db.fetch(q, flow_id)
return [cls._from_row(row) for row in rows]

async def insert(self) -> None:
q = (
"INSERT INTO client (id, homeserver, access_token, device_id, next_batch, filter_id, "
Expand Down
24 changes: 19 additions & 5 deletions menuflow/db/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,45 @@
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Union

from asyncpg import Record
from attr import dataclass
from attr import dataclass, ib
from mautrix.types import SerializableAttrs
from mautrix.util.async_db import Database

fake_db = Database.create("") if TYPE_CHECKING else None


@dataclass
class Flow:
class Flow(SerializableAttrs):
db: ClassVar[Database] = fake_db

id: int | None
flow: Dict[str, Any]
id: int = ib(default=None)
flow: Dict[str, Any] = ib(factory=dict)

@classmethod
def _from_row(cls, row: Record) -> Union["Flow", None]:
return cls(id=row["id"], flow=json.loads(row["flow"]))

@property
def values(self) -> Dict[str, Any]:
return self.flow
return json.dumps(self.flow)

async def insert(self) -> str:
q = "INSERT INTO flow (flow) VALUES ($1)"
await self.db.execute(q, self.values)

async def update(self) -> None:
q = "UPDATE flow SET flow=$1 WHERE id=$2"
await self.db.execute(q, self.values, self.id)

@classmethod
async def all(cls) -> list[Dict]:
q = "SELECT id, flow FROM flow"
rows = await cls.db.fetch(q)
if not rows:
return []

return [cls._from_row(row).serialize() for row in rows]

@classmethod
async def get_by_id(cls, id: int) -> Union["Flow", None]:
q = "SELECT id, flow FROM flow WHERE id=$1"
Expand Down
3 changes: 3 additions & 0 deletions menuflow/example-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ menuflow:
# - database: the flow is defined in a database
load_flow_from: "yaml"

regex:
room_id: ^![\w-]+:[\w.-]+$


server:
# The IP and port to listen to.
Expand Down
13 changes: 8 additions & 5 deletions menuflow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@


class Flow:
flow_utils: FlowUtils
log: TraceLogger = logging.getLogger("menuflow.flow")

def __init__(self) -> None:
self.data: Dict = {}
self.data: Flow = None
self.nodes: List[Dict] = []
self.nodes_by_id: Dict[str, Dict] = {}
self.flow_utils: Optional[FlowUtils] = None
Expand All @@ -62,20 +63,22 @@ async def load_flow(
self,
flow_mxid: Optional[str] = None,
content: Optional[Dict] = None,
flow_utils: Optional[FlowUtils] = None,
config: Optional[Config] = None,
) -> Flow:
self.data = await FlowModel.load_flow(flow_mxid=flow_mxid, content=content, config=config)
self.nodes = self.data.get("nodes", [])
self.nodes = self.data.nodes or []
self.nodes_by_id: Dict[str, Dict] = {}
self.flow_utils = flow_utils

def _add_node_to_cache(self, node_data: Dict):
self.nodes_by_id[node_data.get("id")] = node_data

@property
def flow_variables(self) -> Dict:
return {"flow": self.data.get("flow_variables", {})}
return {"flow": self.data.flow_variables or {}}

@classmethod
def init_cls(cls, flow_utils: FlowUtils) -> None:
cls.flow_utils = flow_utils

def get_node_by_id(self, node_id: str) -> Dict | None:
"""This function returns a node from a cache or a list of nodes based on its ID.
Expand Down
1 change: 1 addition & 0 deletions menuflow/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ async def algorithm(self, room: Room, evt: Optional[MessageEvent] = None) -> Non
if node is None:
self.log.debug(f"Room {room.room_id} does not have a node [{node}]")
await room.update_menu(node_id="start")
await self.algorithm(room=room)
return

self.log.debug(f"The [room: {room.room_id}] [node: {node.id}] [state: {room.route.state}]")
Expand Down
5 changes: 2 additions & 3 deletions menuflow/menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ async def postinit(self) -> None:
self.started = False
self.sync_ok = True
self.flow_cls = Flow()
await self.flow_cls.load_flow(
flow_utils=self.menuflow.flow_utils, flow_mxid=self.id, config=self.menuflow.config
)
await self.flow_cls.load_flow(flow_mxid=self.id, config=self.menuflow.config)
self.matrix_handler: MatrixHandler = self._make_client()
asyncio.create_task(self.matrix_handler.load_all_room_constants())
# if self.enable_crypto:
Expand Down Expand Up @@ -231,6 +229,7 @@ def to_dict(self) -> dict:
# self.crypto.account.fingerprint if self.crypto and self.crypto.account else None
# ),
"autojoin": self.autojoin,
"flow": self.flow,
}

async def delete(self) -> None:
Expand Down
13 changes: 7 additions & 6 deletions menuflow/nodes/check_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
) -> None:
Switch.__init__(self, check_time_node_data, room=room, default_variables=default_variables)
self.content = check_time_node_data
self.util = Util(self.config)

@property
def time_ranges(self) -> List[str]:
Expand Down Expand Up @@ -93,8 +94,8 @@ def check_month(self, month: int) -> bool:

for range_months in self.months:
month_start, month_end = range_months.split("-")
if Util.is_within_range(
month, Util.months.get(month_start), Util.months.get(month_end)
if self.util.is_within_range(
month, self.util.months.get(month_start), self.util.months.get(month_end)
):
return True

Expand All @@ -120,10 +121,10 @@ def check_week_day(self, week_day: str) -> bool:

for week_days_range in self.days_of_week:
week_day_start, week_day_end = week_days_range.split("-")
if Util.is_within_range(
Util.week_days.get(week_day),
Util.week_days.get(week_day_start),
Util.week_days.get(week_day_end),
if self.util.is_within_range(
self.util.week_days.get(week_day),
self.util.week_days.get(week_day_start),
self.util.week_days.get(week_day_end),
):
return True

Expand Down
6 changes: 3 additions & 3 deletions menuflow/repository/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ async def load_flow(
Returns:
Flow: The loaded flow.
"""
if flow_mxid:
if content:
flow = content
elif flow_mxid:
if config["menuflow.load_flow_from"] == "database":
flow = await cls.load_from_db(flow_mxid, config)
else:
flow = cls.load_from_yaml(flow_mxid)
elif content:
flow = content

return cls(**flow["menu"])
Loading

0 comments on commit 76c7a9b

Please sign in to comment.