In [1]:
#############################################
#
# Reference: https: //github.com/RUCKBReasoning/RESDSQL.git
# Reference script: preprocessing.py
#
#############################################

In [3]:
import os
import re
import json
import argparse

from bridge_content_encoder import get_database_matches
from sql_metadata import Parser
from tqdm import tqdm

sql_keywords = ['select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', \
    'except', 'join', 'on', 'as', 'not', 'between', 'in', 'like', 'is', 'exists', 'max', 'min', \
        'count', 'sum', 'avg', 'and', 'or', 'desc', 'asc']

In [4]:
def get_db_contents(question, table_name_original, column_names_original, db_id, db_path):
    matched_contents = []
    # extract matched contents for each column
    for column_name_original in column_names_original:
        matches = get_database_matches(
            question, 
            table_name_original, 
            column_name_original, 
            db_path + "/{}/{}.sqlite".format(db_id, db_id)
        )
        matches = sorted(matches)
        matched_contents.append(matches)
    
    return matched_contents

def get_db_schemas(all_db_infos):
    db_schemas = {}

    for db in all_db_infos:
        table_names_original = db["table_names_original"]
        table_names = db["table_names"]
        column_names_original = db["column_names_original"]
        column_names = db["column_names"]
        column_types = db["column_types"]

        db_schemas[db["db_id"]] = {}
        
        primary_keys, foreign_keys = [], []
        # record primary keys
        for pk_column_idx in db["primary_keys"]:
            pk_table_name_original = table_names_original[column_names_original[pk_column_idx][0]]
            pk_column_name_original = column_names_original[pk_column_idx][1]
            
            primary_keys.append(
                {
                    "table_name_original": pk_table_name_original.lower(), 
                    "column_name_original": pk_column_name_original.lower()
                }
            )

        db_schemas[db["db_id"]]["pk"] = primary_keys

        # record foreign keys
        for source_column_idx, target_column_idx in db["foreign_keys"]:
            fk_source_table_name_original = table_names_original[column_names_original[source_column_idx][0]]
            fk_source_column_name_original = column_names_original[source_column_idx][1]

            fk_target_table_name_original = table_names_original[column_names_original[target_column_idx][0]]
            fk_target_column_name_original = column_names_original[target_column_idx][1]
            
            foreign_keys.append(
                {
                    "source_table_name_original": fk_source_table_name_original.lower(),
                    "source_column_name_original": fk_source_column_name_original.lower(),
                    "target_table_name_original": fk_target_table_name_original.lower(),
                    "target_column_name_original": fk_target_column_name_original.lower(),
                }
            )
        db_schemas[db["db_id"]]["fk"] = foreign_keys

        db_schemas[db["db_id"]]["schema_items"] = []
        for idx, table_name_original in enumerate(table_names_original):
            column_names_original_list = []
            column_names_list = []
            column_types_list = []
            
            for column_idx, (table_idx, column_name_original) in enumerate(column_names_original):
                if idx == table_idx:
                    column_names_original_list.append(column_name_original.lower())
                    column_names_list.append(column_names[column_idx][1].lower())
                    column_types_list.append(column_types[column_idx])
            
            db_schemas[db["db_id"]]["schema_items"].append({
                "table_name_original": table_name_original.lower(),
                "table_name": table_names[idx].lower(), 
                "column_names": column_names_list, 
                "column_names_original": column_names_original_list,
                "column_types": column_types_list
            })

    return db_schemas

