Skip to content

Commit

Permalink
add service factory (#124)
Browse files Browse the repository at this point in the history
* add service factory

* fix flake8

* fix mypy
  • Loading branch information
pushforce committed Mar 28, 2023
1 parent 47eef53 commit 7a39623
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 152 deletions.
10 changes: 1 addition & 9 deletions deeppavlov_agent/core/connectors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import os
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, cast
from typing_extensions import Protocol, runtime_checkable, Type, TypeGuard
from typing_extensions import Protocol, runtime_checkable, Type
from collections import defaultdict
from logging import getLogger
from importlib import import_module
Expand Down Expand Up @@ -214,14 +214,6 @@ def _get_gateway(
return _GATEWAY


def _is_url(value: Any) -> TypeGuard[str]:
return isinstance(value, str)


def _is_urllist(value: Any) -> TypeGuard[List[str]]:
return isinstance(value, list) and all([isinstance(val, str) for val in value])


def _make_http_connector(
config: Union[ConnectorConfig, Dict]
) -> Tuple[Connector, List[QueueListenerBatchifyer]]:
Expand Down
123 changes: 123 additions & 0 deletions deeppavlov_agent/core/service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from typing import Dict, Any, Optional, Callable, Union, Type, Set, List
from importlib import import_module

from typing_extensions import Literal

from .config import ServiceConfig
from .connectors import Connector
from .state_manager import BaseStateManager
from ..state_formatters import all_formatters


class Service:
def __init__(
self,
Expand Down Expand Up @@ -64,3 +75,115 @@ def apply_response_formatter(self, payload):

def simple_workflow_formatter(workflow_record):
return workflow_record["dialog"].to_dict()


def _get_connector(
service_name: str, config: ServiceConfig, connectors: Dict[str, Connector]
) -> Connector:
connector_conf = config.get("connector", None)

connector: Optional[Connector] = None

if isinstance(connector_conf, str):
connector = connectors.get(connector_conf, None)
elif isinstance(connector_conf, dict):
connector = connectors.get(service_name, None)

if connector is None:
raise ValueError(f"connector in pipeline.{service_name} is not declared")

return connector


def _get_state_manager_method(
service_name: str, config: ServiceConfig, state_manager: BaseStateManager
) -> Optional[Callable[..., Any]]:
sm_method_name = config.get("state_manager_method", None)
sm_method: Optional[Callable[..., Any]] = None

if sm_method_name is not None:
sm_method = getattr(state_manager, sm_method_name, None)

if not sm_method:
raise ValueError(
f"state manager doesn't have a method {sm_method_name} (declared in {service_name})"
)

return sm_method


FormatterType = Union[Literal["dialog_formatter"], Literal["response_formatter"]]


def _get_formatter_class(class_name) -> Optional[Type[Any]]:
params = class_name.split(":")
formatter_class = None

if len(params) == 2:
module = import_module(params[0])
formatter_class = getattr(module, params[1], None)

return formatter_class


def _get_formatter(
formatter_type: FormatterType, service_name: str, config: ServiceConfig
) -> Optional[Callable[..., Any]]:
formatter_name = config.get(formatter_type, None)

if formatter_name is None:
return None

formatter: Optional[Callable[..., Any]]

if formatter_name in all_formatters:
formatter = all_formatters[formatter_name]
else:
formatter = _get_formatter_class(formatter_name)

if not formatter:
raise ValueError(
f"{formatter_type} {formatter_name} doesn't exist (declared in {service_name})"
)

return formatter


def _merge_service_names(
config_service_names: Union[Set, List[str]], service_names: Dict[str, Set[str]]
) -> Set[str]:
result = set()

for sn in config_service_names:
result.update(service_names.get(sn, set()))

return result


def make_service(
*,
name: str,
group: Optional[str] = None,
state_manager: BaseStateManager,
connectors: Dict[str, Connector],
service_names: Dict[str, Set[str]],
config: ServiceConfig,
) -> Service:
service_name = ".".join([i for i in [group, name] if i])
connector = _get_connector(service_name, config, connectors)

return Service(
name=service_name,
connector_func=connector.send,
state_processor_method=_get_state_manager_method(name, config, state_manager),
workflow_formatter=simple_workflow_formatter,
dialog_formatter=_get_formatter("dialog_formatter", service_name, config),
response_formatter=_get_formatter("response_formatter", service_name, config),
names_previous_services=_merge_service_names(
config.get("previous_services", set()), service_names
),
names_required_previous_services=_merge_service_names(
config.get("required_previous_services", set()), service_names
),
tags=config.get("tags", []),
)
9 changes: 7 additions & 2 deletions deeppavlov_agent/core/state_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from copy import deepcopy
from typing import Dict

from typing_extensions import Protocol
from datetime import datetime

from .state_schema import Bot, BotUtterance, Dialog, Human, HumanUtterance


class BaseStateManager(Protocol):
pass


# TODO: fix types
class StateManager:
class StateManager(BaseStateManager):
def __init__(self, db):
self._db = db

Expand Down Expand Up @@ -219,6 +224,6 @@ async def add_annotation_and_reset_human_attributes_for_first_turn(
}


class FakeStateManager:
class FakeStateManager(BaseStateManager):
async def add_annotation(self, dialog: Dialog, payload: Dict, label: str, **kwargs):
pass
160 changes: 22 additions & 138 deletions deeppavlov_agent/parse_config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import logging
from collections import defaultdict
from importlib import import_module
from typing import Dict, List, Set, Optional, Any
from types import ModuleType

import aiohttp

from .core.connectors import (
PredefinedOutputConnector,
PredefinedTextConnector,
ConfidenceResponseSelectorConnector,
make_connector,
)
from .core.service import Service, simple_workflow_formatter
from .core.connectors import make_connector
from .core.service import Service, make_service
from .core.state_manager import StateManager
from .core.transport.mapping import GATEWAYS_MAP
from .core.transport.settings import TRANSPORT_SETTINGS
from .core.config import (
parse as parse_conf,
is_service,
Expand All @@ -24,16 +14,9 @@
ServiceConfig,
Node,
)
from .state_formatters import all_formatters

logger = logging.getLogger(__name__)

built_in_connectors = {
"PredefinedOutputConnector": PredefinedOutputConnector,
"PredefinedTextConnector": PredefinedTextConnector,
"ConfidenceResponseSelectorConnector": ConfidenceResponseSelectorConnector,
}


class PipelineConfigParser:
def __init__(self, state_manager: StateManager, config: Dict):
Expand All @@ -47,11 +30,6 @@ def __init__(self, state_manager: StateManager, config: Dict):
self.session = None
self.gateway = None
self.imported_modules: Dict[str, ModuleType] = {}
self.formatters_module: Optional[ModuleType] = None

formatters_module_name = config.get("formatters_module", None)
if formatters_module_name:
self.formatters_module = import_module(formatters_module_name)

self.parse(config)

Expand All @@ -71,123 +49,29 @@ def parse(self, data) -> None:
self.services_names[group].add(name)
self.services_names[name].add(name)

self.make_service(group, name, node.config)
logger.debug(f"Create service: '{name}' config={node.config}")

service = make_service(
name=name,
group=group,
state_manager=self.state_manager,
connectors=self.connectors,
service_names=self.services_names,
config=node.config,
)

if service.is_last_chance():
self.last_chance_service = service
elif service.is_timeout():
self.timeout_service = service
else:
self.services.append(service)
elif is_connector(node):
name = _get_connector_name(node)
logger.debug(f"Create connector: '{name}' config={node.config}")
connector, workers = make_connector(node.config)
self.workers.extend(workers)
self.connectors[_get_connector_name(node)] = connector

def get_session(self):
if not self.session:
self.session = aiohttp.ClientSession()
return self.session

def get_gateway(self, on_channel_callback=None, on_service_callback=None):
if not self.gateway:
transport_type = TRANSPORT_SETTINGS["transport"]["type"]
gateway_cls = GATEWAYS_MAP[transport_type]["agent"]
self.gateway = gateway_cls(
config=TRANSPORT_SETTINGS,
on_service_callback=on_service_callback,
on_channel_callback=on_channel_callback,
)
return self.gateway

def get_external_module(self, module_name: str):
if module_name not in self.imported_modules:
module = import_module(module_name)
self.imported_modules[module_name] = module
else:
module = self.imported_modules[module_name]
return module

def make_service(self, group: Optional[str], name: str, data: ServiceConfig):
logger.debug(f"Create service: '{name}' config={data}")

def check_ext_module(class_name):
params = class_name.split(":")
formatter_class = None
if len(params) == 2:
formatter_class = getattr(
self.get_external_module(params[0]), params[1], None
)
elif len(params) == 1 and self.formatters_module:
formatter_class = getattr(self.formatters_module, params[0], None)
return formatter_class

connector_data = data.get("connector", None)
service_name = ".".join([i for i in [group, name] if i])
if "workflow_formatter" in data and not data["workflow_formatter"]:
workflow_formatter = None
else:
workflow_formatter = simple_workflow_formatter
connector = None
if isinstance(connector_data, str):
connector = self.connectors.get(connector_data, None)
elif isinstance(connector_data, dict):
connector = self.connectors.get(service_name, None)
if not connector:
raise ValueError(f"connector in pipeline.{service_name} is not declared")

sm_data = data.get("state_manager_method", None)
if sm_data:
sm_method = getattr(self.state_manager, sm_data, None)
if not sm_method:
raise ValueError(
f"state manager doesn't have a method {sm_data} (declared in {service_name})"
)
else:
sm_method = None

dialog_formatter = None
response_formatter = None

dialog_formatter_name = data.get("dialog_formatter", None)
response_formatter_name = data.get("response_formatter", None)
if dialog_formatter_name:
if dialog_formatter_name in all_formatters:
dialog_formatter = all_formatters[dialog_formatter_name]
else:
dialog_formatter = check_ext_module(dialog_formatter_name)
if not dialog_formatter:
raise ValueError(
f"formatter {dialog_formatter_name} doesn't exist (declared in {service_name})"
)
if response_formatter_name:
if response_formatter_name in all_formatters:
response_formatter = all_formatters[response_formatter_name]
else:
response_formatter = check_ext_module(response_formatter_name)
if not response_formatter:
raise ValueError(
f"formatter {response_formatter_name} doesn't exist (declared in {service_name})"
)

names_previous_services = set()
for sn in data.get("previous_services", set()):
names_previous_services.update(self.services_names.get(sn, set()))
names_required_previous_services = set()
for sn in data.get("required_previous_services", set()):
names_required_previous_services.update(self.services_names.get(sn, set()))
tags = data.get("tags", [])
service = Service(
name=service_name,
connector_func=connector.send,
state_processor_method=sm_method,
tags=tags,
names_previous_services=names_previous_services,
names_required_previous_services=names_required_previous_services,
workflow_formatter=workflow_formatter,
dialog_formatter=dialog_formatter,
response_formatter=response_formatter,
label=name,
)
if service.is_last_chance():
self.last_chance_service = service
elif service.is_timeout():
self.timeout_service = service
else:
self.services.append(service)
self.connectors[name] = connector


def _get_connector_name(node: Node[ConnectorConfig]) -> str:
Expand Down
1 change: 0 additions & 1 deletion deeppavlov_agent/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"db_config": "db_conf.json",
"overwrite_last_chance": None,
"overwrite_timeout": None,
"formatters_module": None,
"response_logger": True,
"time_limit": 0,
"output_formatter": http_api_output_formatter,
Expand Down
1 change: 0 additions & 1 deletion deeppavlov_agent/settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ agent:
db_config: db_conf.json
overwrite_last_chance: null
overwrite_timeout: null
formatters_module: null
enable_response_logger: true
response_time_limit: 0
output_formatter: deeppavlov_agent.state_formatters.output_formatters.http_api_output_formatter
Expand Down

0 comments on commit 7a39623

Please sign in to comment.