In [1]:
import os
import argparse
import codecs
import json
import random as rnd
import numpy as np
from collections import Counter, defaultdict
from itertools import chain, count
from six import string_types
import torch
import torchtext.data
import torchtext.vocab

# import table
# import table.IO
# import opts
# from tree import SCode

In [2]:
OLD_WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
NEW_WHERE_OPS = ('=','>','<','>=','<=','!=','like','not in','in','between','is')
NEW_WHERE_DICT = {
    '=': 0,
    '>': 1,
    '<': 2,
    '>=': 3,
    '<=': 4,
    '!=': 5,
    'like': 6,
    'not in': 7,
    'in': 8,
    'between': 9,
    'is':10
}
# SQL_OPS = ('none','intersect', 'union', 'except')
SQL_OPS = {
    'none': 0,
    'intersect': 1,
    'union': 2,
    'except': 3
}
KW_DICT = {
    'where': 0,
    'groupBy': 1,
    'orderBy': 2
}
ORDER_OPS = {
    'desc': 0,
    'asc': 1}
AGG_OPS = ('none','max', 'min', 'count', 'sum', 'avg')

COND_OPS = {
    'and':0,
    'or':1
}

ORDER_DICT = {
    0:"asc_limit",
    1:"asc",
    2:"desc_limit",
    3: "desc"
}

TRAIN_COMPONENTS = ('multi_sql','keyword','col','op','agg','root_tem','des_asc','having','andor','value')
COMPONENTS_DICT = {
    'multi_sql':0,
    'keyword':1,
    'col':2,
    'op':3,
    'agg':4,
    'root_tem':5,
    'des_asc':6,
    'having':7,
    'andor':8,
    'value':9
}

In [3]:
def convert_to_op_index(is_not,op):
    op = OLD_WHERE_OPS[op]
    if is_not and op == "in":
        return 7
    try:
        return NEW_WHERE_DICT[op]
    except:
        print("Unsupport op: {}".format(op)) # TODO: check ! =
        return -1

def index_to_column_name(index, table):
    column_name = table["column_names"][index][1]
    table_index = table["column_names"][index][0]
    table_name = table["table_names"][table_index]
    return table_name, column_name, index


def get_label_cols(with_join,fk_dict,labels):
    # list(set([l[1][i][0][2] for i in range(min(len(l[1]), 3))]))
    cols = set()
    ret = []
    for i in range(len(labels)):
        cols.add(labels[i][0][2]) # still col index
        if len(cols) > 3:
            break
    for col in cols:
        # ret.append([col])
        if with_join and len(fk_dict[col]) > 0:
            ret.append(col) #ret.append([col]+fk_dict[col]) TODO: fk_dict removed
        else:
            ret.append(col)
    return ret


# history added
class MultiSqlPredictor:
    def __init__(self, question, sql, history):
        self.sql = sql
        self.question = question
        self.history = history
        self.keywords = ('intersect', 'except', 'union')

    def generate_output(self):
        for key in self.sql:
            if key in self.keywords and self.sql[key]:
                return self.history + ['root'], key, self.sql[key]
        return self.history + ['root'], 'none', self.sql


class KeyWordPredictor:
    def __init__(self, question, sql, history):
        self.sql = sql
        self.question = question
        self.history = history
        self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'limit', 'having')

    def generate_output(self):
        sql_keywords = []
        for key in self.sql:
            if key in self.keywords and self.sql[key]: # included other keywords
                sql_keywords.append(key)
        return self.history, [len(sql_keywords), sql_keywords], self.sql


# history added
class ColPredictor:
    def __init__(self, question, sql, table, history, kw=None):
        self.sql = sql
        self.question = question
        self.history = history
        self.table = table
        self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'having')
        self.kw = kw

    def generate_output(self):
        ret = []
        candidate_keys = self.sql.keys()
        if self.kw:
            candidate_keys = [self.kw]
        for key in candidate_keys:
            if key in self.keywords and self.sql[key]:
                cols = []
                sqls = []
                if key == 'groupBy':
                    sql_cols = self.sql[key]
                    for col in sql_cols:
                        cols.append((index_to_column_name(col[1], self.table), col[2]))
                        sqls.append(col) # col_unit1
                elif key == 'orderBy':
                    sql_cols = self.sql[key][1]
                    for col in sql_cols: # only contain col_unit1 in val_unit: (unit_op, col_unit1, col_unit2)
                        cols.append((index_to_column_name(col[1][1], self.table), col[1][2]))
                        sqls.append(col) # val_unit1
                elif key == 'select':
                    sql_cols = self.sql[key][1]
                    for col in sql_cols:  # only contain col_unit1 in val_unit
                        cols.append((index_to_column_name(col[1][1][1], self.table), col[1][1][2]))
                        sqls.append(col) # (agg_id, val_unit)
                elif key == 'where' or key == 'having':
                    sql_cols = self.sql[key]
                    for col in sql_cols: # TODO: check this one!
                        if not isinstance(col, list):
                            continue
                        try: # col_id of col_unit of val_unit of cond_unit of condition
                            cols.append((index_to_column_name(col[2][1][1], self.table), col[2][1][2]))
                        except:
                            print("Key:{} Col:{} Question:{}".format(key, col, self.question))
                        sqls.append(col) # cond_unit
                ret.append((
                    self.history + [key], (len(cols), cols), sqls
                ))
        return ret
        # ret.append(history+[key],)


class OpPredictor:
    def __init__(self, question, sql, history):
        self.sql = sql # check sql is cond_unit
        self.question = question
        self.history = history # history not change
        # self.keywords = ('select', 'where', 'groupBy', 'orderBy', 'having')

    def generate_output(self): # sql3: val_unit, sql4: val1
        return self.history, convert_to_op_index(self.sql[0],self.sql[1]), (self.sql[3], self.sql[4])


class AggPredictor:
    def __init__(self, question, sql, history,kw=None):
        self.sql = sql
        self.question = question
        self.history = history
        self.kw = kw
    def generate_output(self):
        label = -1
        if self.kw:
            key = self.kw
        else:
            key = self.history[-2]
        if key == 'select':
            label = self.sql[0] # check sql: (agg_id, val_unit)
        elif key == 'orderBy':
            label = self.sql[1][0] # check sql: val_unit1
        elif key == 'having':
            label = self.sql[2][1][0] # check sql: cond_unit
        else: # ADDED
            print("\n Unexpected pre-agg key: ", key)
            exit()
        return self.history, label

# TODO: check why not RootTemPredictor

# class RootTemPredictor:
#     def __init__(self, question, sql):
#         self.sql = sql
#         self.question = question
#         self.keywords = ('intersect', 'except', 'union')
#
#     def generate_output(self):
#         for key in self.sql:
#             if key in self.keywords:
#                 return ['ROOT'], key, self.sql[key]
#         return ['ROOT'], 'none', self.sql