In [5]:
def normalization(sql):
    def white_space_fix(s):
        parsed_s = Parser(s)
        s = " ".join([token.value for token in parsed_s.tokens])

        return s

    # convert everything except text between single quotation marks to lower case
    def lower(s):
        in_quotation = False
        out_s = ""
        for char in s:
            if in_quotation:
                out_s += char
            else:
                out_s += char.lower()
            
            if char == "'":
                if in_quotation:
                    in_quotation = False
                else:
                    in_quotation = True
        
        return out_s
    
    # remove ";"
    def remove_semicolon(s):
        if s.endswith(";"):
            s = s[:-1]
        return s

    # double quotation -> single quotation 
    def double2single(s):
        return s.replace("\"", "'") 
    
    def add_asc(s):
        pattern = re.compile(r'order by (?:\w+ \( \S+ \)|\w+\.\w+|\w+)(?: (?:\+|\-|\<|\<\=|\>|\>\=) (?:\w+ \( \S+ \)|\w+\.\w+|\w+))*')
        if "order by" in s and "asc" not in s and "desc" not in s:
            for p_str in pattern.findall(s):
                s = s.replace(p_str, p_str + " asc")

        return s

    def remove_table_alias(s):
        tables_aliases = Parser(s).tables_aliases
        new_tables_aliases = {}
        for i in range(1,11):
            if "t{}".format(i) in tables_aliases.keys():
                new_tables_aliases["t{}".format(i)] = tables_aliases["t{}".format(i)]
        
        tables_aliases = new_tables_aliases
        for k, v in tables_aliases.items():
            s = s.replace("as " + k + " ", "")
            s = s.replace(k, v)
        
        return s
    
    processing_func = lambda x : remove_table_alias(add_asc(lower(white_space_fix(double2single(remove_semicolon(x))))))
    
    return processing_func(sql)


In [6]:
# extract the skeleton of sql and natsql
def extract_skeleton(sql, db_schema):
    table_names_original, table_dot_column_names_original, column_names_original = [], [], []
    for table in db_schema["schema_items"]:
        table_name_original = table["table_name_original"]
        table_names_original.append(table_name_original)

        for column_name_original in ["*"]+table["column_names_original"]:
            table_dot_column_names_original.append(table_name_original+"."+column_name_original)
            column_names_original.append(column_name_original)
    
    parsed_sql = Parser(sql)
    new_sql_tokens = []
    for token in parsed_sql.tokens:
        # mask table names
        if token.value in table_names_original:
            new_sql_tokens.append("_")
        # mask column names
        elif token.value in column_names_original \
            or token.value in table_dot_column_names_original:
            new_sql_tokens.append("_")
        # mask string values
        elif token.value.startswith("'") and token.value.endswith("'"):
            new_sql_tokens.append("_")
        # mask positive int number
        elif token.value.isdigit():
            new_sql_tokens.append("_")
        # mask negative int number
        elif isNegativeInt(token.value):
            new_sql_tokens.append("_")
        # mask float number
        elif isFloat(token.value):
            new_sql_tokens.append("_")
        else:
            new_sql_tokens.append(token.value.strip())

    sql_skeleton = " ".join(new_sql_tokens)
    
    # remove JOIN ON keywords
    sql_skeleton = sql_skeleton.replace("on _ = _ and _ = _", "on _ = _")
    sql_skeleton = sql_skeleton.replace("on _ = _ or _ = _", "on _ = _")
    sql_skeleton = sql_skeleton.replace(" on _ = _", "")
    pattern3 = re.compile("_ (?:join _ ?)+")
    sql_skeleton = re.sub(pattern3, "_ ", sql_skeleton)

    # "_ , _ , ..., _" -> "_"
    while("_ , _" in sql_skeleton):
        sql_skeleton = sql_skeleton.replace("_ , _", "_")
    
    # remove clauses in WHERE keywords
    ops = ["=", "!=", ">", ">=", "<", "<="]
    for op in ops:
        if "_ {} _".format(op) in sql_skeleton:
            sql_skeleton = sql_skeleton.replace("_ {} _".format(op), "_")
    while("where _ and _" in sql_skeleton or "where _ or _" in sql_skeleton):
        if "where _ and _"in sql_skeleton:
            sql_skeleton = sql_skeleton.replace("where _ and _", "where _")
        if "where _ or _" in sql_skeleton:
            sql_skeleton = sql_skeleton.replace("where _ or _", "where _")

    # remove additional spaces in the skeleton
    while "  " in sql_skeleton:
        sql_skeleton = sql_skeleton.replace("  ", " ")

    return sql_skeleton

