Skip to content

Commit

Permalink
feat(middlewares): ✨ Added new middleware for text translation
Browse files Browse the repository at this point in the history
  • Loading branch information
egalvis39 committed Feb 14, 2024
1 parent 995df0d commit a076cbb
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 5 deletions.
6 changes: 5 additions & 1 deletion menuflow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mautrix.util.logging import TraceLogger

from .flow_utils import FlowUtils
from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware, TTMMiddleware
from .nodes import (
CheckTime,
Email,
Expand Down Expand Up @@ -108,6 +108,10 @@ def middleware(
room=room,
default_variables=self.flow_variables,
)
elif middleware_type == Middlewares.TTM:
middleware_initialized = TTMMiddleware(
ttm_data=middleware_model, room=room, default_variables=self.flow_variables
)

return middleware_initialized

Expand Down
1 change: 1 addition & 0 deletions menuflow/middlewares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .http import HTTPMiddleware
from .irm import IRMMiddleware
from .llm import LLMMiddleware
from .ttm import TTMMiddleware
2 changes: 1 addition & 1 deletion menuflow/middlewares/irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def run(self, image_mxc: str, content_type: str, filename: str) -> Tuple[i
except ContentTypeError:
response_data = await response.text()

self.log.critical(f"response_data: {response_data}")
self.log.info(f"response_data: {response_data}")

if isinstance(response_data, dict):
# Tulir and its magic since time immemorial
Expand Down
131 changes: 131 additions & 0 deletions menuflow/middlewares/ttm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from typing import Dict, Tuple

from aiohttp import ClientTimeout, ContentTypeError, FormData
from mautrix.util.config import RecursiveDict
from ruamel.yaml.comments import CommentedMap

from ..nodes import Base
from ..repository import TTMMiddleware as TTMMiddlewareModel
from ..room import Room


class TTMMiddleware(Base):
def __init__(self, ttm_data: TTMMiddlewareModel, room: Room, default_variables: Dict) -> None:
Base.__init__(self, room=room, default_variables=default_variables)
self.log = self.log.getChild(ttm_data.id)
self.content: TTMMiddlewareModel = ttm_data

@property
def method(self) -> str:
return self.content.method

@property
def url(self) -> str:
return self.render_data(self.content.url)

@property
def variables(self) -> Dict:
return self.render_data(self.content.variables)

@property
def cookies(self) -> Dict:
return self.render_data(self.content.cookies)

@property
def headers(self) -> Dict:
return self.render_data(self.content.headers)

@property
def basic_auth(self) -> Dict:
return self.render_data(self.content.basic_auth)

@property
def target_language(self) -> str:
return self.render_data(self.content.target_language)

@property
def source_language(self) -> str:
return self.render_data(self.content.source_language)

@property
def provider(self) -> str:
return self.render_data(self.content.provider)

async def run(self, text: str) -> Tuple[int, str]:
"""Make the auth request to refresh api token
Parameters
----------
session : ClientSession
ClientSession
Returns
-------
The status code and the response text.
"""

request_body = {}

if self.headers:
request_body["headers"] = self.headers

data = FormData()
data.add_field(name="text", value=text)
data.add_field(name="target_language", value=self.target_language)
data.add_field(name="source_language", value=self.source_language)
data.add_field(name="provider", value=self.provider)
request_body["data"] = data

try:
timeout = ClientTimeout(total=self.config["menuflow.timeouts.middlewares"])
response = await self.session.request(
self.method, self.url, timeout=timeout, **request_body
)
except Exception as e:
self.log.exception(f"Error in middleware: {e}")
return

variables = {}

if self.cookies:
for cookie in self.cookies:
variables[cookie] = response.cookies.output(cookie)

self.log.debug(
f"middleware: {self.id} type: {self.type} method: {self.method} "
f"url: {self.url} status: {response.status}"
)

try:
response_data = await response.json()
except ContentTypeError:
response_data = await response.text()

self.log.info(f"response_data: {response_data}")

if isinstance(response_data, dict):
# Tulir and its magic since time immemorial
serialized_data = RecursiveDict(CommentedMap(**response_data))
if self.variables:
for variable in self.variables:
try:
variables[variable] = self.render_data(
serialized_data[self.variables[variable]]
)
except KeyError:
pass
elif isinstance(response_data, str):
if self.variables:
for variable in self.variables:
try:
variables[variable] = self.render_data(response_data)
except KeyError:
pass

break

if variables:
await self.room.set_variables(variables=variables)

return response.status, response_data.get("text")
2 changes: 1 addition & 1 deletion menuflow/repository/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .flow import Flow
from .flow_utils import FlowUtils
from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware, TTMMiddleware
from .nodes import (
Case,
CheckTime,
Expand Down
15 changes: 13 additions & 2 deletions menuflow/repository/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,23 @@
from mautrix.util.logging import TraceLogger

from ..utils import Middlewares
from .middlewares import ASRMiddleware, EmailServer, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import (
ASRMiddleware,
EmailServer,
HTTPMiddleware,
IRMMiddleware,
LLMMiddleware,
TTMMiddleware,
)

log: TraceLogger = logging.getLogger("menuflow.repository.flow_utils")


@dataclass
class FlowUtils(SerializableAttrs):
middlewares: List[HTTPMiddleware, IRMMiddleware, LLMMiddleware, ASRMiddleware] = ib(default=[])
middlewares: List[
HTTPMiddleware, IRMMiddleware, LLMMiddleware, ASRMiddleware, TTMMiddleware
] = ib(default=[])
email_servers: List[EmailServer] = ib(default=[])

@classmethod
Expand Down Expand Up @@ -60,6 +69,8 @@ def initialize_middleware_dataclass(
return LLMMiddleware.from_dict(middleware)
elif middleware_type == Middlewares.ASR:
return ASRMiddleware.from_dict(middleware)
elif middleware_type == Middlewares.TTM:
return TTMMiddleware(**middleware)

@classmethod
def initialize_email_server_dataclass(cls, email_server: Dict) -> EmailServer:
Expand Down
1 change: 1 addition & 0 deletions menuflow/repository/middlewares/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .http import HTTPMiddleware
from .irm import IRMMiddleware
from .llm import LLMMiddleware
from .ttm import TTMMiddleware
40 changes: 40 additions & 0 deletions menuflow/repository/middlewares/ttm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from typing import Any, Dict

from attr import dataclass, ib

from ..flow_object import FlowObject


@dataclass
class TTMMiddleware(FlowObject):
"""TTM Middleware.
An TTMMiddleware is used to translate text from a source language to a target language.
content:
```
middlewares:
- id: irm_middleware
type: irm
method: POST
url: "https://webapinet.userfoo.com/api/irm/recognize"
prompt: "Given an image, give me the text in it"
variables:
token: token
headers:
Client-token: "example-token"
"""

method: str = ib(default=None)
url: str = ib(default=None)
variables: Dict[str, Any] = ib(factory=dict)
cookies: Dict[str, Any] = ib(factory=dict)
headers: Dict[str, Any] = ib(factory=dict)
basic_auth: Dict[str, Any] = ib(factory=dict)
target_language: str = ib(factory=str)
source_language: str = ib(factory=str)
provider: str = ib(factory=str)
1 change: 1 addition & 0 deletions menuflow/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ class Middlewares(SerializableEnum):
IRM = "irm"
LLM = "llm"
ASR = "asr"
TTM = "ttm"

0 comments on commit a076cbb

Please sign in to comment.