In [1]:
import argparse
import copy
import json
import os
import re
import shlex
import shutil
import tempfile

import nltk
import torch

import sem2SQL
from src import args as arg
from src import utils
from src.models.model import IRNet
from src.rule import semQL

# Settings namespace shared by the helpers in this notebook. The whole object
# is handed to IRNet in load_model(); the fields dataset, toy, cuda,
# load_model, glove_embed_path and beam_size are also read directly here.
# Keys stay in their original (alphabetical) order.
args = argparse.Namespace(**{
    'action_embed_size': 128,
    'att_vec_size': 300,
    'batch_size': 64,
    'beam_size': 10,
    'clip_grad': 5.0,
    'col_embed_size': 300,
    'column_att': 'affine',
    'column_pointer': True,
    'cuda': True,
    'dataset': './fdata',
    'decode_max_time_step': 40,
    'dropout': 0.3,
    'embed_size': 300,
    'epoch': 50,
    'glove_embed_path': './data/glove.42B.300d.txt',
    'hidden_size': 300,
    'load_model': './saved_model/IRNet_pretrained.model',
    'loss_epoch_threshold': 50,
    'lr': 0.001,
    'lr_scheduler': True,
    # Spelling kept as-is; presumably the name the project code expects.
    'lr_scheduler_gammar': 0.5,
    'lstm': 'lstm',
    'max_epoch': -1,
    'model_name': 'rnn',
    'no_query_vec_to_action_map': False,
    'optimizer': 'Adam',
    'query_vec_to_action_diff_map': False,
    'readout': 'identity',
    'save': 'save',
    'save_to': 'model',
    'seed': 90,
    'sentence_features': True,
    'sketch_loss_coefficient': 1.0,
    'toy': False,
    'type_embed_size': 128,
    'word_dropout': 0.2,
})

def load_data(tmpdir):
    """Load the staged question and the schema metadata into module globals.

    Args:
        tmpdir: Working directory as returned by ``feed_question``; its
            ``fdata`` sub-directory is passed to ``utils.load_dataset``.

    Side effects:
        Rebinds the globals ``table_data`` and ``sql_data`` (the second and
        third items returned by ``utils.load_dataset``) and ``schemas`` (a
        dict mapping every ``db_id`` found in ``tables.json`` to its entry).
    """
    global table_data, sql_data, schemas

    # Only the dataset path differs from ``args``, so derive it directly
    # rather than deep-copying the whole namespace to change one field.
    dataset_dir = os.path.join(tmpdir, 'fdata')
    _, table_data, sql_data, _ = utils.load_dataset(dataset_dir, use_small=args.toy)

    # Schema metadata is still read from the original dataset directory
    # (args.dataset), not from the per-question copy.
    with open(os.path.join(args.dataset, 'tables.json'), 'r', encoding='utf8') as f:
        schemas = {entry['db_id']: entry for entry in json.load(f)}

        
def load_model():
    """Build IRNet, restore the pretrained weights and attach word embeddings.

    Reads ``args.cuda``, ``args.load_model`` and ``args.glove_embed_path``.

    Side effects:
        Rebinds the global ``model`` and leaves it in evaluation mode.
    """
    global model

    grammar = semQL.Grammar()
    model = IRNet(args, grammar)
    if args.cuda:
        model.cuda()

    print('load pretrained model from %s'% (args.load_model))
    # map_location keeps every storage where it was deserialized (CPU).
    checkpoint = torch.load(args.load_model,
                            map_location=lambda storage, loc: storage)

    # Drop checkpoint entries the current architecture does not define.
    # The freshly loaded mapping is private to this function, so prune it in
    # place: this avoids deep-copying every tensor and keeps the mapping's
    # type and any attached metadata. The key set is computed once instead of
    # rebuilding state_dict() for every key.
    expected_keys = set(model.state_dict().keys())
    for name in list(checkpoint.keys()):
        if name not in expected_keys:
            del checkpoint[name]
    model.load_state_dict(checkpoint)

    model.word_emb = utils.load_word_emb(args.glove_embed_path)

    # This notebook only runs inference; switch to evaluation mode here so
    # predictions do not depend on a later cell calling model.eval().
    model.eval()


