From a076cbbfb65d1db42cacc09c8e40af03a969c529 Mon Sep 17 00:00:00 2001 From: Esteban Galvis Date: Wed, 14 Feb 2024 16:51:03 -0500 Subject: [PATCH] feat(middlewares): :sparkles: Added new middleware for text translation --- menuflow/flow.py | 6 +- menuflow/middlewares/__init__.py | 1 + menuflow/middlewares/irm.py | 2 +- menuflow/middlewares/ttm.py | 131 ++++++++++++++++++++ menuflow/repository/__init__.py | 2 +- menuflow/repository/flow_utils.py | 15 ++- menuflow/repository/middlewares/__init__.py | 1 + menuflow/repository/middlewares/ttm.py | 40 ++++++ menuflow/utils/types.py | 1 + 9 files changed, 194 insertions(+), 5 deletions(-) create mode 100644 menuflow/middlewares/ttm.py create mode 100644 menuflow/repository/middlewares/ttm.py diff --git a/menuflow/flow.py b/menuflow/flow.py index 1872199..a724bdb 100644 --- a/menuflow/flow.py +++ b/menuflow/flow.py @@ -6,7 +6,7 @@ from mautrix.util.logging import TraceLogger from .flow_utils import FlowUtils -from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware +from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware, TTMMiddleware from .nodes import ( CheckTime, Email, @@ -108,6 +108,10 @@ def middleware( room=room, default_variables=self.flow_variables, ) + elif middleware_type == Middlewares.TTM: + middleware_initialized = TTMMiddleware( + ttm_data=middleware_model, room=room, default_variables=self.flow_variables + ) return middleware_initialized diff --git a/menuflow/middlewares/__init__.py b/menuflow/middlewares/__init__.py index 0fa1d26..ceed507 100644 --- a/menuflow/middlewares/__init__.py +++ b/menuflow/middlewares/__init__.py @@ -2,3 +2,4 @@ from .http import HTTPMiddleware from .irm import IRMMiddleware from .llm import LLMMiddleware +from .ttm import TTMMiddleware diff --git a/menuflow/middlewares/irm.py b/menuflow/middlewares/irm.py index 63a5ce4..0360745 100644 --- a/menuflow/middlewares/irm.py +++ b/menuflow/middlewares/irm.py @@ -97,7 +97,7 @@ async def run(self, image_mxc: str, content_type: str, filename: str) -> Tuple[i except ContentTypeError: response_data = await response.text() - self.log.critical(f"response_data: {response_data}") + self.log.info(f"response_data: {response_data}") if isinstance(response_data, dict): # Tulir and its magic since time immemorial diff --git a/menuflow/middlewares/ttm.py b/menuflow/middlewares/ttm.py new file mode 100644 index 0000000..0acd02a --- /dev/null +++ b/menuflow/middlewares/ttm.py @@ -0,0 +1,131 @@ +from typing import Dict, Tuple + +from aiohttp import ClientTimeout, ContentTypeError, FormData +from mautrix.util.config import RecursiveDict +from ruamel.yaml.comments import CommentedMap + +from ..nodes import Base +from ..repository import TTMMiddleware as TTMMiddlewareModel +from ..room import Room + + +class TTMMiddleware(Base): + def __init__(self, ttm_data: TTMMiddlewareModel, room: Room, default_variables: Dict) -> None: + Base.__init__(self, room=room, default_variables=default_variables) + self.log = self.log.getChild(ttm_data.id) + self.content: TTMMiddlewareModel = ttm_data + + @property + def method(self) -> str: + return self.content.method + + @property + def url(self) -> str: + return self.render_data(self.content.url) + + @property + def variables(self) -> Dict: + return self.render_data(self.content.variables) + + @property + def cookies(self) -> Dict: + return self.render_data(self.content.cookies) + + @property + def headers(self) -> Dict: + return self.render_data(self.content.headers) + + @property + def basic_auth(self) -> Dict: + return self.render_data(self.content.basic_auth) + + @property + def target_language(self) -> str: + return self.render_data(self.content.target_language) + + @property + def source_language(self) -> str: + return self.render_data(self.content.source_language) + + @property + def provider(self) -> str: + return self.render_data(self.content.provider) + + async def run(self, text: str) -> Tuple[int, str]: + """Make the auth request to refresh api token + + Parameters + ---------- + session : ClientSession + ClientSession + + Returns + ------- + The status code and the response text. + + """ + + request_body = {} + + if self.headers: + request_body["headers"] = self.headers + + data = FormData() + data.add_field(name="text", value=text) + data.add_field(name="target_language", value=self.target_language) + data.add_field(name="source_language", value=self.source_language) + data.add_field(name="provider", value=self.provider) + request_body["data"] = data + + try: + timeout = ClientTimeout(total=self.config["menuflow.timeouts.middlewares"]) + response = await self.session.request( + self.method, self.url, timeout=timeout, **request_body + ) + except Exception as e: + self.log.exception(f"Error in middleware: {e}") + return + + variables = {} + + if self.cookies: + for cookie in self.cookies: + variables[cookie] = response.cookies.output(cookie) + + self.log.debug( + f"middleware: {self.id} type: {self.type} method: {self.method} " + f"url: {self.url} status: {response.status}" + ) + + try: + response_data = await response.json() + except ContentTypeError: + response_data = await response.text() + + self.log.info(f"response_data: {response_data}") + + if isinstance(response_data, dict): + # Tulir and its magic since time immemorial + serialized_data = RecursiveDict(CommentedMap(**response_data)) + if self.variables: + for variable in self.variables: + try: + variables[variable] = self.render_data( + serialized_data[self.variables[variable]] + ) + except KeyError: + pass + elif isinstance(response_data, str): + if self.variables: + for variable in self.variables: + try: + variables[variable] = self.render_data(response_data) + except KeyError: + pass + + break + + if variables: + await self.room.set_variables(variables=variables) + + return response.status, response_data.get("text") diff --git a/menuflow/repository/__init__.py b/menuflow/repository/__init__.py index 8d18b41..70f83db 100644 --- a/menuflow/repository/__init__.py +++ b/menuflow/repository/__init__.py @@ -1,6 +1,6 @@ from .flow import Flow from .flow_utils import FlowUtils -from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware +from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware, TTMMiddleware from .nodes import ( Case, CheckTime, diff --git a/menuflow/repository/flow_utils.py b/menuflow/repository/flow_utils.py index 992e123..2ca416c 100644 --- a/menuflow/repository/flow_utils.py +++ b/menuflow/repository/flow_utils.py @@ -9,14 +9,23 @@ from mautrix.util.logging import TraceLogger from ..utils import Middlewares -from .middlewares import ASRMiddleware, EmailServer, HTTPMiddleware, IRMMiddleware, LLMMiddleware +from .middlewares import ( + ASRMiddleware, + EmailServer, + HTTPMiddleware, + IRMMiddleware, + LLMMiddleware, + TTMMiddleware, +) log: TraceLogger = logging.getLogger("menuflow.repository.flow_utils") @dataclass class FlowUtils(SerializableAttrs): - middlewares: List[HTTPMiddleware, IRMMiddleware, LLMMiddleware, ASRMiddleware] = ib(default=[]) + middlewares: List[ + HTTPMiddleware, IRMMiddleware, LLMMiddleware, ASRMiddleware, TTMMiddleware + ] = ib(default=[]) email_servers: List[EmailServer] = ib(default=[]) @classmethod @@ -60,6 +69,8 @@ def initialize_middleware_dataclass( return LLMMiddleware.from_dict(middleware) elif middleware_type == Middlewares.ASR: return ASRMiddleware.from_dict(middleware) + elif middleware_type == Middlewares.TTM: + return TTMMiddleware(**middleware) @classmethod def initialize_email_server_dataclass(cls, email_server: Dict) -> EmailServer: diff --git a/menuflow/repository/middlewares/__init__.py b/menuflow/repository/middlewares/__init__.py index 5170314..6e16370 100644 --- a/menuflow/repository/middlewares/__init__.py +++ b/menuflow/repository/middlewares/__init__.py @@ -3,3 +3,4 @@ from .http import HTTPMiddleware from .irm import IRMMiddleware from .llm import LLMMiddleware +from .ttm import TTMMiddleware diff --git a/menuflow/repository/middlewares/ttm.py b/menuflow/repository/middlewares/ttm.py new file mode 100644 index 0000000..cf8b210 --- /dev/null +++ b/menuflow/repository/middlewares/ttm.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Any, Dict + +from attr import dataclass, ib + +from ..flow_object import FlowObject + + +@dataclass +class TTMMiddleware(FlowObject): + """TTM Middleware. + + An TTMMiddleware is used to translate text from a source language to a target language. + + content: + + ``` + middlewares: + - id: irm_middleware + type: irm + method: POST + url: "https://webapinet.userfoo.com/api/irm/recognize" + prompt: "Given an image, give me the text in it" + variables: + token: token + headers: + Client-token: "example-token" + + """ + + method: str = ib(default=None) + url: str = ib(default=None) + variables: Dict[str, Any] = ib(factory=dict) + cookies: Dict[str, Any] = ib(factory=dict) + headers: Dict[str, Any] = ib(factory=dict) + basic_auth: Dict[str, Any] = ib(factory=dict) + target_language: str = ib(factory=str) + source_language: str = ib(factory=str) + provider: str = ib(factory=str) diff --git a/menuflow/utils/types.py b/menuflow/utils/types.py index 001eaac..2d49991 100644 --- a/menuflow/utils/types.py +++ b/menuflow/utils/types.py @@ -24,3 +24,4 @@ class Middlewares(SerializableEnum): IRM = "irm" LLM = "llm" ASR = "asr" + TTM = "ttm"