In [1]:
from spider.test_suite_eval.process_sql import get_schema_from_json, get_sql
from spider.test_suite_eval.schema import build_schema_mapping, SchemaIndex
from spider.test_suite_eval.evaluation import Evaluator
import re
import json
# Dataset display names (the keys returned by `load_schema_dict`) and, in the
# same order, the directory of each dataset under data/.
SCHEMA_NAMES = ['OncoMX', 'SDSS', 'StatBot', 'WorldCup']
SCHEMA_PATHS = ['oncomx', 'sdss', 'statbot', 'world_cup']

def load_schema_dict(names=SCHEMA_NAMES, paths=SCHEMA_PATHS, data_dir='data'):
    """Load one Spider-style schema per dataset.

    For each ``(name, path)`` pair, reads
    ``<data_dir>/<path>/original/tables.json`` and keeps only its first entry.

    Args:
        names: keys of the returned dict (dataset display names).
        paths: directory names under `data_dir`, parallel to `names`.
        data_dir: root data directory; defaults to ``'data'``.

    Returns:
        dict mapping each name to the first entry of its tables.json.

    Raises:
        ValueError: if `names` and `paths` have different lengths (a plain
            `zip` would silently drop the unmatched entries).
    """
    names, paths = list(names), list(paths)
    if len(names) != len(paths):
        raise ValueError(
            f"names and paths must have the same length "
            f"({len(names)} != {len(paths)})")
    schema_dict = {}
    for name, path in zip(names, paths):
        with open(f'{data_dir}/{path}/original/tables.json', 'r',
                  encoding='utf-8') as inf:
            # Only the first schema entry ([0]) of tables.json is kept.
            schema_dict[name] = json.load(inf)[0]
    return schema_dict


In [2]:
dataset = 'SDSS'                  # dataset name used by the cells below
schema_dict = load_schema_dict()  # {dataset name: schema entry from tables.json}
schema = schema_dict[dataset]     # the schema of the selected dataset


In [5]:
# all functions
def get_hardness(query, evaluator, schema_dict):
    """Return the Spider hardness of `query`, or an error message.

    Args:
        query: SQL string (normally already passed through `query_cleaning`).
        evaluator: object whose ``eval_hardness`` method is applied to the
            parsed query (an `Evaluator` in this notebook).
        schema_dict: ONE schema entry from tables.json (for example
            ``load_schema_dict()['SDSS']``), not the dict of all schemas.

    Returns:
        ``evaluator.eval_hardness(...)`` of the parsed query. If parsing or
        evaluation raises, it returns instead a multi-line message that starts
        with "Cannot parse SQL input." and echoes the input and the error.
        Callers must check for that prefix to detect failures.
    """
    schema_mapping = build_schema_mapping(schema_dict)
    schema = SchemaIndex(
        schema_mapping,
        schema_dict['column_names_original'],
        schema_dict['table_names_original'],
    )
    try:
        # get_sql returns (parsed structure, tokens); only the structure is needed.
        spider_sql_structure, _ = get_sql(schema, query)
        return evaluator.eval_hardness(spider_sql_structure)
    except Exception as e:
        # Best-effort: report the failure inline so a batch run keeps going.
        return ("Cannot parse SQL input. Please check your query again\n"
                f"Input: {query}\nError: {e}")

def normalize_spaces(match_obj):
    """`re.sub` callback for `_add_spaces`: one space on each side of an operator.

    The pattern is an alternation of ``(left)(operator with spaces)(right)``
    group triples. It looks for the first group index `i` such that groups `i`
    and ``i + 2`` both participated in the match. For that index it returns
    ``"<left> <operator> <right>"``, with the whitespace around the operator
    stripped. It returns None when there is no such index.
    """
    total = len(match_obj.groups())
    for idx in range(1, total + 1):
        left = match_obj.group(idx)
        if left is None:
            continue
        right = match_obj.group(idx + 2)
        if right is None:
            continue
        operator = match_obj.group(idx + 1).strip()
        return left + ' ' + operator + ' ' + right
    return None