def isNegativeInt(string):
    if string.startswith("-") and string[1:].isdigit():
        return True
    else:
        return False

def isFloat(string):
    if string.startswith("-"):
        string = string[1:]
    
    s = string.split(".")
    if len(s)>2:
        return False
    else:
        for s_i in s:
            if not s_i.isdigit():
                return False
        return True

In [7]:
import collections
from nltk import word_tokenize

max_encode_len = 0
cnt = collections.Counter()

VALUE_NUM_SYMBOL = "{value}"

def get_encode_Query(query):
    
    global max_encode_len
    global cnt
    
    tokens = strip_nl(query)
    cnt.update(tokens)
    max_encode_len = max(max_encode_len, len(tokens))
    token_sentence = " ".join(tokens)

    return token_sentence

def strip_nl(nl):
    '''
    return keywords of nl query
    '''
    nl_keywords = []
    nl = nl.strip()
    nl = nl.replace(";"," ; ").replace(",", " , ").replace("?", " ? ").replace("\t"," ")
    nl = nl.replace("(", " ( ").replace(")", " ) ")
    
    nl = nl.replace('\"', "'")
    nl = nl.replace("\'", "'")
    
#     str_1 = re.findall("\"[^\"]*\"", nl)
#     str_2 = re.findall("\'[^\']*\'", nl)
#     float_nums = re.findall("[-+]?\d*\.\d+", nl)
#     values = str_1 + str_2 + float_nums
# #     print(values)
#     for val in values:
#         nl = nl.replace(val.strip(), VALUE_NUM_SYMBOL)

    def to_lower(s):
        in_quotation = False
        out_s = ""
        for char in s:
            if in_quotation:
                out_s += char
            elif char == ".":
                if in_quotation:
                    out_s += char
                else:
                    out_s += " . "
            else:
                out_s += char.lower()

            if char == "'":
                if in_quotation:
                    in_quotation = False
                else:
                    in_quotation = True
            
        return out_s
    
    raw_keywords = nl.strip().split()
#     print(raw_keywords)
    for tok in raw_keywords:
#         print(tok)
        if "." in tok:
            to = tok.split()
#             to = tok.replace(".", " . ").split()
#             print("---", to)
#             to = [t.lower() for t in to if len(t)>0]
            nl_keywords.extend(to)
        elif "'" in tok and tok[0]!="'" and tok[-1]!="'":
            to = word_tokenize(tok)
#             to = [t.lower() for t in to if len(t)>0]
            nl_keywords.extend(to)      
#         elif len(tok) > 0:
#             nl_keywords.append(tok.lower())
        else:
            nl_keywords.append(tok)
    
    nl_keywords = to_lower(" ".join(nl_keywords))
    nl_keywords_2 = nl_keywords.split(" ")
    
    return nl_keywords_2


In [8]:
sample = "Find the last name of the latest contact individual of the enrico09@example.com organization \"Labour Party\"."
output = get_encode_Query(sample)
output

"find the last name of the latest contact individual of the enrico09@example . com organization 'Labour Party' . "

In [15]:
def main(input_dataset_path,
         output_dataset_path,
         table_path,
         db_path,
         mode):
    
    max_len_decoder_target = 0
    max_len_classifier_input = 0
    
    dataset = json.load(open(input_dataset_path))
    print(f"Total data points in {input_dataset_path.split('/')[-1]} is {len(dataset)}")
    all_db_infos = json.load(open(table_path))
    
    assert mode in ["train", "eval", "test"]
    
    db_schemas = get_db_schemas(all_db_infos)
    
    preprocessed_dataset = []

