Skip to content

Commit

Permalink
add connectors factory (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
pushforce committed Mar 22, 2023
1 parent 658e295 commit 47eef53
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 106 deletions.
199 changes: 174 additions & 25 deletions deeppavlov_agent/core/connectors.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
import asyncio
from typing import Any, Callable, Dict, List
import os
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, cast
from typing_extensions import Protocol, runtime_checkable, Type, TypeGuard
from collections import defaultdict
from logging import getLogger
import os
from importlib import import_module


import sentry_sdk
import aiohttp

from .transport.settings import TRANSPORT_SETTINGS
from .transport.base import ServiceGatewayConnectorBase
from .transport.gateways.rabbitmq import RabbitMQAgentGateway
from .config import ConnectorConfig

logger = getLogger(__name__)
sentry_sdk.init(os.getenv("DP_AGENT_SENTRY_DSN")) # type: ignore


class HTTPConnector:
def __init__(self, session: aiohttp.ClientSession, url: str, timeout: float):
@runtime_checkable
class Connector(Protocol):
async def send(self, payload: Dict, callback: Callable) -> None:
...


class HTTPConnector(Connector):
def __init__(
self, session: aiohttp.ClientSession, url: str, timeout: Optional[float] = 0
):
self.session = session
self.url = url
self.timeout = aiohttp.ClientTimeout(total=timeout)

async def send(self, payload: Dict, callback: Callable):
async def send(self, payload: Dict, callback: Callable) -> None:
try:
async with self.session.post(
self.url, json=payload["payload"], timeout=self.timeout
Expand All @@ -37,11 +51,11 @@ async def send(self, payload: Dict, callback: Callable):
await callback(task_id=payload["task_id"], response=response)


class AioQueueConnector:
class AioQueueConnector(Connector):
def __init__(self, queue):
self.queue = queue

async def send(self, payload: Dict, **kwargs):
async def send(self, payload: Dict, callback: Callable) -> None:
await self.queue.put(payload)


Expand Down Expand Up @@ -82,44 +96,41 @@ def glue_tasks(self, batch):
return result


class ConfidenceResponseSelectorConnector:
async def send(self, payload: Dict, callback: Callable):
class ConfidenceResponseSelectorConnector(Connector):
async def send(self, payload: Dict, callback: Callable) -> None:
try:
response = payload["payload"]["utterances"][-1]["hypotheses"]
best_skill = max(response, key=lambda x: x["confidence"])
await callback(task_id=payload["task_id"], response=best_skill)
except Exception as e:
sentry_sdk.capture_exception(e)
logger.exception(e)

await callback(task_id=payload["task_id"], response=e)


class EventSetOutputConnector:
class EventSetOutputConnector(Connector):
def __init__(self, service_name: str):
self.service_name = service_name

async def send(self, payload, callback: Callable):
async def send(self, payload, callback: Callable) -> None:
event = payload["payload"].get("event", None)
if not event or not isinstance(event, asyncio.Event):
raise ValueError("'event' key is not presented in payload")
await callback(task_id=payload["task_id"], response=" ")
event.set()


class AgentGatewayToChannelConnector:
pass


class AgentGatewayToServiceConnector:
_to_service_callback: Callable
class AgentGatewayToServiceConnector(Connector):
_gateway: RabbitMQAgentGateway
_service_name: str

def __init__(self, to_service_callback: Callable, service_name: str):
self._to_service_callback = to_service_callback
def __init__(self, gateway: RabbitMQAgentGateway, service_name: str) -> None:
self._gateway = gateway
self._service_name = service_name

async def send(self, payload: Dict, **_kwargs):
await self._to_service_callback(
async def send(self, payload: Dict, callback: Callable) -> None:
await self._gateway.send_to_service(
payload=payload, service_name=self._service_name
)

Expand All @@ -146,21 +157,159 @@ async def send_to_service(self, payloads: List[Dict]) -> List[Any]:
return responses_batch


class PredefinedTextConnector:
class PredefinedTextConnector(Connector):
def __init__(self, response_text, annotations=None):
self.response_text = response_text
self.annotations = annotations or {}

async def send(self, payload: Dict, callback: Callable):
async def send(self, payload: Dict, callback: Callable) -> None:
await callback(
task_id=payload["task_id"],
response={"text": self.response_text, "annotations": self.annotations},
)


class PredefinedOutputConnector:
class PredefinedOutputConnector(Connector):
def __init__(self, output):
self.output = output

async def send(self, payload: Dict, callback: Callable):
async def send(self, payload: Dict, callback: Callable) -> None:
await callback(task_id=payload["task_id"], response=self.output)


_BUILT_IN_CONNECTORS = {
"PredefinedOutputConnector": PredefinedOutputConnector,
"PredefinedTextConnector": PredefinedTextConnector,
"ConfidenceResponseSelectorConnector": ConfidenceResponseSelectorConnector,
}


_SESSION: Optional[aiohttp.ClientSession] = None


def _get_session() -> aiohttp.ClientSession:
global _SESSION

if _SESSION is None:
_SESSION = aiohttp.ClientSession()

return _SESSION


_GATEWAY: Optional[RabbitMQAgentGateway] = None


def _get_gateway(
on_channel_callback=None, on_service_callback=None
) -> RabbitMQAgentGateway:
global _GATEWAY

if _GATEWAY is None:
_GATEWAY = RabbitMQAgentGateway(
config=TRANSPORT_SETTINGS,
on_service_callback=on_service_callback,
on_channel_callback=on_channel_callback,
)

return _GATEWAY


def _is_url(value: Any) -> TypeGuard[str]:
return isinstance(value, str)


def _is_urllist(value: Any) -> TypeGuard[List[str]]:
return isinstance(value, list) and all([isinstance(val, str) for val in value])


def _make_http_connector(
config: Union[ConnectorConfig, Dict]
) -> Tuple[Connector, List[QueueListenerBatchifyer]]:
url = config.get("url")
urllist = config.get("urllist")

if not isinstance(url, str) and not isinstance(urllist, list):
raise ValueError("url or urllist must be provided")

if (
"urllist" in config
or "num_workers" in config
or config.get("batch_size", 1) > 1
):
workers = []
queue: asyncio.Queue[Any] = asyncio.Queue()
batch_size = config.get("batch_size", 1)
urllist = urllist or ([url] * config.get("num_workers", 1))

for url in urllist:
workers.append(
QueueListenerBatchifyer(
_get_session(), cast(str, url), queue, batch_size
)
)

return AioQueueConnector(queue), workers

return (
HTTPConnector(
session=_get_session(), url=cast(str, url), timeout=config.get("timeout")
),
[],
)


def _make_amqp_connector(
config: Union[ConnectorConfig, Dict]
) -> Tuple[Connector, List[QueueListenerBatchifyer]]:
service_name = config.get("service_name") or config["connector_name"]
return (
AgentGatewayToServiceConnector(
gateway=_get_gateway(), service_name=service_name
),
[],
)


def _get_connector_class(class_name: str) -> Optional[Type[Any]]:
params = class_name.split(":")
connector_class: Optional[Type[Any]] = None

if len(params) == 2:
module = import_module(params[0])
connector_class = getattr(module, params[1], None)
elif len(params) == 1:
connector_class = _BUILT_IN_CONNECTORS.get(params[0])

return connector_class


def _make_connector_from_class(
config: Union[ConnectorConfig, Dict]
) -> Tuple[Connector, List[QueueListenerBatchifyer]]:
kwargs = {
key: config[key] for key in config if key not in ["protocol", "class_name"] # type: ignore
}

class_name = config["class_name"]
connector_class = _get_connector_class(class_name)

if not connector_class:
raise ValueError(f"Connector class {class_name} not found")

return connector_class(**kwargs), []


def make_connector(
config: Union[ConnectorConfig, Dict]
) -> Tuple[Connector, List[QueueListenerBatchifyer]]:
if config.get("class_name"):
return _make_connector_from_class(config)

if config.get("protocol") == "http":
return _make_http_connector(config)

# TODO: remove AMQP if it is not used
if config.get("protocol") == "AMQP":
return _make_amqp_connector(config)

raise ValueError("invalid protocol or class_name")
83 changes: 4 additions & 79 deletions deeppavlov_agent/parse_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import logging
from collections import defaultdict
from importlib import import_module
Expand All @@ -8,13 +7,10 @@
import aiohttp

from .core.connectors import (
AgentGatewayToServiceConnector,
AioQueueConnector,
HTTPConnector,
QueueListenerBatchifyer,
PredefinedOutputConnector,
PredefinedTextConnector,
ConfidenceResponseSelectorConnector,
make_connector,
)
from .core.service import Service, simple_workflow_formatter
from .core.state_manager import StateManager
Expand Down Expand Up @@ -51,13 +47,8 @@ def __init__(self, state_manager: StateManager, config: Dict):
self.session = None
self.gateway = None
self.imported_modules: Dict[str, ModuleType] = {}
self.connectors_module: Optional[ModuleType] = None
self.formatters_module: Optional[ModuleType] = None

connectors_module_name = config.get("connectors_module", None)
if connectors_module_name:
self.connectors_module = import_module(connectors_module_name)

formatters_module_name = config.get("formatters_module", None)
if formatters_module_name:
self.formatters_module = import_module(formatters_module_name)
Expand All @@ -82,7 +73,9 @@ def parse(self, data) -> None:

self.make_service(group, name, node.config)
elif is_connector(node):
self.make_connector(_get_connector_name(node), node.config)
connector, workers = make_connector(node.config)
self.workers.extend(workers)
self.connectors[_get_connector_name(node)] = connector

def get_session(self):
if not self.session:
Expand All @@ -108,74 +101,6 @@ def get_external_module(self, module_name: str):
module = self.imported_modules[module_name]
return module

def make_connector(self, name: str, data: ConnectorConfig):
workers: List[object] = []
if data["protocol"] == "http":
connector: Any = None
workers = []
if (
"urllist" in data
or "num_workers" in data
or data.get("batch_size", 1) > 1
):
queue: asyncio.Queue[Any] = asyncio.Queue()
batch_size = data.get("batch_size", 1)
urllist = data.get(
"urllist", [data["url"]] * data.get("num_workers", 1)
)
connector = AioQueueConnector(queue)
for url in urllist:
workers.append(
QueueListenerBatchifyer(
self.get_session(), url, queue, batch_size
)
)
else:
connector = HTTPConnector(
self.get_session(), data["url"], timeout=data.get("timeout", 0)
)

elif data["protocol"] == "AMQP":
gateway = self.get_gateway()
service_name = data.get("service_name") or data["connector_name"]
connector = AgentGatewayToServiceConnector(
to_service_callback=gateway.send_to_service, service_name=service_name
)

elif data["protocol"] == "python":
params = data["class_name"].split(":")
if len(params) == 1:
if params[0] in built_in_connectors:
connector_class: Any = built_in_connectors[params[0]]
module_provided_str = "in deeppavlov_agent built in connectors"
elif self.connectors_module:
connector_class = getattr(self.connectors_module, params[0], None)
module_provided_str = (
f"in {self.connectors_module.__name__} connectors module"
)

if not connector_class:
raise ValueError(
f"Connector's python class {data['class_name']} from {name} "
f"connector was not found ({module_provided_str})"
)
elif len(params) == 2:
connector_class = getattr(
self.get_external_module(params[0]), params[1], None
)
else:
raise ValueError(
f"Expected class description in a `module.submodules:ClassName` form, "
f"but got `{data['class_name']}` (in {name} connector)"
)
others = {
k: v for k, v in data.items() if k not in {"protocol", "class_name"}
}
connector = connector_class(**others)

self.workers.extend(workers)
self.connectors[name] = connector

def make_service(self, group: Optional[str], name: str, data: ServiceConfig):
logger.debug(f"Create service: '{name}' config={data}")

Expand Down

0 comments on commit 47eef53

Please sign in to comment.