In [1]:
import json
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
# Fix every random number generator the notebook touches so reruns are reproducible.
random.seed(2333)
np.random.seed(2333)
torch.manual_seed(2333)
torch.cuda.manual_seed(2333)
# Ask cuDNN to pick deterministic kernels (may cost some speed on GPU).
torch.backends.cudnn.deterministic = True

In [3]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Word-piece tokenizer for the same 'bert-base-uncased' checkpoint that `net`
# loads as its encoder; the token ids it produces must match that model's vocab.
# from_pretrained may fetch files on first use if they are not cached locally.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
# Pre-tokenized schema for each database, keyed by database name (db_infer
# iterates it as `db_name, db_tok`).
# NOTE(review): values are presumably lists of BERT token ids, since db_infer
# concatenates them onto the query id list — confirm against the script that
# wrote save/DB_tok.
# SECURITY: pickle.load can execute arbitrary code; only load files you produced.
# The `with` block closes the file handle instead of leaking it.
with open("save/DB_tok", "rb") as db_tok_file:
    DB_tok = pickle.load(db_tok_file)
# To restrict to the demo databases:
# DB_tok = {db: DB_tok[db] for db in ["college_1", "book_2", "company_employee", "flight_1", "concert_singer"]}

In [6]:
# query_size: fixed token length of the query segment (query2tok pads/truncates
#   to it); database schema tokens are appended after this segment.
#   NOTE(review): BERT-base accepts at most 512 positions, so schema token lists
#   presumably must stay <= 512 - 380 ids — confirm.
# bert_size: width of the encoder's hidden vectors, i.e. the input size of
#   net.MD["linear1"].
query_size, bert_size = 380, 768

In [7]:
def align(x, size, pad_id=0, sep_id=102):
    """Pad or truncate a sequence of token ids to exactly ``size`` elements.

    Sequences shorter than ``size`` are right-padded with ``pad_id``. Sequences
    of length ``size`` or more are cut to ``size - 1`` ids and terminated with
    ``sep_id``, so a truncated query still ends in a separator token.

    Args:
        x: sequence of token ids (list, tuple, ...). It is not modified.
        size: target length; must be >= 1.
        pad_id: id used for padding (default 0, as before).
        sep_id: id appended after truncation (default 102, the [SEP] id in
            the bert-base-uncased vocabulary loaded above — confirm if the
            vocab changes).

    Returns:
        A new list of exactly ``size`` ids.

    Raises:
        ValueError: if ``size`` < 1 (no room for the separator).
    """
    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")
    ids = list(x)
    if len(ids) < size:
        return ids + [pad_id] * (size - len(ids))
    return ids[:size - 1] + [sep_id]

In [8]:
def query2tok(q):
    """Turn a natural-language query into a fixed-length list of BERT token ids.

    The text is wrapped in "[CLS] ... [SEP]" markers, word-piece tokenized with
    the module-level ``tokenizer``, mapped to vocabulary ids, and finally
    padded/truncated to ``query_size`` ids with ``align``.
    """
    wrapped = " ".join(("[CLS]", q, "[SEP]"))
    pieces = tokenizer.tokenize(wrapped)
    ids = tokenizer.convert_tokens_to_ids(pieces)
    return align(ids, query_size)

In [9]:
class net(nn.Module):
    """Relevance scorer for a concatenated (query, database-schema) token sequence.

    A frozen BERT encoder produces hidden states. The vector at position 0
    (the [CLS] slot that query2tok prepends) goes through a
    bert_size -> 768 -> 300 -> 1 MLP with ReLU activations. A sigmoid then maps
    the result to a score.

    Submodules live in ``self.MD`` under the keys "encoder", "linear1",
    "linear2" and "linear3". These names fix the state_dict keys loaded from
    save/best_model.pt, so they must not change.
    """

    def __init__(self):
        super().__init__()
        layers = {
            "encoder": BertModel.from_pretrained('bert-base-uncased'),
            "linear1": nn.Linear(bert_size, 768),
            "linear2": nn.Linear(768, 300),
            "linear3": nn.Linear(300, 1),
        }
        self.MD = nn.ModuleDict(layers)

        # Only the MLP head is trainable: freeze every encoder weight.
        for weight in self.MD["encoder"].parameters():
            weight.requires_grad_(False)

    def forward(self, x):
        """Score a batch.

        Args:
            x: pair ``(db, tok)``. ``db`` is passed to the encoder as input ids
                and ``tok`` as ``token_type_ids`` (segment ids).
                NOTE(review): no attention_mask is passed, so padding ids
                added by align are attended to. This presumably matches how
                the checkpoint was trained; confirm before adding one.

        Returns:
            1-D tensor holding one sigmoid score per batch row.
        """
        input_ids, segment_ids = x
        hidden = self.MD['encoder'](input_ids, token_type_ids=segment_ids)[0]
        cls_vec = hidden[:, 0, :]
        h = torch.relu(self.MD["linear1"](cls_vec))
        h = torch.relu(self.MD["linear2"](h))
        logit = self.MD["linear3"](h)
        return logit.sigmoid().view(-1)

In [10]:
# Build the scorer on the selected device in inference mode (dropout off).
# eval() returns the module itself, so the chain still binds `model`.
model = net().to(device).eval()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (the original call failed there because tensors were restored to CUDA).
model.load_state_dict(torch.load("save/best_model.pt", map_location=device))

<All keys matched successfully>