#     for natsql_data, data in tqdm(zip(natsql_dataset, dataset)):
    for data in tqdm(dataset):
        if data['query'] == 'SELECT T1.company_name FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id  =  T2.maintenance_contract_company_id JOIN Ref_Company_Types AS T3 ON T1.company_type_code  =  T3.company_type_code ORDER BY T2.contract_end_date DESC LIMIT 1':
            data['query'] = 'SELECT T1.company_type FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id  =  T2.maintenance_contract_company_id ORDER BY T2.contract_end_date DESC LIMIT 1'
            data['query_toks'] = ['SELECT', 'T1.company_type', 'FROM', 'Third_Party_Companies', 'AS', 'T1', 'JOIN', 'Maintenance_Contracts', 'AS', 'T2', 'ON', 'T1.company_id', '=', 'T2.maintenance_contract_company_id', 'ORDER', 'BY', 'T2.contract_end_date', 'DESC', 'LIMIT', '1']
            data['query_toks_no_value'] =  ['select', 't1', '.', 'company_type', 'from', 'third_party_companies', 'as', 't1', 'join', 'maintenance_contracts', 'as', 't2', 'on', 't1', '.', 'company_id', '=', 't2', '.', 'maintenance_contract_company_id', 'order', 'by', 't2', '.', 'contract_end_date', 'desc', 'limit', 'value']
            data['question'] = 'What is the type of the company who concluded its contracts most recently?'
            data['question_toks'] = ['What', 'is', 'the', 'type', 'of', 'the', 'company', 'who', 'concluded', 'its', 'contracts', 'most', 'recently', '?']
        if data['query'].startswith('SELECT T1.fname FROM student AS T1 JOIN lives_in AS T2 ON T1.stuid  =  T2.stuid WHERE T2.dormid IN'):
            data['query'] = data['query'].replace('IN (SELECT T2.dormid)', 'IN (SELECT T3.dormid)')
            index = data['query_toks'].index('(') + 2
            assert data['query_toks'][index] == 'T2.dormid'
            data['query_toks'][index] = 'T3.dormid'
            index = data['query_toks_no_value'].index('(') + 2
            assert data['query_toks_no_value'][index] == 't2'
            data['query_toks_no_value'][index] = 't3'

        question = data["question"].replace("\u2018", "'").replace("\u2019", "'").replace("\u201c", "'").replace("\u201d", "'").strip()
        db_id = data["db_id"]
        
        if mode == "test":
            sql, norm_sql, sql_skeleton = "", "", ""
            sql_tokens = []

        else:
            sql = data["query"].strip()
            norm_sql = normalization(sql).strip()
            sql_skeleton = extract_skeleton(norm_sql, db_schemas[db_id]).strip()
            sql_tokens = norm_sql.split()

        norm_question = get_encode_Query(question)
       
        preprocessed_data = {}
        preprocessed_data["question"] = question
        preprocessed_data["norm_question"] = norm_question
        preprocessed_data["db_id"] = db_id
        preprocessed_data["classifier_input"] = question