# history added orderBy only one col and agg! TODO: CHECK multiple orderBy columns
class DesAscPredictor:
    def __init__(self, question, sql, table, history):
        self.sql = sql
        self.question = question
        self.history = history
        self.table = table

    def generate_output(self):
        for key in self.sql: # check sql: whole sql
            if key == "orderBy" and self.sql[key]:
                # self.history.append(key)
                try:
                    col = self.sql[key][1][0][1][1] # w
                except:
                    print("question:{} sql:{}".format(self.question, self.sql))
                # self.history.append(index_to_column_name(col, self.table))
                # self.history.append(self.sql[key][1][0][1][0])
                if self.sql[key][0] == "asc" and self.sql["limit"]: # TODO: get limit value and labels
                    label = 0
                elif self.sql[key][0] == "asc" and not self.sql["limit"]:
                    label = 1
                elif self.sql[key][0] == "desc" and self.sql["limit"]:
                    label = 2
                else:
                    label = 3                                           # agg_id in col_unit of val_unit in orderBy
                return self.history+[index_to_column_name(col, self.table), self.sql[key][1][0][1][0]], label


class AndOrPredictor:
    def __init__(self, question, sql, table, history):
        self.sql = sql
        self.question = question
        self.history = history
        self.table = table

    def generate_output(self):
        if 'where' in self.sql and self.sql['where'] and len(self.sql['where']) > 1:
            return self.history, COND_OPS[self.sql['where'][1]]
        return self.history,-1
    
    
def get_table_dict(table_data_path):
    data = json.load(open(table_data_path))
    table = dict()
    for item in data:
        table[item["db_id"]] = item
    return table

In [4]:
def parse_data_full_history(question_tokens, sql, table, history):
    table_schema = [
        table["table_names"],
        table["column_names"],
        table["column_types"]
    ]
    full_labels = []
    masks = [COMPONENTS_DICT['multi_sql']]
    stack = [("root", sql)]
    with_join = False
    fk_dict = defaultdict(list)
    for fk in table["foreign_keys"]:
        fk_dict[fk[0]].append(fk[1])
        fk_dict[fk[1]].append(fk[0])
    while len(stack) > 0:
        node = stack.pop()
        if node[0] == "root":
            if len(masks) > 1:
                masks.append(COMPONENTS_DICT['root_tem'])
                full_labels.append(-1)
            history, label, sql = MultiSqlPredictor(question_tokens, node[1], history).generate_output()
            full_labels.append(SQL_OPS[label])
            history.append(label)
            if label == "none":
                stack.append((label, sql))
                masks.append(COMPONENTS_DICT['keyword'])
            else:
                node[1][label] = None
                stack.append((label, node[1], sql)) # TODO: double check
                masks.append(COMPONENTS_DICT['multi_sql'])
                # if label != "none":
                # stack.append(("none",node[1]))
        elif node[0] in ('intersect', 'except', 'union'):
            stack.append(("root", node[1]))
            stack.append(("root", node[2]))
        elif node[0] == "none":
            with_join = len(node[1]["from"]["table_units"]) > 1
            history, label, sql = KeyWordPredictor(question_tokens, node[1], history).generate_output()
            # full_labels.append(label)
            # [len(sql_keywords), sql_keywords]

            label_idxs = []
            for item in label[1]:
                if item in KW_DICT:
                    label_idxs.append(KW_DICT[item])
            label_idxs.sort()
            full_labels.append(label_idxs)

            if "orderBy" in label[1]:
                stack.append(("orderBy", node[1]))
            if "groupBy" in label[1]:
                has_having = "having" in label[1]

                stack.append(("groupBy", node[1], has_having))
            if "where" in label[1]:
                stack.append(("where", node[1]))
            if "select" in label[1]:
                stack.append(("select", node[1]))

        elif node[0] in ("select", "having", "orderBy"):
            # if node[0] != "orderBy":
            history.append(node[0])
            masks.append(COMPONENTS_DICT['col'])

            col_ret = ColPredictor(question_tokens, node[1], table, history, node[0]).generate_output()
            agg_col_dict = dict()
            op_col_dict = dict()
            if len(col_ret) > 1:
                print("\nWarning: why return more than one col_ret!")
                exit()
            # history + [key], (len(cols), cols), sqls
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("\nWarning: predicted 0 columns!")
                    exit()

                full_labels.append(get_label_cols(with_join, fk_dict, l[1]))
                for col, sql_item in zip(l[1], s):
                    key = "{}{}{}".format(col[0][0], col[0][1], col[0][2]) #table_name, column_name, index
                    # sql_item: (agg_id, val_unit)/select, val_unit1/orderBy, col_unit1/groupBy, cond_unit/where/having
                    if key not in agg_col_dict:
                        agg_col_dict[key] = [(sql_item, col[0])]
                    else:
                        agg_col_dict[key].append((sql_item, col[0])) # for the same col with multiple agg
                    if key not in op_col_dict:
                        op_col_dict[key] = [(sql_item, col[0])]
                    else:
                        op_col_dict[key].append((sql_item, col[0]))
                for key in agg_col_dict:
                    stack.append(("col", node[0], agg_col_dict[key], op_col_dict[key]))
        elif node[0] == "col":
            # full_labels.append(node[2][-1])
            history.append(node[2][0][1])
            if node[1] == "where":
                # stack.append(("value", node[2], "where"))
                stack.append(("op", node[2], "where"))
                # masks.append(COMPONENTS_DICT['op'])
            elif node[1] != "groupBy":
                labels = []
                for sql_item, col in node[2]:
                    _, label = AggPredictor(question_tokens, sql_item, history, node[1]).generate_output()
                    if label - 1 >= 0:
                        labels.append(label - 1) # TODO: check why -1

                # print(node[2][0][1][2])
                masks.append(COMPONENTS_DICT['agg'])
                full_labels.append(labels[:min(len(labels), 3)])

                if node[1] == "having":
                    # stack.append(("value", node[2], "having"))
                    stack.append(("op", node[2], "having"))
                if node[1] == "orderBy":
                    stack.append(("des_asc", sql))

                if len(labels) > 0:
                    history.append(AGG_OPS[labels[0] + 1])

        elif node[0] == "des_asc":
            orderby_ret = DesAscPredictor(question_tokens, node[1], table, history).generate_output()

            if not orderby_ret:
                continue
            masks.append(COMPONENTS_DICT['des_asc'])
            # print(node[1])
            history.append(orderby_ret[1])
            full_labels.append(orderby_ret[1])
            if len(stack) > 0:
                masks.append(-1)
        elif node[0] == "value":
            masks.append([COMPONENTS_DICT['value'],COMPONENTS_DICT['root_tem']])
            val1 = node[1][3]
            val2 = node[1][4]
            if val2:
                if len(stack) > 0:
                    masks.append(-1)
                full_labels.append([1,[val1,val2]])
                history.append([val1,val2])
            else:
                if len(stack) > 0:
                    masks.append(-1)
                full_labels.append([1,[val1]])
                history.append([val1])

        elif node[0] == "op":
            labels = []

            for sql_item, col in node[1]:
                _, label, s = OpPredictor(question_tokens, sql_item, history).generate_output()
                if label != -1:
                    labels.append(label)
                    history.append(NEW_WHERE_OPS[label])

                # masks.append(COMPONENTS_DICT['root_tem'])
                if isinstance(s[0], dict):
                    stack.append(("root", s[0]))
                    masks.append(COMPONENTS_DICT['root_tem'])
                    full_labels.append([0,[]])
                else:
                    stack.append(("value",sql_item))

            if len(labels) > 2:
                print(question_tokens)

            masks.append(COMPONENTS_DICT['op'])
            full_labels.append(labels)
        elif node[0] == "where":
            history.append(node[0])
            hist, andor_label = AndOrPredictor(question_tokens, node[1], table, history).generate_output()

            col_ret = ColPredictor(question_tokens, node[1], table, history, "where").generate_output()
            masks.append(COMPONENTS_DICT['col'])
            op_col_dict = dict()
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("Warning: predicted 0 columns!")
                    continue

                label = get_label_cols(with_join, fk_dict, l[1])
                if len(label) > 1:
                    full_labels.append([label,[andor_label]])
                    masks.append([COMPONENTS_DICT['col'],COMPONENTS_DICT['andor']])
                else:
                    full_labels.append([label,[]])
                    masks.append([COMPONENTS_DICT['col'], COMPONENTS_DICT['andor']])
                # full_labels.append()
                for col, sql_item in zip(l[1], s):
                    key = "{}{}{}".format(col[0][0], col[0][1], col[0][2])
                    if key not in op_col_dict:
                        op_col_dict[key] = [(sql_item, col[0])]
                    else:
                        op_col_dict[key].append((sql_item, col[0]))
                for key in op_col_dict:
                    stack.append(("col", "where", op_col_dict[key]))
        elif node[0] == "groupBy":
            history.append(node[0])
            col_ret = ColPredictor(question_tokens, node[1], table, history, node[0]).generate_output()
            masks.append(COMPONENTS_DICT['col'])
            # agg_col_dict = dict()
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("Warning: predicted 0 columns!")
                    continue

                full_labels.append(get_label_cols(with_join, fk_dict, l[1]))
                if node[2]:
                    stack.append(("having", node[1]))
                    full_labels.append(1)
                    masks.append(COMPONENTS_DICT['having'])
                else:
                    full_labels.append(0)
                    masks.append(COMPONENTS_DICT['having'])

    return history,full_labels,masks


