-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat/knowledge grounding service (#44)
* Create requirements.txt * feat: add knowledge grounding service * knowledge grounding service: update runtests.sh * knowledge grounding service: update configs * speedup fix * update kubernetes config file * chmod +x test.sh * codestyle fixes * fix: codestyle * fix: formatting * fix: app route respond * feat: knowledge grounding service to gpu with quest gen * fix: batch processing * fix: responses' batch is a list of str * feat/knowledge_grounding_service: change nltk downloader in dockerfile for safe way * added service to proxy.yml * fix: try-except scope and docker-compose wait-hosts * fix: typo * fix: empty knowledge handling Co-authored-by: dilyararimovna <dilyara.rimovna@gmail.com> Co-authored-by: Денис Кузнецов <kuznetsov.den.p@gmail.com>
- Loading branch information
1 parent
b6a1b05
commit 54cac98
Showing
13 changed files
with
287 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Base image: PyTorch 1.5 runtime with CUDA 10.1 / cuDNN 7.
FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime

# wget is used by the download steps further down. The apt lists are removed in
# the same layer so they never end up in the image.
# NOTE(review): --allow-unauthenticated disables package signature checks --
# confirm it is still required for this base image and drop it if not.
RUN apt-get update \
    && apt-get install -y --allow-unauthenticated wget \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /src

# A dedicated virtualenv is left disabled; packages are installed into the
# base image's Python instead.
#ENV VIRTUAL_ENV=/parlaivenv
#RUN python3 -m venv $VIRTUAL_ENV
#ENV PATH="$VIRTUAL_ENV/bin:$PATH"

# Install ParlAI. --no-cache-dir keeps pip's download cache out of the layer.
# NOTE(review): parlai is unpinned, so a rebuild may pick up a different
# release -- pin it to the version this service was validated against.
RUN pip install --no-cache-dir parlai
|
||
# Create the directories that the download steps populate:
#   parlai/tasks/redditgk            - redditgk task code
#   data/redditgk                    - data file for the redditgk task
#                                      (data/ is the default ParlAI DATAPATH)
#   data/models/wizard_of_wikipedia  - wow model file
#   parlai/agents/courier            - courier agent code
# One RUN keeps this single conceptual step in a single layer; `mkdir -p`
# creates the shared `data` parent implicitly, so it needs no entry of its own.
RUN site_packages=/opt/conda/lib/python3.7/site-packages \
    && mkdir -p \
        "$site_packages/parlai/tasks/redditgk" \
        "$site_packages/data/redditgk" \
        "$site_packages/data/models/wizard_of_wikipedia" \
        "$site_packages/parlai/agents/courier"
|
||
# NOTE(review): every download below uses plain HTTP with no integrity check;
# verify a sha256 per artifact once known-good digests are available.

# Fetch the redditgk task scripts into parlai/tasks/redditgk.
# `|| exit 1` makes the build fail on the first file that cannot be fetched.
RUN base_url=http://lnsigo.mipt.ru/export/alexaprize_data/parlai_grounding_knowledge/redditgk \
    && dest=/opt/conda/lib/python3.7/site-packages/parlai/tasks/redditgk \
    && for name in __init__.py agents.py worlds.py test.py; do \
        wget "$base_url/$name" -q -P "$dest" || exit 1; \
    done

# Replace the installed task_list.py with the copy from the cloud
# (-O writes straight over the existing file).
RUN wget http://lnsigo.mipt.ru/export/alexaprize_data/parlai_grounding_knowledge/task_list.py -q -O /opt/conda/lib/python3.7/site-packages/parlai/tasks/task_list.py

# Fetch the courier agent scripts into parlai/agents/courier.
RUN base_url=http://lnsigo.mipt.ru/export/alexaprize_data/parlai_grounding_knowledge/courier \
    && dest=/opt/conda/lib/python3.7/site-packages/parlai/agents/courier \
    && for name in __init__.py courier.py; do \
        wget "$base_url/$name" -q -P "$dest" || exit 1; \
    done

# Download the redditgk data archive and unpack the jsons into DATAPATH/redditgk.
# Download and extraction are one atomic step, so they share a single RUN.
RUN data_dir=/opt/conda/lib/python3.7/site-packages/data/redditgk \
    && wget http://lnsigo.mipt.ru/export/alexaprize_data/parlai_grounding_knowledge/parlai_redditgk_data.tar.gz -q -P "$data_dir" \
    && tar -xvzf "$data_dir/parlai_redditgk_data.tar.gz" -C "$data_dir"
# The archive is kept in the image after extraction; the cleanup is disabled.
#RUN rm parlai_redditgk_data.tar.gz

# Download the wow model tar.gz (left packed) into data/models/wizard_of_wikipedia.
RUN wget http://lnsigo.mipt.ru/export/alexaprize_data/parlai_grounding_knowledge/end2end_generator_0.tar.gz -q -P /opt/conda/lib/python3.7/site-packages/data/models/wizard_of_wikipedia
|
||
WORKDIR /src

# Install the service's Python dependencies before copying the source, so this
# layer stays cached until requirements.txt itself changes.
# --no-cache-dir keeps pip's download cache out of the layer.
COPY ./requirements.txt /src/requirements.txt
RUN pip install --no-cache-dir -r /src/requirements.txt
# Fetch the NLTK 'punkt' data at build time.
RUN python -c "import nltk; nltk.download('punkt')"

COPY . /src

# gunicorn binds to 8083 (EXPOSE is documentation only; it publishes nothing).
EXPOSE 8083

# Exec form: gunicorn runs as PID 1 and receives SIGTERM from `docker stop`
# directly instead of being wrapped in `/bin/sh -c`.
CMD ["gunicorn", "--workers=1", "server:app", "-b", "0.0.0.0:8083"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
flask==1.1.1 | ||
gunicorn==19.9.0 | ||
requests==2.22.0 | ||
sentry-sdk[flask]==0.14.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import logging | ||
import os | ||
import random | ||
import time | ||
|
||
import sentry_sdk | ||
import torch | ||
from flask import Flask, request, jsonify | ||
from parlai.core.params import ParlaiParser | ||
from parlai.core.agents import create_agent | ||
from parlai.core.worlds import create_task | ||
from parlai.core.script import ParlaiScript, register_script | ||
from parlai.agents.courier.courier import CourierAgent | ||
from sentry_sdk.integrations.flask import FlaskIntegration | ||
|
||
# Error reporting: the DSN is read from the SENTRY_DSN environment variable.
sentry_sdk.init(dsn=os.getenv('SENTRY_DSN'), integrations=[FlaskIntegration()])

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Pick the torch device: GPU 0 when CUDA is available, the CPU otherwise.
cuda = torch.cuda.is_available()
if cuda:
    torch.cuda.set_device(0)  # single gpu
device = torch.device('cuda' if cuda else 'cpu')

logger.info(f'knowledge grounding is set to run on {device}')

logger.info('knowledge grounding script is preparing...')
|
||
|
||
@register_script('get model response')
class GetModelResponse(ParlaiScript):
    """ParlAI script that produces one model reply for a single user turn.

    The turn (topic, knowledge, text and history) is supplied through the
    ``--user-input-*`` options and handed to the task world by ``run``.
    """

    @classmethod
    def setup_args(cls):
        """Build the argument parser for this script.

        Returns:
            A ``ParlaiParser`` with the ``--interactive-task`` flag, the four
            ``--user-input-*`` string options (each defaulting to ``''``) and
            the defaults ``interactive_mode=True`` / ``task='interactive'``.
        """
        parser = ParlaiParser(True, True, 'Get response from model in knowledge grounded conversation')
        parser.add_argument(
            '-it',
            '--interactive-task',
            type='bool',
            default=True,
            help='Create interactive version of task',
        )
        # The four user-input options share the same type and default, so they
        # are registered from one table instead of four copied calls.
        for flag, description in (
            ('--user-input-topic', 'User input topic'),
            ('--user-input-knowledge', 'User input knowledge'),
            ('--user-input-text', 'User input text'),
            ('--user-input-history', 'User input history'),
        ):
            parser.add_argument(flag, type=str, default='', help=description)
        parser.set_defaults(interactive_mode=True, task='interactive')
        return parser

    def run(self):
        """Run one exchange and return the text of the reply.

        Returns:
            The ``'text'`` entry of whatever ``world.parley`` returns for the
            user turn built from the ``user_input_*`` options.
        """
        opt = self.opt
        if isinstance(opt, ParlaiParser):
            # Use the module logger (not the root logger) so this message is
            # tagged like every other log line in this file.
            logger.error('opt should be passed, not Parser')
            opt = opt.parse_args()
        # Create model and courier and assign them to the specified task.
        # NOTE(review): the agent is created on every run() call; presumably
        # this reloads the model each time -- consider reusing it.
        agent = create_agent(opt, requireModelExists=True)
        courier_agent = CourierAgent(opt)
        world = create_task(opt, [courier_agent, agent])
        # History arrives as one newline-separated string; an empty history is
        # passed on as a single empty utterance.
        raw_history = opt['user_input_history']
        user_input = {
            'topic': opt['user_input_topic'],
            'knowledge': opt['user_input_knowledge'],
            'text': opt['user_input_text'],
            'history': raw_history.split('\n') if raw_history else [''],
        }
        response = world.parley(user_input)
        courier_agent.finished = True
        return response['text']
|
||
|
||
# Run the script once at import time with a placeholder turn ('hi').
# Presumably this pre-loads the model files before the first request -- TODO
# confirm. A failure is reported to Sentry and the log but does not stop
# start-up.
try:
    GetModelResponse.main(
        task='redditgk',
        datatype='test',
        user_input_topic='',
        user_input_knowledge='.',
        user_input_text='hi',
        user_input_history='',
        split_lines=False,
        model_file='zoo:wizard_of_wikipedia/end2end_generator/model',
    )
except Exception as exc:
    sentry_sdk.capture_exception(exc)
    logger.exception(exc)

logger.info('knowledge grounding script is ready')

app = Flask(__name__)
|
||
|
||
@app.route("/respond", methods=['POST'])
def respond():
    """Generate one reply per sample of the posted batch.

    Expects a JSON body ``{"batch": [sample, ...]}`` where each sample carries
    ``topic``, ``knowledge``, ``text`` and ``history``.

    Returns:
        A JSON list of reply strings in batch order. A sample with empty or
        missing knowledge, or one whose generation raises, yields ``""``.
    """
    batch = request.json['batch']
    responses = []
    # The random module is re-seeded with the same value on every request.
    random.seed(42)
    for sample in batch:
        response = ""
        st_time = time.time()
        # .get(): a sample without a 'knowledge' key is treated like one with
        # empty knowledge instead of failing the whole batch with a KeyError.
        if sample.get('knowledge'):
            try:
                response = GetModelResponse.main(
                    task='redditgk',
                    datatype='test',
                    user_input_topic=sample['topic'],
                    user_input_knowledge=sample['knowledge'],
                    user_input_text=sample['text'],
                    user_input_history=sample['history'],
                    split_lines=False,
                    model_file='zoo:wizard_of_wikipedia/end2end_generator/model',
                )
            except Exception as e:
                # Best effort: report the failure and keep the empty reply so
                # the remaining samples are still answered.
                sentry_sdk.capture_exception(e)
                logger.exception(e)
            logger.info(f'Current sample response: {response}')
        else:
            logger.info('Sample knowledge is empty, returning empty response')
        total_time = time.time() - st_time
        logger.info(f'knowledge grounding: one sample from batch exec time: {total_time:.3f}s')
        responses.append(response)
    return jsonify(responses)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import requests | ||
|
||
|
||
def test_knowledge_grounding():
    """Smoke test: POST one sample to /respond and require a non-empty reply."""
    endpoint = 'http://0.0.0.0:8083/respond'

    sample = {
        'topic': "financial endowment",
        'knowledge': (
            "<h1> financial endowment </h1> <h2> <anchor> criticisms </anchor> </h2> <p> officials in charge of "
            "the endowments of some universities have been criticized for ' hoarding ' and reinvesting too much "
            "of the endowment's income . \ngiven a historical endowment performance of 10 – 11 % , and a payout "
            "rate of 5 % , around half of the endowment's income is reinvested . \nroughly 3 % of the "
            "reinvestment is used to keep pace with inflation , leaving an inflation-adjusted 2 % annual "
            "growth of the endowment . \nof course , many endowments fail to earn 10 – 11 % . \n</p> <p> "
            "two arguments against inflation-adjusted endowment growth are : </p> <h3> hoarding money </h3> <p> "
            "large endowments have been criticized for ' hoarding ' money . \nmost philanthropies are required "
            "by federal law to distribute 5 % of their assets per year , but university endowments are not "
            "required to spend anything . \nmany universities with very large endowments would require less "
            "than 5 % to pay full tuition for all their students . \nfor example , it has been estimated that "
            "if in 2006 all the harvard students had paid the maximum in tuition and fees , it would have "
            "amounted to less than $ 300 million . \nin 2007 , if harvard <h3> size </h3> <p> financial "
            "endowments range in size depending on the size of the institution and the level of community "
            "support . \nat the large end of the spectrum , the total endowment can be over one billion "
            "dollars at many leading private universities . \nharvard university has the largest endowment "
            "in the world with $ 37.6 billion in assets as of june 30 , 2015 . \neach university typically "
            "has numerous endowments , each of which are frequently restricted to funding very specific areas "
            "of the university . \nthe most common examples are endowed professorships , and endowed "
            "scholarships or fellowships <h3> socially and environmentally responsible investing </h3> <p> "
            "many college and university endowments have come under fire in recent years for practices such "
            "as investing in fossil fuels , ' land grabs ' in poor countries and high-risk , high-return "
            "investment practices that led to the financial crisis . </p>"
        ),
        'text': "wow do you know about financial endowment?",
        # Previous utterances, newline-separated in a single string.
        'history': "hello how are you\n fine just got from work \n me too what do you do for living? \n i am a financist",
    }

    reply = requests.post(endpoint, json={'batch': [sample]})
    # The service answers with a list holding one reply per batch sample.
    result = reply.json()[0]
    assert result != '', 'Got empty string as a result'
    print('Got\n{}\nSuccess'.format(result))


if __name__ == '__main__':
    test_knowledge_grounding()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash
# Runs the service smoke test (test.py).
# test.py is resolved relative to this script so the test can be launched from
# any working directory, not only from the script's own directory.

script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
python "$script_dir/test.py"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters