In [8]:
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase

# Open the catastici SQLite database through LangChain's SQLDatabase wrapper.
CATASTICI_DB_URI = "sqlite:///../data/catastici.db"
db = SQLDatabase.from_uri(CATASTICI_DB_URI)

# Smoke test: show the dialect and the usable tables, then fetch one row
# (the last expression is rendered as the cell output).
print(db.dialect, db.get_usable_table_names(), sep="\n")
db.run("SELECT * FROM catastici LIMIT 1;")

sqlite
['catastici']


"[('liberal', 'campi', 'casa e bottega da barbier', 70, 'campo vicino alla chiesa')]"

In [4]:
import pandas as pd
# Load the evaluation table; later cells read its 'generated_query' and
# 'true_answer' columns and write 'generated_answer' back into it.
TEST_DATA_CSV = '../data/test_data_generated.csv'
query_res = pd.read_csv(TEST_DATA_CSV)

In [6]:
import re

# Upper-case SQL keywords that get a separating space when they are glued to
# the preceding token. Order matters: substitutions are applied sequentially.
_SQL_KEYWORDS = (
    'FROM', 'WHERE', 'GROUP', 'ORDER', 'HAVING',
    'JOIN', 'INNER', 'LEFT', 'RIGHT', 'OUTER', 'ON',
    'INSERT', 'INTO', 'VALUES', 'UPDATE', 'SET', 'DELETE', 'CREATE',
    'TABLE', 'ALTER', 'DROP', 'INDEX', 'TRUNCATE', 'DISTINCT', 'AS',
    'AND', 'OR', 'NOT', 'IN', 'BETWEEN', 'LIKE', 'IS', 'NULL', 'ASC',
    'DESC', 'LIMIT', 'OFFSET', 'UNION', 'ALL', 'ANY', 'CASE', 'WHEN',
    'THEN', 'ELSE', 'END', 'EXISTS', 'COUNT', 'MAX', 'MIN', 'SUM',
    'AVG',
)

# Column names that are wrapped in double quotes in the cleaned query.
_QUOTED_COLUMNS = (
    'Owner_First_Name', 'Owner_Family_Name', 'Property_Type',
    'Property_Location', 'Rent_Income',
)


def clean_query(sql_query):
    """Heuristically repair one generated SQL string.

    The clean-up steps run in this order:

    1. Keep only the text before the first ``;``, then before the first
       ``What`` and the first ``How`` (trailing question text), and append
       a single ``;``.
    2. Insert a space before any upper-case SQL keyword that directly
       follows a lower-case letter, a digit or a quote character.
    3. Double every apostrophe that sits between two lower-case letters
       (e.g. ``d'oro`` -> ``d''oro``).
    4. Insert a space before every ``LIMIT`` that directly follows a
       non-space character.
    5. Wrap every occurrence of a known column name in double quotes,
       leaving occurrences that are already wrapped untouched.

    Args:
        sql_query (str): Raw generated SQL text (one candidate statement).

    Returns:
        str: The cleaned statement, always terminated by ``;``.
    """
    # Keep the first statement only and drop any echoed question text.
    sql_query = sql_query.split(';')[0].split('What')[0].split('How')[0] + ';'

    # Separate keywords glued to the previous token, e.g. "catasticiWHERE".
    for keyword in _SQL_KEYWORDS:
        pattern = r'(?<=[a-z0-9"\'])' + re.escape(keyword)
        sql_query = re.sub(pattern, ' ' + keyword, sql_query)

    # Escape apostrophes inside words. Lookarounds do not consume the
    # neighbouring letters, so back-to-back cases like "a'b'c" are all handled.
    sql_query = re.sub(r"(?<=[a-z])'(?=[a-z])", "''", sql_query)

    # Space out each glued LIMIT individually (e.g. "DESCLIMIT", ")LIMIT")
    # instead of deciding for all occurrences based on the first one.
    sql_query = re.sub(r'(?<=[^ ])LIMIT', ' LIMIT', sql_query)

    # Quote column names per occurrence: splitting on the already-quoted form
    # protects those occurrences, and everything left in between is wrapped.
    for column in _QUOTED_COLUMNS:
        quoted = f'"{column}"'
        sql_query = quoted.join(
            part.replace(column, quoted) for part in sql_query.split(quoted)
        )

    return sql_query