#         preprocessed_data["classifier_labels"] = []

        preprocessed_data["sql"] = sql
        preprocessed_data["norm_sql"] = norm_sql
        preprocessed_data["sql_skeleton"] = sql_skeleton
        preprocessed_data["decoder_target"] = sql_skeleton + " | " + norm_sql
        
        preprocessed_data["db_schema"] = []
        preprocessed_data["pk"] = db_schemas[db_id]["pk"]
        preprocessed_data["fk"] = db_schemas[db_id]["fk"]
        preprocessed_data["table_labels"] = []
        preprocessed_data["column_labels"] = []
        
        # add database information (including table name, column name, ..., table_labels, and column labels)
        for table in db_schemas[db_id]["schema_items"]:
            db_contents = get_db_contents(
                question, 
                table["table_name_original"], 
                table["column_names_original"], 
                db_id, 
                db_path
            )

            preprocessed_data["db_schema"].append({
                "table_name_original":table["table_name_original"],
                "table_name":table["table_name"],
                "column_names":table["column_names"],
                "column_names_original":table["column_names_original"],
                "column_types":table["column_types"],
                "db_contents": db_contents
            })

            # extract table and column classification labels
            if table["table_name_original"] in sql_tokens:  # for used tables
                preprocessed_data["table_labels"].append(1)
                column_labels = []
                for column_name_original in table["column_names_original"]:
                    if column_name_original in sql_tokens or \
                        table["table_name_original"]+"."+column_name_original in sql_tokens: # for used columns
                        column_labels.append(1)
                    else:
                        column_labels.append(0)
                preprocessed_data["column_labels"].append(column_labels)
            else:  # for unused tables and their columns
                preprocessed_data["table_labels"].append(0)
                preprocessed_data["column_labels"].append([0 for _ in range(len(table["column_names_original"]))])
                
            # create classifier input
            preprocessed_data["classifier_input"] += " | " + table["table_name"] + ": "
            
            for idx_, col_names in enumerate(table["column_names_original"]):
                preprocessed_data["classifier_input"] += col_names
                
                if idx_ < len(table["column_names_original"])-1:
                    preprocessed_data["classifier_input"] += ', '
                    
        preprocessed_data["classifier_input"] += " | " + sql_skeleton
                    
        # generate classifier labels
#         for idx_, flag in enumerate(preprocessed_data["table_labels"]):
#             preprocessed_data["classifier_labels"].append(flag)
#             preprocessed_data["classifier_labels"].append(flag)
        
#         print((preprocessed_data["classifier_input"]).split(" "))
        max_len_decoder_target = max(max_len_decoder_target,
                                     len((preprocessed_data["decoder_target"]).split(" ")))
#         max_len_classifier_input = max(max_len_classifier_input, len((preprocessed_data["classifier_input"]).split(" ")))
        
        preprocessed_dataset.append(preprocessed_data)
#         break
    
    with open(output_dataset_path, "w") as f:
        preprocessed_dataset_str = json.dumps(preprocessed_dataset, indent = 2)
        f.write(preprocessed_dataset_str)
        
#     print(f"Max encoder input length - {max_len_classifier_input}")
    print(f"Max decoder input length - {max_len_decoder_target}")
        
    print('Done')
    
    return preprocessed_dataset

In [16]:
data_path = '/Users/aishwarya/Downloads/spring23/cs685-NLP/project/spider_data'
table_path_1 = os.path.join(data_path, "tables.json")
db_path_1 = os.path.join(data_path, "database")

data_path_target = "/Users/aishwarya/Downloads/spring23/cs685-NLP/project/data/resdsql_pre"
if not os.path.isdir(data_path_target):
    os.makedirs(data_path_target)
    
mode_1 = "train"
input_dataset_path_1 = "/Users/aishwarya/Downloads/spring23/cs685-NLP/project/data/split/spider_test.json"
input_dataset_path_1 = os.path.join(data_path, "dev.json") # "train_spider.json", dev.json    
output_dataset_path_1 = os.path.join(data_path_target, "preprocessed_dataset_dev.json")

processed_data = main(input_dataset_path_1,
                         output_dataset_path_1,
                         table_path_1,
                         db_path_1,
                         mode_1)

Total data points in dev.json is 1034


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1034/1034 [00:48<00:00, 21.16it/s]


Max decoder input length - 81
Done


In [11]:
def prompt_data(input_dataset_path,
                 output_dataset_path,
                 table_path,
                 db_path,
                 mode):
    
    max_len_decoder_target = 0
    max_len_classifier_input = 0
    
    counter = 0
    
    dataset = json.load(open(input_dataset_path))
    print(f"Total data points in {input_dataset_path.split('/')[-1]} is {len(dataset)}")
    all_db_infos = json.load(open(table_path))
    
    assert mode in ["train", "eval", "test"]
    
    db_schemas = get_db_schemas(all_db_infos)
    
    preprocessed_dataset = []

