In [3]:
import pandas as pd
from datasets import load_dataset

In [4]:
generated_queries = pd.read_csv("test_resources/generated_queries.csv")

test_dataset = load_dataset("wikisql", split='test')
test_queries = test_dataset['question']
test_target = [row['human_readable'] for row in test_dataset['sql']]

Found cached dataset wikisql (/home/nicolopizzo/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)


In [5]:
import re

agg_ops = ['MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']


# idx = 92
# header = test_dataset[idx]['table']['header']
# tid = test_dataset[idx]['table']['id']
# query = test_dataset[idx]['sql']['human_readable']

def get_agg(sel_clause: str) -> str:
    for agg in agg_ops:
        if agg in sel_clause:
            return agg
    
    return None

def extract_column(sel_clause: str) -> str:
    agg = get_agg(sel_clause)
    if agg != None:
        index = sel_clause.index(agg) + len(agg)
    else:
        index = len('SELECT')
    
    return sel_clause[index:].strip()

def extract_clauses(where_clause: str, col_mapping: dict[str, str]) -> str:
    conditions = [c.strip() for c in where_clause.split(' AND ')]
    clauses = []
    for c in conditions:
        splitted = re.split(' ([=><]|OP) ', c)
        if len(splitted) > 3:
            splitted[2] = ''.join(splitted[2:])
            splitted = splitted[:3]
        elif len(splitted) < 3:
            pass
            # print(splitted)
            # print(where_clause)
        [col, op, value] = splitted
        value = value.strip()
        col = col.strip()
        
        col = col_mapping.get(col)
        num_regex = r'^\d+[.\d+]?$'
        value = value.replace('"', '')
        if (not re.match(num_regex, value)) or (not '"' in value): 
            value = f'"{value.lower()}"'
        clauses.append(f'{col}{op}{value}')
    
    return 'WHERE ' + ' AND '.join(clauses)    
    

def fix_query(query: str, header: str, tid: str) -> str:
    # FIX SELECT [AGG] CLAUSE
    from_index = query.index("FROM")
    
    col_mapping = { col: f'col{i}' for i, col in enumerate(header)}
    
    select_clause = query[:from_index]
    col_name = extract_column(select_clause)
    col = col_mapping.get(col_name)
    agg = get_agg(select_clause)
    if agg != None:
        select = f"SELECT {agg}({col})"
    else:
        select = f"SELECT {col}"
        
    # FIX TABLE NAME
    from_clause = f'FROM table_{tid.replace("-", "_")}'
    
    # FIX WHERE CLAUSE
    where = ''
    if 'WHERE' in query:
        where_index = query.index('WHERE')
        where_clause = query[where_index + 5:]
        where = extract_clauses(where_clause, col_mapping)
    
    return f'{select} AS result {from_clause} {where}'

In [6]:
def logical_accuracy(test_queries: list[str], predicted_queries: list[str]) -> int:
  # assert len(test_queries) == len(predicted_queries)

  count = 0
  for tq, pq in zip(test_queries, predicted_queries):
    if tq.lower() == pq.lower():
      count += 1

  return count / len(test_queries)

predicted_queries = generated_queries['query'].tolist().copy()
logical_accuracy(test_target, predicted_queries)

0.2787504723516816

In [21]:
import sqlite3 as sql
from sqlite3 import Error
import copy
from torch.utils.data import Dataset

def execution_accuracy(test_queries: Dataset, predicted_queries: list[str], cursor: sql.Cursor) -> int:
  assert len(test_queries) == len(predicted_queries)

  count = 0
  error = 0
  for i, row in enumerate(test_queries):
    header = row['table']['header']
    tid = row['table']['id']
    query = row['sql']['human_readable']
    
    tq = fix_query(query, header, tid)
    cursor.execute(tq)
    test_rows = cursor.fetchall()

    try:
      pq = fix_query(predicted_queries[i], header, tid)
      cursor.execute(pq)
      pred_rows = cursor.fetchall()
    except:
      pred_rows = None
      error += 1
    
    count += (pred_rows == test_rows)

  print(f"Errors encountered: {error}")
  
  n_q = len(test_queries)
  
  print(f"Queries with valid syntax: {n_q - error}")
  print(f"Correct Queries with valid syntax: {count}")
  print(f"Incorrect Queries with valid syntax: {n_q - error - count}")
  return count / len(test_queries)

def create_connection(db_file: str) -> sql.Connection:
    conn = None
    try:
        conn = sql.connect(db_file)
    except Error as e:
        print(e)

    return conn

db_path = "test_resources/test.db"
connection = create_connection(db_path)
cursor = connection.cursor()

execution_accuracy(test_dataset, predicted_queries, cursor)

Errors encountered: 8238
Queries with valid syntax: 7640
Correct Queries with valid syntax: 4770
Incorrect Queries with valid syntax: 2870


0.3004156694797833