In [5]:
def parse_data_full_history(question_tokens, sql, table, history):
    table_schema = [
        table["table_names"],
        table["column_names"],
        table["column_types"]
    ]
    full_labels = []
    masks = [[COMPONENTS_DICT['multi_sql']]]
    stack = [("root", sql)]
    with_join = False
    fk_dict = defaultdict(list)
    for fk in table["foreign_keys"]:
        fk_dict[fk[0]].append(fk[1])
        fk_dict[fk[1]].append(fk[0])
    while len(stack) > 0:
        node = stack.pop()
        if node[0] == "root":
            if len(node) == 3 and node[2] == "multi":
                masks.append([0])
            # print(sql)
            # print("root")
            # if len(masks) > 1:
                # masks.append([COMPONENTS_DICT['root_tem']])
                # full_labels.append([-1])
            history, label, ret_sql = MultiSqlPredictor(question_tokens, node[1], history).generate_output()
            full_labels.append([SQL_OPS[label]])
            history.append(label)
            if label == "none":
                stack.append((label, ret_sql))
                masks.append([COMPONENTS_DICT['keyword']])
            else:
                node[1][label] = None
                stack.append((label, ret_sql, node[1]))
                masks.append([COMPONENTS_DICT['multi_sql']])
                # if label != "none":
                # stack.append(("none",node[1]))
        elif node[0] in ('intersect', 'except', 'union'):
            full_labels.append([])
            masks.append([-1])
            stack.append(("root", node[1],"multi"))
            stack.append(("root", node[2]))
        elif node[0] == "none":
            with_join = len(node[1]["from"]["table_units"]) > 1
            history, label, sql = KeyWordPredictor(question_tokens, node[1], history).generate_output()
            # full_labels.append(label)

            label_idxs = []
            for item in label[1]:
                if item in KW_DICT:
                    label_idxs.append(KW_DICT[item])
            label_idxs.sort()
            full_labels.append([label_idxs])

            # if "orderBy" in label[1] and "having" in label[1]:
            #     print("Aaaaa!!!!!!!!!")
            #     print(sql)
            # if "having" in label[1]:
            #     stack.append(("having", node[1]))
            if "orderBy" in label[1]:
                stack.append(("orderBy", node[1]))
            if "groupBy" in label[1]:
                has_having = "having" in label[1]
                # if "having" in label[1]:
                #     dataset['having_dataset'].append({
                #         "question_tokens": question_tokens,
                #         "ts": table_schema,
                #         "history": history[:],
                #         "gt_col": node[1]["groupBy"][0][1],
                #         "label": 1
                #     })
                # else:
                #     dataset['having_dataset'].append({
                #         "question_tokens": question_tokens,
                #         "ts": table_schema,
                #         "history": history[:],
                #         "gt_col": node[1]["groupBy"][0][1],
                #         "label": 0
                #     })
                stack.append(("groupBy", node[1],has_having))
            if "where" in label[1]:
                stack.append(("where", node[1]))
            if "select" in label[1]:
                stack.append(("select", node[1]))

        elif node[0] in ("select", "having", "orderBy"):
            # if node[0] != "orderBy":
            history.append(node[0])
            masks.append([COMPONENTS_DICT['col']])
            # if node[0] == "orderBy":
                # orderby_ret = DesAscPredictor(question_tokens, node[1], table, history).generate_output()
                # if orderby_ret:
                #     dataset['des_asc_dataset'].append({
                #         "question_tokens": question_tokens,
                #         "ts": table_schema,
                #         "history": orderby_ret[0],
                #         "gt_col": node[1]["orderBy"][1][0][1][1],
                #         "label": ORDER_OPS[orderby_ret[1]]
                #     })
                    # history.append(orderby_ret[1])
            col_ret = ColPredictor(question_tokens, node[1], table, history, node[0]).generate_output()
            agg_col_dict = dict()
            op_col_dict = dict()
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("Warning: predicted 0 columns!")
                    continue
                # dataset['col_dataset'].append({
                #     "question_tokens": question_tokens,
                #     "ts": table_schema,
                #     "history": history[:],
                #     "label": get_label_cols(with_join, fk_dict, l[1])
                # })
                full_labels.append([get_label_cols(with_join, fk_dict, l[1])])
                for col, sql_item in zip(l[1], s):
                    key = "{}{}{}".format(col[0][0], col[0][1], col[0][2])
                    if key not in agg_col_dict:
                        agg_col_dict[key] = [(sql_item, col[0])]
                    else:
                        agg_col_dict[key].append((sql_item, col[0]))
                    if key not in op_col_dict:
                        op_col_dict[key] = [(sql_item, col[0])]
                    else:
                        op_col_dict[key].append((sql_item, col[0]))
                for key in agg_col_dict:
                    stack.append(("col", node[0], agg_col_dict[key], op_col_dict[key]))
        elif node[0] == "col":
            # full_labels.append(node[2][-1])
            history.append(node[2][0][1])
            if node[1] == "where":
                # stack.append(("value", node[2], "where"))
                stack.append(("op", node[2], "where"))
                # masks.append(COMPONENTS_DICT['op'])
            elif node[1] != "groupBy":
                labels = []
                for sql_item, col in node[2]:
                    _, label = AggPredictor(question_tokens, sql_item, history, node[1]).generate_output()
                    if label - 1 >= 0:
                        labels.append(label - 1)

                # print(node[2][0][1][2])
                if node[1] == "orderBy":
                    stack.append(("des_asc", sql))
                    continue
                masks.append([COMPONENTS_DICT['agg']])
                full_labels.append([labels[:min(len(labels), 3)]])

                # history.append(labels[:min(len(labels), 3)])
                # dataset['agg_dataset'].append({
                #     "question_tokens": question_tokens,
                #     "ts": table_schema,
                #     "history": history[:],
                #     "gt_col": node[2][0][1][2],
                #     "label": labels[:min(len(labels), 3)]
                # })
                # stack.append(("value",node[2],node[1]))
                if node[1] == "having":
                    # stack.append(("value", node[2], "having"))
                    stack.append(("op", node[2], "having"))

                # if len(labels) == 0:
                #     history.append("none")
                # else:
                if len(labels) == 0 and node[1] == "having":
                    history.append("none")
                for v in labels:
                    history.append(AGG_OPS[v + 1])
                    if node[1] != "having":
                        masks.append([-1])
                        full_labels.append([])
                    # history.append([AGG_OPS[labels[i] + 1] for i in range(len(labels))])

        elif node[0] == "des_asc":
            orderby_ret = DesAscPredictor(question_tokens, node[1], table, history).generate_output()

            if not orderby_ret:
                continue
            masks.append([COMPONENTS_DICT['des_asc']])
            # print(node[1])
            history.append(ORDER_DICT[orderby_ret[1]])
            full_labels.append([orderby_ret[1]])
            masks.append([-1])
            full_labels.append([])
        elif node[0] == "value":
            history.append(node[2])
            masks.append([COMPONENTS_DICT['root_tem'],COMPONENTS_DICT['value']])
            val1 = 0 # TODO: node[1][3]
            val2 = 1 # TODO: node[1][4]
            full_labels.append([1, [val1, val2]])
            history.append("value") # TODO: ([val1, val2])
            masks.append([-1])
            full_labels.append([])
            # if val2:
            #     # if len(stack) > 0:
            #
            #         # masks.append(-1)
            #
            #     # history.append(val2)
            # else:
            #     # if len(stack) > 0:
            #     masks.append([-1])
            #     full_labels.append([1,[val1]])
            #     history.append([val1])

        elif node[0] == "op":
            # history.append(node[1][0][1])
            labels = []
            # if len(labels) > 2:
            #     print(question_tokens)
            # dataset['op_dataset'].append({
            #     "question_tokens": question_tokens,
            #     "ts": table_schema,
            #     "history": history[:],
            #     "gt_col": node[1][0][1][2],
            #     "label": labels
            # })


            for sql_item, col in node[1]:
                _, label, s = OpPredictor(question_tokens, sql_item, history).generate_output()
                if label != -1:
                    labels.append(label)

                # masks.append(COMPONENTS_DICT['root_tem'])
                if isinstance(s[0], dict):
                    stack.append(("root", s[0]))
                    history.append(NEW_WHERE_OPS[label])
                    # masks.append([COMPONENTS_DICT['op']])
                    # history.append("root")
                    # dataset['root_tem_dataset'].append({
                    #     "question_tokens": question_tokens,
                    #     "ts": table_schema,
                    #     "history": history[:],
                    #     "gt_col": node[1][0][1][2],
                    #     "label": 0
                    # })
                else:
                    stack.append(("value",sql_item,NEW_WHERE_OPS[label]))

                    # full_labels.append(0)
                    # dataset['root_tem_dataset'].append({
                    #     "question_tokens": question_tokens,
                    #     "ts": table_schema,
                    #     "history": history[:],
                    #     "gt_col": node[1][0][1][2],
                    #     "label": 1
                    # })
                    # history.append("terminal")
            if len(labels) > 2:
                print(question_tokens)
            # dataset['op_dataset'][-1]["label"] = labels
            # full_labels.append(labels)
            # if stack[-1][0] == "value":
            masks.append([COMPONENTS_DICT['op']])
            full_labels.append([labels])
            if stack[-1][0] == "root":
                full_labels.append([0])
                masks.append([COMPONENTS_DICT['root_tem']])
                masks.append([COMPONENTS_DICT['multi_sql']])
        elif node[0] == "where":
            history.append(node[0])
            hist, andor_label = AndOrPredictor(question_tokens, node[1], table, history).generate_output()
            # if andor_label != -1:
                # masks.append(COMPONENTS_DICT['andor'])
                # full_labels.append(label)
                # dataset['andor_dataset'].append({
                #     "question_tokens": question_tokens,
                #     "ts": table_schema,
                #     "history": history[:],
                #     "label": label
                # })
            col_ret = ColPredictor(question_tokens, node[1], table, history, "where").generate_output()
            # masks.append([COMPONENTS_DICT['col']])
            op_col_dict = dict()
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("Warning: predicted 0 columns!")
                    continue
                # dataset['col_dataset'].append({
                #     "question_tokens": question_tokens,
                #     "ts": table_schema,
                #     "history": history[:],
                #     "label": get_label_cols(with_join, fk_dict, l[1])
                # })
                label = get_label_cols(with_join, fk_dict, l[1])
                if len(label) > 1:
                    full_labels.append([label,[andor_label]])
                    masks.append([COMPONENTS_DICT['col'],COMPONENTS_DICT['andor']])
                else:
                    full_labels.append([label,[]])
                    masks.append([COMPONENTS_DICT['col'], COMPONENTS_DICT['andor']])
                # full_labels.append()
                for col, sql_item in zip(l[1], s):
                    key = "{}{}{}".format(col[0][0], col[0][1], col[0][2])
                    if key not in op_col_dict:
                        op_col_dict[key] = [(sql_item, col[0])]
                    else:
                        op_col_dict[key].append((sql_item, col[0]))
                for key in op_col_dict:
                    stack.append(("col", "where", op_col_dict[key]))
        elif node[0] == "groupBy":

            history.append(node[0])
            col_ret = ColPredictor(question_tokens, node[1], table, history, node[0]).generate_output()
            masks.append([COMPONENTS_DICT['col']])
            # agg_col_dict = dict()
            for h, l, s in col_ret:
                if l[0] == 0:
                    print("Warning: predicted 0 columns!")
                    continue
                # dataset['col_dataset'].append({
                #     "question_tokens": question_tokens,
                #     "ts": table_schema,
                #     "history": history[:],
                #     "label": get_label_cols(with_join, fk_dict, l[1])
                # })
                history.append(l[1][0][0])
                full_labels.append([get_label_cols(with_join, fk_dict, l[1])])
                if node[2]:
                    stack.append(("having", node[1]))
                    full_labels.append([1])
                    masks.append([COMPONENTS_DICT['having']])
                else:
                    full_labels.append([0])
                    masks.append([COMPONENTS_DICT['having']])

                # for col, sql_item in zip(l[1], s):
                #     key = "{}{}{}".format(col[0][0], col[0][1], col[0][2])
                #     if key not in agg_col_dict:
                #         agg_col_dict[key] = [(sql_item, col[0])]
                #     else:
                #         agg_col_dict[key].append((sql_item, col[0]))
                # for key in agg_col_dict:
                #     stack.append(("col", node[0], agg_col_dict[key]))
    return history,full_labels,masks