#     for natsql_data, data in tqdm(zip(natsql_dataset, dataset)):
    for data in tqdm(dataset):
        if data['query'] == 'SELECT T1.company_name FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id  =  T2.maintenance_contract_company_id JOIN Ref_Company_Types AS T3 ON T1.company_type_code  =  T3.company_type_code ORDER BY T2.contract_end_date DESC LIMIT 1':
            data['query'] = 'SELECT T1.company_type FROM Third_Party_Companies AS T1 JOIN Maintenance_Contracts AS T2 ON T1.company_id  =  T2.maintenance_contract_company_id ORDER BY T2.contract_end_date DESC LIMIT 1'
            data['query_toks'] = ['SELECT', 'T1.company_type', 'FROM', 'Third_Party_Companies', 'AS', 'T1', 'JOIN', 'Maintenance_Contracts', 'AS', 'T2', 'ON', 'T1.company_id', '=', 'T2.maintenance_contract_company_id', 'ORDER', 'BY', 'T2.contract_end_date', 'DESC', 'LIMIT', '1']
            data['query_toks_no_value'] =  ['select', 't1', '.', 'company_type', 'from', 'third_party_companies', 'as', 't1', 'join', 'maintenance_contracts', 'as', 't2', 'on', 't1', '.', 'company_id', '=', 't2', '.', 'maintenance_contract_company_id', 'order', 'by', 't2', '.', 'contract_end_date', 'desc', 'limit', 'value']
            data['question'] = 'What is the type of the company who concluded its contracts most recently?'
            data['question_toks'] = ['What', 'is', 'the', 'type', 'of', 'the', 'company', 'who', 'concluded', 'its', 'contracts', 'most', 'recently', '?']
        if data['query'].startswith('SELECT T1.fname FROM student AS T1 JOIN lives_in AS T2 ON T1.stuid  =  T2.stuid WHERE T2.dormid IN'):
            data['query'] = data['query'].replace('IN (SELECT T2.dormid)', 'IN (SELECT T3.dormid)')
            index = data['query_toks'].index('(') + 2
            assert data['query_toks'][index] == 'T2.dormid'
            data['query_toks'][index] = 'T3.dormid'
            index = data['query_toks_no_value'].index('(') + 2
            assert data['query_toks_no_value'][index] == 't2'
            data['query_toks_no_value'][index] = 't3'

        question = data["question"].replace("\u2018", "'").replace("\u2019", "'").replace("\u201c", "'").replace("\u201d", "'").strip()
        db_id = data["db_id"]
        
        if mode == "test":
            sql, norm_sql, sql_skeleton = "", "", ""
            sql_tokens = []

        else:
            sql = data["query"].strip()
            norm_sql = normalization(sql).strip()
            sql_skeleton = extract_skeleton(norm_sql, db_schemas[db_id]).strip()
            sql_tokens = norm_sql.split()

        norm_question = get_encode_Query(question)
       
        preprocessed_data = {}
        preprocessed_data["question"] = question
        preprocessed_data["norm_question"] = norm_question
        preprocessed_data["db_id"] = db_id
#         preprocessed_data["classifier_input"] = question
#         preprocessed_data["classifier_labels"] = []

        preprocessed_data["sql"] = sql
        preprocessed_data["norm_sql"] = norm_sql
        preprocessed_data["sql_skeleton"] = sql_skeleton
        preprocessed_data["decoder_target"] = sql_skeleton + " | " + norm_sql
        
        # add database information (including table name, column name, ..., table_labels, and column labels)
        for table in db_schemas[db_id]["schema_items"]:

            # create classifier input
            preprocessed_data["classifier_input"] += " | " + table["table_name"] + ": "
            
            for idx_, col_names in enumerate(table["column_names_original"]):
                preprocessed_data["classifier_input"] += col_names
                
                if idx_ < len(table["column_names_original"])-1:
                    preprocessed_data["classifier_input"] += ', '
                    
