In [None]:
import os
import re
import json
import random
import codecs
from template_config import *
from nltk import word_tokenize
from collections import defaultdict
from transformers.tokenization_roberta import RobertaTokenizer

SEP_TOKEN = "</s>"
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
MAX_TOKEN_LEN = 189

### Read tables from text files

In [None]:
train_corpus = "data/wikitable_dup1_row1.txt"
# output_file = "data/data_comb_tables.txt"

In [None]:
def hasNumbers(inputString):
    return any(char.isdigit() for char in inputString)

def check_name(inpStr):
    return len(inpStr) > 1 and "-" not in inpStr and not hasNumbers(inpStr)

def gen_name(title, must_have=False):
    title_tokens = title.split(" ")
    qualify_words = []
    for w in title_tokens:
        if check_name(w):
            qualify_words.append(w)
    
    if random.random() < 0.4:
        name = " ".join(qualify_words[-2:])
    else:
        name = " ".join(qualify_words[-1:])
    
    if name != "":
        return name
    
    if must_have:
        return title_tokens[0]
    else:
        return name

In [None]:
def main_process(train_corpus):
    total_count = 0
    webtables = []
    with open(train_corpus, "r", encoding="utf-8") as f:
        for line in f:
            skip = False
            if total_count % 100000 == 0:
                print("processed: ", total_count)
            tokens = line.lower().replace("<special7>", "<tabn>").replace("<special8>", "<coln>").replace("<special9>", "<entry>").replace("*", "").replace("|||", "")
            table = {"columns": [], "values": [], "columns_original": [], "column_types": []}
            chunks = tokens.split(" <coln> ")
            for chunk in chunks:
                if "<tabn>" in chunk:
                    page_title = chunk.replace("<tabn>", "").strip()
                    table_name = gen_name(page_title)
                    table["name"] = table_name
                    if table_name == "" or len(table_name) < 2:
                        skip = True
                else:
                    assert "<entry>" in chunk
                    chunk_toks = chunk.split(" <entry> ")
                    if len(chunk_toks) == 2:
                        col_name, entry = chunk_toks[0].strip(), chunk_toks[1].strip()
                        if len(col_name) > 1:
                            table["columns"].append(" ".join(col_name.split(" ")[:5]))
                            table["columns_original"].append(col_name)
                            ctype = "text"
                            if entry.isdigit():
                                ctype = "real"
                            table["column_types"].append(ctype)
                            table["values"].append(" ".join(entry.split(" ")[:5]))
                            
            
            if len(table["columns"]) < 3:
                skip = True
            if not skip:
                table_name = table["name"]
                table["columns"] = [table_name + " *"] + table["columns"]
                table["columns_original"] = ["*"] + table["columns_original"]
                table["column_types"] = ["text"] + table["column_types"]
                table["values"] = ["all"] + table["values"]
                tabn_str = "_".join(table_name.split(" "))
                table["columns"] = [tabn_str +" "+ hd for hd in table["columns"]]
                if "*" not in table['columns'][0]:
                    print(table['columns'])
                webtables.append(table)
                
            total_count += 1
            
    return webtables

In [None]:
web_tables = main_process(train_corpus)

processed:  0
processed:  100000
processed:  200000


In [None]:
web_tables[1]

{'columns': ['transit transit *',
  'transit order year',
  'transit manufacturer',
  'transit model',
  'transit fleet series ( quantity )',
  'transit powertrain ( engine/transmission )',
  'transit fuel propulsion'],
 'values': ['all',
  '1111-11',
  'gillig',
  'phantom ( high floor )',
  '111-111 ( 11 )',
  'dd s10egr allison wb-100r',
  'diesel'],
 'columns_original': ['*',
  'order year',
  'manufacturer',
  'model',
  'fleet series ( quantity )',
  'powertrain ( engine/transmission )',
  'fuel propulsion'],
 'column_types': ['text', 'text', 'text', 'text', 'text', 'text', 'text'],
 'name': 'transit'}

### Read NL-SQL templates and sql component mapping file

In [None]:
MAX_COL_NUM = 25
OPS = ["=", ">", "<", ">=", "<=", "!=", "LIKE"]
nlsql_templates_file = "data/nlsql_templates_context.txt"
nlsql_templates_iso_file = "data/nlsql_templates.txt"
sql_components_file = "data/sql_components.json"

In [None]:
# read NL-SQL templates
templates = []
with open(nlsql_templates_file) as fp:
    lines = fp.readlines()
    template_one = {}
    for line in lines:
        if "\n" == line:
            templates.append(template_one) 
        elif "SQL Pattern:" in line:
            template_one = {}
            sps = line.strip().replace("SQL Pattern: ", "").split("|||")
            template_one["questions"] = []
            if len(sps) == 1:
                template_one["SQL pattern"] = sps[0]
                template_one["SQL constraints"] = []
            elif len(sps) == 2:
                template_one["SQL pattern"] = sps[0]
                template_one["SQL constraints"] = [x.strip() for x in sps[1].split("|") if x != " "]
            else:
                print("\n======Error warning!!!!")
        elif "count: " in line:
            sql_count = int(line.strip().replace("count: ", ""))
            template_one["count"] = sql_count
        elif "question:  " in line:
            sps = line.strip().replace("question:  ", "").split("|||")
            question = sps[0]
            if len(sps) == 2:
                q_constraints = [x.strip() for x in sps[1].split("|") if x != " "]
            else:
                q_constraints = []
            template_one["questions"].append((question, q_constraints))

In [None]:
templates_one_table = []
for template in templates:
    sql_constraints = template['SQL constraints']
    sql_pattern = template["SQL pattern"]
    questions = template["questions"]
    skip = False
    for constraint in sql_constraints:
        if "id" in constraint or "T1" in constraint:
            skip = True
    questions_after = []     
    if not skip:
        for q, qc in questions:
            if "TABLE1" not in q:
                questions_after.append((q, qc))
        if len(questions_after) > 0:
            template_one = {}
            template_one['SQL constraints'] = sql_constraints
            template_one['SQL pattern'] = sql_pattern
            template_one["questions"] = questions_after
            templates_one_table.append(template_one)

In [None]:
all_constraints = []
for tmp in templates_one_table:
    all_constraints.extend(tmp['SQL constraints'])
    for q in tmp['questions']:
        all_constraints.extend(q[1])

print(list(set(all_constraints)))

['P0==']


In [None]:
# read SQL component file
with open(sql_components_file) as json_file:
    sql_components = json.load(json_file)

In [None]:
# get labels for question sep tokens

templates_all = []
with open(nlsql_templates_iso_file) as fp:
    lines = fp.readlines()
    template_one = {}
    for line in lines:
        if "\n" == line:
            templates_all.append(template_one['SQL pattern']) 
        elif "SQL Pattern:" in line:
            template_one = {}
            sps = line.strip().replace("SQL Pattern: ", "").split("|||")
            template_one["questions"] = []
            if len(sps) == 1:
                template_one["SQL pattern"] = sps[0]
                template_one["SQL constraints"] = []
            elif len(sps) == 2:
                template_one["SQL pattern"] = sps[0]
                template_one["SQL constraints"] = [x.strip() for x in sps[1].split("|") if x != " "]
            else:
                print("\n======Error warning!!!!")
        elif "count: " in line:
            sql_count = int(line.strip().replace("count: ", ""))
            template_one["count"] = sql_count
        elif "question:  " in line:
            sps = line.strip().replace("question:  ", "").split("|||")
            question = sps[0]
            if len(sps) == 2:
                q_constraints = [x.strip() for x in sps[1].split("|") if x != " "]
            else:
                q_constraints = []
            template_one["questions"].append((question, q_constraints))

context_templates_file = "data/context_templates.json"
with open(context_templates_file) as json_file:
    context_templates = json.load(json_file)

for ct in context_templates:
    templates_all.insert(0, ct["label"])

In [None]:
qsep_label_map = {}
for i, ex in enumerate(templates_all):
    qsep_label_map[ex] = i + 1

In [None]:
# qsep_label_map

In [None]:
with open("data/qsep_label_map.json", "w") as f:
    json.dump(qsep_label_map, f, indent=2)

### Unify and combine tables as databases

In [None]:
def create_dbs(tables):
#     random.shuffle(tables)  
    dbs = []
    cur_cols = []
    db_one = []
    ahd_cols = []
    for i, tab in enumerate(tables):
        if i % 100000 == 0:
            print("processed: ", i)
        if len(db_one) <= random.choice([0, 1]) and len(ahd_cols) < MAX_COL_NUM:
            db_one.append(tab)
            cur_cols.extend([col+"."+tab["name"] for col in tab["columns"]])
            if i+1 < len(tables):
                ahd_cols = cur_cols + [col+"."+tables[i+1]["name"] for col in tables[i+1]["columns"]]
            else:
                 break
        else:
            if len(cur_cols) == len(list(set(cur_cols))):
                if len(db_one) > 1:
                    db_one_new = []
                    for tab in db_one:
                        if tab["columns"][0] == "id":
                            tab["columns"] = tab["columns"][1:]
                            tab["column_types"] = tab["column_types"][1:]
                            tab["columns_original"] = tab["columns_original"][1:]
                            tab["values"] = tab["values"][1:]
                            
                        if random.random() < 0.7:
                            index_col = "id"
                            if random.random() < 0.3:
                                index_col = "name"

                            if index_col not in tab["columns"]:
                                tabn_str = "_".join(tab["name"].split(" "))
                                tab["columns"] = [tab["columns"][0]] + [tabn_str +" "+ index_col] + tab["columns"][1:]
                                val_add = 1
                                if index_col == "name":
                                    val_add = "value"
                                tab["values"] = [tab["values"][0]] + [val_add] + tab["values"][1:]
                                tab["column_types"] = [tab["column_types"][0]] + ["text"] + tab["column_types"][1:]
                                tab["columns_original"] = [tab["columns_original"][0]] + [index_col] + tab["columns_original"][1:]
                        db_one_new.append(tab)
                    dbs.append(db_one_new)
                else:
                    dbs.append(db_one)
            db_one = []
            cur_cols = []
            ahd_cols = []
            
    return dbs

In [None]:
webtable_dbs = create_dbs(web_tables)

processed:  0
processed:  100000
processed:  200000


In [None]:
len(webtable_dbs)

83973

In [None]:
webtable_dbs[2]