In [6]:
# masks   :  [[0], [1], [2], [4], [2, 8], [3], [5], [0], [1], [2], [4], [-1]]
# label  :  [[0], [[0]], [[[14, 12], 4]], [[]], [[21], []], [[2]], [0], [0], [[]], [[21]], [[4]], []]

def unlist_mask_label(mask_list, label_list):
    label_nolist = []
    label_nums = []
    mask_nolist = []
    mask_inds = []
    
    for ll, ml in zip(label_list, mask_list):
        if len(ml) == 1 and ml[0] == 2 and any(isinstance(i, list) for el in ll for i in el): # for column label is [[[14, 12], 4]]
            mask_nolist.append(ml[0])
            col_num = sum([len(i) if isinstance(i, list) else 1 for el in ll for i in el])
            mask_inds.append(col_num)
        elif len(ml) == 1:
            mask_nolist.append(ml[0])
            mask_inds.append(1)
        elif len(ml) == 2:
            mask_nolist.extend(ml)
            mask_inds.extend([0, 0])
        else:
            print("\nWarning: NOT expected length of the mask list greater than 2!")
            exit()
        
        # [0], [[0]], [[[14, 12], 4]], [[]], [], [[21], []]
        if len(ll) == 0: # []
            #label_nolist.extend(ll)
            label_nums.append(0)
        elif not isinstance(ll[0], list) and len(ll) == 1: # [0]
            label_nolist.extend(ll)
            assert len(ll)  == 1
            label_nums.append(1)
        elif not isinstance(ll[0], list) and isinstance(ll[1], list): # [1, ['VALUE_0', None]]
            assert len(ll) == 2
            label_nolist.append(ll[0])
            label_nums.append(1)
            label_nolist.extend(ll[1]) # TODO: change [1, ['VALUE_0', None]]
            label_nums.append(len(ll[1]))
        elif ml[0] == 2 and any(isinstance(i, list) for el in ll for i in el): #for [[[14, 12], 4]]
            col_label = []
            for el in ll:
                for ii in el:
                    if isinstance(ii, list):
                        col_label.extend(ii)
                    else:
                        col_label.append(ii)
            label_nolist.extend(col_label)
            col_nums = [len(i) if isinstance(i, list) else 1 for el in ll for i in el]
            label_nums.extend(col_nums)
        elif len(ll) == 1: # [[0]] or [[1, 3]]
            label_nolist.extend(ll[0])
            label_nums.append(len(ll[0]))
        else:
            assert len(ll) == 2
            label_nolist.extend(ll[0])
            label_nums.append(len(ll[0]))
            label_nolist.extend(ll[1])
            label_nums.append(len(ll[1]))
            
    assert sum(label_nums) == len(label_nolist)
    assert len(mask_nolist) == len(mask_inds)
    
    return mask_nolist, mask_inds, label_nolist, label_nums

