In this file I am building the sql query based on what is returned from the model

In [1]:
import json
import pandas as pd

In [2]:
path_text2sql_results = '/Users/bleopold/OneDrive/data-analysis/results_from_sqlova/results_dev.jsonl'

path_dev_tables = '/Users/bleopold/OneDrive/data-analysis/vendors/WikiSQL/data/dev.tables.jsonl'

In [133]:
condition_comparison_mapping = {
    0: '='
    , 1: '<'
    , 2: '>'
}


aggregation_mapping = {
    1: 'max'
    , 2: 'min'
    , 3: 'count'
}

# Get the data I need

In [3]:
# get names of all the tables for which I have data
with open(path_dev_tables, 'r') as json_file:
    dev_tables_raw_file = list(json_file)

tables_with_data = []

for json_str in dev_tables_raw_file:
    result = json.loads(json_str)
    table_id = result["id"]
    
    tables_with_data.append(table_id)

In [4]:
# for all these tables I have a text to sql result
with open(path_text2sql_results, 'r') as json_file:
    text2sql_results_raw_file = list(json_file)

tables_with_text2sql_result = []

for json_str in text2sql_results_raw_file:
    result = json.loads(json_str)
    table_id = result["table_id"]
    
    tables_with_text2sql_result.append(table_id)
    
tables_with_text2sql_result = list(set(tables_with_text2sql_result))

In [5]:
# for all these tables I have data and a text to sql result
"""
get the intersect of tables_with_data and tables_with_text2sql_result
"""
ids_usable_tables = [_ for _ in tables_with_text2sql_result if _ in tables_with_data]

In [6]:
mapping_table_id_table_name = pd.read_csv(
        "../data/01_raw/wiki-sql/schema_infos/mapping_table_id_table_name.csv"
        ,  sep = '\t'
    )

mapping_table_id_table_name

Unnamed: 0,table_id,table_name
0,1-10015132-9,toronto_raptors_all-time_roster
1,1-10015132-9,toronto_raptors_all-time_roster


# Define functions

In [85]:
def get_table_name_for_table_id(table_id):
    # holds the table_id in one column and the table_name in the other column
    mapping_table_id_table_name = pd.read_csv(
        "../data/01_raw/wiki-sql/schema_infos/mapping_table_id_table_name.csv"
        ,  sep = '\t'
    )
    
    
    # get the table_name using table_id
    table_name = list(mapping_table_id_table_name[
        mapping_table_id_table_name["table_id"]==table_id
    ]["table_name"])[0]
    
    
    
    return table_name

In [8]:
def replace_enclosed_where_conditions_with_quotes(sql, condition_infos):
    """
    example for condition_infos: [[3, 0, '4th, atlantic division']]
    
    place each condition in the array between dobles quotes
    """
    for condition_info in condition_infos:
        condition_value_orig = condition_info[2]
        condition_value_quoted = "'{}'".format(condition_value_orig)
        sql = sql.replace(condition_value_orig, condition_value_quoted)

    return sql

In [9]:
def insert_clean_column_names(sql, condition_infos, table_id):
    mapping_column_names = pd.read_csv(
        "../data/01_raw/wiki-sql/schema_infos/information_schema.csv"
        ,  sep = '\t'
    )
    
    for condition_info in condition_infos:
        column_order = condition_info[0]
    
    
        column_name_original  = list(
            mapping_column_names[
                (mapping_column_names["table_id"] == table_id)
                &
                (mapping_column_names["column_order"] == column_order)
            ]["column_name_original"]
        )[0]

        column_name_clean  = list(
            mapping_column_names[
                (mapping_column_names["table_id"] == table_id)
                &
                (mapping_column_names["column_order"] == column_order)
            ]["column_name_clean"]
        )[0]


        sql = sql.replace(column_name_original, column_name_clean)
        
        
    return sql

In [28]:
def get_clean_column_name_for_column_id(column_id, table_id):
    """
    column_id = Int
    table_id = String
    """
    inf_schema = pd.read_csv(
        "../data/01_raw/wiki-sql/schema_infos/information_schema.csv"
        , sep='\t'
    )
    column_name_clean = list(
        inf_schema[
            (inf_schema["table_id"] == table_id)
            &
            (inf_schema["column_order"] == column_id)
        ]["column_name_clean"]
    )[0]
    return column_name_clean

# Convert original sql to parse-able sql

In [152]:
result = {'query': {'agg': 3, 'sel': 5, 'conds': [[1, 0, '3']]},
 'table_id': '1-10015132-11',
 'nlu': 'How many schools did player number 3 play at?',
 'sql': 'SELECT count(School/Club Team) FROM 1-10015132-11 WHERE No. = 3'}

In [155]:
table_id = '1-10015132-9'
def get_complete_sql_query(result, table_id):
    select_statement = get_select_statement(result, table_id)
    from_statement = get_from_statement(table_id)
    where_statement  = get_where_statement(result)
    
    print(where_statement)
    
get_complete_sql_query(result, table_id)

 where years = '3'


In [139]:
def get_select_statement(result, table_id):
    select_column_index = result["query"].get('sel')
    select_column_name = get_clean_column_name_for_column_id(select_column_index, table_id)

    select_statement = 'select {}'.format(select_column_name)

    # check if an aggregation exist
    if result["query"].get('agg') != 0:
        aggregation_code = result["query"].get('agg')
        aggregation_value = aggregation_mapping.get(aggregation_code)
        select_statement = 'select {}({})'.format(aggregation_value, select_column_name)


    return select_statement

In [143]:
def get_from_statement(table_id):
    table_name = get_table_name_for_table_id(table_id)
    from_statement = ' from {}'.format(table_name)
    
    return from_statement

In [150]:
def get_where_statement(result):
    conditions = result['query'].get('conds')
    condition_statements = []
    condition_statements_separator = ' and '

    for condition in conditions:
        condition_statement = get_individual_conditions(condition)
        condition_statements.append(condition_statement)


    return ' where {}'.format(condition_statements_separator.join(condition_statements))

In [147]:
def get_individual_conditions(condition):
    condition_column_index = condition[0]
    condition_column_name = get_clean_column_name_for_column_id(
        condition_column_index, '2-12601141-1'
    )

    condition_comparison_code = condition[1]
    condition_comparison_value = condition_comparison_mapping.get(condition_comparison_code)

    condition_content = "'{}'".format(condition[2])
    
    
    condition_statement = '{} {} {}'.format(
        condition_column_name, condition_comparison_value, condition_content
    )
    
    return condition_statement

# Probleme
* Query selber bauen? replacing ist gefährlich
* toronto_raptors_all-time_roster nicht korrekt slugified
* selects dürfen nicht in klammer sein + müssen slugified sein --> ACHTUNG: darf kein einfaches replace sein nur auf den spaltennamen, weil dder spaltennamen in der tabelle auch stehen könnte --> replace auf inkl "select" + das aggregation (zB count)
* Probleme mit case-sensitive

# Aufgaben 
* get_complete_sql_query(result, table_id) zum laufen bringen