-
Notifications
You must be signed in to change notification settings - Fork 190
/
api.py
32 lines (23 loc) · 815 Bytes
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from flask import Flask, request, jsonify
from flask_cors import CORS
import os
from ast import literal_eval
import pandas as pd
from cdqa.utils.filters import filter_paragraphs
from cdqa.pipeline.cdqa_sklearn import QAPipeline
# --- Application and model bootstrap -----------------------------------
# Everything below executes once, when this module is imported, so the
# dataset is read and the pipeline is fitted before the first request.
app = Flask(__name__)
CORS(app)  # flask-cors with its default settings

# Both settings are mandatory: a missing variable raises KeyError right here,
# evaluated left to right (dataset_path is checked first).
dataset_path, reader_path = os.environ['dataset_path'], os.environ['reader_path']

# The 'paragraphs' column is parsed cell-by-cell with ast.literal_eval;
# presumably it is stored as a Python list literal — TODO confirm against
# the dataset. The parsed frame is then passed through cdqa's filter_paragraphs.
df = filter_paragraphs(
    pd.read_csv(dataset_path, converters=dict(paragraphs=literal_eval))
)

# NOTE(review): reader_path presumably points at a pretrained reader model;
# verify against deployment config.
cdqa_pipeline = QAPipeline(reader=reader_path)
cdqa_pipeline.fit(X=df)
@app.route('/api', methods=['GET'])
def api():
    """Answer the question given in the ``query`` query-string parameter.

    The question is run through the module-level ``cdqa_pipeline`` and the
    response is JSON with the original ``query`` plus ``answer``, ``title``
    and ``paragraph``, taken from the first three items of the prediction.

    Returns:
        A JSON response on success, or a ``(JSON error, 400)`` tuple when the
        ``query`` parameter is missing or blank.
    """
    query = request.args.get('query')
    # request.args.get returns None when the parameter is absent; without
    # this guard that None (or an empty string) would be handed to the
    # pipeline instead of being reported to the client as a bad request.
    if query is None or not query.strip():
        return jsonify(error="missing required query parameter 'query'"), 400

    prediction = cdqa_pipeline.predict(X=query)
    # Accessed positionally: predict() is assumed to return a sequence whose
    # first three items are (answer, title, paragraph) — TODO confirm against
    # the installed cdqa version (some versions may append extra items).
    return jsonify(query=query,
                   answer=prediction[0],
                   title=prediction[1],
                   paragraph=prediction[2])