In [22]:
def reconstruct_mask_label(mask_one, mask_inds_one, label_list_one, label_nums_one):

    # mask_one/mi/ln: (max_len) mask_list, mask_inds, label_nums
    module_mask_one = []
    module_label_one = []
    module_label_one_temp = []
    two_mods = []
    label_index = 0
    for j, mlj, mij in zip(count(), mask_one, mask_inds_one):
        if mij != 0:
            if len(two_mods) > 0:
                module_mask_one.append(two_mods)
                two_mods = []   
            module_mask_one.append([mlj])
        elif mij == 0:
            two_mods.append(mlj)
        
    ln = label_nums_one.copy()
    for lnj in ln:
        label_add = []
        for _ in range(lnj):
            label_add.append(label_list_one[label_index])
            label_index += 1
        module_label_one_temp.append(label_add)
        
    label_num_tmp = []
    for j, mij in zip(count(), mask_inds_one):
        if mij < 2:
            label_num_tmp.append(ln.pop(0))
        else:
            lbs = []
            num_count = 0
            while num_count < mij:
                lnp = ln.pop(0)
                lbs.append(lnp)
                num_count += lnp
            assert num_count == mij
            label_num_tmp.append(lbs)
    
    two_labels = []
    assert len(mask_inds_one) == len(label_num_tmp)
    for k, mlk, mik, lnk in zip(count(), mask_one, mask_inds_one, label_num_tmp):
        print("----------lnk: ", lnk)
        if mik == 1:
            if len(two_labels) > 0:
                module_label_one.append(two_labels)
                two_labels = []
            if mlk in [0, 7, -1, 5, 6]:
                module_label_one.append(module_label_one_temp.pop(0))
            else:
                module_label_one.append([module_label_one_temp.pop(0)])
        elif mik == 0:
            label = module_label_one_temp.pop(0)
            if mlk == 5: # for case [1, ['VALUE_0', None]]
                two_labels.append(label[0])
            else:
                two_labels.append(label)
        else:
            col_labels = []
            print("lnk: ", lnk)
            for lnki in lnk:
                col_label = module_label_one_temp.pop(0)
                assert len(col_label) == lnki
                if len(col_label) == 1:
                    col_labels.append(col_label[0])
                else:
                    col_labels.append(col_label)
            module_label_one.append([col_labels])
            
    return module_mask_one, module_label_one

In [23]:
mask_list = [[0], [1], [2], [4], [2, 8], [3], [5], [0], [1], [2], [4], [-1]]
label_list = [[0], [[0]], [[[14, 12], 4]], [[]], [[21], []], [[2]], [0], [0], [[]], [[21]], [[4]], []]

mask_nolist, mask_inds, label_nolist, label_nums = unlist_mask_label(mask_list, label_list)
print("mask_list: ", mask_list)
print("label_list: ", label_list)
print("mask_nolist : ", mask_nolist)
print("mask_inds   : ", mask_inds)
print("label_nums  : ", label_nums)
print("label_nolist: ", label_nolist)

mask_list:  [[0], [1], [2], [4], [2, 8], [3], [5], [0], [1], [2], [4], [-1]]
label_list:  [[0], [[0]], [[[14, 12], 4]], [[]], [[21], []], [[2]], [0], [0], [[]], [[21]], [[4]], []]
mask_nolist :  [0, 1, 2, 4, 2, 8, 3, 5, 0, 1, 2, 4, -1]
mask_inds   :  [1, 1, 3, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1]
label_nums  :  [1, 1, 2, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0]
label_nolist:  [0, 0, 14, 12, 4, 21, 2, 0, 0, 21, 4]


