In [1]:
import json
import shutil

import sqlparse
import spacy
from sqlparse import sql
from tqdm import tqdm
from pathlib import Path

from process_sql import create_sql, SQLParseException

In [3]:
# Loading whole english spider dataset for further tests
spider_en = []
with open('../spider-en/train_spider.json') as json_data:
    spider_en.extend(json.load(json_data))
with open('../spider-en/train_others.json') as json_data:
    spider_en.extend(json.load(json_data))
with open('../spider-en/dev.json') as json_data:
    spider_en.extend(json.load(json_data))

## Query toks no value

In [3]:
def tokenize_query_no_value(query):
    statement = sqlparse.parse(query)[0]
    
    value_tokens = [token for token in statement.flatten() if str(token.ttype).startswith('Token.Literal')]
    for token in value_tokens:
        token.value = 'value'
        
    coarse_tokens =  [str(token).lower() for token in statement.flatten() if str(token).strip() != '']
    
    fine_tokens = []
    for token in coarse_tokens:
        if len(token.split(' ')) > 1:
            fine_tokens.extend(token.split(' '))
        elif token == '!=':
            fine_tokens.extend(['!', '='])
        elif token == '>=':
            fine_tokens.extend(['>', '='])
        elif token == '<=':
            fine_tokens.extend(['<', '='])
        elif token == ';':
            continue
        else:
            fine_tokens.append(token)
    return fine_tokens

### Test

In [44]:
# Compare created tokenization function with tokens from oryginal spider
discrepancies = 0

for sample in spider_en:
    my_tokens = tokenize_query_no_value(sample['query'])
    oryginal_tokens = sample['query_toks_no_value']
    if my_tokens != oryginal_tokens:
        discrepancies += 1
        print(sample['query'])
        print(my_tokens)
        print(oryginal_tokens)
        print()
        
print(discrepancies, 'discrepancies')
# found 18 discrepancies

SELECT count(*) FROM follows GROUP BY f1
['select', 'count', '(', '*', ')', 'from', 'follows', 'group', 'by', 'f1']
['select', 'count', '(', '*', ')', 'from', 'follows', 'group', 'by', 'value']

SELECT Roles.role_description , count(Employees.employee_id) FROM ROLES JOIN Employees ON Employees.role_code = Roles.role_code GROUP BY Employees.role_code HAVING count(Employees.employee_id)  >  1;
['select', 'roles', '.', 'role_description', ',', 'count', '(', 'employees', '.', 'employee_id', ')', 'from', 'roles', 'join', 'employees', 'on', 'employees', '.', 'role_code', '=', 'roles', '.', 'role_code', 'group', 'by', 'employees', '.', 'role_code', 'having', 'count', '(', 'employees', '.', 'employee_id', ')', '>', 'value']
['select', 'roles.role_description', ',', 'count', '(', 'employees.employee_id', ')', 'from', 'roles', 'join', 'employees', 'on', 'employees.role_code', '=', 'roles.role_code', 'group', 'by', 'employees.role_code', 'having', 'count', '(', 'employees.employee_id', ')', '>', 

## Query toks

In [4]:
nlp_pl = spacy.load("pl_core_news_md")

def tokenize_polish(text):
    return [str(token) for token in nlp_pl(text)]

In [5]:
def tokenize_statement(root, tokens=None):
    if tokens is None:
        tokens = []
    
    # Create tokens from elements which we can't divide
    if not hasattr(root, 'tokens'):
        token = str(root).strip()
        if token == '' or token == ';':
            pass
        elif len(token.split(' ')) > 1:
            tokens.extend(token.split(' '))
        elif token == '!=':
            tokens.extend(['!', '='])
        elif token == '>=':
            tokens.extend(['>', '='])
        elif token == '<=':
            tokens.extend(['<', '='])
        else:
            tokens.append(str(root))
    
    # Not split identifiers like "T1.name" into separate tokens
    elif isinstance(root, sql.Identifier) and '.' in str(root):
        tokens.append(str(root))
        
    # Tokenize strings using polish tokenizer
    elif isinstance(root, sql.Identifier) and len(root.tokens) == 1 and str(root.tokens[0].ttype) == 'Token.Literal.String.Symbol':
        tokens.extend(tokenize_polish(str(root)))
        
    # Tokenize other compound elements recursively
    else:
        for token in root.tokens:
            tokenize_statement(token, tokens)
            
    return tokens

In [6]:
def tokenize_query(query):
    statement = sqlparse.parse(query)[0]
    tokens = tokenize_statement(statement)
    return tokens

### Test

In [None]:
# Compare created tokenization function with tokens from oryginal spider
discrepancies = 0

for sample in spider_en:
    my_tokens = tokenize_query(sample['query'])
    oryginal_tokens = sample['query_toks']
    if my_tokens != oryginal_tokens:
        discrepancies += 1
        print(sample['query'])
        print(my_tokens)
        print(oryginal_tokens)
        print()
        
print(discrepancies, 'discrepancies')
# found 4754 discrepancies

## Question tokenization

In [7]:
def tokenize_question(question):
    return tokenize_polish(question)

## Completing dataset

In [8]:
def add_calculated_attributes_to_samples(src_path, tgt_path, tables_path):
    with open(src_path) as f:
        samples = json.load(f)
        
    new_samples = []
        
    for sample in tqdm(samples):
        try:
            sql = create_sql(sample['db_id'], sample['query_pl'], tables_path)
        except SQLParseException:
            # skip samples with invalid sqls
            print('Skipping sample')
        
        new_sample = {
            'db_id': sample['db_id'],
            'question': sample['question_pl'],
            'question_toks': tokenize_question(sample['question_pl']),
            'query': sample['query_pl'],
            'query_toks': tokenize_query(sample['query_pl']),
            'query_toks_no_value': tokenize_query_no_value(sample['query_pl']),
            'sql': sql
        }
        
        new_samples.append(new_sample)
        
    with open(tgt_path, 'w') as f:
        json.dump(new_samples, f, indent=4, ensure_ascii=False)

In [9]:
def create_gold_sql(src_files, tgt_file):
    with open(tgt_file, 'w') as f:
        for src_file in src_files:
            with open(src_file) as g:
                for sample in json.load(g):
                    f.write(f"{sample['query_pl']}\t{sample['db_id']}\n")

In [18]:
def create_complete(src_name, dst_name):
    src_path = Path('../auxiliary') / src_name
    dst_path = Path('../complete') / dst_name
    
    dst_path.mkdir(exist_ok=True, parents=True)
    
    for file in ['dev.json', 'train_others.json', 'train_spider.json']:
        add_calculated_attributes_to_samples(
            src_path / file,
            dst_path / file,
            src_path / 'tables.json'
        )

    create_gold_sql(
        [
            src_path / 'dev.json'
        ],
        dst_path / 'dev_gold.sql'
    )

    create_gold_sql(
        [
            src_path / 'train_spider.json',
            src_path / 'train_others.json'
        ],
        dst_path / 'train_gold.sql'
    )
    
    shutil.copyfile(src_path / 'tables.json', dst_path / 'tables.json')

In [19]:
create_complete('machine_translated', 'machine_translated')