#         preprocessed_data["classifier_input"] += " | " + sql_skeleton
        
#         max_len_classifier_input = max(max_len_classifier_input, len((preprocessed_data["classifier_input"]).split(" ")))
        
        preprocessed_dataset.append(preprocessed_data)
        counter += 1
        if counter > 1: break
    
    with open(output_dataset_path, "w") as f:
        preprocessed_dataset_str = json.dumps(preprocessed_dataset, indent = 2)
        f.write(preprocessed_dataset_str)
        
    print(f"Max encoder input length - {max_len_classifier_input}")
#     print(f"Max decoder input length - {max_len_decoder_target}")
        
    print('Done')
    
    return preprocessed_dataset

In [12]:
data_path = '/Users/aishwarya/Downloads/spring23/cs685-NLP/project/spider_data'
table_path_1 = os.path.join(data_path, "tables.json")
db_path_1 = os.path.join(data_path, "database")

data_path_target = "/Users/aishwarya/Downloads/spring23/cs685-NLP/project/data/resdsql_pre"
if not os.path.isdir(data_path_target):
    os.makedirs(data_path_target)
    
mode_1 = "train"
# input_dataset_path_1 = "/Users/aishwarya/Downloads/spring23/cs685-NLP/project/data/split/spider_test.json"
input_dataset_path_1 = os.path.join(data_path, "dev.json") # "train_spider.json", dev.json    
output_dataset_path_1 = os.path.join(data_path_target, "preprocessed_dataset_dev.json")

processed_data = prompt_data(input_dataset_path_1,
                             output_dataset_path_1,
                             table_path_1,
                             db_path_1,
                             mode_1)

Total data points in dev.json is 1034


  0%|                                                                                                                                                              | 0/1034 [00:00<?, ?it/s]


KeyError: 'classifier_input'

## Check processed data

In [110]:
for block in processed_data:
    if "\"" in block["question"]:
        for k,v in block.items():
            if k in ["question", "norm_question", "sql", "norm_sql"]:
                print(f"{k} - {v}")
        print("\n")

question - Find the payment method and phone of the party with email "enrico09@example.com".
norm_question - find the payment method and phone of the party with email 'enrico09@example.com' . 
sql - SELECT payment_method_code ,  party_phone FROM parties WHERE party_email  =  "enrico09@example.com"
norm_sql - select payment_method_code , party_phone from parties where party_email = 'enrico09@example.com'


question - Find the last name of the latest contact individual of the organization "Labour Party".
norm_question - find the last name of the latest contact individual of the organization 'Labour Party' . 
sql - SELECT t3.individual_last_name FROM organizations AS t1 JOIN organization_contact_individuals AS t2 ON t1.organization_id  =  t2.organization_id JOIN individuals AS t3 ON t2.individual_id  =  t3.individual_id WHERE t1.organization_name  =  "Labour Party" ORDER BY t2.date_contact_to DESC LIMIT 1
norm_sql - select individuals.individual_last_name from organizations join organizat

## Check if classifier data loader works on above dataset

In [8]:
from load_dataset import ColumnAndTableClassifierDataset
from torch.utils.data import DataLoader

In [9]:
train_dataset = ColumnAndTableClassifierDataset(
        dir_ = "/Users/aishwarya/Downloads/spring23/cs685-NLP/project/data/resdsql_pre/preprocessed_dataset.json",
        use_contents = True,
        add_fk_info = False
    )

train_dataloder = DataLoader(
        train_dataset, 
        batch_size = 1, 
        shuffle = True,
        collate_fn = lambda x: x
    )

In [10]:
for batch in train_dataloder:
    print(batch)