In [24]:
def replace_value(conditions,nl,mp):
    for cond in conditions:
        for i,value in enumerate(cond):
            if i < 3:
                continue
            if not value or isinstance(value,dict):
                continue
            old_value = value
            if isinstance(value,str):
                if value[0] in ('\'','\"'):
                    value = value[1:-1]
                value = value.split()
            else:
                value = [value]
            try:
                new_val = 'VALUE_{}'.format(len(mp))
                cond[i] = new_val
                mp.append(old_value)
                if isinstance(value[0],str):
                    idx = nl.index(value[0])
                else:
                    idx = -1
                    for i in range(len(nl)):
                        if nl[i].isdigit() and (float(nl[i]) == value[0]):
                            idx = i
                            break
                    if idx == -1:
                        # print(old_value)
                        # print(nl)
                        continue
                nl = nl[:idx] + [new_val] + nl[idx+len(value):]
            except Exception:
                # print(old_value)
                # print(nl)
                continue
    return conditions,nl

def replace_nl(sql,nl):
    mp = []
    sql["where"],nl = replace_value(sql["where"],nl,mp)
    sql["having"],nl = replace_value(sql["having"],nl,mp)
    d = {}
    for i,val in enumerate(mp):
        d["VALUE_{}".format(i)] = val
    return d,sql,nl

def get_table_schema(table):
    col_names = table["column_names"]
    tab_names = table["table_names"]
    col_types = table["column_types"]
    col_name_comb = []
    for coln, colt in zip(col_names, col_types):
        tab_id, coln_str = coln
        if tab_id == -1:
            col_name_comb.append(["all", coln_str])
        else:
            col_name_comb.append(tab_names[tab_id].split(" ") + coln_str.split(" ") + [colt])
    
    return col_name_comb

def get_col_in_history(history):
    col_inds = []
    col_mask = []
    history_col_replaced = []
    #["singer", "age", 13]
    for hs in history:
        if isinstance(hs, tuple):
            col_inds.append(hs[2])
            col_mask.append(1)
            history_col_replaced.append("column")
        else:
            col_inds.append(-1)
            col_mask.append(0)
            history_col_replaced.append(hs)
    
    return col_inds, col_mask, history_col_replaced
        
            
def parse_data_new_format(data, table_dict):
    dataset = []
    for item in data:
        table_one = table_dict[item["db_id"]]
        table_schema = get_table_schema(table_one)
        mp,sql,nl = replace_nl(item["sql"],item["question_toks"])
        history, labels, masks = parse_data_full_history(nl, sql, table_one, [])
        mask_nolist, mask_inds, label_nolist, label_nums = unlist_mask_label(masks, labels)
        mask_list_recont, label_list_recont = reconstruct_mask_label(mask_nolist, mask_inds, label_nolist, label_nums)
        
        col_inds, col_mask, history_col_replaced = get_col_in_history(history)
#         if not len(history) == len(labels) == len(masks):
        if not (mask_list_recont == masks and label_list_recont == labels):
            print('\n-------------------------------')
            print('len of hisotry: ', len(history), 'len of labels: ', len(labels), "len of masks: ", len(masks))
            print("query  : ", item['query'])
            print("history: ", history)
            print("masks    : ", masks)
#             print("mask_rec : ", mask_list_recont)
            print("label    : ", labels)
#             print("label_rec: ", label_list_recont)
        else:
            dataset.append({
                "sql_history":history_col_replaced,
                "module_label":labels,
                "module_mask":masks,
                "map":mp,
                "src":nl,
                "table_schema": table_schema,
                "col_inds": col_inds,
                "col_mask": col_mask
            })
            
#     with open('dev_init.json', 'w') as outfile:
#         json.dump(dataset, outfile)
    

In [25]:
COMPONENTS_DICT = {
    'multi_sql':0,
    'keyword':1,
    'col':2,
    'op':3,
    'agg':4,
    'root_tem':5,
    'des_asc':6,
    'having':7,
    'andor':8,
    'value':9
}

In [26]:
# stepten no order by
# SELECT T1.Area FROM APPELLATIONS AS T1 JOIN WINE AS T2 ON T1.Appelation  =  T2.Appelation GROUP BY T2.Appelation HAVING T2.year  <  2010 ORDER BY count(*) DESC LIMIT 1

In [27]:
data_path = '../../SyntaxSQL/data/dev.json'
table_data_path = '../../SyntaxSQL/data/tables.json'
data = json.load(open(data_path))
table_dict = get_table_dict(table_data_path)
parse_data_new_format(data, table_dict)

In [28]:
train_data = torch.load('../data_model/syntaxSQL/train.pt')

In [69]:
for example in train_data.examples:
    print("\n-----------------------")
    print("src: ", example.src)
#     print("tbl: ", example.tbl)
    print("col_inds: ", example.col_inds)
    print("col_mask: ", example.col_mask)
    print("sql_history: ", example.sql_history)
    print("mask_list: ", example.mask_list)
    print("mask_inds: ", example.mask_inds)
    print("label_list: ", example.label_list)
    print("label_nums: ", example.label_nums)


-----------------------
src:  ['How', 'many', 'singers', 'do', 'we', 'have', '?']
col_inds:  [-1, -1, -1, 0, -1]
col_mask:  [0, 0, 0, 1, 0]
sql_history:  ['root', 'none', 'select', 'column', 'count']
mask_list:  [0, 1, 2, 4, -1]
mask_inds:  [1, 1, 1, 1, 1]
label_list:  [0, 0, 2]
label_nums:  [1, 0, 1, 1, 0]

-----------------------
src:  ['What', 'is', 'the', 'total', 'number', 'of', 'singers', '?']
col_inds:  [-1, -1, -1, 0, -1]
col_mask:  [0, 0, 0, 1, 0]
sql_history:  ['root', 'none', 'select', 'column', 'count']
mask_list:  [0, 1, 2, 4, -1]
mask_inds:  [1, 1, 1, 1, 1]
label_list:  [0, 0, 2]
label_nums:  [1, 0, 1, 1, 0]

-----------------------
src:  ['Show', 'name', ',', 'country', ',', 'age', 'for', 'all', 'singers', 'ordered', 'by', 'age', 'from', 'the', 'oldest', 'to', 'the', 'youngest', '.']
col_inds:  [-1, -1, -1, 13, 10, 9, -1, 13, -1]
col_mask:  [0, 0, 0, 1, 1, 1, 0, 1, 0]
sql_history:  ['root', 'none', 'select', 'column', 'column', 'column', 'orderBy', 'column', 'desc']
mas

label_list:  [0, 1, 21, 23, 4, 23, 0]
label_nums:  [1, 1, 2, 0, 1, 0, 1, 1]

