diff --git a/annotators/midas_classification/Dockerfile b/annotators/midas_classification/Dockerfile
index 1bbd1c5b43..0afff9a0f5 100644
--- a/annotators/midas_classification/Dockerfile
+++ b/annotators/midas_classification/Dockerfile
@@ -16,6 +16,7 @@ COPY . /src/
 WORKDIR /src
 
 RUN python -m deeppavlov install $CONFIG
+RUN python -m spacy download en_core_web_sm
 
 RUN sed -i "s|$SED_ARG|g" "$CONFIG"
 
diff --git a/annotators/midas_classification/requirements.txt b/annotators/midas_classification/requirements.txt
index 8efa0c3c7d..e110c05391 100644
--- a/annotators/midas_classification/requirements.txt
+++ b/annotators/midas_classification/requirements.txt
@@ -4,4 +4,4 @@ sentry-sdk==0.14.2
 requests==2.23.0
 gunicorn==19.9.0
 numpy==1.17.2
-nltk==3.2.5
+spacy==3.0.6
\ No newline at end of file
diff --git a/annotators/midas_classification/server.py b/annotators/midas_classification/server.py
index d1a4423638..baef684584 100644
--- a/annotators/midas_classification/server.py
+++ b/annotators/midas_classification/server.py
@@ -4,9 +4,9 @@
 
 import numpy as np
 import sentry_sdk
+import spacy
 from deeppavlov import build_model
 from flask import Flask, request, jsonify
-from nltk.tokenize import sent_tokenize
 
 sentry_sdk.init(os.getenv("SENTRY_DSN"))
 
@@ -18,6 +18,7 @@
 EMPTY_SIGN = ": EMPTY >"
 
 try:
+    spacy_nlp = spacy.load("en_core_web_sm")
     model = build_model("midas_conv_bert.json", download=True)
     m = model(["hi"])
 except Exception as e:
@@ -76,6 +77,14 @@ def recombine_responses(responses, dialog_ids, n_dialogs):
     return final_responses
 
 
+def spacy_sent_tokenize(sentence):
+    doc = spacy_nlp(sentence)
+    segments = []
+    for sent in doc.sents:
+        segments.append(sent.text)
+    return segments
+
+
 @app.route("/model", methods=["POST"])
 def respond():
     st_time = time.time()
@@ -84,17 +93,22 @@ def respond():
     inputs = []
     for i, dialog in enumerate(dialogs):
         if len(dialog["bot_utterances"]):
-            prev_bot_uttr_text = dialog["bot_utterances"][-1].get("text", "").lower()
-            tokenized_sentences = sent_tokenize(prev_bot_uttr_text)
+            tokenized_sentences = (
+                dialog["bot_utterances"][-1].get("annotations", {}).get("sentseg", {}).get("segments", [])
+            )
+            if len(tokenized_sentences) == 0:
+                prev_bot_uttr_text = dialog["bot_utterances"][-1].get("text", "").lower()
+                tokenized_sentences = spacy_sent_tokenize(prev_bot_uttr_text)
             context = tokenized_sentences[-1].lower() if len(tokenized_sentences) > 0 else ""
         else:
             context = ""
         if len(dialog["human_utterances"]):
-            curr_human_uttr_text = dialog["human_utterances"][-1].get("text", "").lower()
+            sentences = dialog["human_utterances"][-1].get("annotations", {}).get("sentseg", {}).get("segments", [])
+            if len(sentences) == 0:
+                sentences = spacy_sent_tokenize(dialog["human_utterances"][-1].get("text", "").lower())
         else:
-            curr_human_uttr_text = ""
+            sentences = []
 
-        sentences = sent_tokenize(curr_human_uttr_text)
         for sent in sentences:
             input_ = f"{context} {EMPTY_SIGN} {sent}"
             inputs.append(input_)
@@ -113,12 +127,19 @@ def batch_respond():
     st_time = time.time()
     bot_utterances = request.json["sentences"]
     human_utterances = request.json["last_human_utterances"]
-    bot_utterances_sentences = [sent_tokenize(utterance) for utterance in bot_utterances]
+    bot_utterances_sentences = [
+        spacy_sent_tokenize(utterance) if isinstance(utterance, str) else utterance for utterance in bot_utterances
+    ]
     dialog_ids = []
     inputs = []
     for i, bot_utterance_sents in enumerate(bot_utterances_sentences):
         if human_utterances[i]:
-            context = sent_tokenize(human_utterances[i])[-1].lower()
+            if isinstance(human_utterances[i], str):
+                context = spacy_sent_tokenize(human_utterances[i])[-1].lower()
+            elif isinstance(human_utterances[i], list):
+                context = human_utterances[i][-1].lower()
+            else:
+                context = ""
         else:
             context = ""
         for utterance in bot_utterance_sents: