In [1]:
import os
import argparse
import codecs
import json
import random as rnd
import numpy as np
from collections import Counter, defaultdict
from itertools import chain, count
from six import string_types

In [2]:
OLD_WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
NEW_WHERE_OPS = ('=','>','<','>=','<=','!=','like','not in','in','between','is')
NEW_WHERE_DICT = {
    '=': 0,
    '>': 1,
    '<': 2,
    '>=': 3,
    '<=': 4,
    '!=': 5,
    'like': 6,
    'not in': 7,
    'in': 8,
    'between': 9,
    'is':10
}
# SQL_OPS = ('none','intersect', 'union', 'except')
SQL_OPS = {
    'none': 0,
    'intersect': 1,
    'union': 2,
    'except': 3
}
KW_DICT = {
    'where': 0,
    'groupBy': 1,
    'orderBy': 2
}
ORDER_OPS = {
    'desc': 0,
    'asc': 1}
AGG_OPS = ('none','max', 'min', 'count', 'sum', 'avg')

COND_OPS = {
    'and':0,
    'or':1
}

ORDER_DICT = {
    0:"asc_limit",
    1:"asc",
    2:"desc_limit",
    3: "desc"
}

TRAIN_COMPONENTS = ('multi_sql','keyword','col','op','agg','root_tem','des_asc','having','andor','value')
COMPONENTS_DICT = {
    'multi_sql':0,
    'keyword':1,
    'col':2,
    'op':3,
    'agg':4,
    'root_tem':5,
    'des_asc':6,
    'having':7,
    'andor':8,
    'value':9
}

In [3]:
def convert_to_op_index(is_not,op):
    op = OLD_WHERE_OPS[op]
    if is_not and op == "in":
        return 7
    try:
        return NEW_WHERE_DICT[op]
    except:
        print("Unsupport op: {}".format(op)) # TODO: check ! =
        return -1

def index_to_column_name(index, table):
    column_name = table["column_names"][index][1]
    table_index = table["column_names"][index][0]
    table_name = table["table_names"][table_index]
    return table_name, column_name, index


def get_label_cols(with_join,fk_dict,labels):
    # list(set([l[1][i][0][2] for i in range(min(len(l[1]), 3))]))
    cols = set()
    ret = []
    for i in range(len(labels)):
        cols.add(labels[i][0][2]) # still col index
        if len(cols) > 3:
            break
    for col in cols:
        # ret.append([col])
        if with_join and len(fk_dict[col]) > 0:
            ret.append(col) #ret.append([col]+fk_dict[col]) TODO: fk_dict removed
        else:
            ret.append(col)
    return ret


# history added
class MultiSqlPredictor:
    def __init__(self, question, sql, history):
        self.sql = sql
        self.question = question
        self.history = history
        self.keywords = ('intersect', 'except', 'union')

    def generate_output(self):
        for key in self.sql:
            if key in self.keywords and self.sql[key]:
                return self.history + ['root'], key, self.sql[key]
        return self.history + ['root'], 'none', self.sql


class KeyWordPredictor:
    def __init__(self, question, sql, history):
        self.sql = sql
        self.question = question
        self.history = history
        self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'limit', 'having')

    def generate_output(self):
        sql_keywords = []
        for key in self.sql:
            if key in self.keywords and self.sql[key]: # included other keywords
                sql_keywords.append(key)
        return self.history, [len(sql_keywords), sql_keywords], self.sql