def normalize_as(match_obj):
    reserved_keywords = ['as', 'where', 'order', 'on'
                         'group', 'limit', 'join', 'having']
    # print(match_obj.groups())
    for i, grp in enumerate(match_obj.groups(), 1):
        if match_obj.group(i) is not None and match_obj.group(i+1) is not None and match_obj.group(i+2) is not None and match_obj.group(i+2).lower() not in reserved_keywords:
            #print(i, match_obj.group(i), match_obj.group(i+1))
            if match_obj.group(i+2):
                return match_obj.group(i) + ' ' + match_obj.group(i+1).rstrip() + ' AS ' + match_obj.group(i+2).strip()
        elif match_obj.group(i) is not None and match_obj.group(
                i+1) is not None and match_obj.group(i+2) is not None:
            return match_obj.group(i) + ' ' + match_obj.group(i+1).rstrip() + ' ' + match_obj.group(i+2).strip()

def quotate_boolean_values(match_obj):
    """`re.sub` callback: wrap the first participating group in single quotes.

    Any single quotes already at either end of that group are removed first,
    so ``true`` and ``'true'`` both become ``'true'``. Returns None if no
    group participated in the match.
    """
    first = next((grp for grp in match_obj.groups() if grp is not None), None)
    if first is None:
        return None
    return "'{}'".format(first.strip("'"))
    
def _add_spaces(query):
    """Put exactly one space around binary operators: ``A.a=B.b`` -> ``A.a = B.b``.

    Handles ``>= <= != <> + - / * = > <`` between an identifier or number on
    the left and an identifier, number or quoted value on the right.
    `normalize_spaces` rewrites the matched spacing. Single-quoted string
    literals are masked before matching, so operator characters inside them
    (e.g. the dashes in ``BETWEEN '2014-01-01' AND ...``) are left alone.
    On any error the query is printed and returned unchanged.
    """
    # Two-character operators come first so '>=' is not split into '>' '='.
    operators_2 = ['>=', '<=', '!=', '<>']
    operators_1 = ['+', '-', '/', '*', '=', '>', '<']
    ops = operators_2 + operators_1
    # Left operand is `word` or `word.word` (no trailing dot, so `t.*` stays
    # intact). Raw strings avoid invalid-escape warnings on Python 3.12+.
    re_patterns_list = [
        rf"(\w+(?:\.\w+)?)(\s*{re.escape(op)}\s*)([\"\'\-\w]+\.?[\"\'\w]*)"
        for op in ops]
    regex = re.compile('|'.join(re_patterns_list))
    try:
        # Placeholder prefix that does not occur anywhere in the query text.
        marker = '__sqllit'
        while marker in query:
            marker += 'x'
        literals = []

        def _stash(literal_match):
            # Replace a literal with a word-only placeholder and remember it.
            literals.append(literal_match.group(0))
            return f'{marker}{len(literals) - 1}__'

        masked = re.sub(r"'(?:[^']|'')*'", _stash, query)
        spaced = regex.sub(normalize_spaces, masked)
        new_query = re.sub(re.escape(marker) + r'(\d+)__',
                           lambda m: literals[int(m.group(1))], spaced)
    except Exception as e:
        print(f"_add_spaces: {query}")
        print(e)
        new_query = query
    return new_query

def _add_as(query):
    """Make implicit table aliases explicit: ``FROM t x`` -> ``FROM t AS x``.

    Handles the word after ``FROM <table>`` and ``JOIN <table>``
    (case-insensitive). `normalize_as` decides whether that word is an alias
    or a keyword. On any error the query is printed and returned unchanged.
    """
    # `\b` stops FROM/JOIN from matching inside identifiers such as
    # `valid_from AS x`, which would otherwise become `valid_from AS AS x`.
    # (`\w` already includes '_', so no separate underscore is needed.)
    regex = re.compile(
        r'\b(FROM)\s+(\w+)\s+(\w+)|\b(JOIN)\s+(\w+)\s+(\w+)', flags=re.IGNORECASE)
    try:
        new_query = regex.sub(normalize_as, query)
    except Exception as e:
        print(f"_add_as: {query}")
        print(e)
        new_query = query
    return new_query

def _add_quotes(query, keywords=['true', 'false']):

    _keywords = [f'\'{word}\'' for word in keywords]
    # keywords = _keywords + keywords
    if isinstance(query, str):
        for _k, k in zip(_keywords, keywords):
            query = query.replace(_k, k)
    regex = re.compile(r'\b(%s)\b' % '|'.join(keywords),
                       flags=re.IGNORECASE | re.MULTILINE)
    # print(regex)
    try:
        new_query = re.sub(regex, quotate_boolean_values, query)
    except Exception as e:
        print(f"_add_quotes: {query}")
        print(e)
        new_query = query

    return new_query