In [16]:
def db_infer(q, candidates=('college_1', 'book_2', 'company_employee', 'flight_1', 'concert_singer')):
    """Pick the database whose schema best matches a natural-language query.

    For each candidate in ``DB_tok``, the tokenized query is concatenated with
    that database's pre-tokenized schema. The pair is scored by ``model``, and
    the highest-scoring database wins (ties go to the first in ``DB_tok`` order).

    Args:
        q: the query text.
        candidates: database names to consider. Names absent from ``DB_tok``
            are ignored. The default is the five databases the original code
            hard-coded; pass ``None`` to score every database in ``DB_tok``.

    Returns:
        ``(db_name, score)`` for the best database. ``score`` is a Python
        float produced by the model's sigmoid.

    Raises:
        ValueError: if no candidate database is present in ``DB_tok``.
    """
    q_tok = query2tok(q)
    allowed = None if candidates is None else set(candidates)
    model.eval()  # idempotent; done once rather than per database
    scores = {}
    # Inference only: skip building autograd graphs for the trainable MLP head.
    with torch.no_grad():
        for db_name, db_tok in DB_tok.items():
            if allowed is not None and db_name not in allowed:
                continue
            # Segment 0 = query ids, segment 1 = schema ids (token_type_ids).
            seq = torch.tensor([q_tok + db_tok], device=device)
            segments = torch.tensor([[0] * len(q_tok) + [1] * len(db_tok)], device=device)
            scores[db_name] = model((seq, segments)).item()
    if not scores:
        raise ValueError("db_infer: none of the candidate databases are present in DB_tok")
    best = max(scores, key=scores.get)
    return best, scores[best]

In [17]:
# Smoke test: a question about singers is expected to route to concert_singer.
# Named sample_query (not `query`) because the Flask cell below defines a view
# function called `query`, which would silently rebind the same global.
sample_query = "How many singers do we have?"
db_infer(sample_query)

('concert_singer', 0.6469578146934509)

In [None]:
import os

from flask import Flask
from flask import request, jsonify
from flask_cors import CORS

# Returned when inference fails, so the front end always gets a well-formed payload.
FALLBACK_DB = "college_1"

# A single app instance (the original cell created and CORS-wrapped it twice).
app = Flask(__name__)
CORS(app)


@app.route('/')
def hello_world():
    """Liveness check."""
    return 'Hello, World!'


@app.route('/query')
def query():
    """Return the best-matching database for ``GET /query?q=<text>`` as JSON.

    Response body: ``{"data": <db name>, "confidence": <score>}``.

    - A request without ``q`` gets HTTP 400. Previously ``None`` reached
      query2tok and surfaced as a 500.
    - Any error raised by db_infer is logged, and the request is answered with
      FALLBACK_DB and confidence 0.0. In the original, the try only wrapped
      jsonify, so this fallback was unreachable for inference errors.
    """
    q = request.args.get('q')
    if q is None:
        return jsonify({"error": "missing query parameter 'q'"}), 400
    try:
        db, c = db_infer(q)
    except Exception:
        app.logger.exception("db_infer failed for query %r", q)
        return jsonify({
            "data": FALLBACK_DB,
            "confidence": 0.0
        })
    return jsonify({
        "data": db,
        "confidence": c
    })

# NOTE: this is Flask's development server, and it speaks plain HTTP. Port 443
# is the conventional HTTPS port, so clients attempting TLS there are rejected
# with "Bad request syntax" (the \x16\x03\x01... request in the log below looks
# like a TLS handshake). For real use, put it behind a TLS-terminating proxy or
# production WSGI server, or pass ssl_context=... . app.run() also blocks this
# kernel. HOST/PORT env vars override the original defaults.
app.run(host=os.environ.get("HOST", "0.0.0.0"), port=int(os.environ.get("PORT", "443")))

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://0.0.0.0:443/ (Press CTRL+C to quit)
184.105.247.196 - - [16/Jan/2020 03:23:59] code 400, message Bad request syntax ('\x16\x03\x01\x00\x9a\x01\x00\x00\x96\x03\x03(Dwn\x89ùq§³]P¸\xa0"\x96ê¢î\x9b\x12Zôú4\x1cÇ~Ó5¯ÖÜ\x00\x00\x1aÀ/À+À\x11À\x07À\x13À\tÀ\x14À')
184.105.247.196 - - [16/Jan/2020 03:23:59] "   (Dwnùq§³]P¸ "ê¢îZôú4Ç~Ó5¯ÖÜ  À/À+ÀÀÀÀ	ÀÀ" HTTPStatus.BAD_REQUEST -
203.208.61.68 - - [16/Jan/2020 03:30:00] "GET /query?q=show+me+the+list+of+authors HTTP/1.1" 200 -
203.208.61.68 - - [16/Jan/2020 03:30:20] "GET /query?q=show+me+the+author+of+the+cheapest+book HTTP/1.1" 200 -
203.208.61.68 - - [16/Jan/2020 03:31:41] "GET /query?q=show+me+the+author+of+the+cheapest+book HTTP/1.1" 200 -
203.208.61.68 - - [16/Jan/2020 03:52:42] "GET /query?q=What+is+the+title+of+the+most+expensive+book%3F HTTP/1.1" 200 -
203.208.61.68 - - [16/Jan/2020 03:52:50] "GET /query?q=How+many+books+has+each+author+written%3F HTTP/1.1" 200 -
203.208.61.68 - - [16/Jan/2020 03:53:25