# Templates for Generating SQL Explanations

Reproducing "Speak to your Parser: Interactive Text-to-SQL with Natural Language Feedback" 

In [1]:
import sys, os
from pathlib import  Path
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(Path(SCRIPT_DIR).parent))
from utils.sqlComponents import Query     # The convinent data structure for SQL queries

import re

ImportError: attempted relative import with no known parent package

In [2]:
def get_explanation(sql_parse: Query, step: int=1, join_table: bool=False):
    explanations = []
    if not sql_parse.intersect.is_empty:
        sub_sql = sql_parse.intersect.sub_query
        sql_parse.intersect.empty()
        sent1, step = get_explanation(sql_parse, step)
        part1_step = step - 1
        sent2, step = get_explanation(sub_sql, step)
        part2_step = step - 1
        sent3 = "show the rows that are in both the results of step %d and step %d" % (part1_step, part2_step)
        explanations = sent1 + sent2 + [sent3]
        step += 1

    elif not sql_parse.union.is_empty:
        sub_sql = sql_parse.union.sub_query
        sql_parse.union.empty()
        sent1, step = get_explanation(sql_parse, step)
        part1_step = step - 1
        sent2, step = get_explanation(sub_sql, step)
        part2_step = step - 1
        sent3 = "show the rows that are in any of the results of step %d and step %d" % (part1_step, part2_step)
        explanations = sent1 + sent2 + [sent3]
        step += 1

    elif not sql_parse.exceptt.is_empty:
        sub_sql = sql_parse.exceptt.sub_query
        sql_parse.exceptt.empty()
        sent1, step = get_explanation(sql_parse, step)
        part1_step = step - 1
        sent2, step = get_explanation(sub_sql, step)
        part2_step = step - 1
        sent3 = "show the rows that are in the results of step %d but not in the results of step %d" % (part1_step, part2_step)
        explanations = sent1 + sent2 + [sent3]
        step += 1
    
    elif not sql_parse.limit.is_empty and sql_parse.limit.val > 1:
        sent2 = "only show the first %d rows of the results" % sql_parse.limit.val
        # remove from sql_parse
        sql_parse.limit.empty()
        sent1, step = get_explanation(sql_parse, step)
        explanations = sent1 + [sent2]
        step += 1

    elif not sql_parse.fromm.is_empty and isinstance(sql_parse.fromm.tables[0], Query):
        # when nested sub query in FROM clause
        assert len(sql_parse.fromm.tables) == 1, "Apart of the subquery, more things in the FROM clause"
        sub_query = sql_parse.fromm.tables[0]
        sql_parse.fromm.empty()
        sent1, step = get_explanation(sub_query, step)
        sent2, step = get_explanation(sql_parse, step, join_table=True)

        explanations = sent1 + sent2
    
    elif not sql_parse.fromm.is_empty and len(sql_parse.fromm.tables) > 1: 
        # Assume no nested SQL in FROM in this case
        assert not isinstance(sql_parse.fromm.tables[0], Query), "subquery in non-single FROM clause"
        table_name1 = str(sql_parse.fromm.tables[0])
        sent1 = "for each row in table %s" % table_name1
        is_first = True
        for other_table_unit in sql_parse.fromm.tables[1:]:
            other_table_name = str(other_table_unit)
            if is_first:
                sent1 += ", find the conrresponding rows in %s table" % other_table_name
                is_first = False
            else:
                sent1 += " and in %s table" % other_table_name
    
        # remove from sql_parse
        sql_parse.fromm.empty()
        sent2, step = get_explanation(sql_parse, step + 1, join_table=True)
        explanations = [sent1] + sent2
    
    else:
        bool_groupby = not sql_parse.group_by.is_empty
        bool_orderby = not sql_parse.order_by.is_empty
        bool_having = not sql_parse.having.is_empty
        bool_where = not sql_parse.where.is_empty
        
        if join_table:  # including join tables and nested SQL in FROM clause
            table_description = "of the results of step %d " % (step - 1)
        else:
            table_name = str(sql_parse.fromm.tables[0])
            table_description = "in %s table" % table_name

        if sql_parse.fromm.is_empty and not join_table:
            table_description = ''
        
        sel_column_names = [str(col) for col in sql_parse.select.args]
        if sql_parse.select.distinct:
            sel_column_names = list(map(lambda x: 'distinct ' + x, sel_column_names))
        
        # select ...
        if sql_parse.select.distinct:
            sel_sent = 'find without repetition %s' % ", ".join(sel_column_names)
        sel_sent = "find the %s" % ", ".join(sel_column_names)
        if sql_parse.select.is_empty:
            sel_sent = ''

        # the WHERE clause, there might be nested quiries within it, and the behavior depends on whether the GROUP BY exists
        where_explanation = []  # for the nested SQL in WHERE clause
        where_description = ''
        if bool_where:
            where_description = "whose"
            where_descriptions = []
            for cond in sql_parse.where.conds:
                if cond.is_nested:
                    assert len(cond.sub_quiries) == 1, "more than one subquires in one cond_unit"
                    assert not isinstance(cond.val0, Query) and isinstance(cond.val1, Query), "new format of cond_unit founded!"
                    where_explanation, step = get_explanation(cond.sub_quiries[0], step)
                    where_descriptions += [' '.join([str(cond.val0), str(cond.op), 'the results of step %d' % (step - 1)])]
                else:
                    where_descriptions +=[str(cond)]
            where_description += ' and '.join(where_descriptions)
        
        if bool_groupby:
            groupby_column_names = [str(col) for col in sql_parse.group_by.args]

        if bool_orderby:
            orderby_column_names = [str(val_unit) for val_unit in sql_parse.order_by.args]
            orderby_direction = {"asc": "ascending", "desc": "descending"}[sql_parse.order_by.dir]
            orderby_est = {"asc": "smallest", "desc": "largest"}[sql_parse.order_by.dir]
            limit_value = sql_parse.limit.val
            if sql_parse.limit.is_empty:
                limit_value = None
            assert limit_value is None or limit_value == 1  # the LIMIT larger than 1 is processed outside
        
        if bool_having:
            assert len(sql_parse.having.conds) == 1, "assume only one cond_unit in HAVING"
            # assert not sql_parse.having.conds[0].is_nested, "no nested SQL in the HAVING clause"
            having_column_name = str(sql_parse.having.conds[0].val0)
            having_description = sql_parse.having.conds[0].description

        # This is for the case there are no GROUPBY
        # select ... from ... (where) ... (order by)
        if sql_parse.having.is_empty and sql_parse.group_by.is_empty:
            if sql_parse.fromm.is_empty and not join_table:
                table_description = ''
            sent1 = ' '.join([sel_sent, table_description])
            if bool_where:
                sent1 += ' ' + where_description
            if bool_orderby:
                if limit_value is None:
                    sent1 += ' '.join(['ordered', orderby_direction, 'by'] + orderby_column_names)
                else:
                    sent1 += ' '.join(['with', orderby_est, 'value of'] + orderby_column_names)
            
            return where_explanation + [sent1], step + 1

        # This is for the case where there are both WHERE and GROUPBY
        if bool_groupby and bool_where:
            assert len(groupby_column_names) == 1, "only one column name in groupby"
            if join_table:
                where_explanation += ["only keep te results of step %d %s" % (step - 1, where_description)]
            else:
                where_explanation += [' '.join(["find rows", table_description, where_description])]
            table_description = "of the results of step %d" % step
            if sql_parse.fromm.is_empty and not join_table:
                table_description = ''
            step += 1

        # select ... (where ...) group by ...
        # e.g., SELECT Employee_ID , Count ( * ) FROM Employees GROUP BY Employee_ID ->
        # find each value of Employee_ID in Employees table along with the number of the corresponding rows to each value
        if bool_groupby and not (bool_orderby or bool_having):
            # assert not bool_where, "WHERE should not occur with GROUP BY"
            sent = "find each value of %s %s along with %s of the corresponding rows to each value" % (", ".join(groupby_column_names),
                                                                                                        table_description,
                                                                                                        " and ".join(list(set(sel_column_names) - set(groupby_column_names))))
            return where_explanation + [sent], step + 1
        
        # select ... (where ...) group by ... order by ...
        if bool_groupby and bool_orderby and not bool_having:
            # assert not bool_where, "WHERE should not occur with GROUP BY"
            sent1 = "find the %s of each value of %s %s" % (", ".join(orderby_column_names),
                                                        ", ".join(groupby_column_names),
                                                        table_description)
            if limit_value and orderby_direction == "descending":
                sent2 = sel_sent + "with largest value in the results of step %d" % step
            elif limit_value and orderby_direction == "ascending":
                sent2 = sel_sent + "with smallest value in the results of step %d" % step
            else:
                sent2 = sel_sent + "ordered %s by the results of step %d" % (orderby_direction, step)

            return where_explanation + [sent1, sent2], step + 2
        
        # select ... (where ...) group by ... having ...
        if bool_groupby and bool_having and not bool_orderby:
            # assert not bool_where, "WHERE should not occur with GROUP BY"
            sent1 = "find the %s of each value of %s %s" % (having_column_name, 
                                                        ", ".join(groupby_column_names),
                                                        table_description)
            sent2 = sel_sent + " whose corresponding value in step %d %s " % (step + 1, having_description)

            return where_explanation + [sent1, sent2], step + 2

        # This case is only for feedback generation
        if bool_having and not bool_groupby:
            sent1 = 'make sure %s ' % connect_words(sql_parse.having.conds)
            
            return where_explanation + [sent1], step + 1
                    
        raise Exception("New patterns discovered!")
        
    return explanations, step

