Skip to content

Commit

Permalink
Merge pull request #91 from iKonoTelecomunicaciones/83-midleware-to-p…
Browse files Browse the repository at this point in the history
…roccess-audio

Middleware to process audio
  • Loading branch information
jcardenas3 committed Feb 9, 2024
2 parents 300ea4e + 42f3e5e commit 995df0d
Show file tree
Hide file tree
Showing 12 changed files with 214 additions and 16 deletions.
12 changes: 10 additions & 2 deletions menuflow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mautrix.util.logging import TraceLogger

from .flow_utils import FlowUtils
from .middlewares import HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .nodes import (
CheckTime,
Email,
Expand Down Expand Up @@ -75,7 +75,9 @@ def get_node_by_id(self, node_id: str) -> Dict | None:
self._add_node_to_cache(node)
return node

def middleware(self, middleware_id: str, room: Room) -> HTTPMiddleware | IRMMiddleware | None:
def middleware(
self, middleware_id: str, room: Room
) -> HTTPMiddleware | IRMMiddleware | ASRMiddleware | None:
middleware_model = self.flow_utils.get_middleware_by_id(middleware_id=middleware_id)
try:
middleware_type = Middlewares(middleware_model.type)
Expand All @@ -100,6 +102,12 @@ def middleware(self, middleware_id: str, room: Room) -> HTTPMiddleware | IRMMidd
middleware_initialized = LLMMiddleware(
llm_data=middleware_model, room=room, default_variables=self.flow_variables
)
elif middleware_type == Middlewares.ASR:
middleware_initialized = ASRMiddleware(
asr_middleware_content=middleware_model,
room=room,
default_variables=self.flow_variables,
)

return middleware_initialized

Expand Down
12 changes: 8 additions & 4 deletions menuflow/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,31 @@

from .middlewares.http import HTTPMiddleware
from .repository import FlowUtils as FlowUtilsModel
from .repository.middlewares import HTTPMiddleware, IRMMiddleware
from .repository.middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware
from .repository.middlewares.email import EmailServer

log: TraceLogger = logging.getLogger("menuflow.flow_utils")


class FlowUtils:
# Cache dicts
middlewares_by_id: Dict[str, HTTPMiddleware | IRMMiddleware] = {}
middlewares_by_id: Dict[str, HTTPMiddleware | IRMMiddleware | ASRMiddleware] = {}
email_servers_by_id: Dict[str, EmailServer] = {}

def __init__(self) -> None:
self.data: FlowUtilsModel = FlowUtilsModel.load_flow_utils()

def _add_middleware_to_cache(self, middleware_model: HTTPMiddleware | IRMMiddleware) -> None:
def _add_middleware_to_cache(
self, middleware_model: HTTPMiddleware | IRMMiddleware | ASRMiddleware
) -> None:
self.middlewares_by_id[middleware_model.id] = middleware_model

def _add_email_server_to_cache(self, email_server_model: EmailServer) -> None:
self.email_servers_by_id[email_server_model.server_id] = email_server_model

def get_middleware_by_id(self, middleware_id: str) -> HTTPMiddleware | IRMMiddleware | None:
def get_middleware_by_id(
self, middleware_id: str
) -> HTTPMiddleware | IRMMiddleware | ASRMiddleware | None:
"""This function retrieves a middleware by its ID from a cache or a list of middlewares.
Parameters
Expand Down
1 change: 1 addition & 0 deletions menuflow/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .asr import ASRMiddleware
from .http import HTTPMiddleware
from .irm import IRMMiddleware
from .llm import LLMMiddleware
123 changes: 123 additions & 0 deletions menuflow/middlewares/asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Dict, Tuple

from aiohttp import ClientTimeout, ContentTypeError, FormData
from mautrix.util.config import RecursiveDict
from ruamel.yaml.comments import CommentedMap

from ..nodes import Base
from ..repository import ASRMiddleware as ASRMiddlewareModel
from ..room import Room


class ASRMiddleware(Base):
    """Middleware that sends an audio file to an HTTP speech-recognition service.

    The audio is downloaded through the room's Matrix client, posted as
    multipart form data to the configured endpoint, and the variables declared
    in the middleware configuration are extracted from the response and stored
    in the room.
    """

    def __init__(
        self,
        asr_middleware_content: ASRMiddlewareModel,
        room: Room,
        default_variables: Dict,
    ) -> None:
        """Bind the middleware configuration to a room.

        Parameters
        ----------
        asr_middleware_content : ASRMiddlewareModel
            Parsed middleware configuration (url, method, provider, ...).
        room : Room
            Room whose Matrix client downloads the audio and whose variables
            are updated with the recognition result.
        default_variables : Dict
            Flow-level variables forwarded to ``Base``.
        """
        Base.__init__(self, room=room, default_variables=default_variables)
        # Use a child logger named after the middleware id.
        self.log = self.log.getChild(asr_middleware_content.id)
        self.content: ASRMiddlewareModel = asr_middleware_content

    # Each property below passes the raw configuration value through
    # ``render_data`` every time it is accessed.

    @property
    def url(self) -> str:
        """Rendered endpoint URL of the recognition service."""
        return self.render_data(self.content.url)

    @property
    def headers(self) -> Dict:
        """Rendered HTTP headers to send with the request (may be empty/None)."""
        return self.render_data(self.content.headers)

    @property
    def middleware_variables(self) -> Dict:
        """Rendered mapping of room-variable name -> key to read from the response."""
        return self.render_data(self.content.variables)

    @property
    def method(self) -> str:
        """Rendered HTTP method used for the request."""
        return self.render_data(self.content.method)

    @property
    def cookies(self) -> Dict:
        """Rendered cookie names to capture from the response (may be empty/None)."""
        return self.render_data(self.content.cookies)

    @property
    def provider(self) -> str:
        """Rendered provider name, sent to the service as the ``provider`` form field."""
        return self.render_data(self.content.provider)

    async def run(
        self, extended_data: Dict, audio_url: str, audio_name: str = None
    ) -> "Tuple[int, str | None] | None":
        """Download the audio and send it to the recognition service.

        Parameters
        ----------
        extended_data : Dict
            Not used by this middleware; kept so existing callers that pass it
            positionally keep working.
        audio_url : str
            URL of the audio, passed to the room's Matrix client for download.
        audio_name : str, optional
            File name attached to the uploaded audio.

        Returns
        -------
        Whatever ``http_request`` returns: ``(status, text)`` on success,
        ``(status, None)`` when the response has no usable JSON body, or
        ``None`` when the audio is empty or the request failed.
        """
        audio = await self.room.matrix_client.download_media(url=audio_url)
        return await self.http_request(audio=audio, audio_name=audio_name)

    async def http_request(self, audio, audio_name) -> "Tuple[int, str | None] | None":
        """Recognize the text and return the status code and the text.

        Parameters
        ----------
        audio
            Audio content to upload; a falsy value aborts the request.
        audio_name
            File name attached to the ``audio`` form field.

        Returns
        -------
        ``(status, response_text)`` when the response carried JSON data,
        ``(status, None)`` when it did not, or ``None`` if there was no audio
        or the HTTP request raised.
        """
        if not audio:
            self.log.error("Error getting the audio")
            return None

        form_data = FormData()
        form_data.add_field("audio", audio, filename=audio_name, content_type="audio/ogg")
        form_data.add_field("provider", self.provider)

        # Only forward headers when some are configured.
        request_kwargs = {}
        headers = self.headers
        if headers:
            request_kwargs["headers"] = headers

        try:
            timeout = ClientTimeout(total=self.config["menuflow.timeouts.middlewares"])
            response = await self.session.request(
                self.method,
                self.url,
                timeout=timeout,
                data=form_data,
                **request_kwargs,
            )
        except Exception as e:
            # Best effort: a failing recognition service must not crash the flow.
            self.log.exception(f"Audio to text conversion error: {e}")
            return None

        variables = {}

        cookies = self.cookies
        if cookies:
            for cookie in cookies:
                # NOTE(review): the cookie name is passed as the first argument
                # of ``output``; if ``response.cookies`` is a SimpleCookie that
                # argument is ``attrs``, not a cookie selector — confirm whether
                # the cookie's value was intended here.
                variables[cookie] = response.cookies.output(cookie)

        try:
            response_data = await response.json()
        except (ContentTypeError, ValueError):
            # ContentTypeError: the response is not declared as JSON.
            # ValueError: covers json.JSONDecodeError for a malformed body.
            response_data = {}

        if not response_data:
            # NOTE(review): cookie variables collected above are not stored on
            # this path — confirm that this is intended.
            return response.status, None

        # Render the variable mapping once instead of on every loop iteration.
        middleware_variables = self.middleware_variables
        if middleware_variables:
            if isinstance(response_data, dict):
                # Tulir and its magic since time immemorial
                serialized_data = RecursiveDict(CommentedMap(**response_data))
                for variable in middleware_variables:
                    try:
                        variables[variable] = self.render_data(
                            serialized_data[middleware_variables[variable]]
                        )
                    except KeyError:
                        # The response does not contain the requested key.
                        pass
            elif isinstance(response_data, str):
                # A plain string response can only feed one variable: the first
                # one declared receives the whole (rendered) string.
                first_variable = next(iter(middleware_variables))
                try:
                    variables[first_variable] = self.render_data(response_data)
                except KeyError:
                    pass

        if variables:
            await self.room.set_variables(variables=variables)

        return response.status, await response.text()
13 changes: 10 additions & 3 deletions menuflow/nodes/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .switch import Switch

if TYPE_CHECKING:
from ..middlewares import IRMMiddleware, LLMMiddleware
from ..middlewares import ASRMiddleware, IRMMiddleware, LLMMiddleware


class Input(Switch, Message):
Expand Down Expand Up @@ -110,7 +110,6 @@ async def run(self, evt: Optional[MessageEvent]):
if not evt:
self.log.warning("A problem occurred getting message event.")
return

if self.input_type == MessageType.TEXT:
if self.middleware:
await self.middleware.run(text=evt.content.body)
Expand All @@ -127,8 +126,16 @@ async def run(self, evt: Optional[MessageEvent]):
o_connection = await Switch.run(self, generate_event=False)
else:
o_connection = await self.input_media(content=evt.content)
elif self.input_type == MessageType.AUDIO:
if self.middleware and evt.content.msgtype == MessageType.AUDIO:
audio_name = evt.content.file or "audio.ogg"
await self.middleware.run(
self, audio_url=evt.content.url, audio_name=audio_name
)
o_connection = await Switch.run(self=self, generate_event=False)
else:
o_connection = await self.input_media(content=evt.content)
elif self.input_type in [
MessageType.AUDIO,
MessageType.FILE,
MessageType.VIDEO,
]:
Expand Down
2 changes: 1 addition & 1 deletion menuflow/repository/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .flow import Flow
from .flow_utils import FlowUtils
from .middlewares import HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .nodes import (
Case,
CheckTime,
Expand Down
3 changes: 1 addition & 2 deletions menuflow/repository/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mautrix.util.logging import TraceLogger

from ..utils import Util
from .middlewares import HTTPMiddleware
from .middlewares import ASRMiddleware, HTTPMiddleware
from .nodes import CheckTime, HTTPRequest, Input, Message, Switch

log: TraceLogger = logging.getLogger("menuflow.repository.flow")
Expand All @@ -19,7 +19,6 @@
@dataclass
class Flow(SerializableAttrs):
nodes: List[Message, Input, HTTPRequest, Switch, CheckTime] = ib(factory=list)
middlewares: List[HTTPMiddleware] = ib(default=[])
flow_variables: Dict[str, Any] = ib(default={})

@classmethod
Expand Down
8 changes: 5 additions & 3 deletions menuflow/repository/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from mautrix.util.logging import TraceLogger

from ..utils import Middlewares
from .middlewares import EmailServer, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import ASRMiddleware, EmailServer, HTTPMiddleware, IRMMiddleware, LLMMiddleware

log: TraceLogger = logging.getLogger("menuflow.repository.flow_utils")


@dataclass
class FlowUtils(SerializableAttrs):
middlewares: List[HTTPMiddleware, IRMMiddleware, LLMMiddleware] = ib(default=[])
middlewares: List[HTTPMiddleware, IRMMiddleware, LLMMiddleware, ASRMiddleware] = ib(default=[])
email_servers: List[EmailServer] = ib(default=[])

@classmethod
Expand Down Expand Up @@ -45,7 +45,7 @@ def from_dict(cls, data: dict) -> "FlowUtils":
@classmethod
def initialize_middleware_dataclass(
cls, middleware: Dict
) -> HTTPMiddleware | IRMMiddleware | LLMMiddleware | None:
) -> HTTPMiddleware | IRMMiddleware | LLMMiddleware | ASRMiddleware | None:
try:
middleware_type = Middlewares(middleware.get("type"))
except ValueError:
Expand All @@ -58,6 +58,8 @@ def initialize_middleware_dataclass(
return IRMMiddleware.from_dict(middleware)
elif middleware_type == Middlewares.LLM:
return LLMMiddleware.from_dict(middleware)
elif middleware_type == Middlewares.ASR:
return ASRMiddleware.from_dict(middleware)

@classmethod
def initialize_email_server_dataclass(cls, email_server: Dict) -> EmailServer:
Expand Down
1 change: 1 addition & 0 deletions menuflow/repository/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .asr import ASRMiddleware
from .email import EmailServer
from .http import HTTPMiddleware
from .irm import IRMMiddleware
Expand Down
52 changes: 52 additions & 0 deletions menuflow/repository/middlewares/asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from typing import Any, Dict

from attr import dataclass, ib

from ..flow_object import FlowObject


@dataclass
class ASRMiddleware(FlowObject):
    """ASRMiddleware

    Middleware node that recognizes the text from a sound file.

    content:

    ```
    - id: m1
      type: asr
      method: GET
      url: "http://localhost:5000/asr"
      provider: "azure"
      cookies:
          cookie1: "value1"
      headers:
          Client-token: "client-token"
      variables:
          variable1: "value1"
    ```
    """

    id: str = ib(default=None)
    type: str = ib(default=None)
    method: str = ib(default=None)
    url: str = ib(default=None)
    provider: str = ib(default=None)
    cookies: Dict[str, Any] = ib(factory=dict)
    headers: Dict[str, Any] = ib(default=None)
    variables: Dict[str, Any] = ib(factory=dict)

    @classmethod
    def from_dict(cls, data: Dict) -> ASRMiddleware:
        """Build the middleware model from its parsed configuration mapping.

        Keys that are absent from ``data`` fall back to the same defaults the
        attributes declare: ``None`` for the scalar fields and ``headers``, an
        empty dict for ``variables`` and ``cookies``.
        """
        return cls(
            id=data.get("id"),
            type=data.get("type"),
            method=data.get("method"),
            url=data.get("url"),
            provider=data.get("provider"),
            # ``or {}`` keeps the declared ``factory=dict`` default instead of
            # overriding it with None when the key is missing.
            variables=data.get("variables") or {},
            cookies=data.get("cookies") or {},
            headers=data.get("headers"),
        )
2 changes: 1 addition & 1 deletion menuflow/repository/nodes/input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List
from typing import Any, Dict, List

from attr import dataclass, ib
from mautrix.types import SerializableAttrs
Expand Down
1 change: 1 addition & 0 deletions menuflow/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ class Middlewares(SerializableEnum):
BASE = "base"
IRM = "irm"
LLM = "llm"
ASR = "asr"

0 comments on commit 995df0d

Please sign in to comment.