Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Middleware to process audio #91

Merged
merged 11 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions menuflow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mautrix.util.logging import TraceLogger

from .flow_utils import FlowUtils
from .middlewares import HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .nodes import (
CheckTime,
Email,
Expand Down Expand Up @@ -75,7 +75,9 @@ def get_node_by_id(self, node_id: str) -> Dict | None:
self._add_node_to_cache(node)
return node

def middleware(self, middleware_id: str, room: Room) -> HTTPMiddleware | IRMMiddleware | None:
def middleware(
self, middleware_id: str, room: Room
) -> HTTPMiddleware | IRMMiddleware | ASRMiddleware | None:
middleware_model = self.flow_utils.get_middleware_by_id(middleware_id=middleware_id)
try:
middleware_type = Middlewares(middleware_model.type)
Expand All @@ -100,6 +102,12 @@ def middleware(self, middleware_id: str, room: Room) -> HTTPMiddleware | IRMMidd
middleware_initialized = LLMMiddleware(
llm_data=middleware_model, room=room, default_variables=self.flow_variables
)
elif middleware_type == Middlewares.ASR:
middleware_initialized = ASRMiddleware(
asr_middleware_content=middleware_model,
room=room,
default_variables=self.flow_variables,
)

return middleware_initialized

Expand Down
12 changes: 8 additions & 4 deletions menuflow/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,31 @@

from .middlewares.http import HTTPMiddleware
from .repository import FlowUtils as FlowUtilsModel
from .repository.middlewares import HTTPMiddleware, IRMMiddleware
from .repository.middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware
from .repository.middlewares.email import EmailServer

log: TraceLogger = logging.getLogger("menuflow.flow_utils")


class FlowUtils:
# Cache dicts
middlewares_by_id: Dict[str, HTTPMiddleware | IRMMiddleware] = {}
middlewares_by_id: Dict[str, HTTPMiddleware | IRMMiddleware | ASRMiddleware] = {}
email_servers_by_id: Dict[str, EmailServer] = {}

def __init__(self) -> None:
self.data: FlowUtilsModel = FlowUtilsModel.load_flow_utils()

def _add_middleware_to_cache(self, middleware_model: HTTPMiddleware | IRMMiddleware) -> None:
def _add_middleware_to_cache(
self, middleware_model: HTTPMiddleware | IRMMiddleware | ASRMiddleware
) -> None:
self.middlewares_by_id[middleware_model.id] = middleware_model

def _add_email_server_to_cache(self, email_server_model: EmailServer) -> None:
self.email_servers_by_id[email_server_model.server_id] = email_server_model

def get_middleware_by_id(self, middleware_id: str) -> HTTPMiddleware | IRMMiddleware | None:
def get_middleware_by_id(
self, middleware_id: str
) -> HTTPMiddleware | IRMMiddleware | ASRMiddleware | None:
"""This function retrieves a middleware by its ID from a cache or a list of middlewares.

Parameters
Expand Down
1 change: 1 addition & 0 deletions menuflow/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .asr import ASRMiddleware
from .http import HTTPMiddleware
from .irm import IRMMiddleware
from .llm import LLMMiddleware
123 changes: 123 additions & 0 deletions menuflow/middlewares/asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Dict, Optional, Tuple

from aiohttp import ClientTimeout, ContentTypeError, FormData
from mautrix.util.config import RecursiveDict
from ruamel.yaml.comments import CommentedMap

from ..nodes import Base
from ..repository import ASRMiddleware as ASRMiddlewareModel
from ..room import Room


class ASRMiddleware(Base):
    """Middleware that sends an audio file to an HTTP speech-recognition service.

    The audio is downloaded through the room's Matrix client, posted as
    multipart form data to the configured ``url`` and the response is mapped
    onto room variables according to the middleware's ``variables`` and
    ``cookies`` settings.

    Every configuration value is exposed as a property that passes the raw
    model value through ``render_data`` on each access.
    """

    def __init__(
        self,
        asr_middleware_content: ASRMiddlewareModel,
        room: Room,
        default_variables: Dict,
    ) -> None:
        Base.__init__(self, room=room, default_variables=default_variables)
        self.log = self.log.getChild(asr_middleware_content.id)
        self.content: ASRMiddlewareModel = asr_middleware_content

    @property
    def url(self) -> str:
        """Rendered URL of the recognition service."""
        return self.render_data(self.content.url)

    @property
    def headers(self) -> Dict:
        """Rendered HTTP headers to send with the request (may be empty/None)."""
        return self.render_data(self.content.headers)

    @property
    def middleware_variables(self) -> Dict:
        """Rendered mapping of room-variable name -> key looked up in the response."""
        return self.render_data(self.content.variables)

    @property
    def method(self) -> str:
        """Rendered HTTP method passed to ``session.request``."""
        return self.render_data(self.content.method)

    @property
    def cookies(self) -> Dict:
        """Rendered mapping whose keys name the response cookies to store."""
        return self.render_data(self.content.cookies)

    @property
    def provider(self) -> str:
        """Rendered provider name, sent as the ``provider`` form field."""
        return self.render_data(self.content.provider)

    async def run(
        self, extended_data: Dict, audio_url: str, audio_name: Optional[str] = None
    ) -> Optional[Tuple[int, Optional[str]]]:
        """Download the audio at ``audio_url`` and send it to the recognition service.

        Parameters
        ----------
        extended_data : Dict
            Accepted for call compatibility; not used by this method.
        audio_url : str
            Media URL handed to ``room.matrix_client.download_media``.
        audio_name : str, optional
            File name attached to the uploaded form field.

        Returns
        -------
        Whatever ``http_request`` returns: ``(status, text)`` on success,
        ``(status, None)`` for an empty/non-JSON body, or ``None`` when the
        audio is missing or the request fails.
        """
        audio = await self.room.matrix_client.download_media(url=audio_url)
        return await self.http_request(audio=audio, audio_name=audio_name)

    async def http_request(self, audio, audio_name) -> Optional[Tuple[int, Optional[str]]]:
        """Recognize the text and return the status code and the text.

        Parameters
        ----------
        audio
            Audio payload to upload; a falsy value aborts the call.
        audio_name
            File name used for the ``audio`` form field.

        Returns
        -------
        ``None`` if there is no audio or the request raised (the error is
        logged), ``(status, None)`` if the response has no usable JSON body,
        otherwise ``(status, response_text)``.
        """
        if not audio:
            self.log.error("Error getting the audio")
            return None

        form_data = FormData()
        form_data.add_field("audio", audio, filename=audio_name, content_type="audio/ogg")
        form_data.add_field("provider", self.provider)

        # Render the headers once; only forward them when there is something to send.
        request_kwargs = {}
        headers = self.headers
        if headers:
            request_kwargs["headers"] = headers

        try:
            timeout = ClientTimeout(total=self.config["menuflow.timeouts.middlewares"])
            response = await self.session.request(
                self.method,
                self.url,
                timeout=timeout,
                data=form_data,
                **request_kwargs,
            )
        except Exception as e:
            # Best effort: a failing recognition service must not break the flow.
            self.log.exception(f"Audio to text conversion error: {e}")
            return None

        variables = {}

        cookies = self.cookies
        if cookies:
            for cookie in cookies:
                variables[cookie] = response.cookies.output(cookie)

        try:
            response_data = await response.json()
        except ContentTypeError:
            # The body is not declared as JSON; treat it as empty.
            response_data = {}

        if not response_data:
            return response.status, None

        # Render the variable mapping once instead of on every loop iteration.
        middleware_variables = self.middleware_variables
        if middleware_variables:
            if isinstance(response_data, dict):
                # Wrap the payload so the configured keys are resolved through
                # RecursiveDict (presumably to allow dotted paths into nested
                # objects -- confirm against mautrix).
                serialized_data = RecursiveDict(CommentedMap(**response_data))
                for variable, response_key in middleware_variables.items():
                    try:
                        variables[variable] = self.render_data(serialized_data[response_key])
                    except KeyError:
                        # Key absent from the response: leave the variable unset.
                        pass
            elif isinstance(response_data, str):
                # A bare string can only feed one variable: the first configured one.
                first_variable = next(iter(middleware_variables))
                try:
                    variables[first_variable] = self.render_data(response_data)
                except KeyError:
                    pass

        if variables:
            await self.room.set_variables(variables=variables)

        return response.status, await response.text()
13 changes: 10 additions & 3 deletions menuflow/nodes/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .switch import Switch

if TYPE_CHECKING:
from ..middlewares import IRMMiddleware, LLMMiddleware
from ..middlewares import ASRMiddleware, IRMMiddleware, LLMMiddleware


class Input(Switch, Message):
Expand Down Expand Up @@ -110,7 +110,6 @@ async def run(self, evt: Optional[MessageEvent]):
if not evt:
self.log.warning("A problem occurred getting message event.")
return

if self.input_type == MessageType.TEXT:
if self.middleware:
await self.middleware.run(text=evt.content.body)
Expand All @@ -127,8 +126,16 @@ async def run(self, evt: Optional[MessageEvent]):
o_connection = await Switch.run(self, generate_event=False)
else:
o_connection = await self.input_media(content=evt.content)
elif self.input_type == MessageType.AUDIO:
if self.middleware and evt.content.msgtype == MessageType.AUDIO:
audio_name = evt.content.file or "audio.ogg"
await self.middleware.run(
self, audio_url=evt.content.url, audio_name=audio_name
)
o_connection = await Switch.run(self=self, generate_event=False)
else:
o_connection = await self.input_media(content=evt.content)
elif self.input_type in [
MessageType.AUDIO,
MessageType.FILE,
MessageType.VIDEO,
]:
Expand Down
2 changes: 1 addition & 1 deletion menuflow/repository/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .flow import Flow
from .flow_utils import FlowUtils
from .middlewares import HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import ASRMiddleware, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .nodes import (
Case,
CheckTime,
Expand Down
3 changes: 1 addition & 2 deletions menuflow/repository/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mautrix.util.logging import TraceLogger

from ..utils import Util
from .middlewares import HTTPMiddleware
from .middlewares import ASRMiddleware, HTTPMiddleware
from .nodes import CheckTime, HTTPRequest, Input, Message, Switch

log: TraceLogger = logging.getLogger("menuflow.repository.flow")
Expand All @@ -19,7 +19,6 @@
@dataclass
class Flow(SerializableAttrs):
nodes: List[Message, Input, HTTPRequest, Switch, CheckTime] = ib(factory=list)
middlewares: List[HTTPMiddleware] = ib(default=[])
flow_variables: Dict[str, Any] = ib(default={})

@classmethod
Expand Down
8 changes: 5 additions & 3 deletions menuflow/repository/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from mautrix.util.logging import TraceLogger

from ..utils import Middlewares
from .middlewares import EmailServer, HTTPMiddleware, IRMMiddleware, LLMMiddleware
from .middlewares import ASRMiddleware, EmailServer, HTTPMiddleware, IRMMiddleware, LLMMiddleware

log: TraceLogger = logging.getLogger("menuflow.repository.flow_utils")


@dataclass
class FlowUtils(SerializableAttrs):
middlewares: List[HTTPMiddleware, IRMMiddleware, LLMMiddleware] = ib(default=[])
middlewares: List[HTTPMiddleware, IRMMiddleware, LLMMiddleware, ASRMiddleware] = ib(default=[])
email_servers: List[EmailServer] = ib(default=[])

@classmethod
Expand Down Expand Up @@ -45,7 +45,7 @@ def from_dict(cls, data: dict) -> "FlowUtils":
@classmethod
def initialize_middleware_dataclass(
cls, middleware: Dict
) -> HTTPMiddleware | IRMMiddleware | LLMMiddleware | None:
) -> HTTPMiddleware | IRMMiddleware | LLMMiddleware | ASRMiddleware | None:
try:
middleware_type = Middlewares(middleware.get("type"))
except ValueError:
Expand All @@ -58,6 +58,8 @@ def initialize_middleware_dataclass(
return IRMMiddleware.from_dict(middleware)
elif middleware_type == Middlewares.LLM:
return LLMMiddleware.from_dict(middleware)
elif middleware_type == Middlewares.ASR:
return ASRMiddleware.from_dict(middleware)

@classmethod
def initialize_email_server_dataclass(cls, email_server: Dict) -> EmailServer:
Expand Down
1 change: 1 addition & 0 deletions menuflow/repository/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .asr import ASRMiddleware
from .email import EmailServer
from .http import HTTPMiddleware
from .irm import IRMMiddleware
Expand Down
52 changes: 52 additions & 0 deletions menuflow/repository/middlewares/asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from typing import Any, Dict

from attr import dataclass, ib

from ..flow_object import FlowObject


@dataclass
class ASRMiddleware(FlowObject):
    """ASRMiddleware

    Middleware node that recognizes the text from a sound file.

    content:

    ```
    - id: m1
      type: asr
      method: GET
      url: "http://localhost:5000/asr"
      provider: "azure"
      cookies:
          cookie1: "value1"
      headers:
          Client-token: "client-token"
      variables:
          variable1: "value1"
    ```
    """

    id: str = ib(default=None)
    type: str = ib(default=None)
    method: str = ib(default=None)
    url: str = ib(default=None)
    provider: str = ib(default=None)
    cookies: Dict[str, Any] = ib(factory=dict)
    headers: Dict[str, Any] = ib(default=None)
    variables: Dict[str, Any] = ib(factory=dict)

    @classmethod
    def from_dict(cls, data: Dict) -> ASRMiddleware:
        """Build the model from the raw middleware mapping of a flow definition.

        Keys missing from ``data`` become ``None``, except ``variables`` and
        ``cookies`` which fall back to an empty dict.
        """
        return cls(
            id=data.get("id"),
            type=data.get("type"),
            method=data.get("method"),
            url=data.get("url"),
            provider=data.get("provider"),
            # ``or {}`` keeps the ``factory=dict`` defaults: passing the ``None``
            # from a missing (or null) key would otherwise override them.
            variables=data.get("variables") or {},
            cookies=data.get("cookies") or {},
            headers=data.get("headers"),
        )
2 changes: 1 addition & 1 deletion menuflow/repository/nodes/input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List
from typing import Any, Dict, List

from attr import dataclass, ib
from mautrix.types import SerializableAttrs
Expand Down
1 change: 1 addition & 0 deletions menuflow/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ class Middlewares(SerializableEnum):
BASE = "base"
IRM = "irm"
LLM = "llm"
ASR = "asr"
Loading