def query_cleaning(query):
    """Normalize a SQL query so the Spider parser can handle it.

    The steps run in this order:
      1. `_add_spaces`: one space around binary operators
         (``A.a=B.b`` -> ``A.a = B.b``);
      2. `_add_as`: explicit ``AS`` for table aliases
         (``FROM A a`` -> ``FROM A AS a``);
      3. `_add_quotes`: quote the boolean keywords ``true``/``false``
         (``A.c = true`` -> ``A.c = 'true'``).

    Example:
        SELECT * FROM A a JOIN B b ON A.a=B.b WHERE A.c = true
        becomes
        SELECT * FROM A AS a JOIN B AS b ON A.a = B.b WHERE A.c = 'true'
    """
    cleaned = query
    for step in (_add_spaces, _add_as, _add_quotes):
        cleaned = step(cleaned)
    return cleaned

In [6]:
evaluator = Evaluator()
# Resolve the dataset's directory from SCHEMA_NAMES/SCHEMA_PATHS; plain
# `dataset.lower()` gives 'worldcup' for 'WorldCup', but the directory is 'world_cup'.
dataset_dir = f"data/{dict(zip(SCHEMA_NAMES, SCHEMA_PATHS)).get(dataset, dataset.lower())}"
train_file = f"{dataset_dir}/seed.json"
dev_file = f"{dataset_dir}/dev.json"

def get_dataset_hardness(datafile, schema_dict, dataset):
    """Annotate every example in `datafile` with its hardness and save the result.

    Reads a JSON list of examples, each with 'query' and 'question' keys. For
    each example it computes `get_hardness` of the cleaned query against
    ``schema_dict[dataset]``. It then writes a list of
    ``{'db_id', 'question', 'query', 'hardness'}`` records to
    ``data/<dataset dir>/<split>_hardness.json``, where <split> is the input
    file name up to its first '.'.

    The stored 'query' is the original, uncleaned query. 'hardness' is
    whatever `get_hardness` returns, including its error message when a
    query cannot be parsed.
    """
    # Read the file that was passed in, not a notebook global.
    with open(datafile, 'r', encoding='utf-8') as data:
        examples = json.load(data)
    schema = schema_dict[dataset]
    res = []
    for example in examples:
        query = example['query']
        question = example['question']
        hardness = get_hardness(query_cleaning(query), Evaluator(), schema)
        res.append({
            'db_id': dataset,
            'question': question,
            'query': query,
            'hardness': hardness,
        })
    split = datafile.split('/')[-1].split('.')[0]
    # Same directory convention as the input files (e.g. 'WorldCup' -> 'world_cup').
    dataset_dir = dict(zip(SCHEMA_NAMES, SCHEMA_PATHS)).get(dataset, dataset.lower())
    hardness_output = f"data/{dataset_dir}/{split}_hardness.json"
    with open(hardness_output, 'w') as outf:
        json.dump(res, outf, indent=4)


# Use the notebook's `dataset` instead of a hard-coded name, so the schema,
# the input files and the output location all refer to the same dataset.
get_dataset_hardness(train_file, schema_dict, dataset)
get_dataset_hardness(dev_file, schema_dict, dataset)

In [6]:
# Spot-check one query end to end. Its tables (baby_names_favorite_firstname,
# spatial_unit) are not SDSS tables, so parse it against the StatBot schema
# instead of the SDSS `schema` selected earlier.
# NOTE(review): confirm these tables are in data/statbot/original/tables.json.
query = '''SELECT amount, rank
FROM baby_names_favorite_firstname bnff
JOIN spatial_unit su ON bnff.spatialunit_uid=su.spatialunit_uid
WHERE year = 2014
    AND bnff.gender = 'girl'
    AND bnff.first_name = 'Lena'
    AND su.name = 'Switzerland'
    AND su.country = TRUE;
'''
query = query_cleaning(query)
evaluator = Evaluator()
res = get_hardness(query, evaluator, schema_dict['StatBot'])

res

re.compile('\\b(true|false)\\b', re.IGNORECASE|re.MULTILINE)


'extra'