From e3beef58975c4cf4727f3ac5a1e685604d907f5d Mon Sep 17 00:00:00 2001 From: egalvis Date: Tue, 16 May 2023 13:49:42 -0500 Subject: [PATCH] Fixed error in room variables --- menuflow/flow.py | 39 +++++++++++++++++++++++----------- menuflow/matrix.py | 1 - menuflow/middlewares/http.py | 7 +++--- menuflow/nodes/base.py | 10 +++++---- menuflow/nodes/check_time.py | 9 +++++--- menuflow/nodes/email.py | 7 +++--- menuflow/nodes/http_request.py | 9 ++++++-- menuflow/nodes/input.py | 7 +++--- menuflow/nodes/location.py | 7 ++++-- menuflow/nodes/media.py | 5 +++-- menuflow/nodes/message.py | 8 +++++-- menuflow/nodes/switch.py | 4 +++- test/conftest.py | 30 +++++++++++--------------- 13 files changed, 87 insertions(+), 56 deletions(-) diff --git a/menuflow/flow.py b/menuflow/flow.py index 34f27eb..f289602 100644 --- a/menuflow/flow.py +++ b/menuflow/flow.py @@ -93,8 +93,9 @@ def middleware(self, middleware_id: str, room: Room) -> HTTPMiddleware: if not middleware_data: return - middleware_initialized = HTTPMiddleware(http_middleware_data=middleware_data) - middleware_initialized.room = room + middleware_initialized = HTTPMiddleware( + http_middleware_data=middleware_data, room=room, default_variables=self.flow_variables + ) return middleware_initialized @@ -107,21 +108,37 @@ def node( return if node_data.get("type") == "message": - node_initialized = Message(message_node_data=node_data) + node_initialized = Message( + message_node_data=node_data, room=room, default_variables=self.flow_variables + ) elif node_data.get("type") == "media": - node_initialized = Media(media_node_data=node_data) + node_initialized = Media( + media_node_data=node_data, room=room, default_variables=self.flow_variables + ) elif node_data.get("type") == "email": - node_initialized = Email(email_node_data=node_data) + node_initialized = Email( + email_node_data=node_data, room=room, default_variables=self.flow_variables + ) elif node_data.get("type") == "location": - node_initialized = Location(location_node_data=node_data) + node_initialized = Location( + location_node_data=node_data, room=room, default_variables=self.flow_variables + ) elif node_data.get("type") == "switch": - node_initialized = Switch(switch_node_data=node_data) + node_initialized = Switch( + switch_node_data=node_data, room=room, default_variables=self.flow_variables + ) elif node_data.get("type") == "input": - node_initialized = Input(input_node_data=node_data) + node_initialized = Input( + input_node_data=node_data, room=room, default_variables=self.flow_variables + ) elif node_data.get("type") == "check_time": - node_initialized = CheckTime(check_time_node_data=node_data) + node_initialized = CheckTime( + check_time_node_data=node_data, room=room, default_variables=self.flow_variables + ) elif node_data.get("type") == "http_request": - node_initialized = HTTPRequest(http_request_node_data=node_data) + node_initialized = HTTPRequest( + http_request_node_data=node_data, room=room, default_variables=self.flow_variables + ) if node_data.get("middleware"): middleware = self.middleware(node_data.get("middleware"), room) @@ -129,6 +146,4 @@ def node( else: return - node_initialized.room = room - return node_initialized diff --git a/menuflow/matrix.py b/menuflow/matrix.py index d6df433..164a534 100644 --- a/menuflow/matrix.py +++ b/menuflow/matrix.py @@ -52,7 +52,6 @@ def __init__(self, config: Config, *args, **kwargs) -> None: Base.init_cls( config=self.config, session=self.api.session, - default_variables=self.flow.flow_variables, ) def handle_sync(self, data: Dict) -> list[asyncio.Task]: diff --git a/menuflow/middlewares/http.py b/menuflow/middlewares/http.py index bce6655..88dc853 100644 --- a/menuflow/middlewares/http.py +++ b/menuflow/middlewares/http.py @@ -10,9 +10,10 @@ class HTTPMiddleware(Base): - room: Room = None - - def __init__(self, http_middleware_data: HTTPMiddlewareModel) -> None: + def __init__( + self, http_middleware_data: HTTPMiddlewareModel, room: Room, default_variables: Dict + ) -> None: + Base.__init__(self, room=room, default_variables=default_variables) self.log = self.log.getChild(http_middleware_data.get("id")) self.content: Dict = http_middleware_data diff --git a/menuflow/nodes/base.py b/menuflow/nodes/base.py index 0a71c2a..b5c9b2a 100644 --- a/menuflow/nodes/base.py +++ b/menuflow/nodes/base.py @@ -66,7 +66,10 @@ class Base: session: ClientSession content: Dict - room: Room + + def __init__(self, room: Room, default_variables: Dict) -> None: + self.room = room + self.default_variables = default_variables @property def id(self) -> str: @@ -77,10 +80,9 @@ def type(self) -> str: return self.content.get("type", "") @classmethod - def init_cls(cls, config: Config, session: ClientSession, default_variables: Dict): + def init_cls(cls, config: Config, session: ClientSession): cls.config = config cls.session = session - cls.variables = default_variables or {} @abstractmethod async def run(self): @@ -150,7 +152,7 @@ def render_data(self, data: Dict | List | str) -> Dict | List | str: self.log.exception(e) return - copy_variables = {**self.variables, **self.room._variables} + copy_variables = {**self.default_variables, **self.room._variables} try: data = loads(data_template.render(**copy_variables)) diff --git a/menuflow/nodes/check_time.py b/menuflow/nodes/check_time.py index 49eaf5b..97513a8 100644 --- a/menuflow/nodes/check_time.py +++ b/menuflow/nodes/check_time.py @@ -1,16 +1,19 @@ from datetime import datetime -from typing import Any, List +from typing import Any, Dict, List import pytz from ..repository import CheckTime as CheckTimeModel +from ..room import Room from ..utils import Util from .switch import Switch class CheckTime(Switch): - def __init__(self, check_time_node_data: CheckTimeModel) -> None: - Switch.__init__(self, check_time_node_data) + def __init__( + self, check_time_node_data: CheckTimeModel, room: Room, default_variables: Dict + ) -> None: + Switch.__init__(self, check_time_node_data, room=room, default_variables=default_variables) self.content = check_time_node_data @property diff --git a/menuflow/nodes/email.py b/menuflow/nodes/email.py index 86c5114..a53d307 100644 --- a/menuflow/nodes/email.py +++ b/menuflow/nodes/email.py @@ -1,17 +1,18 @@ import asyncio -from typing import List +from typing import Dict, List from ..email_client import Email as EmailMessage from ..email_client import EmailClient from ..repository import Email as EmailModel +from ..room import Room from .message import Message class Email(Message): email_client: EmailClient = None - def __init__(self, email_node_data: EmailModel) -> None: - Message.__init__(self, email_node_data) + def __init__(self, email_node_data: EmailModel, room: Room, default_variables: Dict) -> None: + Message.__init__(self, email_node_data, room=room, default_variables=default_variables) self.content = email_node_data @property diff --git a/menuflow/nodes/http_request.py b/menuflow/nodes/http_request.py index 5bb7d74..880857e 100644 --- a/menuflow/nodes/http_request.py +++ b/menuflow/nodes/http_request.py @@ -6,6 +6,7 @@ from ..db.room import RoomState from ..repository import HTTPRequest as HTTPRequestModel +from ..room import Room from .switch import Switch if TYPE_CHECKING: @@ -17,8 +18,12 @@ class HTTPRequest(Switch): middleware: "HTTPMiddleware" = None - def __init__(self, http_request_node_data: HTTPRequestModel) -> None: - Switch.__init__(self, http_request_node_data) + def __init__( + self, http_request_node_data: HTTPRequestModel, room: Room, default_variables: Dict + ) -> None: + Switch.__init__( + self, http_request_node_data, room=room, default_variables=default_variables + ) self.log = self.log.getChild(http_request_node_data.get("id")) self.content: Dict = http_request_node_data diff --git a/menuflow/nodes/input.py b/menuflow/nodes/input.py index 890c958..4109497 100644 --- a/menuflow/nodes/input.py +++ b/menuflow/nodes/input.py @@ -14,15 +14,16 @@ from ..db.room import RoomState from ..repository import Input as InputModel +from ..room import Room from ..utils import Util from .message import Message from .switch import Switch class Input(Switch, Message): - def __init__(self, input_node_data: InputModel) -> None: - Switch.__init__(self, input_node_data) - Message.__init__(self, input_node_data) + def __init__(self, input_node_data: InputModel, room: Room, default_variables: Dict) -> None: + Switch.__init__(self, input_node_data, room=room, default_variables=default_variables) + Message.__init__(self, input_node_data, room=room, default_variables=default_variables) self.content = input_node_data @property diff --git a/menuflow/nodes/location.py b/menuflow/nodes/location.py index 473cb40..e58b91c 100644 --- a/menuflow/nodes/location.py +++ b/menuflow/nodes/location.py @@ -7,12 +7,15 @@ from ..db.room import RoomState from ..repository import Location as LocationModel +from ..room import Room from .message import Message class Location(Message): - def __init__(self, location_node_data: LocationModel) -> None: - Message.__init__(self, location_node_data) + def __init__( + self, location_node_data: LocationModel, room: Room, default_variables: Dict + ) -> None: + Message.__init__(self, location_node_data, room=room, default_variables=default_variables) self.log = self.log.getChild(location_node_data.get("id")) self.content: Dict = location_node_data diff --git a/menuflow/nodes/media.py b/menuflow/nodes/media.py index 146f240..f02ac43 100644 --- a/menuflow/nodes/media.py +++ b/menuflow/nodes/media.py @@ -18,6 +18,7 @@ from ..db.room import RoomState from ..repository import Media as MediaModel +from ..room import Room from .message import Message try: @@ -29,8 +30,8 @@ class Media(Message): media_cache: Dict[str, MediaMessageEventContent] = {} - def __init__(self, media_node_data: MediaModel) -> None: - Message.__init__(self, media_node_data) + def __init__(self, media_node_data: MediaModel, room: Room, default_variables: Dict) -> None: + Message.__init__(self, media_node_data, room=room, default_variables=default_variables) self.log = self.log.getChild(media_node_data.get("id")) self.content: Dict = media_node_data diff --git a/menuflow/nodes/message.py b/menuflow/nodes/message.py index a1399ab..db3d024 100644 --- a/menuflow/nodes/message.py +++ b/menuflow/nodes/message.py @@ -1,15 +1,19 @@ -from typing import Dict +from typing import Any, Dict from markdown import markdown from mautrix.types import Format, MessageType, TextMessageEventContent from ..db.room import RoomState from ..repository import Message as MessageModel +from ..room import Room from .base import Base class Message(Base): - def __init__(self, message_node_data: MessageModel) -> None: + def __init__( + self, message_node_data: MessageModel, room: Room, default_variables: Dict + ) -> None: + Base.__init__(self, room=room, default_variables=default_variables) self.log = self.log.getChild(message_node_data.get("id")) self.content: Dict = message_node_data diff --git a/menuflow/nodes/switch.py b/menuflow/nodes/switch.py index 1df8abc..6aacaad 100644 --- a/menuflow/nodes/switch.py +++ b/menuflow/nodes/switch.py @@ -3,11 +3,13 @@ from typing import Dict, List from ..repository import Switch as SwitchModel +from ..room import Room from .base import Base, safe_data_convertion class Switch(Base): - def __init__(self, switch_node_data: SwitchModel) -> None: + def __init__(self, switch_node_data: SwitchModel, room: Room, default_variables: Dict) -> None: + Base.__init__(self, room=room, default_variables=default_variables) self.log = self.log.getChild(switch_node_data.get("id")) self.content: Dict = switch_node_data diff --git a/test/conftest.py b/test/conftest.py index 5339a36..c4376ba 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -50,18 +50,16 @@ async def base(sample_flow_1: Flow, room: Room, mocker: MockerFixture) -> Base: Base, "run", ) - base = Base() - base.room = room - base.variables = sample_flow_1.flow_variables + base = Base(room=room, default_variables=sample_flow_1.flow_variables) return base @pytest_asyncio.fixture async def message(sample_flow_1: Flow, base: Base) -> Message: message_node_data = sample_flow_1.get_node_by_id("start") - message_node = Message(message_node_data) - message_node.room = base.room - message_node.variables = base.variables + message_node = Message( + message_node_data, room=base.room, default_variables=base.default_variables + ) message_node.matrix_client = Client(base_url="") return message_node @@ -69,9 +67,9 @@ async def message(sample_flow_1: Flow, base: Base) -> Message: @pytest_asyncio.fixture async def switch(sample_flow_1: Flow, base: Base) -> Switch: switch_node_data = sample_flow_1.get_node_by_id("switch-1") - switch_node = Switch(switch_node_data) - switch_node.room = base.room - switch_node.variables = base.variables + switch_node = Switch( + switch_node_data, room=base.room, default_variables=base.default_variables + ) switch_node.matrix_client = Client(base_url="") return switch_node @@ -79,9 +77,7 @@ async def switch(sample_flow_1: Flow, base: Base) -> Switch: @pytest_asyncio.fixture async def input_text(sample_flow_1: Flow, base: Base) -> Input: input_node_data = sample_flow_1.get_node_by_id("input-1") - input_node = Input(input_node_data) - input_node.room = base.room - input_node.variables = base.variables + input_node = Input(input_node_data, room=base.room, default_variables=base.default_variables) input_node.matrix_client = Client(base_url="") return input_node @@ -89,9 +85,7 @@ async def input_text(sample_flow_1: Flow, base: Base) -> Input: @pytest_asyncio.fixture async def input_media(sample_flow_1: Flow, base: Base) -> Input: input_node_data = sample_flow_1.get_node_by_id("input-4") - input_node = Input(input_node_data) - input_node.room = base.room - input_node.variables = base.variables + input_node = Input(input_node_data, room=base.room, default_variables=base.default_variables) input_node.matrix_client = Client(base_url="") return input_node @@ -99,8 +93,8 @@ async def input_media(sample_flow_1: Flow, base: Base) -> Input: @pytest_asyncio.fixture async def location(sample_flow_1: Flow, base: Base) -> Location: location_node_data = sample_flow_1.get_node_by_id("location-1") - location_node = Location(location_node_data) - location_node.room = base.room - location_node.variables = base.variables + location_node = Location( + location_node_data, room=base.room, default_variables=base.default_variables + ) location_node.matrix_client = Client(base_url="") return location_node