# history added
class ColPredictor:
    def __init__(self, question, sql, table, history, kw=None):
        self.sql = sql
        self.question = question
        self.history = history
        self.table = table
        self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'having')
        self.kw = kw

    def generate_output(self):
        ret = []
        candidate_keys = self.sql.keys()
        if self.kw:
            candidate_keys = [self.kw]
        for key in candidate_keys:
            if key in self.keywords and self.sql[key]:
                cols = []
                sqls = []
                if key == 'groupBy':
                    sql_cols = self.sql[key]
                    for col in sql_cols:
                        cols.append((index_to_column_name(col[1], self.table), col[2]))
                        sqls.append(col) # col_unit1
                elif key == 'orderBy':
                    sql_cols = self.sql[key][1]
                    for col in sql_cols: # only contain col_unit1 in val_unit: (unit_op, col_unit1, col_unit2)
                        cols.append((index_to_column_name(col[1][1], self.table), col[1][2]))
                        sqls.append(col) # val_unit1
                elif key == 'select':
                    sql_cols = self.sql[key][1]
                    for col in sql_cols:  # only contain col_unit1 in val_unit
                        cols.append((index_to_column_name(col[1][1][1], self.table), col[1][1][2]))
                        sqls.append(col) # (agg_id, val_unit)
                elif key == 'where' or key == 'having':
                    sql_cols = self.sql[key]
                    for col in sql_cols: # TODO: check this one!
                        if not isinstance(col, list):
                            continue
                        try: # col_id of col_unit of val_unit of cond_unit of condition
                            cols.append((index_to_column_name(col[2][1][1], self.table), col[2][1][2]))
                        except:
                            print("Key:{} Col:{} Question:{}".format(key, col, self.question))
                        sqls.append(col) # cond_unit
                ret.append((
                    self.history + [key], (len(cols), cols), sqls
                ))
        return ret
        # ret.append(history+[key],)


class OpPredictor:
    def __init__(self, question, sql, history):
        self.sql = sql # check sql is cond_unit
        self.question = question
        self.history = history # history not change
        # self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'having')

    def generate_output(self): # sql3: val_unit, sql4: val1
        return self.history, convert_to_op_index(self.sql[0],self.sql[1]), (self.sql[3], self.sql[4])


class AggPredictor:
    def __init__(self, question, sql, history,kw=None):
        self.sql = sql
        self.question = question
        self.history = history
        self.kw = kw
    def generate_output(self):
        label = -1
        if self.kw:
            key = self.kw
        else:
            key = self.history[-2]
        if key == 'select':
            label = self.sql[0] # check sql: (agg_id, val_unit)
        elif key == 'orderBy':
            label = self.sql[1][0] # check sql: val_unit1
        elif key == 'having':
            label = self.sql[2][1][0] # check sql: cond_unit
        else: # ADDED
            print("\n Unexpected pre-agg key: ", key)
            exit()
        return self.history, label

# TODO: check why not RootTemPredictor

# class RootTemPredictor:
#     def __init__(self, question, sql):
#         self.sql = sql
#         self.question = question
#         self.keywords = ('intersect', 'except', 'union')
#
#     def generate_output(self):
#         for key in self.sql:
#             if key in self.keywords:
#                 return ['ROOT'], key, self.sql[key]
#         return ['ROOT'], 'none', self.sql


# history added orderBy only one col and agg! TODO: CHECK multiple orderBy columns
class DesAscPredictor:
    def __init__(self, question, sql, table, history):
        self.sql = sql
        self.question = question
        self.history = history
        self.table = table

    def generate_output(self):
        for key in self.sql: # check sql: whole sql
            if key == "orderBy" and self.sql[key]:
                # self.history.append(key)
                try:
                    col = self.sql[key][1][0][1][1] # w
                except:
                    print("question:{} sql:{}".format(self.question, self.sql))
                # self.history.append(index_to_column_name(col, self.table))
                # self.history.append(self.sql[key][1][0][1][0])
                if self.sql[key][0] == "asc" and self.sql["limit"]: # TODO: get limit value and labels
                    label = 0
                elif self.sql[key][0] == "asc" and not self.sql["limit"]:
                    label = 1
                elif self.sql[key][0] == "desc" and self.sql["limit"]:
                    label = 2
                else:
                    label = 3                                           # agg_id in col_unit of val_unit in orderBy
                return self.history+[index_to_column_name(col, self.table), self.sql[key][1][0][1][0]], label


class AndOrPredictor:
    def __init__(self, question, sql, table, history):
        self.sql = sql
        self.question = question
        self.history = history
        self.table = table

    def generate_output(self):
        if 'where' in self.sql and self.sql['where'] and len(self.sql['where']) > 1:
            return self.history, COND_OPS[self.sql['where'][1]]
        return self.history,-1
    
    
def get_table_dict(table_data_path):
    data = json.load(open(table_data_path))
    table = dict()
    for item in data:
        table[item["db_id"]] = item
    return table

