Skip to content

Commit

Permalink
Merge pull request #28 from iKonoTelecomunicaciones/bug_fix_mixed_flows
Browse files Browse the repository at this point in the history
Bug fix mixed flows
  • Loading branch information
bramenn committed May 2, 2023
2 parents 1551fdf + 8b07764 commit c63db29
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 128 deletions.
153 changes: 82 additions & 71 deletions menuflow/flow.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,110 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, List

from mautrix.types import SerializableAttrs
from mautrix.util.logging import TraceLogger

from .middlewares import HTTPMiddleware
from .nodes import CheckTime, Email, HTTPRequest, Input, Location, Media, Message, Switch
from .repository import Flow as FlowModel
from .room import Room

if TYPE_CHECKING:
from .middlewares import HTTPMiddleware


class Flow:
log: TraceLogger = logging.getLogger("menuflow.flow")

nodes: Dict[str, (Message, Input, Switch, HTTPRequest, CheckTime)]
middlewares: Dict[str, HTTPMiddleware]
nodes: List[Dict]
middlewares: List[Dict]

def __init__(self, flow_data: FlowModel) -> None:
self.data: FlowModel = (
flow_data.serialize() if isinstance(flow_data, SerializableAttrs) else flow_data
)
self.nodes = {}
self.middlewares = {}
self.nodes = self.data.get("nodes", [])
self.middlewares = self.data.get("middlewares", [])
self.nodes_by_id: Dict[str, Dict] = {}
self.middlewares_by_id: Dict[str, Dict] = {}

def _add_node_to_cache(self, node_data: Dict):
self.nodes_by_id[node_data.get("id")] = node_data

def _add_middleware_to_cache(self, middleware_data: Dict):
self.middlewares_by_id[middleware_data.get("id")] = middleware_data

@property
def flow_variables(self) -> Dict:
return self.data.get("flow_variables", {})

def load(self):
self.load_middlewares()
self.load_nodes()

def load_nodes(self):
"""It takes the nodes from the flow data and creates a new node object for each one"""
for node in self.data.get("nodes", []):
if node.get("type") == "message":
node = Message(message_node_data=node)
elif node.get("type") == "media":
node = Media(media_node_data=node)
elif node.get("type") == "email":
node = Email(email_node_data=node)
elif node.get("type") == "location":
node = Location(location_node_data=node)
elif node.get("type") == "switch":
node = Switch(switch_node_data=node)
elif node.get("type") == "input":
node = Input(input_node_data=node)
elif node.get("type") == "check_time":
node = CheckTime(check_time_node_data=node)
elif node.get("type") == "http_request":
node = HTTPRequest(http_request_node_data=node)

if node.content.get("middleware"):
node.middleware = self.get_middleware_by_id(node.content.get("middleware"))
else:
continue

node.variables = self.flow_variables or {}
self.nodes[node.id] = node

def load_middlewares(self):
"""It loads the middlewares from the data file into the `middlewares` dictionary"""
for middleware in self.data.get("middlewares", []):
middleware = HTTPMiddleware(http_middleware_data=middleware)
self.middlewares[middleware.id] = middleware

def get_node_by_id(self, node_id: str) -> HTTPRequest | Input | Message | Switch | CheckTime:
return self.nodes.get(node_id)

def get_middleware_by_id(self, middleware_id: str) -> HTTPMiddleware:
return self.middlewares.get(middleware_id)

def node(self, room: Room) -> HTTPRequest | Input | Message | Switch | CheckTime:
"""It returns the node that should be executed next
Parameters
----------
room : Room
The room object that the user is currently in.
Returns
-------
The node object.
"""
node = self.get_node_by_id(node_id=room.node_id or "start")

if not node:
def get_node_by_id(self, node_id: str) -> Dict | None:
node = self.nodes_by_id.get(node_id)
if node:
return node

for node in self.nodes:
if node_id == node.get("id", ""):
self._add_node_to_cache(node)
return node

return None

def get_middleware_by_id(self, middleware_id: str) -> Dict | None:
middleware = self.middlewares_by_id.get(middleware_id)
if middleware:
return middleware

for middleware in self.middlewares:
if middleware_id == middleware.get("id", ""):
self._add_middleware_to_cache(middleware)
return middleware

return None

def middleware(self, middleware_id: str, room: Room) -> HTTPMiddleware:
middleware_data = self.get_middleware_by_id(middleware_id=middleware_id)

if not middleware_data:
return

middleware_initialized = HTTPMiddleware(http_middleware_data=middleware_data)
middleware_initialized.room = room

return middleware_initialized

def node(
self, room: Room
) -> Message | Input | HTTPRequest | Switch | CheckTime | Media | Email | Location | None:
node_data = self.get_node_by_id(node_id=room.node_id)

if not node_data:
return

if node_data.get("type") == "message":
node_initialized = Message(message_node_data=node_data)
elif node_data.get("type") == "media":
node_initialized = Media(media_node_data=node_data)
elif node_data.get("type") == "email":
node_initialized = Email(email_node_data=node_data)
elif node_data.get("type") == "location":
node_initialized = Location(location_node_data=node_data)
elif node_data.get("type") == "switch":
node_initialized = Switch(switch_node_data=node_data)
elif node_data.get("type") == "input":
node_initialized = Input(input_node_data=node_data)
elif node_data.get("type") == "check_time":
node_initialized = CheckTime(check_time_node_data=node_data)
elif node_data.get("type") == "http_request":
node_initialized = HTTPRequest(http_request_node_data=node_data)