def check_sql_executability(query, db):
    """Run ``query`` through ``db.run`` and report failures as a sentinel.

    Args:
        query (str): SQL statement to execute.
        db: Object exposing a ``run(query)`` method (here the SQLDatabase
            created earlier in the notebook).

    Returns:
        Whatever ``db.run(query)`` returns, or the string ``"ERROR"`` if the
        call raises an ``Exception``.
    """
    try:
        return db.run(query)
    except Exception:
        # Only ordinary errors are mapped to the sentinel; KeyboardInterrupt
        # and SystemExit still propagate so a long run can be interrupted.
        return "ERROR"

In [15]:
# For every row, treat each line of the generated text as a candidate query,
# clean it, and keep the first candidate that executes without error.
# NOTE: 'generated_query' is overwritten in place with the cleaned text.
for idx, row in query_res.iterrows():
    cleaned_candidates = []
    final_out = None
    answer = "ERROR"
    for raw_line in row['generated_query'].split('\n'):
        # clean_query turns an empty line into a bare ';'. Skip blank lines so
        # such a no-op statement cannot be picked as the "executable" query
        # ahead of the real candidates.
        if not raw_line.strip():
            continue
        candidate = clean_query(raw_line)
        cleaned_candidates.append(candidate)
        answer = check_sql_executability(candidate, db)
        if answer != "ERROR":
            final_out = candidate
            break
    if final_out is None:
        # Nothing executed: keep all cleaned candidates for later inspection.
        final_out = '\n'.join(cleaned_candidates)
        answer = "ERROR"
    query_res.loc[idx, 'generated_answer'] = answer
    query_res.loc[idx, 'generated_query'] = final_out

In [12]:
# Rows whose query executed but whose answer differs from the reference answer.
executed_mask = query_res['generated_answer'] != 'ERROR'
mismatch_mask = query_res['generated_answer'] != query_res['true_answer']
query_res[executed_mask & mismatch_mask].shape

(266, 8)

In [13]:
# Rows for which no candidate query could be executed.
failed_mask = query_res['generated_answer'] == 'ERROR'
query_res[failed_mask].shape

(24, 8)

# Check

Wrong -> 266<br>
Error -> 24<br>
True -> **230**

### Wrong Ground Truth
5, 20, 35, 40, 90, 155, 165, 170, 185, 481, 121, 309, 312

Ambiguous questions <br>
15-19, 450-454, 85-89, 225-228, 255-259, 260-264, 305-309

Super hard<br>
494, 488, 478 (CodeS True), 464, 459, 305 (CodeS True)

### Wrong generation
Split on ; - 2,                         # *solved* <br>
Wrap column names with "" - 4           # *solved* <br>
White Space problem - 6, 10, 29, 30     # *solved* <br>
Uppercases the names                    # lowercase everything <br>

### Limitations
Sometimes it emits SQL keywords that do not exist or are not supported, such as STDDEV, ALL, ... <br>
Sometimes it passes more than one argument to COUNT <br>

Sometimes it adds an extra filter (on LIMIT) - 3, 8, 11, 12, 13, 14, 25, 35<br>
Sometimes it confuses the column names, e.g. Rent_Income instead of Property_Type

In [None]:
# Dump every row whose query executed but returned a different answer than
# the reference, so the mismatches can be inspected by hand.
ran_ok = query_res['generated_answer'] != 'ERROR'
differs = query_res['generated_answer'] != query_res['true_answer']
for idx, row in query_res[ran_ok & differs].iterrows():
    print(
        f"{row['level']} - {row['question_id']} - {idx}",
        f"Question: {row['question']}",
        f"Answer True: {row['true_answer']}",
        f"Answer Generated: {row['generated_answer']}",
        'True SQL:',
        row['true_query'],
        'Generated SQL:',
        row['generated_query'],
        '\n\n',
        sep='\n',
    )