-----------------------
src:  ['What', 'is', 'the', 'average', 'weight', 'and', 'year', 'for', 'each', 'year', '?']
col_inds:  [-1, -1, -1, 23, 21, -1, -1, 23]
col_mask:  [0, 0, 0, 1, 1, 0, 0, 1]
sql_history:  ['root', 'none', 'select', 'column', 'column', 'avg', 'groupBy', 'column']
mask_list:  [0, 1, 2, 4, 4, -1, 2, 7]
mask_inds:  [1, 1, 1, 1, 1, 1, 1, 1]
label_list:  [0, 1, 21, 23, 4, 23, 0]
label_nums:  [1, 1, 2, 0, 1, 0, 1, 1]

-----------------------
src:  ['Which', 'countries', 'in', 'VALUE_0', 'have', 'at', 'least', 'VALUE_1', 'car', 'manufacturers', '?']
col_inds:  [-1, -1, -1, 4, -1, 2, -1, -1, -1, 4, -1, 0, -1, -1, -1]
col_mask:  [0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0]
sql_history:  ['root', 'none', 'select', 'column', 'where', 'column', '=', 'value', 'groupBy', 'column', 'having', 'column', 'count', '>=', 'value']
mask_list:  [0, 1, 2, 4, 2, 8, 3, 5, 9, -1, 2, 7, 2, 4, 3, 5, 9, -1]
mask_inds:

mask_inds:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1]
label_list:  [0, 1, 5, 6, 27, 1, 0, 2, 1, 1, 0, 1]
label_nums:  [1, 1, 2, 0, 0, 1, 1, 1, 1, 1, 1, 2, 0]

-----------------------
src:  ['Which', 'cities', 'have', 'served', 'as', 'host', 'cities', 'more', 'than', 'once', '?', 'Return', 'me', 'their', 'GDP', 'and', 'population', '.']
col_inds:  [-1, -1, -1, 5, 6, -1, 27, -1, 0, -1, -1, -1]
col_mask:  [0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0]
sql_history:  ['root', 'none', 'select', 'column', 'column', 'groupBy', 'column', 'having', 'column', 'count', '>', 'value']
mask_list:  [0, 1, 2, 4, 4, 2, 7, 2, 4, 3, 5, 9, -1]
mask_inds:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1]
label_list:  [0, 1, 5, 6, 27, 1, 0, 2, 1, 1, 0, 1]
label_nums:  [1, 1, 2, 0, 0, 1, 1, 1, 1, 1, 1, 2, 0]