[('How many singers do we have?', ['stadium', 'singer', 'concert', 'singer in concert'], [0, 1, 0, 0], [['stadium id', 'location', 'name', 'capacity', 'highest', 'lowest', 'average'], ['singer id', 'name', 'country', 'song name', 'song release year', 'age', 'is male'], ['concert id', 'concert name', 'theme', 'stadium id', 'year'], ['concert id', 'singer id']], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])]


In [11]:
for batch in train_dataloder:
    print(batch)

[('How many singers do we have?', ['stadium', 'singer', 'concert', 'singer in concert'], [0, 1, 0, 0], [['stadium id', 'location', 'name', 'capacity', 'highest', 'lowest', 'average'], ['singer id', 'name', 'country', 'song name', 'song release year', 'age', 'is male'], ['concert id', 'concert name', 'theme', 'stadium id', 'year'], ['concert id', 'singer id']], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])]


In [12]:
batch_questions = [data[0] for data in batch]
    
batch_table_names = [data[1] for data in batch]
batch_table_labels = [data[2] for data in batch]

batch_column_infos = [data[3] for data in batch]
batch_column_labels = [data[4] for data in batch]

In [13]:
batch_questions

['How many singers do we have?']

In [14]:
batch_table_names

[['stadium', 'singer', 'concert', 'singer in concert']]

In [15]:
batch_table_labels

[[0, 1, 0, 0]]

In [16]:
batch_column_infos

[[['stadium id',
   'location',
   'name',
   'capacity',
   'highest',
   'lowest',
   'average'],
  ['singer id',
   'name',
   'country',
   'song name',
   'song release year',
   'age',
   'is male'],
  ['concert id', 'concert name', 'theme', 'stadium id', 'year'],
  ['concert id', 'singer id']]]

In [17]:
batch_table_labels

[[0, 1, 0, 0]]

## Check changes

In [2]:
def strip_nl(nl):
    '''
    return keywords of nl query
    '''
    nl_keywords = []
    nl = nl.strip()
    nl = nl.replace(";"," ; ").replace(",", " , ").replace("?", " ? ").replace("\t"," ")
    nl = nl.replace("(", " ( ").replace(")", " ) ")
    
    nl = nl.replace('\"', "'")
    nl = nl.replace("\'", "'")
    
#     str_1 = re.findall("\"[^\"]*\"", nl)
#     str_2 = re.findall("\'[^\']*\'", nl)
#     float_nums = re.findall("[-+]?\d*\.\d+", nl)
#     values = str_1 + str_2 + float_nums
# #     print(values)
#     for val in values:
#         nl = nl.replace(val.strip(), VALUE_NUM_SYMBOL)

    def to_lower(s):
        in_quotation = False
        out_s = ""
        for char in s:
            if in_quotation:
                out_s += char
            elif char == ".":
                if in_quotation:
                    out_s += char
                else:
                    out_s += " . "
            else:
                out_s += char.lower()

            if char == "'":
                if in_quotation:
                    in_quotation = False
                else:
                    in_quotation = True
            
        return out_s
    
    raw_keywords = nl.strip().split()
#     print(raw_keywords)
    for tok in raw_keywords:
#         print(tok)
        if "." in tok:
            to = tok.split()
#             to = tok.replace(".", " . ").split()
#             print("---", to)
#             to = [t.lower() for t in to if len(t)>0]
            nl_keywords.extend(to)
        elif "'" in tok and tok[0]!="'" and tok[-1]!="'":
            to = word_tokenize(tok)
#             to = [t.lower() for t in to if len(t)>0]
            nl_keywords.extend(to)      
#         elif len(tok) > 0:
#             nl_keywords.append(tok.lower())
        else:
            nl_keywords.append(tok)
    
    nl_keywords = to_lower(" ".join(nl_keywords))
    nl_keywords_2 = nl_keywords.split(" ")
    
    return nl_keywords_2