In [4]:
def parse_data_full_history(question_tokens, sql, table, history):
    table_schema = [
        table["table_names"],
        table["column_names"],
        table["column_types"]
    ]
    full_labels = []
    masks = [[COMPONENTS_DICT['multi_sql']]]
    stack = [("root", sql)]
    with_join = False
    fk_dict = defaultdict(list)
    for fk in table["foreign_keys"]:
        fk_dict[fk[0]].append(fk[1])
        fk_dict[fk[1]].append(fk[0])
    while len(stack) > 0:
        node = stack.pop()
        if node[0] == "root":
            if len(node) == 3 and node[2] == "multi":
                masks.append([0])
            history, label, ret_sql = MultiSqlPredictor(question_tokens, node[1], history).generate_output()
            full_labels.append([SQL_OPS[label]])
            history.append(label)
            if label == "none":
                stack.append((label, ret_sql))
                masks.append([COMPONENTS_DICT['keyword']])
            else:
                node[1][label] = None
                stack.append((label, ret_sql, node[1]))
                masks.append([COMPONENTS_DICT['multi_sql']])
        elif node[0] in ('intersect', 'except', 'union'):
            full_labels.append([])
            masks.append([-1])
            stack.append(("root", node[1],"multi"))
            stack.append(("root", node[2]))
        elif node[0] == "none":
            with_join = len(node[1]["from"]["table_units"]) > 1
            history, label, sql = KeyWordPredictor(question_tokens, node[1], history).generate_output()

            label_idxs = []
            for item in label[1]:
                if item in KW_DICT:
                    label_idxs.append(KW_DICT[item])
            label_idxs.sort()
            full_labels.append([label_idxs])

            if "orderBy" in label[1]:
                stack.append(("orderBy", node[1]))
            if "groupBy" in label[1]:
                has_having = "having" in label[1]
                stack.append(("groupBy", node[1],has_having))
            if "where" in label[1]:
                stack.append(("where", node[1]))
            if "select" in label[1]:
                stack.append(("select", node[1]))

        elif node[0] in ("select", "having", "orderBy"):
            history.append(node[0])
            masks.append([COMPONENTS_DICT['col']])

            col_ret = ColPredictor(question_tokens, node[1], table, history, node[0]).generate_output()
            agg_col_dict = dict()
            op_col_dict = dict()
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("Warning: predicted 0 columns!")
                    continue
                full_labels.append([get_label_cols(with_join, fk_dict, l[1])])
                for col, sql_item in zip(l[1], s):
                    key = "{}{}{}".format(col[0][0], col[0][1], col[0][2])
                    if key not in agg_col_dict:
                        agg_col_dict[key] = [(sql_item, col[0])]
                    else:
                        agg_col_dict[key].append((sql_item, col[0]))
                    if key not in op_col_dict:
                        op_col_dict[key] = [(sql_item, col[0])]
                    else:
                        op_col_dict[key].append((sql_item, col[0]))
                for key in agg_col_dict:
                    stack.append(("col", node[0], agg_col_dict[key], op_col_dict[key]))
        elif node[0] == "col":
            history.append(node[2][0][1])
            if node[1] == "where":
                stack.append(("op", node[2], "where"))
            elif node[1] != "groupBy":
                labels = []
                for sql_item, col in node[2]:
                    _, label = AggPredictor(question_tokens, sql_item, history, node[1]).generate_output()
                    if label - 1 >= 0:
                        labels.append(label - 1)

                if node[1] == "orderBy":
                    stack.append(("des_asc", sql))
                    continue
                masks.append([COMPONENTS_DICT['agg']])
                full_labels.append([labels[:min(len(labels), 3)]])

                if node[1] == "having":
                    stack.append(("op", node[2], "having"))

                if len(labels) == 0 and node[1] == "having":
                    history.append("none")
                for v in labels:
                    history.append(AGG_OPS[v + 1])
                    if node[1] != "having":
                        masks.append([-1])
                        full_labels.append([])

        elif node[0] == "des_asc":
            orderby_ret = DesAscPredictor(question_tokens, node[1], table, history).generate_output()

            if not orderby_ret:
                continue
            masks.append([COMPONENTS_DICT['des_asc']])
            # print(node[1])
            history.append(ORDER_DICT[orderby_ret[1]])
            full_labels.append([orderby_ret[1]])
            masks.append([-1])
            full_labels.append([])
        elif node[0] == "value":
            history.append(node[2])
            masks.append([COMPONENTS_DICT['root_tem'],COMPONENTS_DICT['value']])
            val1 = 0 # TODO: node[1][3]
            val2 = 1 # TODO: node[1][4]
            full_labels.append([1, [val1, val2]])
            history.append("value") # TODO: ([val1, val2])
            masks.append([-1])
            full_labels.append([])

        elif node[0] == "op":
            # history.append(node[1][0][1])
            labels = []

            for sql_item, col in node[1]:
                _, label, s = OpPredictor(question_tokens, sql_item, history).generate_output()
                if label != -1:
                    labels.append(label)

                # masks.append(COMPONENTS_DICT['root_tem'])
                if isinstance(s[0], dict):
                    stack.append(("root", s[0]))
                    history.append(NEW_WHERE_OPS[label])
                else:
                    stack.append(("value",sql_item,NEW_WHERE_OPS[label]))

            if len(labels) > 2:
                print(question_tokens)
            masks.append([COMPONENTS_DICT['op']])
            full_labels.append([labels])
            if stack[-1][0] == "root":
                full_labels.append([0])
                masks.append([COMPONENTS_DICT['root_tem']])
                masks.append([COMPONENTS_DICT['multi_sql']])
        elif node[0] == "where":
            history.append(node[0])
            hist, andor_label = AndOrPredictor(question_tokens, node[1], table, history).generate_output()
            col_ret = ColPredictor(question_tokens, node[1], table, history, "where").generate_output()
            # masks.append([COMPONENTS_DICT['col']])
            op_col_dict = dict()
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("Warning: predicted 0 columns!")
                    continue

                label = get_label_cols(with_join, fk_dict, l[1])
                if len(label) > 1:
                    full_labels.append([label,[andor_label]])
                    masks.append([COMPONENTS_DICT['col'],COMPONENTS_DICT['andor']])
                else:
                    full_labels.append([label,[]])
                    masks.append([COMPONENTS_DICT['col'], COMPONENTS_DICT['andor']])
                # full_labels.append()
                for col, sql_item in zip(l[1], s):
                    key = "{}{}{}".format(col[0][0], col[0][1], col[0][2])
                    if key not in op_col_dict:
                        op_col_dict[key] = [(sql_item, col[0])]
                    else:
                        op_col_dict[key].append((sql_item, col[0]))
                for key in op_col_dict:
                    stack.append(("col", "where", op_col_dict[key]))
        elif node[0] == "groupBy":

            history.append(node[0])
            col_ret = ColPredictor(question_tokens, node[1], table, history, node[0]).generate_output()
            masks.append([COMPONENTS_DICT['col']])
            # agg_col_dict = dict()
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("Warning: predicted 0 columns!")
                    continue

                history.append(l[1][0][0])
                full_labels.append([get_label_cols(with_join, fk_dict, l[1])])
                if node[2]:
                    stack.append(("having", node[1]))
                    full_labels.append([1])
                    masks.append([COMPONENTS_DICT['having']])
                else:
                    full_labels.append([0])
                    masks.append([COMPONENTS_DICT['having']])

    return history,full_labels,masks