-----------------------
src:  ['List', 'every', 'individual', "'s", 'first', 'name', ',', 'middle', 'name', 'and', 'last', 'name', 'in', 'alphabetical', 'order', 'by', 'last', 'name', '.']
col_inds:  [-1, -1, -1, 23, 19, 18,

col_inds:  [-1, -1, -1, 4, -1, 3, -1, -1, -1, 4, -1, 0, -1, -1, -1]
col_mask:  [0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0]
sql_history:  ['root', 'none', 'select', 'column', 'where', 'column', '<', 'value', 'groupBy', 'column', 'having', 'column', 'count', '>', 'value']
mask_list:  [0, 1, 2, 4, 2, 8, 3, 5, 9, -1, 2, 7, 2, 4, 3, 5, 9, -1]
mask_inds:  [1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1]
label_list:  [0, 0, 1, 4, 3, 2, 1, 0, 1, 4, 1, 0, 2, 1, 1, 0, 1]
label_nums:  [1, 2, 1, 0, 1, 0, 1, 1, 2, 0, 1, 1, 1, 1, 1, 1, 2, 0]

-----------------------
src:  ['Find', 'the', 'cities', 'that', 'have', 'more', 'than', 'one', 'employee', 'under', 'age', 'VALUE_0', '.']
col_inds:  [-1, -1, -1, 4, -1, 3, -1, -1, -1, 4, -1, 0, -1, -1, -1]
col_mask:  [0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0]
sql_history:  ['root', 'none', 'select', 'column', 'where', 'column', '<', 'value', 'groupBy', 'column', 'having', 'column', 'count', '>', 'value']
mask_list:  [0, 1, 2, 4, 2, 8, 3, 5, 9, -1, 2


-----------------------
src:  ['What', 'are', 'the', 'average', 'and', 'maximum', 'number', 'of', 'tickets', 'bought', 'in', 'all', 'visits', '?']
col_inds:  [-1, -1, -1, 11, -1, -1]
col_mask:  [0, 0, 0, 1, 0, 0]
sql_history:  ['root', 'none', 'select', 'column', 'avg', 'max']
mask_list:  [0, 1, 2, 4, -1, -1]
mask_inds:  [1, 1, 1, 1, 1, 1]
label_list:  [0, 11, 4, 0]
label_nums:  [1, 0, 1, 2, 0, 0]

-----------------------
src:  ['What', 'is', 'the', 'total', 'ticket', 'expense', 'of', 'the', 'visitors', 'whose', 'membership', 'level', 'is', 'VALUE_0', '?']
col_inds:  [-1, -1, -1, 12, -1, -1, 7, -1, -1]
col_mask:  [0, 0, 0, 1, 0, 0, 1, 0, 0]
sql_history:  ['root', 'none', 'select', 'column', 'sum', 'where', 'column', '=', 'value']
mask_list:  [0, 1, 2, 4, -1, 2, 8, 3, 5, 9, -1]
mask_inds:  [1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1]
label_list:  [0, 0, 12, 3, 7, 0, 1, 0, 1]
label_nums:  [1, 1, 1, 1, 0, 1, 0, 1, 1, 2, 0]

-----------------------
src:  ['What', 'is', 'the', 'name', 'of', 'the', 'v

label_list:  [0, 1, 0, 2, 2, 2, 0]
label_nums:  [1, 1, 2, 0, 1, 0, 1, 1]

-----------------------
src:  ['What', 'are', 'the', 'names', 'of', 'the', 'scientists', ',', 'and', 'how', 'many', 'projects', 'are', 'each', 'of', 'them', 'working', 'on', '?']
col_inds:  [-1, -1, -1, 2, 0, -1, -1, 2]
col_mask:  [0, 0, 0, 1, 1, 0, 0, 1]
sql_history:  ['root', 'none', 'select', 'column', 'column', 'count', 'groupBy', 'column']
mask_list:  [0, 1, 2, 4, 4, -1, 2, 7]
mask_inds:  [1, 1, 1, 1, 1, 1, 1, 1]
label_list:  [0, 1, 0, 2, 2, 2, 0]
label_nums:  [1, 1, 2, 0, 1, 0, 1, 1]

-----------------------
src:  ['Find', 'the', 'SSN', 'and', 'name', 'of', 'scientists', 'who', 'are', 'assigned', 'to', 'the', 'project', 'with', 'the', 'longest', 'hours', '.']
col_inds:  [-1, -1, -1, 2, 1, -1, 5, -1, -1, -1, -1, 5, -1]
col_mask:  [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0]
sql_history:  ['root', 'none', 'select', 'column', 'column', 'where', 'column', '=', 'root', 'none', 'select', 'column', 'max']
mask_list:  [

sql_history:  ['root', 'intersect', 'root', 'none', 'select', 'column', 'where', 'column', '=', 'value', 'root', 'none', 'select', 'column', 'where', 'column', '=', 'value']
mask_list:  [0, 0, -1, 1, 2, 4, 2, 8, 3, 5, 9, -1, 0, 1, 2, 4, 2, 8, 3, 5, 9, -1]
mask_inds:  [1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1]
label_list:  [1, 0, 0, 34, 38, 0, 1, 0, 1, 0, 0, 34, 38, 0, 1, 0, 1]
label_nums:  [1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 2, 0, 1, 1, 1, 0, 1, 0, 1, 1, 2, 0]

-----------------------
src:  ['What', 'are', 'the', 'names', 'of', 'players', 'who', 'won', 'in', 'both', 'VALUE_0', 'and', '2016', '?']
col_inds:  [-1, -1, -1, -1, -1, 34, -1, 38, -1, -1, -1, -1, -1, 34, -1, 38, -1, -1]
col_mask:  [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0]
sql_history:  ['root', 'intersect', 'root', 'none', 'select', 'column', 'where', 'column', '=', 'value', 'root', 'none', 'select', 'column', 'where', 'column', '=', 'value']
mask_list:  [0, 0, -1, 1, 2, 4, 2, 8, 3, 5, 9, -1, 0,

label_list:  [0, 2, 3]
label_nums:  [1, 0, 2, 0, 0]

-----------------------
src:  ['What', 'is', 'the', 'first', 'and', 'second', 'line', 'for', 'all', 'addresses', '?']
col_inds:  [-1, -1, -1, 3, 2]
col_mask:  [0, 0, 0, 1, 1]
sql_history:  ['root', 'none', 'select', 'column', 'column']
mask_list:  [0, 1, 2, 4, 4]
mask_inds:  [1, 1, 1, 1, 1]
label_list:  [0, 2, 3]
label_nums:  [1, 0, 2, 0, 0]

-----------------------
src:  ['How', 'many', 'courses', 'in', 'total', 'are', 'listed', '?']
col_inds:  [-1, -1, -1, 0, -1]
col_mask:  [0, 0, 0, 1, 0]
sql_history:  ['root', 'none', 'select', 'column', 'count']
mask_list:  [0, 1, 2, 4, -1]
mask_inds:  [1, 1, 1, 1, 1]
label_list:  [0, 0, 2]
label_nums:  [1, 0, 1, 1, 0]

-----------------------
src:  ['How', 'many', 'courses', 'are', 'there', '?']
col_inds:  [-1, -1, -1, 0, -1]
col_mask:  [0, 0, 0, 1, 0]
sql_history:  ['root', 'none', 'select', 'column', 'count']
mask_list:  [0, 1, 2, 4, -1]
mask_inds:  [1, 1, 1, 1, 1]
label_list:  [0, 0, 2]
labe

src:  ['What', 'are', 'the', 'faculty', 'id', 'and', 'the', 'number', 'of', 'students', 'each', 'faculty', 'has', '?']
col_inds:  [-1, -1, -1, 0, -1, 15, -1, 15]
col_mask:  [0, 0, 0, 1, 0, 1, 0, 1]
sql_history:  ['root', 'none', 'select', 'column', 'count', 'column', 'groupBy', 'column']
mask_list:  [0, 1, 2, 4, -1, 4, 2, 7]
mask_inds:  [1, 1, 1, 1, 1, 1, 1, 1]
label_list:  [0, 1, 0, 15, 2, 15, 0]
label_nums:  [1, 1, 2, 1, 0, 0, 1, 1]

-----------------------
src:  ['Show', 'all', 'the', 'faculty', 'ranks', 'and', 'the', 'number', 'of', 'students', 'advised', 'by', 'each', 'rank', '.']
col_inds:  [-1, -1, -1, 0, -1, 18, -1, 18]
col_mask:  [0, 0, 0, 1, 0, 1, 0, 1]
sql_history:  ['root', 'none', 'select', 'column', 'count', 'column', 'groupBy', 'column']
mask_list:  [0, 1, 2, 4, -1, 4, 2, 7]
mask_inds:  [1, 1, 1, 1, 1, 1, 1, 1]
label_list:  [0, 1, 0, 18, 2, 18, 0]
label_nums:  [1, 1, 2, 1, 0, 0, 1, 1]

-----------------------
src:  ['How', 'many', 'students', 'are', 'advised', 'by', 'eac

col_mask:  [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0]
sql_history:  ['root', 'intersect', 'root', 'none', 'select', 'column', 'where', 'column', '=', 'value', 'column', '=', 'value', 'root', 'none', 'select', 'column', 'where', 'column', '=', 'value', 'column', '=', 'value']
mask_list:  [0, 0, -1, 1, 2, 4, 2, 8, 3, 5, 9, -1, 3, 5, 9, -1, 0, 1, 2, 4, 2, 8, 3, 5, 9, -1, 3, 5, 9, -1]
mask_inds:  [1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1]
label_list:  [1, 0, 0, 9, 24, 25, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 9, 24, 25, 0, 0, 1, 0, 1, 0, 1, 0, 1]
label_nums:  [1, 0, 1, 1, 1, 0, 2, 1, 1, 1, 2, 0, 1, 1, 2, 0, 1, 1, 1, 0, 2, 1, 1, 1, 2, 0, 1, 1, 2, 0]

-----------------------
src:  ['Give', 'the', 'names', 'of', 'countries', 'with', 'VALUE_0', 'and', 'French', 'as', 'official', 'languages', '.']
col_inds:  [-1, -1, -1, -1, -1, 9, -1, 25, -1, -1, 24, -1, -1, -1, -1, -1, 9, -1, 25, -1, -1, 24, -1, -1]
col_mask:  [0, 0, 0, 

col_inds:  [-1, -1, -1, -1, -1, 5, -1, -1, -1, 5, -1, 2, -1, -1]
col_mask:  [0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0]
sql_history:  ['root', 'except', 'root', 'none', 'select', 'column', 'root', 'none', 'select', 'column', 'where', 'column', '=', 'value']
mask_list:  [0, 0, -1, 1, 2, 4, 0, 1, 2, 4, 2, 8, 3, 5, 9, -1]
mask_inds:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1]
label_list:  [3, 0, 5, 0, 0, 5, 2, 0, 1, 0, 1]
label_nums:  [1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 2, 0]

-----------------------
src:  ['Show', 'all', 'movie', 'titles', ',', 'years', ',', 'and', 'directors', ',', 'ordered', 'by', 'budget', '.']
col_inds:  [-1, -1, -1, 11, 10, 9, -1, 12, -1]
col_mask:  [0, 0, 0, 1, 1, 1, 0, 1, 0]
sql_history:  ['root', 'none', 'select', 'column', 'column', 'column', 'orderBy', 'column', 'asc']
mask_list:  [0, 1, 2, 4, 4, 4, 2, 6, -1]
mask_inds:  [1, 1, 1, 1, 1, 1, 1, 1, 1]
label_list:  [0, 2, 9, 10, 11, 12, 1]
label_nums:  [1, 1, 3, 0, 0, 0, 1, 1, 0]

----------------------