[{'columns': ['year_award year award *',
   'year_award season',
   'year_award player',
   'year_award position',
   'year_award nationality',
   'year_award team',
   'year_award draft pick #',
   'year_award draft class',
   'year_award college'],
  'values': ['all',
   '1111',
   'steve ralston category : articles',
   'midfielder',
   'united states',
   'tampa bay mutiny',
   '11',
   '1111 mls college draft',
   'florida international'],
  'columns_original': ['*',
   'season',
   'player',
   'position',
   'nationality',
   'team',
   'draft pick #',
   'draft class',
   'college'],
  'column_types': ['text',
   'real',
   'text',
   'text',
   'text',
   'text',
   'real',
   'text',
   'text'],
  'name': 'year award'},
 {'columns': ['directions directions *',
   'directions name',
   'directions name',
   'directions direction',
   'directions mantra',
   'directions weapon',
   'directions consort',
   'directions graha ( planet )',
   'directions guardian mātṛkā'],
  'valu

In [1]:
for db in webtable_dbs[:1000]:
    tab_names = []
    col_count = 0
    for tab in db:
        tab_names.append(tab["name"])
        col_count += len(tab["columns"])
    print("----------")
    print("table names: ", tab_names)
    print("column num: ", col_count)
    print("table num: ", len(tab_names))

### Start generate NL-SQL examples based on new databases and CFG grammars

##### detect question and SQL slots and process constraints

In [None]:
def get_sql_slots(sql_pattern):
    sql_tokens = sql_pattern.split(" ")
    columns = {}
    ops = {}
    values = {}
    aggs = {}
    dasc = False
    slots = []
    val_pros = []
    for i, tok in enumerate(sql_tokens):
        if "{" in tok and "}" in tok and "FROM" not in tok:
            if tok not in slots:
                slots.append(tok)
                
        if "AGG" in tok:
            if i + 2 < len(sql_tokens) and "(" == sql_tokens[i+1]:
                if "COLUMN" in sql_tokens[i+2]:
                    if sql_tokens[i+2] not in columns.keys():
                        columns[sql_tokens[i+2]] = ["number"]
                    else:
                        columns[sql_tokens[i+2]].append("number")
                    aggs[tok] = sql_tokens[i+2]
                else:
                    print("\nTemplate Error: AGG format is wrong!!!")
                    print(sql_pattern)
        elif "COLUMN" in tok:
            if tok not in columns.keys():
                columns[tok] = []
        elif "OP" in tok:
            if i - 1 >= 0 and "COLUMN" in sql_tokens[i-1]:
                ops[tok] = [sql_tokens[i-1]]
                if i + 1 < len(sql_tokens) and "VALUE" in sql_tokens[i+1]:
                    ops[tok].append(sql_tokens[i+1])
                    val_pros.append(sql_tokens[i+1])
            elif i - 2 >= 0 and ")" == sql_tokens[i-1] and ("COLUMN" in sql_tokens[i-2] or "*" == sql_tokens[i-2]):
                ops[tok] = [sql_tokens[i-2]]
                if i + 1 < len(sql_tokens) and "VALUE" in sql_tokens[i+1]:
                    ops[tok].append(sql_tokens[i+1])
                    val_pros.append(sql_tokens[i+1])
            else:
                print("\nTemplate Error: OP format is wrong!!!")
                print(sql_pattern)
        elif "VALUE" in tok and tok not in val_pros:
            """
            OP} {VALUE0}
            LIMIT {VALUE0}
            {COLUMN1} BETWEEN {VALUE0} AND {VALUE1}
            HAVING COUNT ( * ) {OP1} {VALUE1}
            = {VALUE1}
            """
            if i - 2 >= 0 and ("BETWEEN" == sql_tokens[i-1] or "AND" == sql_tokens[i-1]):
                values[tok] = "number"
                if "BETWEEN" == sql_tokens[i-1]:
                    columns[sql_tokens[i-2]].append("number")
            elif i - 1 >= 0 and "LIMIT" == sql_tokens[i-1]:
                values[tok] = "integer"
            elif i - 1 >= 0 and "=" == sql_tokens[i-1]:
                assert "COLUMN" in sql_tokens[i-2]
                columns[sql_tokens[i-2]].append(tok)
            else:
                print("\nTemplate Error: VALUE format is wrong!!!")
                print(sql_pattern)
        elif "DASC" in tok:
            dasc = True
    
    return (list(set(slots)), columns, ops, values, aggs, dasc)


def get_q_slots(question):
    q_toks = [x.replace("?", "").replace("!", "").replace(".", "") for x in question.strip().split(" ")]
    q_slots = list(set([tok for tok in q_toks if "TABLE" in tok or "SC" in tok or ("{" in tok and "}" in tok)]))
    
    return q_slots
    

def process_constraints(constraints, columns, slots):
    slot_values = {}
    skip_db_with_one_table = False
    for constraint in constraints:
        if "P0==" == constraint:
            assert "{OP0}" in slots
            slot_values["{OP0}"] = "="
        elif "P1==" == constraint:
            assert "{OP1}" in slots
            slot_values["{OP1}"] = "="
        elif "P0=P1==" == constraint:
            assert "{OP0}" in slots and "{OP1}" in slots
            slot_values["{OP0}"] = "="
            slot_values["{OP1}"] = "="
        elif "P0=P1=P2==" == constraint:
            assert "{OP0}" in slots and "{OP1}" in slots and "{OP2}" in slots
            slot_values["{OP0}"] = "="
            slot_values["{OP1}"] = "="
            slot_values["{OP2}"] = "="
        elif "P0=>" == constraint:
            assert "{OP0}" in slots
            slot_values["{OP0}"] = ">"
        elif "P0=<" == constraint:
            assert "{OP0}" in slots
            slot_values["{OP0}"] = "<"
        elif "{AGG0}=MIN" == constraint:
            assert "{AGG0}" in slots
            slot_values["{AGG0}"] = "MIN"
        elif "{AGG0}=MAX" == constraint:
            assert "{AGG0}" in slots
            slot_values["{AGG0}"] = "MAX"
        elif "C0-id" == constraint:
            skip_db_with_one_table = True
            assert "{COLUMN0}" in slots and "{COLUMN0}" in columns.keys()
            columns["{COLUMN0}"].append("id")
        elif "C1-id" == constraint:
            skip_db_with_one_table = True
            assert "{COLUMN1}" in slots and "{COLUMN1}" in columns.keys()
            columns["{COLUMN1}"].append("id")
        elif "C2-id" == constraint:
            skip_db_with_one_table = True
            assert "{COLUMN2}" in slots and "{COLUMN2}" in columns.keys()
            columns["{COLUMN2}"].append("id")
        elif "C3-T1" == constraint:
            skip_db_with_one_table = True
            assert "{COLUMN3}" in slots and "{COLUMN3}" in columns.keys()
            columns["{COLUMN3}"].append("T1")
        elif "T0-T1-JOIN" == constraint or 'T0-T1-NO-JOIN' == constraint:
            skip_db_with_one_table = True
            columns["{COLUMN0}"].append("T0")
            if "{COLUMN1}" in columns.keys():
                columns["{COLUMN1}"].append("T1")
    
    return (slot_values, columns, skip_db_with_one_table)


# helper function
def gen_col_info(col_str, columns, columns_inf):
    col_conds = columns[col_str]
    value_slot = [cc for cc in col_conds if "VALUE" in cc]
    col = ""
    value_val = None
    if "id" in col_conds:
        has_id = False
        for c, t, v in columns_inf:
            if "id" in col or "name" in col:
                has_id = True
                col, ctype, values = c, t, v
                break
        if not has_id:
            col, ctype, value = columns_inf[0]
    elif "number" in col_conds:
        for colinfo in columns_inf[1:]:
            if colinfo[1] == "real":
                col, ctype, value = colinfo
    if col == "":
        col, ctype, value = random.choice(columns_inf[1:])

    if len(value_slot) > 0:
        assert len(value_slot) < 3
        if len(value_slot) == 1:
            value_val = [(value_slot[0], value)]
        else:
            value_val = [(value_slot[0], value), (value_slot[1], value)]
    
    return (col, value_val)


def replace_dict(inp, dicts):
    for rep_in, rep_out in dicts.items():
        inp = inp.replace(rep_in, str(rep_out))
    
    return inp


##### Get classification label for each column based on SQL templates

In [None]:
STRUCT_KEYWORDS = ["WHERE", "GROUP_BY", "HAVING", "ORDER_BY", "SELECT"]
EXTRA_OPS = ["NOT_IN", "IN", "BETWEEN", "="]
COUNT = "COUNT"
OTHER_KEYWORDS = ["LIMIT"] #AGG, OP, DASC, OR, =
NEST_KEYWORDS = ["EXCEPT", "UNION", "INTERSECT"]

def get_labels(sql_pattern):
    sql_tokens = sql_pattern.replace("GROUP BY", "GROUP_BY").replace("ORDER BY", "ORDER_BY").replace("NOT IN", "NOT_IN").split(" ")
    columns = {}
    cur_nest = ""
    cur_struct = ""
    cur_len = len(sql_tokens)
    select_count = 0
    skip = False
    for i, tok in enumerate(sql_tokens):
        if tok in NEST_KEYWORDS:
            if cur_nest == "" or cur_nest == "OP_SEL":
                cur_nest = tok
            else:
                cur_nest = cur_nest + " " + tok
        elif tok in STRUCT_KEYWORDS:
            cur_struct = tok
            if tok == "SELECT":
                select_count += 1
                if select_count > 1 and cur_nest == "":
                    cur_nest = "OP_SEL"
        elif "COLUMN" in tok or "*" == tok:
            if tok not in columns.keys():
                columns[tok] = []
            # SELECT {COLUMN0}
            # SELECT {COLUMN0} , {COLUMN1}
            # SELECT {AGG0} ( {COLUMN0} )
            # SELECT {COLUMN0} {FROM} WHERE {COLUMN1} {OP} ( SELECT {AGG0} ( {COLUMN1} ) {FROM} ) AND {COLUMN2} {OP0} {VALUE0}
            if cur_struct == "SELECT":
                if "," == sql_tokens[i-1] or "SELECT" == sql_tokens[i-1]:
                    columns[tok].append(cur_nest + " " + cur_struct)
                elif "(" == sql_tokens[i-1]:
                    columns[tok].append(cur_nest + " " + cur_struct + " " + sql_tokens[i-2])
                else:
                    print("\nWarning: unexcepted SELECT format")
                    skip = True
                    print(sql_pattern)
            # WHERE {COLUMN} {OP}
            # WHERE {COLUMN2} {OP0}
            # WHERE OR {COLUMN2} {OP0}
            # WHERE {COLUMN2} BETWEEN
            elif cur_struct == "WHERE":
                assert "OP" in sql_tokens[i+1] or sql_tokens[i+1] in EXTRA_OPS
                last_tok = sql_tokens[i-1]
                if "OR" == last_tok or (i+3 < cur_len and "OR" == sql_tokens[i+3]):
                    columns[tok].append(cur_nest + " " + cur_struct + " OR " + sql_tokens[i+1])
                elif "WHERE" == last_tok or "AND" == last_tok:
                    columns[tok].append(cur_nest + " " + cur_struct + " " + sql_tokens[i+1])
                else:
                    print("\nWarning: unexcepted WHERE format")
                    skip = True
            # GROUP BY {COLUMN0} , {COLUMN0}
            elif cur_struct == "GROUP_BY":
                columns[tok].append(cur_nest + " " + cur_struct)
            # HAVING COUNT ( * ) {OP0}
            # HAVING {AGG0} ( {COLUMN2} ) {OP0}
            elif cur_struct == "HAVING":
                last_tok = sql_tokens[i-1]
                if last_tok != "(" and not ("AGG" in sql_tokens[i-2] or COUNT == sql_tokens[i-2]):
                    print("\nWarning: unexcepted HAVING format")
                    skip = True
                columns[tok].append(cur_nest + " " + cur_struct + " " + sql_tokens[i-2] + " " + sql_tokens[i+2])
            # ORDER BY COUNT ( * ) {DASC} LIMIT
            # ORDER BY COUNT ( * ) {DASC}
            # ORDER BY {COLUMN1} {DASC} LIMIT
            # ORDER BY {COLUMN1} LIMIT
            # ORDER BY {COLUMN1} , {COLUMN1} {DASC} LIMIT
            # ORDER BY {COLUMN1} {DASC} if no DASC then is ASC
            elif cur_struct == "ORDER_BY":
                last_tok = sql_tokens[i-1]
                if last_tok == "(":
                    dasc_tok = "{DASC}"
                    limit_tok = ""
                    if sql_tokens[i+2] != "{DASC}":
                        dasc_tok = "ASC"
                        if sql_tokens[i+2] == "LIMIT":
                            limit_tok = "LIMIT"
                    elif i+3 < cur_len and sql_tokens[i+3] == "LIMIT":
                        limit_tok = "LIMIT"
                        
                    columns[tok].append(cur_nest + " " + cur_struct + " " + sql_tokens[i-2] + " " + dasc_tok + " " + limit_tok)
                elif last_tok == "ORDER_BY" or last_tok == ",":
                    dasc_tok = "ASC"
                    limit_tok = ""
                    # small dirty pass
                    if i+1 < cur_len and sql_tokens[i+1] == "{DASC}":
                        dasc_tok = "{DASC}"
                        if i+2 < cur_len and sql_tokens[i+2] == "LIMIT":
                            limit_tok = "LIMIT"
                    elif i+1 < cur_len and sql_tokens[i+1] == "LIMIT":
                        limit_tok = "LIMIT"
                    
                    columns[tok].append(cur_nest + " " + cur_struct + " " + dasc_tok + " " + limit_tok)
        
            else:
                print("\n------------Warning: unexcepted COLUMN label format")
                skip = True
    
    column_labels = {}
    for col, labels in columns.items():
        label_str = " ".join([l.strip() for l in labels])
        column_labels[col] = label_str
        
    return column_labels, skip


##### Populate one example for a given database based on a given nl-SQL template and sql component mapping

In [None]:
def populate_one(db, templates, templates_one, sql_components):
    """
    'P0=P1==', 'P0=P1=P2==', 'P0==', 'P1==', 'P0=>', 'P0=<', '{AGG0}=MAX', '{AGG0}=MIN'
    'T0-T1-JOIN', 'T0-T1-NO-JOIN', 
    'C0-id',, 'C2-id', , 'C1-id',  'C3-T1'
    """
    if len(db) > 1:
        template = random.choice(templates)
    else:
        template = random.choice(templates_one)
        
    sql_constraints = template['SQL constraints']
    sql_pattern = template["SQL pattern"]
    question, q_constraints = random.choice(template["questions"])
    constraints = list(set(sql_constraints + q_constraints))

    slots, columns, ops, vals, aggs, dasc = get_sql_slots(sql_pattern)
    slot_values, columns, skip_db_with_one_table = process_constraints(constraints, columns, slots)

    q_slots = get_q_slots(question)
    q_slot_values = {}

    # 1 process ops - update columns and values constraints
    for op, colv in ops.items():
        if colv[0] == "*":
            if op not in slot_values.keys():
                op_val = random.choice([">", "<", ">=", "<=", "="])
                slot_values[op] = op_val
                if len(colv) == 2:
                    slot_values[colv[1]] = random.randint(1, 10)
        else:
            if colv[0] not in columns.keys():
                print("\n-----colv[0] not in columns.keys(): ")
                print(columns.keys())
                print(ops)
            assert colv[0] in columns.keys()
            if op not in slot_values.keys():
                if random.random() < 0.4:
                    op_val = "="
                else:
                    op_val = random.choice(OPS)
                slot_values[op] = op_val
                if op_val in [">", "<", ">=", "<="]:
                    columns[colv[0]].append("number")
            if len(colv) == 2:
                columns[colv[0]].append(colv[1])
    
    # 2 process columns
    random.shuffle(db)
    table_0, table_1 = None, None
    table_label_0 = ""
    table_label_1 = ""
    use_table_1 = False
    
    if "{COLUMN0}" in columns.keys() or "{TABLE0}" in q_slots:
        table_label_0 = "SELECT"
        
    if len(db) >= 2:
        table_0, table_1 = db[:2]
        if "{TABLE1}" in q_slots:
            table_label_1 = "SELECT"
            if "{TABLE0}" in q_slots:
                # p<0.5 from T0, T1 AND to SELECT T1 *
                # otherwise all from T0 AND to SELECT T1 *
                if random.random() < 0.5:
                    use_table_1 = True                 
            else:
                # p<0.4 all from T0 
                # AND to SELECT T1 *
                if random.random() < 0.6:
                    use_table_1 = True
                    if "{COLUMN1}" in columns.keys():
                        table_label_1 = "SELECT"
        else:
            # p<0.5 from T0, T1 AND to SELECT T1 *
            # otherwise all from T0, NOT to SELECT T1 *
            if random.random() < 0.5:
                use_table_1 = True
                if "{COLUMN1}" in columns.keys():
                    table_label_1 = "SELECT"
    else:
        table_0, table_1 = db[0], db[0]
    
    T0 = table_0["name"]
    T1 = table_1["name"]
    columns_inf_0 = list(zip(table_0["columns"], table_0["column_types"], table_0["values"]))[1:]
    if use_table_1:
        columns_inf_1 = list(zip(table_1["columns"], table_1["column_types"], table_1["values"]))[1:]
    
    if "{COLUMN0}" in columns.keys():
        col_0, value_0 = gen_col_info("{COLUMN0}", columns, columns_inf_0)
        slot_values["{COLUMN0}"] = col_0
        if value_0 is not None:
            for k, v in value_0:
                slot_values[k] = v
        if len(columns_inf_0) > 2:
            columns_inf_0 = [(col, ctype, val) for col, ctype, val in columns_inf_0 if col != col_0]
    
    if use_table_1:
        columns_input = columns_inf_1
        columns_all = columns_inf_0 + columns_inf_1
    else:
        columns_input = columns_inf_0
        columns_all = columns_inf_0
                
    if "{COLUMN1}" in columns.keys():
        col_1, value_1 = gen_col_info("{COLUMN1}", columns, columns_input)
        slot_values["{COLUMN1}"] = col_1
        if value_1 is not None:
            for k, v in value_1:
                slot_values[k] = v
        columns_input_org = columns_input
        if len(columns_input) > 3:
            columns_input = [(col, ctype, val) for col, ctype, val in columns_input if col != col_1]
        if len(columns_input) < 2:
            columns_input = columns_input_org
        columns_all = [(col, ctype, val) for col, ctype, val in columns_all if col != col_1]
        
    if "{COLUMN2}" in columns.keys():
        col_2, value_2 = gen_col_info("{COLUMN2}", columns, columns_input)
        slot_values["{COLUMN2}"] = col_2
        if value_2 is not None:
            for k, v in value_2:
                slot_values[k] = v
        columns_input_org = columns_input
        if len(columns_input) > 2:
            columns_input = [(col, ctype, val) for col, ctype, val in columns_input if col != col_2]
        if len(columns_input) < 2:
            columns_input = columns_input_org
        columns_all = [(col, ctype, val) for col, ctype, val in columns_all if col != col_2]
                
    if "{COLUMN3}" in columns.keys():
        col_3, value_3 = gen_col_info("{COLUMN3}", columns, columns_input)
        slot_values["{COLUMN3}"] = col_3
        if value_3 is not None:
            for k, v in value_3:
                slot_values[k] = v
        columns_all = [(col, ctype, val) for col, ctype, val in columns_all if col != col_3]
                
        
    # 3 aggs
    for agg in aggs.keys():
        if agg not in slot_values.keys():
            slot_values[agg] = random.choice(["MAX", "MIN", "SUM", "AVG"])
    # 4 values
    NUM = 1
    for val, cond in vals.items():
        assert val not in slot_values.keys()
        if cond == "integer":
            if random.random() < 0.5:
                slot_values[val] = 1
            else:
                NUM = random.randint(2, 10)
                slot_values[val] = NUM
        else:
            slot_values[val] = random.randint(0, 100)
                    
    # 5 dasc - true
    if dasc == True:
        slot_values["{DASC}"] = random.choice(["ASC", "DESC"])
    
    # 6 check if all sql slot values are done
    if len(slots) != len(slot_values):
        print("\nlen(slots) != len(slot_values)")
        print("sql_pattern: ", sql_pattern)
        print("slots: ", slots)
        print("slot_values: ", slot_values.keys())
    assert len(slots) == len(slot_values)
    
    # 7 for the questions slots:
    for qs in q_slots:
        if qs == "{TABLE0}":
            q_slot_values["{TABLE0}"] = T0
        elif qs == "{TABLE1}":
            q_slot_values["{TABLE1}"] = T1
        elif "SC" in qs:
            sc = slot_values["{DASC}"]
            if "SC" == qs:
                q_slot_values[qs] = random.choice(sql_components["SC"][sc])
            elif "SC_COL_LIMIT" == qs:
                if NUM > 1:
                    sc =  sc + "_NUM"
                    q_slot_values[qs] = random.choice(sql_components["SC_COL_LIMIT"][sc]).replace("[NUM]", str(NUM))
                else:
                    q_slot_values[qs] = random.choice(sql_components["SC_COL_LIMIT"][sc])
            elif "SC_COL_COUNT_LIMIT" in qs:
                sc_type = qs.replace("SC_COL_COUNT_LIMIT", "")
                if NUM > 1:
                    sc =  sc + "_NUM" + sc_type
                    q_slot_values[qs] = random.choice(sql_components["SC_COL_COUNT_LIMIT"][sc]).replace("[NUM]", str(NUM))
                else:
                    sc =  sc + sc_type
                    q_slot_values[qs] = random.choice(sql_components["SC_COL_COUNT_LIMIT"][sc])
            else:
                if "-" not in qs:
                    print("qs wrong", qs)
                assert "-" in qs
                if "C1" in qs:
                    sc_col = slot_values["{COLUMN1}"]
                elif "C2" in qs:
                    sc_col = slot_values["{COLUMN2}"]
                q_slot_values[qs] = random.choice(sql_components["SC_COL"][sc]).replace("[COL]", sc_col)
        else:
            if qs not in slot_values.keys():
                print("qs not in sv: ", qs)
                print("sql_pattern: ", sql_pattern)
                print("slot_values: ", slot_values)
            assert qs in slot_values.keys()
            if "OP" in qs:
                q_slot_values[qs] = random.choice(sql_components["OP"][slot_values[qs]])
            elif "AGG" in qs:
                q_slot_values[qs] = random.choice(sql_components["AGG"][slot_values[qs]])
            elif "COLUMN" in qs:
                q_slot_values[qs] = " ".join(slot_values[qs].split(" ")[:6])
            elif "VALUE" in qs:
                q_slot_values[qs] = " ".join(str(slot_values[qs]).split(" ")[:5])
            else:
                print("\nWarning: some q slot type not considered!")
                print(qs)
    
    # 8 check if all question slots are processed
    assert len(q_slots) == len(q_slot_values)
    
    # 9 generate final SQL-question pair
    question_gen = replace_dict(question, q_slot_values)
    
    
    # 10 generate column labels
    slot_values_new = {}
    for sl, vl in slot_values.items():
        if "COLUMN" in sl:
            slot_values_new[sl] = "_=_".join(vl.split(" "))
        else:
            slot_values_new[sl] = vl
            
    column_labels, skip = get_labels(sql_pattern)
    column_lables_real = {}
    for col, label in column_labels.items():
        if col != "*":
            col = slot_values[col]
        for slot, value in slot_values.items():
            label = label.replace(slot, str(value))
        column_lables_real[col] = label
    
    # also add labels for table column * 
    if table_label_0 != "":
        column_lables_real[table_0["columns"][0]] = table_label_0
    if table_label_1 != "":
        column_lables_real[table_1["columns"][0]] = table_label_1
    
    sql_gen = replace_dict(sql_pattern.replace(" {FROM}", ""), slot_values_new)
    
    return (sql_gen, question_gen, column_lables_real, q_slot_values, slot_values, template["SQL pattern"], columns_all)

##### generatee examples for all databases

In [None]:
# let's start data augmentation!
def augment_db(db, templates, templates_one_table, sql_components, aug_limit):
    count = 0
    augment_pairs = []
    while count < aug_limit:
        sql_gen, question_gen, column_lables, q_slot_values, slot_values, template, columns_all = populate_one(db, templates, templates_one_table, sql_components)
        qsep_label = qsep_label_map[template]
        augment_pairs.append((question_gen, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all, [qsep_label]))
        count += 1
    
    return augment_pairs
    

def augment_all_dbs(dbs, templates, templates_one_table, sql_components, aug_limit):
    augment_data = {}
    schema_dbs = {}
    for idx, db in enumerate(dbs):
        if idx % 10000 == 0:
            print("processed: ", idx)
        db_cols = ["*"]
        db_values = [""]
        for tab in db:
            db_cols.extend(tab["columns"])
            db_values.extend(tab["values"])
        schema_str = " </s> ".join(db_cols)
        values_str = " </s> ".join([str(k) for k in db_values])
        schema_str = schema_str + " |-| " + values_str
        augment_pairs = augment_db(db, templates, templates_one_table, sql_components, aug_limit)
        augment_data[schema_str] = augment_pairs
        schema_dbs[schema_str] = db
    
    return augment_data, schema_dbs

In [None]:
augment_first_webtable, schema_dbs_webtable = augment_all_dbs(webtable_dbs, templates, templates_one_table, sql_components, 2)

processed:  0
processed:  10000
processed:  20000
processed:  30000
processed:  40000
processed:  50000
processed:  60000
processed:  70000
processed:  80000


In [None]:
# read context template file
context_templates_file = "data/context_templates.json"
with open(context_templates_file) as json_file:
    context_templates = json.load(json_file)


# context_label_maps = {}
# for i, ct in enumerate(context_templates):
#     context_label_maps[ct["label"]] = i+1

In [None]:
SQL_OPS = ('INTERSECT', 'UNION', 'EXCEPT')
AGG_OPS = ["MAX", "MIN", "SUM", "AVG"]
OPS = [">", "<", ">=", "<=", "=", "!="]
SQLPARSE_MAP = {"\n      ": " ", "\n     ": " ", "\n    ": " ", "\n   ": " ", "\n  ": " ", "\n ": " ", "\nhaving": " having", "\nlimit": " limit"}
import sqlparse

In [None]:
prev_token = " <unk> "

In [None]:
def col_select(col_conds, columns_inf):
    value_slot = [cc for cc in col_conds if "VALUE" in cc]
    col = ""
    value_val = None
    if "id" in col_conds:
        has_id = False
        for c, t, v in columns_inf:
            if "id" in col or "name" in col:
                has_id = True
                col, ctype, values = c, t, v
                break
        if not has_id:
            col, ctype, value = columns_inf[0]
    elif "number" in col_conds:
        for colinfo in columns_inf:
            if colinfo[1] == "real":
                col, ctype, value = colinfo
    if len(columns_inf) == 0:
        print("\n---------------------------------------- columns_inf: ", columns_inf)
    if col == "":
        col, ctype, value = random.choice(columns_inf)

    if len(value_slot) > 0:
        assert len(value_slot) < 3
        if len(value_slot) == 1:
            value_val = [(value_slot[0], value)]
        else:
            value_val = [(value_slot[0], value), (value_slot[1], value)]
    
    return (col, value_val)


def replace_words(s, words):
    for k, v in words.items():
        s = s.replace(k, v)
    return s


def edit_sql(sql_pattern, context_label, slot_values_prev, columns_all_prev, context_template):
    sql = sql_pattern.lower().replace("{from}", "", 1).strip()
    sql_clauses = {"select": "", "where": "", "group_by": "", "order_by": ""}
    sql = sqlparse.format(sql, reindent=True)
    parsed = [x for x in replace_words(sql, SQLPARSE_MAP).split("\n")]
    slot_values = slot_values_prev.copy()
    for p in parsed:
        p_toks = p.split(" ")
        if p_toks[0] == "select":
            sql_clauses["select"] = p
        elif p_toks[0] == "where":
            sql_clauses["where"] = p
        elif p_toks[0] == "group":
            sql_clauses["group_by"] = p
        elif p_toks[0] == "order":
            sql_clauses["order_by"] = p
        else:
            raise Exception("unexcepted sql clause: ", p)
            
    context_question, context_constraints = random.choice(context_template["questions"])
    context_q_slots = get_q_slots(context_question)
    context_q_slots = [x.replace("1", "10").replace("2", "20").replace("3", "30") for x in context_q_slots]
            
    q_slot_values = {}
    sql_pattern_new = ""
    context_q = ""
    satisfy = True
    
    if context_label == "select replace column":
        if sql_clauses["group_by"] != "":
            satisfy = False
        else:
            sql_clauses["select"] = "select " + " , ".join(context_q_slots)
            col_num = len(context_q_slots)
            for i, qs in enumerate(context_q_slots):
                col, _ = col_select([], columns_all_prev)
    #             if col_num - i <= len(columns_all_prev) and len(columns_all_prev) > 1:
    #                 columns_all_prev = [x for x in columns_all_prev if x[0] != col]
                q_slot_values[qs] = " ".join(col.split(" ")[:5])
                slot_values[qs] = col
    elif context_label == "select insert column":
        sql_clauses["select"] = sql_clauses["select"] + " , " + " , ".join(context_q_slots)
        col_num = len(context_q_slots)
        for i, qs in enumerate(context_q_slots):
            col, _ = col_select([], columns_all_prev)
#             if col_num - i <= len(columns_all_prev) and len(columns_all_prev) > 1:
#                 columns_all_prev = [x for x in columns_all_prev if x[0] != col]
            q_slot_values[qs] = " ".join(col.split(" ")[:5])
            slot_values[qs] = col
    elif context_label == "select replace agg":
        if 'agg' not in sql_clauses["select"] or sql_clauses["select"].count("agg") > 1:
            satisfy = False
        else:
            assert sql_clauses["select"].count("agg") == 1
            for s, v in slot_values.items():
                if "AGG" in s:
                    agg_kw_prev, agg_prev = s, v
                    break
            agg_cur_list = [x for x in AGG_OPS if x != agg_prev]
            agg_kw_cur = context_q_slots[0]
            sql_clauses["select"] = sql_clauses["select"].replace(agg_kw_prev.lower(), agg_kw_cur)
            agg_cur = random.choice(agg_cur_list)
            q_slot_values[agg_kw_cur] = random.choice(sql_components["AGG"][agg_cur])
            slot_values[agg_kw_cur] = agg_cur
            if "COLUMN0" in context_question:
                q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
    elif context_label == "select delete column":
        if "agg" in sql_clauses["select"] or sql_clauses["select"].count("column") <= 1 or (sql_clauses["select"].count("column") <= 2 and len(context_q_slots) == 2):
            satisfy = False
        else:
            assert sql_clauses["select"].count("column") > 1
            sql_clauses["select"] = "select " + " , ".join(context_q_slots)
            for qs in context_q_slots:
                if "{COLUMN10}" == qs:
                    if "{COLUMN12}" in slot_values.keys():
                        slot_values["{COLUMN10}"] = slot_values["{COLUMN12}"]
                    elif "{COLUMN11}" in slot_values.keys():
                        slot_values["{COLUMN10}"] = slot_values["{COLUMN11}"]
                    elif "{COLUMN1}" in slot_values.keys():
                        slot_values["{COLUMN10}"] = slot_values["{COLUMN1}"]
                    q_slot_values[qs] = slot_values["{COLUMN10}"]
                else:
                    q_slot_values[qs] = slot_values[qs]
    elif context_label == "select delete agg": # need to check more examples
        if "count" not in sql_clauses["select"] or ("*" in sql_clauses["select"] and "column" in sql_clauses["select"]) or (sql_clauses["select"].count("column") == 0 and len(context_q_slots) == 0) or (sql_clauses["select"].count("column0") == 0 and "COLUMN0" in context_question):
            satisfy = False
        else:
            if len(context_q_slots) == 0:
                sql_clauses["select"] = "select {COLUMN0}"
            elif "{COLUMN0}" in context_q_slots:
                q_slot_values[context_q_slots[0]] = slot_values["{COLUMN0}"]
            else:
                col, _ = col_select([], columns_all_prev)
                q_slot_values[context_q_slots[0]] = col
                slot_values[context_q_slots[0]] = col
                sql_clauses["select"] = "select " + context_q_slots[0]
    elif context_label == "select insert agg":
        if "agg" in sql_clauses["select"] or "count" in sql_clauses["select"] or sql_clauses["select"].count("column") > 1:
            satisfy = False
        else:
            
            if "{AGG0}" in context_q_slots:
                slot_values["{AGG0}"] = random.choices(["MAX", "MIN", "SUM", "AVG", "COUNT"], weights=(1, 1, 1, 1, 3), k=1)[0]
                q_slot_values["{AGG0}"] = random.choice(sql_components["AGG"][slot_values["{AGG0}"]])
            else:
                slot_values["{AGG0}"] = "COUNT"
                
            if sql_clauses["select"].count("column") == 0:
                sql_clauses["select"] = "select {AGG0} (*)"
                q_slot_values["{COLUMN0}"] = ""
            elif sql_clauses["select"].count("column") == 1:
                sql_clauses["select"] = "select {AGG0} ({COLUMN0})"
                q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
            else:
                raise Exception("unexcepted select clause: ", sql_clauses["select"])
    elif context_label == "select insert agg and column":
        if "agg" not in sql_clauses["select"] and sql_clauses["group_by"] == "":
            satisfy = False
        else:
            sql_clauses["select"] = sql_clauses["select"] + " , " + "{agg10} ({column10})"
            for i, qs in enumerate(context_q_slots):
                if "AGG" in qs:
                    slot_values[qs] = random.choice(AGG_OPS)
                    q_slot_values[qs] = random.choice(sql_components["AGG"][slot_values[qs]])
                elif "COLUMN" in qs:
                    col, _ = col_select(["number"], columns_all_prev)
                    q_slot_values[qs] = " ".join(col.split(" ")[:5])
                    slot_values[qs] = col

            if "agg" in sql_clauses["select"] and "COLUMN" not in context_question:
                slot_values["{COLUMN10}"] = slot_values["{COLUMN0}"]
            
            if "agg" not in sql_clauses["select"] and "COLUMN" not in context_question:
                satisfy = False
    elif context_label == "select replace agg and column":
        if sql_clauses["group_by"] == "" and sql_clauses["where"] == "":
            satisfy = False
        else:
            sql_clauses["select"] = "select {agg10} ({column10})"
            for i, qs in enumerate(context_q_slots):
                if "AGG" in qs:
                    slot_values[qs] = random.choice(AGG_OPS)
                    q_slot_values[qs] = random.choice(sql_components["AGG"][slot_values[qs]])
                elif "COLUMN" in qs:
                    col, _ = col_select(["number"], columns_all_prev)
                    q_slot_values[qs] = " ".join(col.split(" ")[:5])
                    slot_values[qs] = col
                    
            if len(context_q_slots) == 0:
                slot_values["{AGG10}"] = "COUNT"
                slot_values["{COLUMN10}"] = "*"
            elif len(context_q_slots) == 4:
                sql_clauses["select"] = "select {agg10} ({column10}) , {agg20} ({column20})"
            
            if sql_clauses["group_by"] != "":
                gb_col = sql_clauses["group_by"].split(" ")[2].upper()
                assert "COLUMN" in gb_col
                sql_clauses["select"] = sql_clauses["select"] + " , " + gb_col
    elif context_label == "where insert":
        if "agg" in sql_clauses["select"] or "count" in sql_clauses["select"]:
            satisfy = False
        else:
            if sql_clauses["where"] != "":
                sql_clauses["where"] = sql_clauses["where"] + " AND " + "{COLUMN10} {OP10} {VALUE10}"
            else:
                sql_clauses["where"] = "WHERE {COLUMN10} {OP10} {VALUE10}"
            
            if "{OP" in context_question:
                op_val = random.choice([">", "<", ">=", "<=", "="])
            else:
                op_val = "="
                
            slot_values["{OP10}"] = op_val
            q_slot_values["{OP10}"] = random.choice(sql_components["OP"][op_val])
                
            if op_val != "=":
                col, value = col_select(["number", "VALUE10"], columns_all_prev)
            else:
                col, value = col_select(["VALUE10"], columns_all_prev)
            
            q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
            slot_values["{COLUMN10}"] = col
            q_slot_values["{VALUE10}"] = " ".join(str(value[0][1]).split(" ")[:5])
            slot_values["{VALUE10}"] = value[0][1]
            if "COLUMN0" in context_question:
                q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
    elif context_label == "where replace":
        if sql_clauses["where"].count("column") != 1:
            satisfy = False
        else:
            sql_clauses["where"] = "WHERE {COLUMN10} {OP10} {VALUE10}"
            
            if "{OP" in context_question:
                op_val = random.choice([">", "<", ">=", "<=", "="])
            else:
                op_val = "="
                
            slot_values["{OP10}"] = op_val
            q_slot_values["{OP10}"] = random.choice(sql_components["OP"][op_val])
                
            if op_val != "=":
                col, value = col_select(["number", "VALUE10"], columns_all_prev)
            else:
                col, value = col_select(["VALUE10"], columns_all_prev)
            
            q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
            slot_values["{COLUMN10}"] = col
            q_slot_values["{VALUE10}"] = " ".join(str(value[0][1]).split(" ")[:5])
            slot_values["{VALUE10}"] = value[0][1]
            if "COLUMN0" in context_question:
                q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
    elif context_label == "where replace value":
        if sql_clauses["where"].count("column") != 1 or sql_clauses["group_by"] != "" or sql_clauses["order_by"] != "":
            satisfy = False
        else:
            wh_toks = sql_clauses["where"].split(" ")
            for tok in wh_toks:
                if "column" in tok:
                    wh_col = tok
                elif "value" in tok:
                    wh_val = tok
            sql_clauses["where"] = sql_clauses["where"].replace(wh_val, "{VALUE10}")
            q_slot_values["{COLUMN0}"] = slot_values[wh_col.upper()]
            q_slot_values["{VALUE10}"] = " ".join(str(slot_values[wh_val.upper()]).split(" ")[:2]) + " " + q_slot_values["{COLUMN0}"].split(" ")[0] #just to add noisy to fake value
            slot_values["{VALUE10}"] = q_slot_values["{VALUE10}"]
    elif context_label == "where replace operation":
        if sql_clauses["where"].count("column") != 1 or sql_clauses["group_by"] != "" or sql_clauses["order_by"] != "":
            satisfy = False
        else:
            wh_toks = sql_clauses["where"].split(" ")
            for tok in wh_toks:
                if "column" in tok:
                    wh_col = tok
                elif "value" in tok:
                    wh_val = tok
                elif "op" in tok:
                    wh_op = tok
            sql_clauses["where"] = sql_clauses["where"].replace(wh_op, "{OP10}")
            q_slot_values["{VALUE0}"] = " ".join(str(slot_values[wh_val.upper()]).split(" ")[:4])
            op_prev = slot_values[wh_op.upper()]
            op_cur_list = [x for x in OPS if x != op_prev]
            op_val = random.choice(op_cur_list)
            slot_values["{OP10}"] = op_val
            q_slot_values["{OP10}"] = random.choice(sql_components["OP"][op_val])
    
    elif context_label == "order_by insert":
        if "agg" in sql_clauses["select"] or "count" in sql_clauses["select"] or sql_clauses["group_by"] != "" or sql_clauses["order_by"] != "":
            satisfy = False
        else:
            sql_clauses["order_by"] = "order by {column10} {dasc}"
            sc = random.choice(["ASC", "DESC"])
            slot_values["{DASC}"] = sc
            col, _ = col_select([], columns_all_prev)
            slot_values["{COLUMN10}"] = col
            q_slot_values["SC_COL"] = random.choice(sql_components["SC_COL"][sc]).replace("[COL]", " ".join(col.split(" ")[:5]))
    elif context_label == "order_by insert limit":
        if "agg" in sql_clauses["select"] or "count" in sql_clauses["select"] or sql_clauses["order_by"] == "" or "limit" in sql_clauses["order_by"]:
            satisfy = False
        else:
            if "{dasc}" in sql_clauses["order_by"]:
                sc = slot_values["{DASC}"]
            else:
                sc = "ASC"
            sql_clauses["order_by"] += " limit {value10}"
            limit_val = random.choice([1,1,1,2,3,5])
            slot_values["{VALUE10}"] = limit_val
            if limit_val == 1:
                q_slot_values["SC_COL_LIMIT"] = random.choice(sql_components["SC_COL_LIMIT"][sc])
            else:
                q_slot_values["SC_COL_LIMIT"] = random.choice(sql_components["SC_COL_LIMIT"][sc+"_NUM"]).replace("[NUM]", str(limit_val))
    elif context_label == "order_by insert limit | select delete agg and column":
        if "count" not in sql_clauses["select"] or sql_clauses["group_by"] == "" or sql_clauses["order_by"] != "":
            satisfy = False
        else:
            sql_clauses["order_by"] = "order by count (*) {dasc} limit 1"
            sel_cols = [x for x in sql_clauses["select"].split(" ") if "column" in x]
            sql_clauses["select"] = "select " + " , ".join(sel_cols)
            sc = random.choice(["ASC", "DESC"])
            slot_values["{DASC}"] = sc
            q_slot_values["SC_COL_LIMIT"] = random.choice(sql_components["SC_COL_LIMIT"][sc])
            if "{COLUMN0}" in slot_values.keys():
                q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
    elif context_label == "order_by insert limit | select replace column":
        if sql_clauses["order_by"] != "" or sql_clauses["where"] != "":
            satisfy = False
        else:
            col, _ = col_select([], columns_all_prev)
            q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
            slot_values["{COLUMN10}"] = col
            sql_clauses["select"] = "select {column10}"
            
            sc = random.choice(["ASC", "DESC"])
            slot_values["{DASC}"] = sc
            q_slot_values["SC_COL_LIMIT"] = random.choice(sql_components["SC_COL_LIMIT"][sc])
            
            if "{COLUMN2}" in context_question:
                col2, _ = col_select(["number"], columns_all_prev)
                q_slot_values["{COLUMN20}"] = " ".join(col2.split(" ")[:5])
                slot_values["{COLUMN20}"] = col2
                sql_clauses["order_by"] = "order by {column20} {dasc} limit 1" 
            else:
                sql_clauses["order_by"] = "order by count (*) {dasc} limit 1"           
    elif context_label == "order_by replace sc":
        if "agg" in sql_clauses["select"] or "limit" not in sql_clauses["order_by"]:
            satisfy = False
        else:
            if "{dasc}" in sql_clauses["order_by"]:
                sc_prev = slot_values["{DASC}"]
            else:
                sc_prev = "ASC"
            sc = "DESC" if sc_prev == "ASC" else "ASC"
            slot_values["{DASC}"] = sc
            q_slot_values["SC_COL_LIMIT"] = random.choice(sql_components["SC_COL_LIMIT"][sc])
            if "{COLUMN0}" in slot_values.keys():
                q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
    elif context_label == "no change":
        if "having" not in sql_clauses["group_by"] and "value" not in sql_clauses["order_by"]:
            satisfy = False
        else:
            q_slot_values["{OP0}"] = "top"
            q_slot_values["[NUM]"] = random.choice([3,5,10])
            if "having" in sql_clauses["group_by"]:
                gb_toks = sql_clauses["group_by"].split(" ")
                for tok in gb_toks:
                    if "op" in tok:
                        gb_op = tok.upper()
                        q_slot_values["{OP0}"] = random.choice(sql_components["OP"][slot_values[gb_op]])
    elif context_label == "group_by insert | select insert agg and column":
        if sql_clauses["group_by"] != "" or sql_clauses["order_by"] != "" or sql_clauses["where"].count("column") > 1 or "agg" in sql_clauses["select"] or "count" in sql_clauses["select"] or "column0" not in sql_clauses["select"]:
            satisfy = False
        else:
            q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
            sql_clauses["group_by"] = "group by {column0}"
            if "AGG" not in context_question:
                sql_clauses["select"] = sql_clauses["select"] + " , count (*)"
            else:
                sql_clauses["select"] = sql_clauses["select"] + " , {agg10} ({column10})"
                agg_cur = random.choice(AGG_OPS)
                slot_values["{AGG10}"] = agg_cur
                q_slot_values["{AGG10}"] = random.choice(sql_components["AGG"][agg_cur])
                col, _ = col_select(["number"], columns_all_prev)
                q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
                slot_values["{COLUMN10}"] = col
    elif context_label == "group_by insert | select replace agg and column":
        if sql_clauses["group_by"] != "" or sql_clauses["order_by"] != "" or sql_clauses["where"] != "" or "agg" in sql_clauses["select"] or "count" in sql_clauses["select"] or "column0" not in sql_clauses["select"]:
            satisfy = False
        else:
            q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
            sql_clauses["group_by"] = "group by {column0}"
            if "AGG" not in context_question:
                sql_clauses["select"] = "select {column0} , count (*)"
            else:
                if "COLUMN1" in context_question:
                    col, _ = col_select([], columns_all_prev)
                    q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
                    slot_values["{COLUMN10}"] = col
                    sql_clauses["select"] = "select {column10} , {agg10} ({column20})"
                    sql_clauses["group_by"] = "group by {column10}"
                else:
                    sql_clauses["select"] = "select {column0} , {agg10} ({column20})"
                agg_cur = random.choice(AGG_OPS)
                slot_values["{AGG10}"] = agg_cur
                q_slot_values["{AGG10}"] = random.choice(sql_components["AGG"][agg_cur])
                col, _ = col_select(["number"], columns_all_prev)
                q_slot_values["{COLUMN20}"] = " ".join(col.split(" ")[:5])
                slot_values["{COLUMN20}"] = col
    elif context_label == "group_by insert | order_by insert":
        if sql_clauses["group_by"] != "" or sql_clauses["order_by"] != "" or sql_clauses["where"] != "" or "agg" in sql_clauses["select"] or "count" in sql_clauses["select"] or "column0" not in sql_clauses["select"] or sql_clauses["select"].count("column") > 2:
            satisfy = False
        else:
            q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
            sql_clauses["group_by"] = "group by {column0}"
            sql_clauses["order_by"] = "order by count (*) {dasc}"

            if "COLUMN1" in context_question:
                col, _ = col_select([], columns_all_prev)
                q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
                slot_values["{COLUMN10}"] = col
                sql_clauses["group_by"] = "group by {column10}"
            
            sc = random.choice(["ASC", "DESC"])
            slot_values["{DASC}"] = sc
            q_slot_values["SC"] = random.choice(sql_components["SC"][sc])
    elif context_label == "group_by insert having":
        if "having" in sql_clauses["group_by"] or sql_clauses["order_by"] != "" or sql_clauses["where"] != "" or "agg" in sql_clauses["select"] or "count" in sql_clauses["select"] or "column0" not in sql_clauses["select"] or sql_clauses["select"].count("column") > 2:
            satisfy = False
        else:
            q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
            if sql_clauses["group_by"] == "":
                sql_clauses["group_by"] = "group by {column0} having count (*) {op0} {value0}"
            else:
                sql_clauses["group_by"] += " having count (*) {op0} {value0}"
            
            op_val = random.choice(OPS)
            slot_values["{OP0}"] = op_val
            q_slot_values["{OP0}"] = random.choice(sql_components["OP"][op_val])
            
            value = random.choice([1, 3, 5, 10])
            slot_values["{VALUE0}"] = value
            q_slot_values["{VALUE0}"] = str(value)
    elif context_label == "group_by insert having | select delete agg and column":
        if "having" in sql_clauses["group_by"] or sql_clauses["group_by"] == "" or sql_clauses["order_by"] != "" or sql_clauses["where"] != "" or ("agg" not in sql_clauses["select"] and "count" not in sql_clauses["select"]) or "column0" not in sql_clauses["select"]:
            satisfy = False
        else:
            q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
            sel_toks = sql_clauses["select"].split(" ")
            agg_col_toks = []
            for tok in sel_toks:
                if "select" not in tok and "," not in tok:
                    if "({column" in tok or "column" not in tok:
                        agg_col_toks.append(tok)
                if ")" in tok:
                    break
            agg_col = " ".join(agg_col_toks)
            sql_clauses["select"] = " ".join([x for x in sel_toks if x not in agg_col_toks])
            
            sql_clauses["group_by"] += " having " + agg_col + " {op0} {value0}"
            
            op_val = random.choice(OPS)
            slot_values["{OP0}"] = op_val
            q_slot_values["{OP0}"] = random.choice(sql_components["OP"][op_val])
            
            value = random.choice([1, 3, 5, 10])
            slot_values["{VALUE0}"] = value
            q_slot_values["{VALUE0}"] = str(value)
    elif context_label == "group_by replace | select replace column":
        if sql_clauses["group_by"] == "" or sql_clauses["where"] != "" or "column0" not in sql_clauses["select"] or sql_clauses["select"].count("column") > 2:
            satisfy = False
        else:
            col, _ = col_select([], columns_all_prev)
            q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
            slot_values["{COLUMN10}"] = col
            if "{column0}" in sql_clauses["group_by"]:
                sql_clauses["select"] = sql_clauses["select"].replace("{column0}", "{column10}")
                sql_clauses["group_by"] = sql_clauses["group_by"].replace("{column0}", "{column10}")
            elif "{column1}" in sql_clauses["group_by"]:
                sql_clauses["select"] = sql_clauses["select"].replace("{column1}", "{column10}")
                sql_clauses["group_by"] = sql_clauses["group_by"].replace("{column1}", "{column10}")
            else:
                satisfy = False
    elif context_label == "insert SQL":
        if sql_clauses["group_by"] != "" or sql_clauses["order_by"] != "" or sql_clauses["where"].count("value") != 1 or "select" in sql_clauses["where"] or "agg" in sql_clauses["select"] or "count" in sql_clauses["select"] or "column0" not in sql_clauses["select"] or sql_clauses["select"].count("column") > 2:
            satisfy = False
        else:
            wh_toks = sql_clauses["where"].split(" ")
            for tok in wh_toks:
                if "column" in tok:
                    wh_col = tok
                elif "value" in tok:
                    wh_val = tok
                elif "op" in tok:
                    wh_op = tok
            
            q_slot_values["{COLUMN0}"] = slot_values[wh_col.upper()]
            q_slot_values["{VALUE0}"] = slot_values[wh_val.upper()]
            value = random.choice([1, 3, 5, 10])
            slot_values["{VALUE10}"] = value
            q_slot_values["{VALUE10}"] = str(value)
            op_val = random.choice(OPS)
            slot_values["{OP10}"] = op_val
            q_slot_values["{OP10}"] = random.choice(sql_components["OP"][op_val])
            if "{OP1" in context_question:
                sql_where = " where {column0} {op10} {value10}"
            elif "{VALUE1" in context_question:
                sql_where = " where {column0} {op0} {value10}"
            else:
                sql_where = " where {column0} {op0} {value0}"
            if "COLUMN1" in context_question:
                col, _ = col_select([], columns_all_prev)
                q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
                slot_values["{COLUMN10}"] = col
                sql_clauses["select"] = "select {column10}"
            if context_constraints[0] == "intersect":
                sql_clauses["group_by"] = "intersect " + sql_clauses["select"] + sql_where
            else:
                sql_clauses["group_by"] = "except " + sql_clauses["select"] + sql_where
                sql_clauses["where"] = ""
    elif context_label == "where insert SQL":
        if sql_clauses["group_by"] != "" or sql_clauses["order_by"] != "" or sql_clauses["where"].count("column") > 1 or "select" in sql_clauses["where"] or "count" in sql_clauses["select"] or "column0" not in sql_clauses["select"] or sql_clauses["select"].count("column") > 1:
            satisfy = False
        elif "agg" in sql_clauses["select"] and context_constraints[0] == "op":
            col, _ = col_select([], columns_all_prev)
            q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
            slot_values["{COLUMN10}"] = col
            sql_clauses["select"] = "select {column10}"
            q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
            op_val = random.choice(OPS)
            slot_values["{OP10}"] = op_val
            q_slot_values["{OP10}"] = random.choice(sql_components["OP"][op_val])
            sql_clauses["where"] = "where {column0} {op10} (" + sql_pattern.lower() + ")"
            sql_clauses["where"] = sql_clauses["where"].replace("( ", "(").replace(" )", ")")
        elif "agg" not in sql_clauses["select"] and context_constraints[0] != "op":
            q_slot_values["{COLUMN0}"] = slot_values["{COLUMN0}"]
            if "{TABLE0}" in slot_values.keys():
                q_slot_values["{TABLE0}"] = slot_values["{TABLE0}"]
            else:
                q_slot_values["{TABLE0}"] = ""
                
            if "COLUMN1" in context_question:
                col, _ = col_select([], columns_all_prev)
                q_slot_values["{COLUMN10}"] = " ".join(col.split(" ")[:5])
                slot_values["{COLUMN10}"] = col
                sql_clauses["select"] = "select {column10}"
                
            if context_constraints[0] == "not in":
                sql_clauses["where"] = "where {column0} not in (" + sql_pattern.lower() + ")"
            else:
                sql_clauses["where"] = "where {column0} in (" + sql_pattern.lower() + ")"
            sql_clauses["where"] = sql_clauses["where"].replace("( ", "(").replace(" )", ")")
        else:
            satisfy = False
    else:
        print("\n--------------------Unexcepted context template: ", context_label)
        satisfy = False
    
        
    if satisfy:
#         print("parsed prev sql: ", parsed)
#         print("slot_values: ", slot_values)
#         print("q_slot_values: ", q_slot_values)
        sql_str_list = [v for k, v in sql_clauses.items() if v != ""]
        sql_str_list.insert(1, "{from}")

        sql_pattern_new = " ".join(sql_str_list).upper().replace("(", "( ").replace(")", " )")
        
        # 9 generate final SQL-question pair
        q_slot_values = {k.replace("10", "1").replace("20", "2").replace("30", "3"): v for k, v in q_slot_values.items()}
        context_q = replace_dict(context_question, q_slot_values)

    return sql_pattern_new, slot_values, context_q, satisfy


def add_augment_context(augment_data, context_templates, schema_dbs):
    #question_gen, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all

    data_new = {}
    skip_count = 0
    count = 0
    augment_iso = augment_data.copy()
    for schema_str, exs in augment_iso.items():
        count += 1
        if count % 10000 == 0:
            print("processed: ", count)
        data_new[schema_str] = []
        for ex in exs:
            sql_pattern = ex[5]
            columns_all_prev = ex[6].copy()
            question_prev = ex[0]
            sql_prev = ex[1]
            col_labels_prev = ex[2].copy()
            q_slot_values_prev = ex[3]
            slot_values_prev = ex[4].copy()
            context_label_list = ex[7].copy()
            
            if random.random() <= 0.8:
                try_num = 0
                if "INTERSECT" in sql_pattern or "UNION" in sql_pattern or "EXCEPT" in sql_pattern or len(columns_all_prev) < 1:
                    continue

                while try_num < 3:
                    context_template = random.choice(context_templates)
                    context_label = context_template['label']
                    prereqs = context_template["prereqs"]
                    edited_sql_pattern, slot_values, context_q, satisfy = edit_sql(sql_pattern, context_label, slot_values_prev, columns_all_prev, context_template)

                    try_num += 1
                    if satisfy:
                        break

                if not satisfy:
                    continue

                context_q = context_q + prev_token + question_prev

    #             print("question: ", context_q)
    #             print("previous sql pattern: ", sql_pattern)
    #             print("edited_sql_pattern: ", edited_sql_pattern)

                # 10 generate column labels
                slot_values_new = {}
                for sl, vl in slot_values.items():
                    if "COLUMN" in sl:
                        slot_values_new[sl] = "_=_".join(vl.split(" "))
                    else:
                        slot_values_new[sl] = vl

                column_labels, skip = get_labels(edited_sql_pattern)
                if skip:
                    continue
                column_lables_real = {}
                for col, label in column_labels.items():
                    if col != "*":
                        col = slot_values[col]
                    for slot, value in slot_values.items():
                        label = label.replace(slot, str(value))
                    column_lables_real[col] = label

                edited_sql = replace_dict(edited_sql_pattern.replace(" {FROM}", ""), slot_values_new)

                #(question_gen, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all)

    #             print("edited_sql: ", edited_sql)
    #             print("column_lables_real: ", column_lables_real)
    #             print("")
                context_label_int = qsep_label_map[context_label]
                context_label_list.insert(0, context_label_int)
                data_new[schema_str].append((context_q, edited_sql, column_lables_real, None, slot_values, edited_sql_pattern, columns_all_prev, context_label_list))
            else:
                db = schema_dbs[schema_str]
                sql_gen, question_gen, column_lables, q_slot_values, slot_values, template, columns_all = populate_one(db, templates, templates_one_table, sql_components)
                context_q = question_gen + prev_token + question_prev
                context_label_list.insert(0, 0)
                data_new[schema_str].append((context_q, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all, context_label_list))
            
    return data_new       

In [None]:
augment_second_webtable = add_augment_context(augment_first_webtable, context_templates, schema_dbs_webtable)

processed:  10000
processed:  20000
processed:  30000
processed:  40000
processed:  50000
processed:  60000
processed:  70000
processed:  80000


In [2]:
two_count = 0
for schema, examples in augment_second_webtable.items():
    if two_count > 100:
        break
    for ex in examples:
        two_count += 1
        sql_pattern = ex[5]
        columns_all_prev = ex[6]
        question_prev = ex[0]
        sql_prev = ex[1]
        col_labels_prev = ex[2]
        q_slot_values_prev = ex[3]
        slot_values_prev = ex[4]
        context_label_list = ex[7]
        print("\nsql_pattern: ", sql_pattern)
        print("question: ", question_prev)
        print("sql: ", sql_prev)
        print("column labels: ", col_labels_prev)
        print("slot values: ", slot_values_prev)
        print("context_label_list: ", context_label_list)

In [None]:
slot_update_dict = {"10": "11", "20": "21", "30": "31"}

def add_augment_context_second(augment_second_data, context_templates, schema_dbs):
    #question_gen, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all, [context_label_int]

    data_new = {}
    skip_count = 0
    count = 0
    augment_second_iso = augment_second_data.copy()
    for schema_str, exs in augment_second_iso.items():
        count += 1
        if count % 10000 == 0:
            print("processed: ", count)
        data_new[schema_str] = []
        for ex in exs:
            sql_pattern = replace_dict(ex[5], slot_update_dict)
            columns_all_prev = ex[6].copy()
            question_prev = ex[0]
            sql_prev = ex[1]
            col_labels_prev = ex[2].copy()
            q_slot_values_prev = ex[3]
            slot_values_prev = {replace_dict(k, slot_update_dict) : v for k, v in ex[4].items()}.copy()
            context_label_list = ex[7].copy()
            
            
            if random.random() <= 0.8:
                try_num = 0
                if "INTERSECT" in sql_pattern or "UNION" in sql_pattern or "EXCEPT" in sql_pattern or len(columns_all_prev) < 1:
                    continue

                while try_num < 3:
                    context_template = random.choice(context_templates)
                    context_label = context_template['label']
                    prereqs = context_template["prereqs"]
                    edited_sql_pattern, slot_values, context_q, satisfy = edit_sql(sql_pattern, context_label, slot_values_prev, columns_all_prev, context_template)

                    try_num += 1
                    if satisfy:
                        break

                if not satisfy:
                    continue

                context_q = context_q + prev_token + question_prev

    #             print("question: ", context_q)
    #             print("previous sql pattern: ", sql_pattern)
    #             print("edited_sql_pattern: ", edited_sql_pattern)

                # 10 generate column labels
                slot_values_new = {}
                for sl, vl in slot_values.items():
                    if "COLUMN" in sl:
                        slot_values_new[sl] = "_=_".join(vl.split(" "))
                    else:
                        slot_values_new[sl] = vl

                column_labels, skip = get_labels(edited_sql_pattern)
                if skip:
                    continue
                column_lables_real = {}
                for col, label in column_labels.items():
                    if col != "*":
                        if col not in slot_values.keys():
                            print("slot_values_prev: ", slot_values_prev)
                            print("q_slot_values_prev: ", q_slot_values_prev)
                            print("sql_pattern: ", sql_pattern)
                            print("context_label: ", context_label)
                            print("edited_sql_pattern: ", edited_sql_pattern)
                            print("slot_values: ", slot_values)
                            print("column_labels: ", column_labels)
                        col = slot_values[col]
                    for slot, value in slot_values.items():
                        label = label.replace(slot, str(value))
                    column_lables_real[col] = label

                edited_sql = replace_dict(edited_sql_pattern.replace(" {FROM}", ""), slot_values_new)

                #(question_gen, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all)

    #             print("edited_sql: ", edited_sql)
    #             print("column_lables_real: ", column_lables_real)
    #             print("")
                context_label_int = qsep_label_map[context_label]
                context_label_list.insert(0, context_label_int)
                
                data_new[schema_str].append((context_q, edited_sql, column_lables_real, None, slot_values, edited_sql_pattern, columns_all_prev, context_label_list))
            else:
                db = schema_dbs[schema_str]
                sql_gen, question_gen, column_lables, q_slot_values, slot_values, template, columns_all = populate_one(db, templates, templates_one_table, sql_components)
                context_q = question_gen + prev_token + question_prev
                context_label_list.insert(0, 0)
                data_new[schema_str].append((context_q, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all, context_label_list))
            
    return data_new

In [3]:
augment_third_webtable = add_augment_context_second(augment_second_webtable, context_templates, schema_dbs_webtable)

In [None]:
two_count = 0
for schema, examples in augment_third_webtable.items():
    for ex in examples:
        two_count += 1
#         sql_pattern = ex[5]
#         columns_all_prev = ex[6]
#         question_prev = ex[0]
#         sql_prev = ex[1]
#         col_labels_prev = ex[2]
#         q_slot_values_prev = ex[3]
#         slot_values_prev = ex[4]
#         context_label_list = ex[7]
#         print("\nsql_pattern: ", sql_pattern)
#         print("question: ", question_prev)
#         print("sql: ", sql_prev)
#         print("column labels: ", col_labels_prev)
#         print("context_label_list: ", context_label_list)
#         print("slot_values: ", slot_values_prev)
print(two_count)

913


In [None]:
slot_update_dict = {"10": "12", "20": "22", "30": "32"}

def add_augment_context_third(augment_third_data, context_templates, schema_dbs):
    #question_gen, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all, [context_label_int]

    data_new = {}
    skip_count = 0
    count = 0
    augment_third_iso = augment_third_data.copy()
    for schema_str, exs in augment_third_iso.items():
        count += 1
        if count % 10000 == 0:
            print("processed: ", count)
        data_new[schema_str] = []
        for ex in exs:
            sql_pattern = replace_dict(ex[5], slot_update_dict)
            columns_all_prev = ex[6].copy()
            question_prev = ex[0]
            sql_prev = ex[1]
            col_labels_prev = ex[2].copy()
            q_slot_values_prev = ex[3]
            slot_values_prev = {replace_dict(k, slot_update_dict) : v for k, v in ex[4].items()}.copy()
            context_label_list = ex[7].copy()
            
            
            if random.random() <= 0.8:
                try_num = 0
                if "INTERSECT" in sql_pattern or "UNION" in sql_pattern or "EXCEPT" in sql_pattern or len(columns_all_prev) < 1:
                    continue

                while try_num < 3:
                    context_template = random.choice(context_templates)
                    context_label = context_template['label']
                    prereqs = context_template["prereqs"]
                    edited_sql_pattern, slot_values, context_q, satisfy = edit_sql(sql_pattern, context_label, slot_values_prev, columns_all_prev, context_template)

                    try_num += 1
                    if satisfy:
                        break

                if not satisfy:
                    continue

                context_q = context_q + prev_token + question_prev

    #             print("question: ", context_q)
    #             print("previous sql pattern: ", sql_pattern)
    #             print("edited_sql_pattern: ", edited_sql_pattern)

                # 10 generate column labels
                slot_values_new = {}
                for sl, vl in slot_values.items():
                    if "COLUMN" in sl:
                        slot_values_new[sl] = "_=_".join(vl.split(" "))
                    else:
                        slot_values_new[sl] = vl

                column_labels, skip = get_labels(edited_sql_pattern)
                if skip:
                    continue
                column_lables_real = {}
                for col, label in column_labels.items():
                    if col != "*":
                        if col not in slot_values.keys():
                            print("slot_values_prev: ", slot_values_prev)
                            print("q_slot_values_prev: ", q_slot_values_prev)
                            print("sql_pattern: ", sql_pattern)
                            print("context_label: ", context_label)
                            print("edited_sql_pattern: ", edited_sql_pattern)
                            print("slot_values: ", slot_values)
                            print("column_labels: ", column_labels)
                        col = slot_values[col]
                    for slot, value in slot_values.items():
                        label = label.replace(slot, str(value))
                    column_lables_real[col] = label

                edited_sql = replace_dict(edited_sql_pattern.replace(" {FROM}", ""), slot_values_new)

                #(question_gen, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all)

    #             print("edited_sql: ", edited_sql)
    #             print("column_lables_real: ", column_lables_real)
    #             print("")
                context_label_int = qsep_label_map[context_label]
                context_label_list.insert(0, context_label_int)
                
                data_new[schema_str].append((context_q, edited_sql, column_lables_real, None, slot_values, edited_sql_pattern, columns_all_prev, context_label_list))
            else:
                db = schema_dbs[schema_str]
                sql_gen, question_gen, column_lables, q_slot_values, slot_values, template, columns_all = populate_one(db, templates, templates_one_table, sql_components)
                context_q = question_gen + prev_token + question_prev
                context_label_list.insert(0, 0)
                data_new[schema_str].append((context_q, sql_gen, column_lables, q_slot_values, slot_values, template, columns_all, context_label_list))
            
    return data_new

In [None]:
augment_fourth_webtable = add_augment_context_third(augment_third_webtable, context_templates, schema_dbs_webtable)


SELECT {COLUMN12} {FROM} WHERE {COLUMN0} NOT IN ( SELECT {COLUMN0} { FROM} ) AND {COLUMN10} {OP10} {VALUE10}
processed:  10000

SELECT {COLUMN0} , {COLUMN12} {FROM} WHERE {COLUMN0} NOT IN ( SELECT {COLUMN0} { FROM} ) AND {COLUMN10} {OP10} {VALUE10}
processed:  20000

SELECT {COLUMN11} , {COLUMN12} , {COLUMN22} {FROM} WHERE {COLUMN0} IN ( SELECT {COLUMN0} { FROM} ) AND {COLUMN10} {OP10} {VALUE10}
processed:  30000
processed:  40000

SELECT {COLUMN0} , {COLUMN12} {FROM} WHERE {COLUMN0} NOT IN ( SELECT {COLUMN0} { FROM} ) AND {COLUMN10} {OP10} {VALUE10}
processed:  50000
processed:  60000
processed:  70000
processed:  80000


In [4]:
two_count = 0
for schema, examples in augment_fourth_webtable.items():
    if two_count > 100:
        break
    for ex in examples:
        two_count += 1
        sql_pattern = ex[5]
        columns_all_prev = ex[6]
        question_prev = ex[0]
        sql_prev = ex[1]
        col_labels_prev = ex[2]
        q_slot_values_prev = ex[3]
        slot_values_prev = ex[4]
        context_label_list = ex[7]
        print("\nsql_pattern: ", sql_pattern)
        print("question: ", question_prev)
        print("sql: ", sql_prev)
        print("column labels: ", col_labels_prev)
        print("slot values: ", slot_values_prev)
        print("context_label_list: ", context_label_list)

In [None]:
two_count = 0
for schema, examples in augment_fourth_webtable.items():
    for ex in examples:
        two_count += 1
#         sql_pattern = ex[5]
#         columns_all_prev = ex[6]
#         question_prev = ex[0]
#         sql_prev = ex[1]
#         col_labels_prev = ex[2]
#         q_slot_values_prev = ex[3]
#         slot_values_prev = ex[4]
#         context_label_list = ex[7]
#         print("\nsql_pattern: ", sql_pattern)
#         print("question: ", question_prev)
#         print("sql: ", sql_prev)
#         print("column labels: ", col_labels_prev)
#         print("context_label_list: ", context_label_list)
#         print("slot_values: ", slot_values_prev)
print(two_count)

54557


In [None]:
prev1_token = " </ "
prev2_token = " :/ "

##### Map SQL labels of all augmented examples into numeric labels

In [None]:
### process label prints for each column
def get_label_map(data):
    label_dict = defaultdict(int)
    for schema_str, example_list in data.items():
        for example in example_list:
            (question, sql, col_labels) = example
            for val in col_labels.values():
                label_dict[val] += 1
    label_list = sorted(label_dict.items(), key=lambda kv: kv[1], reverse=True)
    label_map = {}
    count = 1
    for label, _ in label_list:
        label_map[label] = count
        count += 1
    
    return label_map

def map_labels(data, label_map, is_dev=False):
    data_new = {}
    skip_count = 0
    count = 0
    augment_data = data.copy()
    for schema_str, exs in augment_data.items():
        count += 1
        if count % 100000 == 0:
            print("processed: ", count)
        data_new[schema_str] = []
        for ex in exs:
            skip = False
            label_dict = ex[2]
            label_dict_new = {}
            for col, label in label_dict.items():
                if label in label_map.keys():
                    label_dict_new[col] = label_map[label]
                else:
                    skip = True
                    skip_count += 1
                    #else just skip
#             context_q, edited_sql, column_lables_real, label_dict_int, slot_values, edited_sql_pattern, context_label_list
            if not skip:
        
                data_new[schema_str].append((ex[0], ex[1], ex[2], label_dict_new, ex[4], ex[5], ex[7]))   
    
    print("skip_count: ", skip_count)
    return data_new

In [None]:
import pickle
label_map_file = "data/labels_map.pkl"
# label_map_final = get_label_map(fine_tuning_data_augment_with_dev_wikisql)
# with open(label_map_file, 'wb') as fp:
#     pickle.dump(label_map_final, fp, protocol=pickle.HIGHEST_PROTOCOL)
with open(label_map_file, 'rb') as fp:
    label_map = pickle.load(fp)

In [None]:
augment_first_webtable = map_labels(augment_first_webtable, label_map)
augment_second_webtable = map_labels(augment_second_webtable, label_map)
augment_third_webtable = map_labels(augment_third_webtable, label_map)
augment_fourth_webtable = map_labels(augment_fourth_webtable, label_map)

skip_count:  0
skip_count:  4885
skip_count:  3807
skip_count:  2632


In [None]:
augment_context_all_webtable = defaultdict(list)
for augment_one in [augment_first_webtable, augment_second_webtable, augment_third_webtable, augment_fourth_webtable]:
    for schema, examples in augment_one.items():
        augment_context_all_webtable[schema].extend(examples)

two_count = 0
for schema, examples in augment_context_all_webtable.items():
    for ex in examples:
        two_count += 1
print(two_count)

409424


##### Write and save file

In [None]:
MAX_TOKEN_LEN = 200
def write_final_file(augment_data):
    data_json = []
    skip_count = 0
    line_count = 0
    dup_count = 0
    pro_count = 0
    for schema_str, exs in augment_data.items():
        for ex in exs:
            line_count += 1
            if line_count % 100000 == 0:
                print("processed: ", line_count)
            question, sql, label_strs, label_ints, sql_slot_values, sql_pattern, context_label_list = ex
            col_str, val_str = schema_str.split(" |-| ")
            colns = col_str.split(" </s> ")
            values = val_str.split(" </s> ")
            assert len(colns) == len(values)
            cols = []
            label_num = len(label_ints)
            label_count = 0
            for idx, coln in enumerate(colns):
                col = {}
                col["name"] = coln
                col["value"] = values[idx]
                if coln != "*":
                    col["name"] = " ".join(coln.split(" ")[1:])
                col["label_int"] = 0
                if coln in label_ints.keys():
                    col["label_int"] = label_ints[coln]
                    label_count += 1
                cols.append(col)
            
            assert label_count >= label_num
            if label_count > label_num:
                dup_count += 1
#                 print("\nWARNING: deplicated columns!")
#                 print("label_ints: ", label_ints)
#                 print("colns: ", colns)
            
            col_list = []
            label_list = []
            value_list = []
            col_count = 0
            for i, col in enumerate(cols):
                if col_count > 40 and col["label_int"] == 0:
                    continue
                col_list.append(col["name"])
                value_list.append(col["value"])
                col_count += 1
                label_list.append(int(col["label_int"]))
            assert len(col_list) == len(value_list)
            
            assert question.count(prev_token) + 1 == len(context_label_list)
            
            label_str = " ".join([str(k) for k in label_list])
            q_col_str = "<s> " + question.lower() + " </s> " + " </s> ".join(col_list).strip() + " </s> "
            example_str = q_col_str + " ||| " + label_str + " ||| " + " ".join([str(x) for x in context_label_list])
            tokens = tokenizer.tokenize(q_col_str)
            if len(tokens) > MAX_TOKEN_LEN:
                continue
                
            data_json.append({"question": question.lower(),
                              "columns": col_list,
                              "rows": [value_list],
                              "column_labels": label_list,
                              "example_str": example_str,
                              "context_labels": context_label_list
                             })
            pro_count += 1

    print("total line: ", line_count)
    print("skiped line: ", skip_count)
    print("dup line: ", dup_count)
    print("pro line: ", pro_count)
    
    return data_json

In [None]:
data_json = write_final_file(augment_context_all_webtable)

processed:  100000
processed:  200000
processed:  300000
processed:  400000
total line:  409424
skiped line:  0
dup line:  1577
pro line:  391439


In [None]:
len(data_json)

391439

In [None]:
with open('data/augment_wikitable_context.json', 'w') as outfile:
    json.dump(data_json, outfile)

In [None]:
import codecs
def write_to_file(sql_data, output_file):
    table_file = codecs.open(output_file, "w", "utf-8")
    valid_count = 0
    num_sql = len(sql_data)
    check_point = int(num_sql*0.1)
    max_col_num = 0
    unique_labels = set()
    skip_count = 0
    for tn, sql_one in enumerate(sql_data):
        if tn % check_point == 0:
            print("processed: ", str(round(tn/num_sql, 2)))
        example_str = sql_one['example_str']
        valid_count += 1
        table_file.write(example_str.strip().replace("\n", ""))
        #add column names in another new line
        table_file.write("\n")

    table_file.close()

    return valid_count

In [5]:
write_to_file(data_json, "data/augment_wikitable_context.txt")