Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use yandex alice integration with agents and not just chainers #537

Merged
merged 3 commits into from Oct 19, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion deeppavlov/deep.py
Expand Up @@ -23,6 +23,7 @@
from deeppavlov.core.common.log import get_logger
from deeppavlov.download import deep_download
from deeppavlov.core.common.cross_validation import calc_cv_score
from utils.alice.alice import start_alice_server
from utils.telegram_utils.telegram_ui import interact_model_by_telegram
from utils.server_utils.server import start_model_server
from utils.ms_bot_framework_utils.server import run_ms_bf_default_agent
Expand Down Expand Up @@ -112,7 +113,10 @@ def main():
https = args.https
ssl_key = args.key
ssl_cert = args.cert
start_model_server(pipeline_config_path, alice, https, ssl_key, ssl_cert)
if alice:
start_alice_server(pipeline_config_path, https, ssl_key, ssl_cert)
else:
start_model_server(pipeline_config_path, https, ssl_key, ssl_cert)
elif args.mode == 'predict':
predict_on_stream(pipeline_config_path, args.batch_size, args.file_path)
elif args.mode == 'install':
Expand Down
59 changes: 33 additions & 26 deletions deeppavlov/skills/default_skill/default_skill.py
Expand Up @@ -12,11 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Optional
from typing import Tuple, Optional, List

from deeppavlov.core.common.chainer import Chainer
from deeppavlov.core.skill.skill import Skill

proposals = {
'en': 'expecting_arg: {}',
'ru': 'Пожалуйста, введите параметр {}'
}


class DefaultStatelessSkill(Skill):
"""Default stateless skill class.
Expand All @@ -26,8 +31,9 @@ class DefaultStatelessSkill(Skill):
Attributes:
model: DeepPavlov model to be wrapped into default skill instance.
"""
def __init__(self, model: Chainer, *args, **kwargs) -> None:
self.model: Chainer = model
def __init__(self, model: Chainer, lang: str='en', *args, **kwargs) -> None:
self.model = model
self.proposal: str = proposals[lang]