def feed_question(qstr, db_id='concert_singer'):
    """Stage one natural-language question and run the preprocessing script.

    The first record of ``preprocess/fuck.json`` is used as a template; its
    ``db_id``, ``question`` and ``question_toks`` fields are overwritten. A
    fresh temporary directory receives a copy of ``fdata`` plus the staged
    question (``tmp.json``), and ``preprocess/run_me.sh`` is asked to write
    ``fdata/dev.json`` inside it.

    Args:
        qstr: The question text.
        db_id: Identifier of the database the question refers to.

    Returns:
        Path of the temporary directory. Exit statuses of the two shell
        commands are not checked (unchanged from the previous behaviour).
    """
    # Close the template file deterministically instead of leaking the handle.
    with open("preprocess/fuck.json") as template_file:
        template = json.load(template_file)[0]
    template['db_id'] = db_id
    template['question'] = qstr
    template['question_toks'] = nltk.tokenize.TweetTokenizer().tokenize(qstr)

    tempdir = tempfile.mkdtemp('lalala')
    question_path = os.path.join(tempdir, "tmp.json")
    output_path = os.path.join(tempdir, "fdata", "dev.json")

    # Quote every path that goes through the shell so a temp location with
    # spaces or shell metacharacters cannot break the command.
    copy_cmd = 'cp -r fdata ' + shlex.quote(tempdir)
    print(copy_cmd)
    os.system(copy_cmd)

    # The file must be flushed and closed before the script reads it.
    with open(question_path, "w") as question_file:
        json.dump([template], question_file)

    preprocess_cmd = 'preprocess/run_me.sh {} tables.json {}'.format(
        shlex.quote(question_path), shlex.quote(output_path))
    print(preprocess_cmd)
    os.system(preprocess_cmd)
    return tempdir

    
def do_prediction(db_id='concert_singer'):
    """Run beam search on the loaded question and convert each beam to SQL.

    Uses the globals ``sql_data``, ``table_data``, ``schemas`` and ``model``
    set up by ``load_data`` and ``load_model``.

    Args:
        db_id: Key into ``schemas`` selecting the schema passed to
            ``sem2SQL.transform``.

    Returns:
        A list of SQL strings, one per beam candidate that could be
        converted, in the order returned by ``model.parse``.
    """
    print(sql_data)
    perm = list(range(len(sql_data)))
    examples = utils.to_batch_seq(sql_data, table_data, perm, 0, 1, is_train=False)
    example = examples[0]
    print(example.src_sent)
    beams = model.parse(example, args.beam_size)[0]

    schema = schemas[db_id]
    sql_sentences = []
    for beam in beams:
        # Use this candidate's own actions. The previous code read
        # results[0] on every iteration, so all entries were the same SQL.
        action_text = ' '.join(str(action) for action in beam.actions)
        # Put the prediction last so a same-named key already present in the
        # example cannot overwrite it.
        candidate = {**sql_data[0], 'model_result_replace': action_text}
        try:
            sql_sentence = sem2SQL.transform(candidate, schema)[0]
        except Exception as err:
            # One unconvertible candidate must not discard the others.
            print('skipping beam candidate that could not be converted to SQL: {}'.format(err))
            continue
        sql_sentences.append(sql_sentence)
    print(sql_sentences)
    return sql_sentences

In [None]:
# Later cells use the global `model`, so it has to be loaded on a fresh
# kernel. Loading the checkpoint and the GloVe vectors can be slow, so the
# call is skipped when a previous run already left `model` in the kernel.
if 'model' not in globals():
    load_model()

Use Column Pointer:  True
load pretrained model from ./saved_model/IRNet_pretrained.model
Loading word embedding from ./data/glove.42B.300d.txt


In [None]:
def e2e_demo(qstr, db_id='concert_singer'):
    """Turn a question into candidate SQL strings.

    Chains ``feed_question`` -> ``load_data`` -> ``do_prediction``.

    Args:
        qstr: The question text.
        db_id: Identifier of the target database.

    Returns:
        The list of SQL strings produced by ``do_prediction``.
    """
    tempdir = feed_question(qstr, db_id=db_id)
    try:
        load_data(tempdir)
        return do_prediction(db_id=db_id)
    finally:
        # feed_question copies the whole fdata directory for every question;
        # remove the copy once the data is in memory so repeated calls do not
        # pile up temporary directories.
        shutil.rmtree(tempdir, ignore_errors=True)

def _find_common_column(dbschema, t1, t2):
    """Return ``(name, name)`` for the first column name shared by two tables.

    Tables are looked up in ``table_names_original`` (a ``ValueError`` from
    ``list.index`` propagates for unknown names). Returns ``(None, None)``
    when the tables share no column name.
    """
    table_names = dbschema['table_names_original']
    column_names = dbschema['column_names_original']
    i1 = table_names.index(t1)
    i2 = table_names.index(t2)
    columns_1 = [name for table_idx, name in column_names if table_idx == i1]
    columns_2 = [name for table_idx, name in column_names if table_idx == i2]
    for name in columns_1:
        if name in columns_2:
            return name, name
    return None, None