def add_step(explanation: list) -> list:
    if len(explanation) == 1:
        return explanation
    added = []
    for idx, e in enumerate(explanation):
        added += ['Step %d: ' % (idx + 1) + e]

    return added


In [3]:
# Imported for testing
from utils import load_json
from config import SPLASH_TRAIN_JSON
from config import SPLASH_DEV_JSON
from config import SPLASH_TEST_JSON

from tqdm import tqdm
import json
import random


In [4]:
train = load_json(SPLASH_TRAIN_JSON)
valid = load_json(SPLASH_DEV_JSON)
test = load_json(SPLASH_TEST_JSON)
data = train + valid + test

In [5]:
# Test explanation on all data and find errors
explanations = []
diff_steps = []
greater_than_5 = []
errors = []
for id, sample in tqdm(list(enumerate(data))):
    gold_sql = sample['gold_parse']
    pred_sql = sample['predicted_parse_with_values']
    db_id = sample['db_id']
    pred_query = Query(pred_sql, db_id)
    gold_query = Query(gold_sql, db_id)

    try:
        explanation = add_step(get_explanation(pred_query)[0])
        explanation_gold = add_step(get_explanation(gold_query)[0])
        steps_explanation = len(explanation)
        steps_gold = len(sample['predicted_parse_explanation'])
        example = [{'predicted_parse_with_values': pred_sql, 'gold_explanation': sample['predicted_parse_explanation'], 'pred_explanation': explanation}]

        explanations += example
        if steps_explanation != steps_gold:
            diff_steps += example
        if steps_gold > 5:
            greater_than_5 += [example]
    except:
        errors += [id]