if node_data.get("middleware"):
middleware = self.middleware(node_data.get("middleware"), room)
node_initialized.middleware = middleware
else:
return

node.room = room
node.variables.update(room._variables)
node_initialized.room = room
node_initialized.variables = self.flow_variables or {}
node_initialized.variables.update(room._variables)

return node
return node_initialized
20 changes: 7 additions & 13 deletions menuflow/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(self, config: Config, *args, **kwargs) -> None:
session=self.api.session,
default_variables=self.flow.flow_variables,
)
self.flow.load()

def handle_sync(self, data: Dict) -> list[asyncio.Task]:
# This is a way to remove duplicate events from the sync
Expand Down Expand Up @@ -152,6 +151,10 @@ async def load_room_constants(self, room_id: RoomID):
if not await room.get_variable("bot_mxid"):
await room.set_variable("bot_mxid", self.mxid)

if not await room.get_variable("customer_mxid"):
await User.get_by_mxid(mxid=await room.creator)
await room.set_variable("customer_mxid", await room.creator)

async def handle_join(self, evt: StrippedStateEvent):
if evt.room_id in self.LOCKED_ROOMS:
self.log.debug(f"Ignoring menu request in {evt.room_id} Menu locked")
Expand Down Expand Up @@ -195,18 +198,9 @@ async def handle_message(self, message: MessageEvent) -> None:
)
return

try:
user: User = await User.get_by_mxid(mxid=message.sender)
room = await Room.get_by_room_id(room_id=message.room_id)
room.config = user.config = self.config
room.matrix_client = self

if not await room.get_variable("customer_mxid"):
await room.set_variable("customer_mxid", user.mxid)

except Exception as e:
self.log.exception(e)
return
room = await Room.get_by_room_id(room_id=message.room_id)
room.config = self.config = self.config
room.matrix_client = self

if not room:
return
Expand Down
67 changes: 34 additions & 33 deletions menuflow/nodes/http_request.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
from typing import Dict
from typing import TYPE_CHECKING, Dict

from aiohttp import BasicAuth, ClientTimeout, ContentTypeError
from mautrix.util.config import RecursiveDict
from ruamel.yaml.comments import CommentedMap

from ..db.room import RoomState
from ..middlewares import HTTPMiddleware
from ..repository import HTTPRequest as HTTPRequestModel
from .switch import Switch

if TYPE_CHECKING:
from ..middlewares import HTTPMiddleware


class HTTPRequest(Switch):
HTTP_ATTEMPTS: Dict = {}

middleware: HTTPMiddleware = None
middleware: "HTTPMiddleware" = None

def __init__(self, http_request_node_data: HTTPRequestModel) -> None:
Switch.__init__(self, http_request_node_data)
Expand Down Expand Up @@ -130,36 +132,35 @@ async def make_request(self):
variables = {}
o_connection = None

if response.status in [200, 201]:
if self.cookies:
for cookie in self.cookies:
variables[cookie] = response.cookies.output(cookie)

try:
response_data = await response.json()
except ContentTypeError:
response_data = {}

if isinstance(response_data, dict):
# Tulir and its magic since time immemorial
serialized_data = RecursiveDict(CommentedMap(**response_data))
if self.http_variables:
for variable in self.http_variables:
try:
variables[variable] = self.render_data(
serialized_data[self.http_variables[variable]]
)
except KeyError:
pass
elif isinstance(response_data, str):
if self.http_variables:
for variable in self.http_variables:
try:
variables[variable] = self.render_data(response_data)
except KeyError:
pass

break
if self.cookies:
for cookie in self.cookies:
variables[cookie] = response.cookies.output(cookie)

try:
response_data = await response.json()
except ContentTypeError:
response_data = {}

if isinstance(response_data, dict):
# Tulir and its magic since time immemorial
serialized_data = RecursiveDict(CommentedMap(**response_data))
if self.http_variables:
for variable in self.http_variables:
try:
variables[variable] = self.render_data(
serialized_data[self.http_variables[variable]]
)
except KeyError:
pass
elif isinstance(response_data, str):
if self.http_variables:
for variable in self.http_variables:
try:
variables[variable] = self.render_data(response_data)
except KeyError:
pass

break

if self.cases:
o_connection = await self.get_case_by_id(id=response.status)
Expand Down
18 changes: 17 additions & 1 deletion menuflow/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from typing import Any, Dict, cast

from mautrix.client import Client as MatrixClient
from mautrix.types import RoomID
from mautrix.types import EventType, RoomID, StateEventContent
from mautrix.util.logging import TraceLogger

from .config import Config
from .db.room import Room as DBRoom
from .db.room import RoomState
from .utils import Util


class Room(DBRoom):
Expand Down Expand Up @@ -40,13 +41,28 @@ def _add_to_cache(self) -> None:
self.by_room_id[self.room_id] = self

async def clean_up(self):
await Util.cancel_task(task_name=self.room_id)
del self.by_room_id[self.room_id]
self.variables = "{}"
self._variables = {}
self.node_id = RoomState.START.value
self.state = None
await self.update()

@property
async def creator(self) -> Dict:
"""This function retrieves the creator of a Matrix room.
Returns
-------
The `creator` of the Matrix room is being returned as a string.
"""
created_room_event: StateEventContent = await self.matrix_client.get_state_event(
self.room_id, event_type=EventType.ROOM_CREATE
)
return created_room_event.get("creator")

@classmethod
async def get_by_room_id(cls, room_id: RoomID, create: bool = True) -> "Room" | None:
"""It gets a room from the database, or creates one if it doesn't exist
Expand Down
Loading

0 comments on commit c63db29

Please sign in to comment.