-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(middlewares): ✨ Added new middleware for text translation
- Loading branch information
Showing
9 changed files
with
194 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from typing import Dict, Tuple | ||
|
||
from aiohttp import ClientTimeout, ContentTypeError, FormData | ||
from mautrix.util.config import RecursiveDict | ||
from ruamel.yaml.comments import CommentedMap | ||
|
||
from ..nodes import Base | ||
from ..repository import TTMMiddleware as TTMMiddlewareModel | ||
from ..room import Room | ||
|
||
|
||
class TTMMiddleware(Base): | ||
def __init__(self, ttm_data: TTMMiddlewareModel, room: Room, default_variables: Dict) -> None: | ||
Base.__init__(self, room=room, default_variables=default_variables) | ||
self.log = self.log.getChild(ttm_data.id) | ||
self.content: TTMMiddlewareModel = ttm_data | ||
|
||
@property | ||
def method(self) -> str: | ||
return self.content.method | ||
|
||
@property | ||
def url(self) -> str: | ||
return self.render_data(self.content.url) | ||
|
||
@property | ||
def variables(self) -> Dict: | ||
return self.render_data(self.content.variables) | ||
|
||
@property | ||
def cookies(self) -> Dict: | ||
return self.render_data(self.content.cookies) | ||
|
||
@property | ||
def headers(self) -> Dict: | ||
return self.render_data(self.content.headers) | ||
|
||
@property | ||
def basic_auth(self) -> Dict: | ||
return self.render_data(self.content.basic_auth) | ||
|
||
@property | ||
def target_language(self) -> str: | ||
return self.render_data(self.content.target_language) | ||
|
||
@property | ||
def source_language(self) -> str: | ||
return self.render_data(self.content.source_language) | ||
|
||
@property | ||
def provider(self) -> str: | ||
return self.render_data(self.content.provider) | ||
|
||
async def run(self, text: str) -> Tuple[int, str]: | ||
"""Make the auth request to refresh api token | ||
Parameters | ||
---------- | ||
session : ClientSession | ||
ClientSession | ||
Returns | ||
------- | ||
The status code and the response text. | ||
""" | ||
|
||
request_body = {} | ||
|
||
if self.headers: | ||
request_body["headers"] = self.headers | ||
|
||
data = FormData() | ||
data.add_field(name="text", value=text) | ||
data.add_field(name="target_language", value=self.target_language) | ||
data.add_field(name="source_language", value=self.source_language) | ||
data.add_field(name="provider", value=self.provider) | ||
request_body["data"] = data | ||
|
||
try: | ||
timeout = ClientTimeout(total=self.config["menuflow.timeouts.middlewares"]) | ||
response = await self.session.request( | ||
self.method, self.url, timeout=timeout, **request_body | ||
) | ||
except Exception as e: | ||
self.log.exception(f"Error in middleware: {e}") | ||
return | ||
|
||
variables = {} | ||
|
||
if self.cookies: | ||
for cookie in self.cookies: | ||
variables[cookie] = response.cookies.output(cookie) | ||
|
||
self.log.debug( | ||
f"middleware: {self.id} type: {self.type} method: {self.method} " | ||
f"url: {self.url} status: {response.status}" | ||
) | ||
|
||
try: | ||
response_data = await response.json() | ||
except ContentTypeError: | ||
response_data = await response.text() | ||
|
||
self.log.info(f"response_data: {response_data}") | ||
|
||
if isinstance(response_data, dict): | ||
# Tulir and its magic since time immemorial | ||
serialized_data = RecursiveDict(CommentedMap(**response_data)) | ||
if self.variables: | ||
for variable in self.variables: | ||
try: | ||
variables[variable] = self.render_data( | ||
serialized_data[self.variables[variable]] | ||
) | ||
except KeyError: | ||
pass | ||
elif isinstance(response_data, str): | ||
if self.variables: | ||
for variable in self.variables: | ||
try: | ||
variables[variable] = self.render_data(response_data) | ||
except KeyError: | ||
pass | ||
|
||
break | ||
|
||
if variables: | ||
await self.room.set_variables(variables=variables) | ||
|
||
return response.status, response_data.get("text") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Dict | ||
|
||
from attr import dataclass, ib | ||
|
||
from ..flow_object import FlowObject | ||
|
||
|
||
@dataclass | ||
class TTMMiddleware(FlowObject): | ||
"""TTM Middleware. | ||
An TTMMiddleware is used to translate text from a source language to a target language. | ||
content: | ||
``` | ||
middlewares: | ||
- id: irm_middleware | ||
type: irm | ||
method: POST | ||
url: "https://webapinet.userfoo.com/api/irm/recognize" | ||
prompt: "Given an image, give me the text in it" | ||
variables: | ||
token: token | ||
headers: | ||
Client-token: "example-token" | ||
""" | ||
|
||
method: str = ib(default=None) | ||
url: str = ib(default=None) | ||
variables: Dict[str, Any] = ib(factory=dict) | ||
cookies: Dict[str, Any] = ib(factory=dict) | ||
headers: Dict[str, Any] = ib(factory=dict) | ||
basic_auth: Dict[str, Any] = ib(factory=dict) | ||
target_language: str = ib(factory=str) | ||
source_language: str = ib(factory=str) | ||
provider: str = ib(factory=str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,3 +24,4 @@ class Middlewares(SerializableEnum): | |
IRM = "irm" | ||
LLM = "llm" | ||
ASR = "asr" | ||
TTM = "ttm" |