print(len(errors))
print(len(diff_steps) / len(data))
print(len(greater_than_5))

100%|██████████| 9314/9314 [00:35<00:00, 264.89it/s]

24
0.01771526733948894
4





In [6]:
# samples of generated explanations
samples = random.sample(explanations, 50)
samples_json = json.dumps(samples, indent=2)
print(samples_json)

[
  {
    "predicted_parse_with_values": "SELECT T1.name FROM ACCOUNTS AS T1 JOIN SAVINGS AS T2 ON T1.custid = T2.custid WHERE T2.balance > ( SELECT Avg ( T2.balance ) FROM SAVINGS AS T2 )",
    "gold_explanation": [
      "Step 1: find the average balance in SAVINGS table",
      "Step 2: For each row in ACCOUNTS table, find the corresponding rows in SAVINGS table",
      "Step 3: find name in the results of step 2 whose balance greater than the results of step 1"
    ],
    "pred_explanation": [
      "Step 1: for each row in table accounts, find the conrresponding rows in savings table",
      "Step 2: find the average balance in savings table",
      "Step 3: find the name of the results of step 1  whosebalance > the results of step 2"
    ]
  },
  {
    "predicted_parse_with_values": "SELECT country FROM Addresses GROUP BY country HAVING Count ( * ) > 4",
    "gold_explanation": [
      "Step 1: find the number of rows of each value of country in Addresses table",
      "Step 2: f

In [7]:
# Check one example
id = 1310
gold = data[id]['gold_parse'].strip()
pred_sql = data[id]['predicted_parse_with_values'].strip()
db_id = data[id]['db_id']
feedback = data[id]['feedback']
explanation_gold = data[id]['predicted_parse_explanation']
pred_query = Query(pred_sql, db_id)
gold_query = Query(gold, db_id)
explanation_pred, _ = get_explanation(pred_query)
explanation_gold, _ = get_explanation(gold_query)

print(pred_sql)
print(pred_query)
print(feedback)

AssertionError: assume only one cond_unit in HAVING

In [59]:
print(gold)

SELECT T2.Name ,  T1.ArtistId FROM ALBUM AS T1 JOIN ARTIST AS T2 ON T1.ArtistId  =  T2.ArtistID GROUP BY T1.ArtistId HAVING COUNT(*)  >=  3 ORDER BY T2.Name


In [41]:
print(explanation_pred)
print(explanation_gold)

['find the addresses.state_province_county in addresses table', 'find the addresses.state_province_county in addresses table', 'show the rows that are in the results of step 1 but not in the results of step 2']
['Step 1: find the state_province_county of Addresses table', 'Step 2: find the state_province_county of Addresses table', 'Step 3: show the rows that are in the results of step 1 but not in the results of step 2']