def _add_join_conditions(sql, dbschema):
    """Append an ``ON`` clause to each ``A AS x JOIN B AS y`` fragment of ``sql``.

    The clause equates the first column name the two tables have in common;
    a fragment whose tables share no column name is left without ``ON``.
    """
    # Raw strings: '\s' in a normal literal is an invalid escape sequence.
    joins = re.findall(r'([^\s]+)\s+AS\s+([^\s]+)\s+JOIN\s+([^\s]+)\s+AS\s+([^\s]+)', sql)
    pieces = re.split(r'[^\s]+\s+AS\s+[^\s]+\s+JOIN\s+[^\s]+\s+AS\s+[^\s]+', sql)
    print(pieces, joins)
    assert len(pieces) == len(joins) + 1

    rebuilt = ''
    for idx, piece in enumerate(pieces):
        rebuilt += piece
        if idx != len(pieces) - 1:
            join = joins[idx]
            print(join)
            k1, k2 = _find_common_column(dbschema, join[0], join[2])
            if k1 is not None:
                rebuilt += '{0} AS {1} JOIN {2} AS {3} ON {1}.{4} = {3}.{5}'.format(*(list(join) + [k1, k2]))
            else:
                rebuilt += '{0} AS {1} JOIN {2} AS {3}'.format(*join)
    return rebuilt


def e2e_execute_demo(qstr, db_id='concert_singer', last=None):
    """Predict SQL for a question and execute it against the SQLite database.

    Candidates from ``e2e_demo`` are tried in order; the first one that
    returns at least one row wins.

    Args:
        qstr: The question text.
        db_id: Database identifier; selects both the schema in ``schemas``
            and the ``.sqlite`` file that is queried.
        last: Optional value substituted for the ``= 1`` placeholder in the
            generated SQL. Ignored when ``None`` or empty.

    Returns:
        ``(sql, df)``: the executed SQL and its result DataFrame. If every
        candidate returns zero rows, the last candidate and its empty frame
        are returned. If there are no candidates at all, ``('', <empty
        DataFrame>)`` is returned.
    """
    import sqlalchemy
    import pandas as pd

    sqls = e2e_demo(qstr, db_id=db_id)
    if not sqls:
        # Previously this path hit an unbound `sql`/`df` at the final return.
        return '', pd.DataFrame()

    dbschema = schemas[db_id]
    # One engine for all candidates, released when done.
    engine = sqlalchemy.engine.create_engine("sqlite:///../text2sql/database/{}/{}.sqlite".format(db_id, db_id))
    try:
        for sql in sqls:
            if last is not None and last != '':
                # `last` comes from the caller (the HTTP layer passes request
                # data), so double embedded quotes before splicing it into a
                # SQL string literal.
                escaped = str(last).replace("'", "''")
                sql = sql.replace('= 1', "= '{}'".format(escaped))
                print(sql)
            sql = _add_join_conditions(sql, dbschema)
            print(sql)
            df = pd.read_sql(sql, engine)
            if df.shape[0] != 0:
                break
        return sql, df
    finally:
        engine.dispose()

In [None]:
# Put the global `model` (bound by load_model()) into evaluation mode, then
# run one sample question end to end against the book_2 database.
model.eval()
e2e_execute_demo("Give me the author of the cheapest book", db_id="book_2")

In [None]:
# Another sample question against the book_2 database.
e2e_execute_demo("show me the titles of the books", db_id="book_2")

In [None]:
from flask import Flask
from flask import request, jsonify
from flask_cors import CORS

# Create the Flask application that serves the routes defined below and wrap
# it with flask_cors' CORS (called with its default settings).
app = Flask(__name__)
CORS(app)

@app.route('/')
def hello_world():
    """Answer requests for the root URL with a fixed greeting."""
    greeting = 'Hello, World!'
    return greeting

@app.route('/query')
def query():
    """Answer ``/query?q=<question>&db=<db_id>&last=<value>`` with JSON.

    Query parameters:
        q: The question text (required).
        db: Database identifier; defaults to ``'concert_singer'``, the same
            default the prediction helpers use.
        last: Optional value forwarded to ``e2e_execute_demo``.

    Returns:
        A JSON object ``{"data": [...rows...], "sql": "<executed sql>"}``.
        On any failure the fallback ``{"data": [], "sql": ""}`` is returned;
        a missing ``q`` additionally sets HTTP status 400.
    """
    q = request.args.get('q')
    db = request.args.get('db', 'concert_singer')
    last = request.args.get('last', '')

    if not q:
        # Nothing to translate; reject instead of failing in the tokenizer.
        return jsonify({
            "data": [],
            "sql": ''
        }), 400

    try:
        sql, val = e2e_execute_demo(q, db_id=db, last=last)
        return jsonify({
            "data": json.loads(val.to_json(orient='records')),
            "sql": sql
        })
    except Exception:
        # Keep the best-effort fallback, but no longer catch SystemExit or
        # KeyboardInterrupt, and log the reason instead of hiding it.
        app.logger.exception('query failed for q=%r db=%r', q, db)
        return jsonify({
            "data": [],
            "sql": ''
        })

# Start Flask's built-in server on all interfaces, port 443. The call blocks
# this cell until the server stops.
# NOTE(review): 443 is the conventional HTTPS port, but no ssl_context is
# passed, so this presumably serves plain HTTP there; binding a port below
# 1024 also usually needs elevated privileges. Confirm this is intended.
app.run(host='0.0.0.0', port=443)