In [5]:
def replace_value(conditions,nl,mp):
    for cond in conditions:
        for i,value in enumerate(cond):
            if i < 3:
                continue
            if not value or isinstance(value,dict):
                continue
            old_value = value
            if isinstance(value,str):
                if value[0] in ('\'','\"'):
                    value = value[1:-1]
                value = value.split()
            else:
                value = [value]
            try:
                new_val = 'VALUE_{}'.format(len(mp))
                cond[i] = new_val
                mp.append(old_value)
                if isinstance(value[0],str):
                    idx = nl.index(value[0])
                else:
                    idx = -1
                    for i in range(len(nl)):
                        if nl[i].isdigit() and (float(nl[i]) == value[0]):
                            idx = i
                            break
                    if idx == -1:
                        # print(old_value)
                        # print(nl)
                        continue
                nl = nl[:idx] + [new_val] + nl[idx+len(value):]
            except Exception:
                # print(old_value)
                # print(nl)
                continue
    return conditions,nl

def replace_nl(sql,nl):
    mp = []
    sql["where"],nl = replace_value(sql["where"],nl,mp)
    sql["having"],nl = replace_value(sql["having"],nl,mp)
    d = {}
    for i,val in enumerate(mp):
        d["VALUE_{}".format(i)] = val
    return d,sql,nl

