Commit

Merge branch 'dev'
my-master committed Sep 12, 2019
2 parents 2bd9417 + b9e2fc7 commit 34195cd
Showing 25 changed files with 564 additions and 145 deletions.
2 changes: 2 additions & 0 deletions agent.env
@@ -0,0 +1,2 @@
TELEGRAM_TOKEN=
TELEGRAM_PROXY=
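The Telegram credentials move out of config.py into agent.env (config.py now only records AGENT_ENV_FILE = "agent.env"), and core/run.py reads them with os.getenv inside experimental_bot further down. A minimal sketch of loading such an env file into the process environment before starting the agent — the load_env helper here is hypothetical, standing in for however agent.env is actually consumed (e.g. docker-compose's env_file):

import os

def load_env(path: str = "agent.env") -> None:
    # Hypothetical helper: export KEY=VALUE pairs from agent.env into os.environ
    # so that getenv('TELEGRAM_TOKEN') / getenv('TELEGRAM_PROXY') can pick them up.
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip())

load_env()
token = os.getenv("TELEGRAM_TOKEN")  # empty until agent.env is filled in
proxy = os.getenv("TELEGRAM_PROXY")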
31 changes: 19 additions & 12 deletions config.py
@@ -1,21 +1,23 @@
from state_formatters.dp_formatters import *
from functools import partial

TELEGRAM_TOKEN = ''
TELEGRAM_PROXY = ''
from state_formatters.dp_formatters import *

DB_NAME = 'test'
HOST = '127.0.0.1'
PORT = 27017
DB_HOST = '127.0.0.1'
DB_PORT = 27017
DB_PATH = '/data/db'

MAX_WORKERS = 4

AGENT_ENV_FILE = "agent.env"

SKILLS = [
{
"name": "odqa",
"protocol": "http",
"host": "127.0.0.1",
"port": 2080,
"endpoint": "odqa",
"endpoint": "model",
"path": "odqa/ru_odqa_infer_wiki",
"env": {
"CUDA_VISIBLE_DEVICES": ""
@@ -35,30 +37,33 @@
},
"profile_handler": True,
"dockerfile": "dockerfile_skill_cpu",
"formatter": odqa_formatter
"formatter": chitchat_formatter
}
]

ANNOTATORS = [
ANNOTATORS_1 = [
{
"name": "ner",
"protocol": "http",
"host": "127.0.0.1",
"port": 2083,
"endpoint": "ner",
"endpoint": "model",
"path": "ner/ner_rus",
"env": {
"CUDA_VISIBLE_DEVICES": ""
},
"dockerfile": "dockerfile_skill_cpu",
"formatter": ner_formatter
},
}
]

ANNOTATORS_2 = [
{
"name": "sentiment",
"protocol": "http",
"host": "127.0.0.1",
"port": 2084,
"endpoint": "intents",
"endpoint": "model",
"path": "classifiers/rusentiment_cnn",
"env": {
"CUDA_VISIBLE_DEVICES": ""
@@ -68,13 +73,15 @@
}
]

ANNOTATORS_3 = []

SKILL_SELECTORS = [
{
"name": "chitchat_odqa",
"protocol": "http",
"host": "127.0.0.1",
"port": 2082,
"endpoint": "intents",
"endpoint": "model",
"path": "classifiers/rusentiment_bigru_superconv",
"env": {
"CUDA_VISIBLE_DEVICES": ""
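Every service entry in config.py (SKILLS, ANNOTATORS_1/2/3, SKILL_SELECTORS) shares one schema — name, protocol, host, port, endpoint, path, env, dockerfile, formatter — and this commit standardizes the endpoint field to "model". core/run.py below also imports RESPONSE_SELECTORS from the transformed config; its entries are not visible in this diff, but presumably follow the same schema. A hypothetical example entry (the name, port, path and formatter are assumptions, not values from the repo):

# Hypothetical RESPONSE_SELECTORS entry, mirroring the schema of the entries above;
# relies on the wildcard import from state_formatters.dp_formatters at the top of config.py.
RESPONSE_SELECTORS = [
    {
        "name": "confidence_selector",             # assumed name
        "protocol": "http",
        "host": "127.0.0.1",
        "port": 2085,                              # assumed port
        "endpoint": "model",
        "path": "classifiers/rusentiment_bigru_superconv",  # assumed model path
        "env": {"CUDA_VISIBLE_DEVICES": ""},
        "dockerfile": "dockerfile_skill_cpu",
        "formatter": chitchat_formatter            # assumed formatter
    }
]
# If RESPONSE_SELECTORS stays empty, run() falls back to ConfidenceResponseSelector.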
11 changes: 6 additions & 5 deletions core/agent.py
@@ -12,11 +12,11 @@


class Agent:
def __init__(self, state_manager: StateManager, preprocessor: Callable,
def __init__(self, state_manager: StateManager, preprocessors: List[Callable],
postprocessor: Callable,
skill_manager: SkillManager) -> None:
self.state_manager = state_manager
self.preprocessor = preprocessor
self.preprocessors = preprocessors
self.postprocessor = postprocessor
self.skill_manager = skill_manager

@@ -50,9 +50,10 @@ def __call__(self, utterances: Sequence[str], user_telegram_ids: Sequence[Hashab
return sent_responses # return text only to the users

def _update_annotations(self, me_dialogs: Sequence[Dialog]) -> None:
annotations = self.preprocessor(get_state(me_dialogs))
utterances = [dialog.utterances[-1] for dialog in me_dialogs]
self.state_manager.add_annotations(utterances, annotations)
for prep in self.preprocessors:
annotations = prep(get_state(me_dialogs))
utterances = [dialog.utterances[-1] for dialog in me_dialogs]
self.state_manager.add_annotations(utterances, annotations)

def _update_profiles(self, me_users: Sequence[Human], profiles: List[Profile]) -> None:
for me_user, profile in zip(me_users, profiles):
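Agent now takes a list of preprocessors and applies each one to the dialog state in turn, so every annotator group contributes its own annotations (merged by StateManager.add_annotations, see core/state_manager.py below). A wiring sketch that mirrors run() in core/run.py — annotator_groups is a placeholder for the (names, urls, formatters) triples built from ANNOTATORS_1/2/3, and the module paths are taken from the file layout in this diff:

from core.agent import Agent
from core.state_manager import StateManager
from core.skill_manager import SkillManager
from core.rest_caller import RestCaller
from core.transform_config import MAX_WORKERS
from models.postprocessor import DefaultPostprocessor
from models.response_selector import ConfidenceResponseSelector

annotator_groups = []  # e.g. [(('ner',), ('http://127.0.0.1:2083/model',), (ner_formatter,)), ...]
preprocessors = [
    RestCaller(max_workers=MAX_WORKERS, names=names, urls=urls, formatters=formatters)
    for names, urls, formatters in annotator_groups  # one RestCaller per ANNOTATORS_i group
]
skill_manager = SkillManager(response_selector=ConfidenceResponseSelector(),
                             skill_caller=RestCaller(max_workers=MAX_WORKERS),
                             skill_selector=None, profile_handlers=[])
agent = Agent(StateManager(), preprocessors, DefaultPostprocessor(), skill_manager)
# agent(utterances, user_telegram_ids, ...) then runs every preprocessor over the
# current dialog state before the skills are called.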
4 changes: 2 additions & 2 deletions core/connection.py
@@ -1,5 +1,5 @@
from mongoengine import connect

from core.transform_config import HOST, PORT, DB_NAME
from core.transform_config import DB_HOST, DB_PORT, DB_NAME

state_storage = connect(host=HOST, port=PORT, db=DB_NAME)
state_storage = connect(host=DB_HOST, port=DB_PORT, db=DB_NAME)
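With connection.py importing the renamed DB_HOST/DB_PORT, a single import of this module registers the default mongoengine connection, after which document classes can be queried directly — the same pattern the new /dialogs handlers in core/run.py rely on. A small usage sketch:

import core.connection  # noqa: F401 — registers the default mongoengine connection
from core.state_schema import Dialog

# The same kind of query the new /dialogs endpoint performs.
for d in Dialog.objects():
    print(d.id, d.channel_type, d.location)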
5 changes: 2 additions & 3 deletions core/rest_caller.py
@@ -1,5 +1,5 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Any, Optional, Sequence, Union
from typing import Dict, List, Any, Optional, Sequence, Union, Callable

import requests

@@ -30,8 +30,7 @@ def __init__(self, max_workers: int = MAX_WORKERS,
def __call__(self, payload: Union[Dict, Sequence[Dict]],
names: Optional[Sequence[str]] = None,
urls: Optional[Sequence[str]] = None,
formatters = None) -> List[
Dict[str, Dict[str, Any]]]:
formatters: List[Callable] = None) -> List[Dict[str, Dict[str, Any]]]:

names = names if names is not None else self.names
urls = urls if urls is not None else self.urls
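RestCaller.__call__ now declares formatters as List[Callable], one formatter per called service, applied to each service's response. A hedged usage sketch — the stand-in formatter and its exact signature are assumptions; the real formatters live in state_formatters/dp_formatters.py:

from core.rest_caller import RestCaller

def dummy_formatter(response):
    # Assumed signature: takes the raw service response, returns the agent-state payload.
    return {"result": response}

state_payload = {}  # placeholder for the dialog-state dict the agent builds
caller = RestCaller(max_workers=4,
                    names=["ner"],
                    urls=["http://127.0.0.1:2083/model"],
                    formatters=[dummy_formatter])
annotations = caller(payload=state_payload)  # names/urls/formatters default to the values above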
80 changes: 64 additions & 16 deletions core/run.py
@@ -1,8 +1,10 @@
import argparse
import time
from os import getenv

from aiohttp import web
from datetime import datetime
from string import hexdigits
from threading import Thread
from multiprocessing import Process, Pipe
from multiprocessing.connection import Connection
@@ -11,8 +13,6 @@
import telebot
from telebot.types import Message, Location, User

from core.transform_config import TELEGRAM_TOKEN, TELEGRAM_PROXY

parser = argparse.ArgumentParser()
parser.add_argument("-ch", "--channel", help="run agent in telegram, cmd_client or http_client", type=str,
choices=['telegram', 'cmd_client', 'http_client'], default='cmd_client')
@@ -31,8 +31,10 @@ def _model_process(model_function: Callable, conn: Connection, batch_size: int =

while True:
batch: List[Tuple[str, Hashable]] = []
while conn.poll() and len(batch) < batch_size and time.time() - check_time <= poll_period:
while conn.poll() and len(batch) < batch_size:
batch.append(conn.recv())
if time.time() - check_time >= poll_period:
break

if not batch:
continue
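The reworked loop in _model_process drains the pipe until either batch_size messages have been collected or poll_period seconds have passed since the last check, rather than requiring every condition to hold before each receive. The same pattern in isolation, with a plain queue standing in for the multiprocessing connection:

import time
from queue import Queue

def collect_batch(q: Queue, batch_size: int, poll_period: float, check_time: float) -> list:
    # Drain until the batch is full or poll_period has elapsed since check_time.
    batch = []
    while not q.empty() and len(batch) < batch_size:
        batch.append(q.get_nowait())
        if time.time() - check_time >= poll_period:
            break
    return batch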
@@ -46,8 +48,8 @@

def experimental_bot(
model_function: Callable[
..., Callable[[Collection[Message], Collection[Hashable]], Collection[str]]],
token: str, proxy: Optional[str] = None, *, batch_size: int = -1, poll_period: float = 0.5):
..., Callable[[Collection[Message], Collection[Hashable]], Collection[str]]], *,
batch_size: int = -1, poll_period: float = 0.5):
"""
Args:
@@ -60,6 +62,10 @@ def experimental_bot(
Returns: None
"""

token = getenv('TELEGRAM_TOKEN')
proxy = getenv('TELEGRAM_PROXY')

if proxy is not None:
telebot.apihelper.proxy = {'https': proxy}

@@ -92,34 +98,48 @@ def run():
from core.rest_caller import RestCaller
from models.postprocessor import DefaultPostprocessor
from models.response_selector import ConfidenceResponseSelector
from core.transform_config import MAX_WORKERS, ANNOTATORS, SKILL_SELECTORS, SKILLS
from core.transform_config import MAX_WORKERS, ANNOTATORS, SKILL_SELECTORS, SKILLS, RESPONSE_SELECTORS

import logging

logging.getLogger('requests.packages.urllib3.connectionpool').setLevel(logging.WARNING)

state_manager = StateManager()

anno_names, anno_urls, anno_formatters = zip(
*[(a['name'], a['url'], a['formatter']) for a in ANNOTATORS])
preprocessor = RestCaller(max_workers=MAX_WORKERS, names=anno_names, urls=anno_urls,
formatters=anno_formatters)
preprocessors = []
for ants in ANNOTATORS:
if ants:
anno_names, anno_urls, anno_formatters = zip(
*[(a['name'], a['url'], a['formatter']) for a in ants])
else:
anno_names, anno_urls, anno_formatters = [], [], []
preprocessors.append(RestCaller(max_workers=MAX_WORKERS, names=anno_names, urls=anno_urls,
formatters=anno_formatters))

postprocessor = DefaultPostprocessor()
skill_caller = RestCaller(max_workers=MAX_WORKERS)
response_selector = ConfidenceResponseSelector()

if RESPONSE_SELECTORS:
rs_names, rs_urls, rs_formatters = zip(
*[(rs['name'], rs['url'], rs['formatter']) for rs in RESPONSE_SELECTORS])
response_selector = RestCaller(max_workers=MAX_WORKERS, names=rs_names, urls=rs_urls,
formatters=rs_formatters)
else:
response_selector = ConfidenceResponseSelector()

skill_selector = None
if SKILL_SELECTORS:
ss_names, ss_urls, ss_formatters = zip(
*[(selector['name'], selector['url'], selector['formatter']) for selector in
SKILL_SELECTORS])
*[(ss['name'], ss['url'], ss['formatter']) for ss in SKILL_SELECTORS])
skill_selector = RestCaller(max_workers=MAX_WORKERS, names=ss_names, urls=ss_urls,
formatters=ss_formatters)
formatters=ss_formatters)

skill_manager = SkillManager(skill_selector=skill_selector, response_selector=response_selector,
skill_caller=skill_caller,
profile_handlers=[skill['name'] for skill in SKILLS
if skill.get('profile_handler')])

agent = Agent(state_manager, preprocessor, postprocessor, skill_manager)
agent = Agent(state_manager, preprocessors, postprocessor, skill_manager)

def infer_telegram(messages: Collection[Message], dialog_ids):
utterances: List[Optional[str]] = [message.text for message in messages]
@@ -167,6 +187,8 @@ async def init_app():
app = web.Application()
handle_func = await api_message_processor(run())
app.router.add_post('/', handle_func)
app.router.add_get('/dialogs', users_dialogs)
app.router.add_get('/dialogs/{dialog_id}', dialog)
return app


@@ -191,9 +213,35 @@ async def api_handle(request):
return api_handle


async def users_dialogs(request):
from core.state_schema import Dialog
exist_dialogs = Dialog.objects()
result = list()
for i in exist_dialogs:
result.append(
{'id': str(i.id), 'location': i.location, 'channel_type': i.channel_type, 'user': i.user.to_dict()})
return web.json_response(result)


async def dialog(request):
from core.state_schema import Dialog
dialog_id = request.match_info['dialog_id']
if dialog_id == 'all':
dialogs = Dialog.objects()
return web.json_response([i.to_dict() for i in dialogs])
elif len(dialog_id) == 24 and all(c in hexdigits for c in dialog_id):
dialog = Dialog.objects(id__exact=dialog_id)
if not dialog:
raise web.HTTPNotFound(reason=f'dialog with id {dialog_id} does not exist')
else:
return web.json_response(dialog[0].to_dict())
else:
raise web.HTTPBadRequest(reason='dialog id should be a 24-character hex string')
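The two new routes expose read-only views of stored dialogs: GET /dialogs lists id, location, channel_type and user for every dialog, and GET /dialogs/{dialog_id} returns one full dialog by its 24-character hex id ('all' dumps everything). A quick way to exercise them once the http_client channel is running — the host and port below are assumptions, since web.run_app is configured outside this hunk:

import requests

BASE = "http://127.0.0.1:4242"  # assumed host/port of the aiohttp app

dialogs = requests.get(f"{BASE}/dialogs").json()          # [{id, location, channel_type, user}, ...]
first_id = dialogs[0]["id"]                               # 24-character hex ObjectId string
one = requests.get(f"{BASE}/dialogs/{first_id}").json()   # a single dialog document
everything = requests.get(f"{BASE}/dialogs/all").json()   # the 'all' shortcut returns every dialog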


def main():
if CHANNEL == 'telegram':
experimental_bot(run, token=TELEGRAM_TOKEN, proxy=TELEGRAM_PROXY)
experimental_bot(run)
elif CHANNEL == 'cmd_client':
message_processor = run()
user_id = input('Provide user id: ')
17 changes: 11 additions & 6 deletions core/skill_manager.py
@@ -28,16 +28,22 @@ def __init__(self, response_selector, skill_caller, skill_selector=None, profile
self.skill_caller = skill_caller
self.skills = SKILLS
self.skill_names = [s['name'] for s in self.skills]
self.skill_responses = []

self.profile_handlers = [name for name in reversed(profile_handlers) if name in self.skill_names]
self.profile_fields = list(Human.profile.default.keys())

def __call__(self, dialogs):

user_profiles = self._get_user_profiles(self.skill_responses)
selected_skill_names, utterances, confidences = self.response_selector(self.skill_responses)
utterances = [utt if utt else NOANSWER_UTT for utt in utterances]
skill_responses = [d.utterances[-1]['selected_skills'] for d in dialogs]
user_profiles = self._get_user_profiles(skill_responses)
rs_response = self.response_selector(get_state(dialogs))
# should be a flatten list because there is always only one ResponseSelector:
selected_skill_names = list(v for d in rs_response for _, v in d.items())
utterances = []
confidences = []
for responses, selected_name in zip(skill_responses, selected_skill_names):
selected_skill = responses[selected_name]
utterances.append(selected_skill['text'] or NOANSWER_UTT)
confidences.append(selected_skill['confidence'])
return selected_skill_names, utterances, confidences, user_profiles

def _get_user_profiles(self, skill_responses) -> Optional[List[Dict]]:
@@ -93,5 +99,4 @@ def get_skill_responses(self, dialogs):
payloads.append(s)
skill_responses = self.skill_caller(payload=payloads, names=skill_names, urls=skill_urls,
formatters=skill_formatters)
self.skill_responses = skill_responses
return skill_responses
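SkillManager no longer caches skill responses on the instance: __call__ reads them from the last utterance of each dialog and pairs them with the single ResponseSelector's choice. A hand-made example of the structures involved — the skill names, texts and confidences are illustrative, not repo data:

NOANSWER_UTT = "<no answer>"  # placeholder; the repo defines its own fallback text

skill_responses = [  # one dict per dialog: skill name -> that skill's response
    {"odqa": {"text": "Москва", "confidence": 0.9},
     "chitchat": {"text": "Привет!", "confidence": 0.3}},
]
selected_skill_names = ["odqa"]  # flattened output of the single ResponseSelector

utterances, confidences = [], []
for responses, selected_name in zip(skill_responses, selected_skill_names):
    selected = responses[selected_name]
    utterances.append(selected["text"] or NOANSWER_UTT)  # empty answers fall back to NOANSWER_UTT
    confidences.append(selected["confidence"])
# utterances -> ["Москва"], confidences -> [0.9]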
2 changes: 1 addition & 1 deletion core/state_manager.py
@@ -84,7 +84,7 @@ def add_bot_utterances(cls, dialogs: Sequence[Dialog], orig_texts: Sequence[str]
@staticmethod
def add_annotations(utterances: Sequence[Utterance], annotations: Sequence[Dict]):
for utt, ann in zip(utterances, annotations):
utt.annotations = ann
utt.annotations.update(ann)
utt.save()

@staticmethod
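Switching add_annotations from assignment to dict.update makes each preprocessor's annotations merge into the utterance instead of replacing whatever an earlier annotator group wrote — which is what lets Agent._update_annotations loop over ANNOTATORS_1/2/3 safely. A tiny illustration (the annotation payload shapes are made up):

annotations = {}
annotations.update({"ner": [["Москва"]]})        # first annotator group
annotations.update({"sentiment": ["neutral"]})   # second group keeps the NER result
assert annotations == {"ner": [["Москва"]], "sentiment": ["neutral"]}

# With plain assignment (the old behaviour) the second group would wipe out the first:
annotations = {"ner": [["Москва"]]}
annotations = {"sentiment": ["neutral"]}         # NER annotations are lost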
13 changes: 7 additions & 6 deletions core/transform_config.py
@@ -8,14 +8,15 @@

from config import *

ANNOTATORS = [ANNOTATORS_1, ANNOTATORS_2, ANNOTATORS_3]

# generate component url
for service in chain(ANNOTATORS, SKILL_SELECTORS, SKILLS, RESPONSE_SELECTORS, POSTPROCESSORS):
for service in chain(*ANNOTATORS, SKILL_SELECTORS, SKILLS, RESPONSE_SELECTORS,
POSTPROCESSORS):
host = service['name'] if getenv('DPA_LAUNCHING_ENV') == 'docker' else service['host']
service['url'] = f"{service['protocol']}://{host}:{service['port']}/{service['endpoint']}"

HOST = 'mongo' if getenv('DPA_LAUNCHING_ENV') == 'docker' else HOST
TELEGRAM_TOKEN = TELEGRAM_TOKEN or getenv('TELEGRAM_TOKEN')
TELEGRAM_PROXY = TELEGRAM_PROXY or getenv('TELEGRAM_PROXY')
DB_HOST = 'mongo' if getenv('DPA_LAUNCHING_ENV') == 'docker' else DB_HOST


def _get_config_path(component_config: dict) -> dict:
@@ -51,8 +52,8 @@ def _get_config_path(component_config: dict) -> dict:
MAX_WORKERS = config.get('MAX_WORKERS', MAX_WORKERS)

DB_NAME = config.get('DB_NAME', DB_NAME)
HOST = config.get('HOST', HOST)
PORT = config.get('PORT', PORT)
DB_HOST = config.get('HOST', DB_HOST)
DB_PORT = config.get('PORT', DB_PORT)

for group in _component_groups:
setattr(_module, group, list(map(_get_config_path, config.get(group, []))))
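transform_config now flattens the three annotator groups with chain(*ANNOTATORS, ...) before building each component's URL, and under Docker the service name replaces the host. A worked example of the URL line for the ner entry from config.py above:

from os import getenv

service = {"name": "ner", "protocol": "http", "host": "127.0.0.1",
           "port": 2083, "endpoint": "model"}

host = service["name"] if getenv("DPA_LAUNCHING_ENV") == "docker" else service["host"]
service["url"] = f"{service['protocol']}://{host}:{service['port']}/{service['endpoint']}"
# locally: http://127.0.0.1:2083/model      in docker: http://ner:2083/model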