def __call__(self, utterances_batch: list, history_batch: list,
states_batch: Optional[list]=None) -> Tuple[list, list, list]:
Expand All @@ -53,34 +59,35 @@ def __call__(self, utterances_batch: list, history_batch: list,
batch_len = len(utterances_batch)
confidence_batch = [1.0] * batch_len

if len(self.model.in_x) > 1:
response_batch = [None] * batch_len
infer_indexes = []
response_batch: List[Optional[str]] = [None] * batch_len
infer_indexes = []

if not states_batch:
states_batch = [None] * batch_len
if not states_batch:
states_batch: List[Optional[dict]] = [None] * batch_len

for utt_i, utterance in enumerate(utterances_batch):
if not states_batch[utt_i]:
states_batch[utt_i] = {'expected_args': list(self.model.in_x), 'received_values': []}
for utt_i, utterance in enumerate(utterances_batch):
if not states_batch[utt_i]:
states_batch[utt_i] = {'expected_args': list(self.model.in_x), 'received_values': []}

if utterance:
states_batch[utt_i]['expected_args'].pop(0)
states_batch[utt_i]['received_values'].append(utterance)

if states_batch[utt_i]['expected_args']:
response = 'expecting_arg:{}'.format(states_batch[utt_i]['expected_args'][0])
response_batch[utt_i] = response
else:
infer_indexes.append(utt_i)

if infer_indexes:
infer_utterances = zip(*[tuple(states_batch[i]['received_values']) for i in infer_indexes])
infer_results = self.model(*infer_utterances)

for infer_i, infer_result in zip(infer_indexes, infer_results):
response_batch[infer_i] = infer_result
states_batch[infer_i] = None
else:
response_batch = self.model(utterances_batch)
if states_batch[utt_i]['expected_args']:
response = self.proposal.format(states_batch[utt_i]['expected_args'][0])
response_batch[utt_i] = response
else:
infer_indexes.append(utt_i)

if infer_indexes:
infer_utterances = zip(*[tuple(states_batch[i]['received_values']) for i in infer_indexes])
infer_results = self.model(*infer_utterances)

if len(self.model.out_params) > 1:
infer_results = infer_results[0]

for infer_i, infer_result in zip(infer_indexes, infer_results):
response_batch[infer_i] = infer_result
states_batch[infer_i] = None

return response_batch, confidence_batch, states_batch
Empty file added utils/alice/__init__.py
Empty file.
162 changes: 162 additions & 0 deletions utils/alice/alice.py
@@ -0,0 +1,162 @@
# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ssl
from collections import namedtuple
from pathlib import Path
from typing import Union, Optional

from flasgger import Swagger, swag_from
from flask import Flask, request, jsonify, redirect
from flask_cors import CORS

from deeppavlov.agents.default_agent.default_agent import DefaultAgent
from deeppavlov.agents.processors.default_rich_content_processor import DefaultRichContentWrapper
from deeppavlov.core.agent import Agent
from deeppavlov.core.agent.rich_content import RichMessage
from deeppavlov.core.common.log import get_logger
from deeppavlov.skills.default_skill.default_skill import DefaultStatelessSkill
from utils.server_utils.server import get_server_params, init_model

SERVER_CONFIG_FILENAME = 'server_config.json'

log = get_logger(__name__)

app = Flask(__name__)
Swagger(app)
CORS(app)


DialogID = namedtuple('DialogID', ['user_id', 'session_id'])


def interact_alice(agent: Agent):
"""
Exchange messages between basic pipelines and the Yandex.Dialogs service.
If the pipeline returns multiple values, only the first one is forwarded to Yandex.
"""
data = request.get_json()
text = data['request'].get('command', '').strip()
payload = data['request'].get('payload')

session_id = data['session']['session_id']
user_id = data['session']['user_id']
message_id = data['session']['message_id']

dialog_id = DialogID(user_id, session_id)

response = {
'response': {
'end_session': True,
'text': ''
},
"session": {
'session_id': session_id,
'message_id': message_id,
'user_id': user_id
},
'version': '1.0'
}

agent_response: Union[str, RichMessage] = agent([payload or text], [dialog_id])[0]
if isinstance(agent_response, RichMessage):
response['response']['text'] = '\n'.join([j['content']
for j in agent_response.json()
if j['type'] == 'plain_text'])
else:
response['response']['text'] = str(agent_response)

return jsonify(response), 200


def start_alice_server(model_config_path, https=False, ssl_key=None, ssl_cert=None):
if not https:
ssl_key = ssl_cert = None

server_config_dir = Path(__file__).parent
litinsky marked this conversation as resolved.
Show resolved Hide resolved
server_config_path = server_config_dir.parent / SERVER_CONFIG_FILENAME

server_params = get_server_params(server_config_path, model_config_path)
host = server_params['host']
port = server_params['port']
model_endpoint = server_params['model_endpoint']

model = init_model(model_config_path)
skill = DefaultStatelessSkill(model, lang='ru')
agent = DefaultAgent([skill], skills_processor=DefaultRichContentWrapper())

start_agent_server(agent, host, port, model_endpoint, ssl_key, ssl_cert)


def start_agent_server(agent: Agent, host: str, port: int, endpoint: str,
ssl_key: Optional[Union[str, Path]]=None,
ssl_cert: Optional[Union[str, Path]]=None):
if ssl_key and ssl_cert:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ssh_key_path = Path(ssl_key).resolve()
ssh_cert_path = Path(ssl_cert).resolve()
ssl_context.load_cert_chain(ssh_cert_path, ssh_key_path)
else:
ssl_context = None

@app.route('/')
def index():
return redirect('/apidocs/')

endpoint_description = {
'description': 'A model endpoint',
'parameters': [
{
'name': 'data',
'in': 'body',
'required': 'true',
'example': {
'meta': {
'locale': 'ru-RU',
'timezone': 'Europe/Moscow',
"client_id": 'ru.yandex.searchplugin/5.80 (Samsung Galaxy; Android 4.4)'
},
'request': {
'command': 'где ближайшее отделение',
'original_utterance': 'Алиса спроси у Сбербанка где ближайшее отделение',
'type': 'SimpleUtterance',
'markup': {
'dangerous_context': True
},
'payload': {}
},
'session': {
'new': True,
'message_id': 4,
'session_id': '2eac4854-fce721f3-b845abba-20d60',
'skill_id': '3ad36498-f5rd-4079-a14b-788652932056',
'user_id': 'AC9WC3DF6FCE052E45A4566A48E6B7193774B84814CE49A922E163B8B29881DC'
},
'version': '1.0'
}
}
],
'responses': {
"200": {
"description": "A model response"
}
}
}

@app.route(endpoint, methods=['POST'])
@swag_from(endpoint_description)
def answer():
return interact_alice(agent)

app.run(host=host, port=port, threaded=False, ssl_context=ssl_context)
79 changes: 2 additions & 77 deletions utils/server_utils/server.py
Expand Up @@ -26,7 +26,6 @@
from deeppavlov.core.common.file import read_json
from deeppavlov.core.common.log import get_logger
from deeppavlov.core.data.utils import check_nested_dict_keys, jsonify_data
from deeppavlov.core.models.component import Component

SERVER_CONFIG_FILENAME = 'server_config.json'

Expand Down Expand Up @@ -66,54 +65,6 @@ def get_server_params(server_config_path, model_config_path):
return server_params


memory = {}


def interact_alice(model: Component, params_names: list):
"""
Exchange messages between basic pipelines and the Yandex.Dialogs service.
If the pipeline returns multiple values, only the first one is forwarded to Yandex.
"""
data = request.get_json()
text = data['request']['command'].strip()

session_id = data['session']['session_id']
message_id = data['session']['message_id']
user_id = data['session']['user_id']

response = {
'response': {
'end_session': True
},
"session": {
'session_id': session_id,
'message_id': message_id,
'user_id': user_id
},
'version': '1.0'
}

params = memory.pop(session_id, [])
if text:
params.append([text])

if len(params) < len(params_names):
memory[session_id] = params
response['response']['text'] = 'Пожалуйста, введите параметр ' + params_names[len(params)]
response['response']['end_session'] = False
return jsonify(response), 200

response_text = model(*params)[0]
if not isinstance(response_text, str) and isinstance(response_text, (list, tuple)):
try:
response_text = response_text[0]
except Exception as e:
log.warning(f'Could not get the first element of `{repr(response_text)}` because of `{e}`')

response['response']['text'] = str(response_text)
return jsonify(response), 200


def interact(model: Chainer, params_names: List[str]) -> Tuple[Response, int]:
if not request.is_json:
log.error("request Content-Type header is not application/json")
Expand Down Expand Up @@ -155,7 +106,7 @@ def interact(model: Chainer, params_names: List[str]) -> Tuple[Response, int]:
return jsonify(result), 200


def start_model_server(model_config_path, alice=False, https=False, ssl_key=None, ssl_cert=None):
def start_model_server(model_config_path, https=False, ssl_key=None, ssl_cert=None):
server_config_dir = Path(__file__).parent
server_config_path = server_config_dir.parent / SERVER_CONFIG_FILENAME

Expand Down Expand Up @@ -196,35 +147,9 @@ def index():
}
}

if alice:
endpoint_description['parameters'][0]['example'] = {
'meta': {
'locale': 'ru-RU',
'timezone': 'Europe/Moscow',
"client_id": 'ru.yandex.searchplugin/5.80 (Samsung Galaxy; Android 4.4)'
},
'request': {
'command': 'где ближайшее отделение',
'original_utterance': 'Алиса спроси у Сбербанка где ближайшее отделение',
'type': 'SimpleUtterance',
'markup': {
'dangerous_context': True
},
'payload': {}
},
'session': {
'new': True,
'message_id': 4,
'session_id': '2eac4854-fce721f3-b845abba-20d60',
'skill_id': '3ad36498-f5rd-4079-a14b-788652932056',
'user_id': 'AC9WC3DF6FCE052E45A4566A48E6B7193774B84814CE49A922E163B8B29881DC'
},
'version': '1.0'
}

@app.route(model_endpoint, methods=['POST'])
@swag_from(endpoint_description)
def answer():
return interact_alice(model, model_args_names) if alice else interact(model, model_args_names)
return interact(model, model_args_names)

app.run(host=host, port=port, threaded=False, ssl_context=ssl_context)