def get_table_schema(table):
    col_names = table["column_names"]
    tab_names = table["table_names"]
    col_types = table["column_types"]
    col_name_comb = []
    for coln, colt in zip(col_names, col_types):
        tab_id, coln_str = coln
        if tab_id == -1:
            col_name_comb.append(["all", coln_str])
        else:
            col_name_comb.append(tab_names[tab_id].split(" ") + coln_str.split(" ") + [colt])
    
    return col_name_comb

def get_col_in_history(history):
    col_inds = []
    col_mask = []
    history_col_replaced = []
    #["singer", "age", 13]
    for hs in history:
        if isinstance(hs, tuple):
            col_inds.append(hs[2])
            col_mask.append(1)
            history_col_replaced.append("column")
        else:
            col_inds.append(-1)
            col_mask.append(0)
            history_col_replaced.append(hs)
    
    return col_inds, col_mask, history_col_replaced
        
            
def parse_data_new_format(data, table_dict):
    dataset = []
    for item in data:
        table_one = table_dict[item["db_id"]]
        table_schema = get_table_schema(table_one)
        mp,sql,nl = replace_nl(item["sql"],item["question_toks"])
        history, labels, masks = parse_data_full_history(nl, sql, table_one, [])
        
        col_inds, col_mask, history_col_replaced = get_col_in_history(history)
        if not len(history) == len(labels) == len(masks):
            print('\n-------------------------------')
            print('len of hisotry: ', len(history), 'len of labels: ', len(labels), "len of masks: ", len(masks))
            print("query  : ", item['query'])
            print("history: ", history)
            print("masks    : ", masks)
            print("label    : ", labels)

    

In [6]:
COMPONENTS_DICT = {
    'multi_sql':0,
    'keyword':1,
    'col':2,
    'op':3,
    'agg':4,
    'root_tem':5,
    'des_asc':6,
    'having':7,
    'andor':8,
    'value':9
}

In [7]:
# stepten no order by
# SELECT T1.Area FROM APPELLATIONS AS T1 JOIN WINE AS T2 ON T1.Appelation  =  T2.Appelation GROUP BY T2.Appelation HAVING T2.year  <  2010 ORDER BY count(*) DESC LIMIT 1

In [8]:
data_path = '../../SyntaxSQL/data/dev.json'
table_data_path = '../../SyntaxSQL/data/tables.json'
data = json.load(open(data_path))
table_dict = get_table_dict(table_data_path)
parse_data_new_format(data, table_dict)


-------------------------------
len of hisotry:  14 len of labels:  13 len of masks:  13
query  :  SELECT t3.individual_last_name FROM organizations AS t1 JOIN organization_contact_individuals AS t2 ON t1.organization_id  =  t2.organization_id JOIN individuals AS t3 ON t2.individual_id  =  t3.individual_id WHERE t1.uk_vat_number  =  (SELECT max(uk_vat_number) FROM organizations) ORDER BY t2.date_contact_to ASC LIMIT 1
history:  ['root', 'none', 'select', ('individuals', 'individual last name', 23), 'where', ('organizations', 'uk vat number', 27), '=', 'root', 'none', 'select', ('organizations', 'uk vat number', 27), 'max', 'orderBy', ('organization contact individuals', 'date contact to', 35)]
masks    :  [[0], [1], [2], [4], [2, 8], [3], [5], [0], [1], [2], [4], [-1], [2]]
label    :  [[0], [[0, 2]], [[23]], [[]], [[27], []], [[0]], [0], [0], [[]], [[27]], [[0]], [], [[35]]]

-------------------------------
len of hisotry:  14 len of labels:  13 len of masks:  13
query  :  SELECT t3.