In [623]:
import re
import json
import shutil

import sqlparse
from sqlparse import sql
import sqlglot
import sqlglot.expressions as exp
import spacy
from tqdm import tqdm
from pathlib import Path

from process_sql import create_sql, SQLParseException

In [3]:
# Loading whole english spider dataset for further tests
spider_en = []
with open('../../spider-en/train_spider.json') as json_data:
    spider_en.extend(json.load(json_data))
with open('../../spider-en/train_others.json') as json_data:
    spider_en.extend(json.load(json_data))
with open('../../spider-en/dev.json') as json_data:
    spider_en.extend(json.load(json_data))

## Query toks no value

In [3]:
def tokenize_query_no_value(query):
    statement = sqlparse.parse(query)[0]
    
    value_tokens = [token for token in statement.flatten() if str(token.ttype).startswith('Token.Literal')]
    for token in value_tokens:
        token.value = 'value'
        
    coarse_tokens =  [str(token).lower() for token in statement.flatten() if str(token).strip() != '']
    
    fine_tokens = []
    for token in coarse_tokens:
        if len(token.split(' ')) > 1:
            fine_tokens.extend(token.split(' '))
        elif token == '!=':
            fine_tokens.extend(['!', '='])
        elif token == '>=':
            fine_tokens.extend(['>', '='])
        elif token == '<=':
            fine_tokens.extend(['<', '='])
        elif token == ';':
            continue
        else:
            fine_tokens.append(token)
    return fine_tokens

### Test

In [None]:
# Compare created tokenization function with tokens from oryginal spider
discrepancies = 0

for sample in spider_en:
    my_tokens = tokenize_query_no_value(sample['query'])
    oryginal_tokens = sample['query_toks_no_value']
    if my_tokens != oryginal_tokens:
        discrepancies += 1
        print(sample['query'])
        print(my_tokens)
        print(oryginal_tokens)
        print()
        
print(discrepancies, 'discrepancies')
# found 18 discrepancies

## Query toks

In [4]:
nlp_pl = spacy.load("pl_core_news_md")

def tokenize_polish(text):
    return [str(token) for token in nlp_pl(text)]

In [5]:
def tokenize_statement(root, tokens=None):
    if tokens is None:
        tokens = []
    
    # Create tokens from elements which we can't divide
    if not hasattr(root, 'tokens'):
        token = str(root).strip()
        if token == '' or token == ';':
            pass
        elif len(token.split(' ')) > 1:
            tokens.extend(token.split(' '))
        elif token == '!=':
            tokens.extend(['!', '='])
        elif token == '>=':
            tokens.extend(['>', '='])
        elif token == '<=':
            tokens.extend(['<', '='])
        else:
            tokens.append(str(root))
    
    # Not split identifiers like "T1.name" into separate tokens
    elif isinstance(root, sql.Identifier) and '.' in str(root):
        tokens.append(str(root))
        
    # Tokenize strings using polish tokenizer
    elif isinstance(root, sql.Identifier) and len(root.tokens) == 1 and str(root.tokens[0].ttype) == 'Token.Literal.String.Symbol':
        tokens.extend(tokenize_polish(str(root)))
        
    # Tokenize other compound elements recursively
    else:
        for token in root.tokens:
            tokenize_statement(token, tokens)
            
    return tokens

In [6]:
def tokenize_query(query):
    statement = sqlparse.parse(query)[0]
    tokens = tokenize_statement(statement)
    return tokens

### Test

In [None]:
# Compare created tokenization function with tokens from oryginal spider
discrepancies = 0

for sample in spider_en:
    my_tokens = tokenize_query(sample['query'])
    oryginal_tokens = sample['query_toks']
    if my_tokens != oryginal_tokens:
        discrepancies += 1
        print(sample['query'])
        print(my_tokens)
        print(oryginal_tokens)
        print()
        
print(discrepancies, 'discrepancies')
# found 4754 discrepancies

## Question tokenization

In [7]:
def tokenize_question(question):
    return tokenize_polish(question)

## Completing dataset

In [8]:
def add_calculated_attributes_to_samples(src_path, tgt_path, tables_path):
    with open(src_path) as f:
        samples = json.load(f)
        
    new_samples = []
        
    for sample in tqdm(samples):
        try:
            sql = create_sql(sample['db_id'], sample['query_pl'], tables_path)
        except SQLParseException:
            # skip samples with invalid sqls
            print('Skipping sample')
        
        new_sample = {
            'db_id': sample['db_id'],
            'question': sample['question_pl'],
            'question_toks': tokenize_question(sample['question_pl']),
            'query': sample['query_pl'],
            'query_toks': tokenize_query(sample['query_pl']),
            'query_toks_no_value': tokenize_query_no_value(sample['query_pl']),
            'sql': sql
        }
        
        new_samples.append(new_sample)
        
    with open(tgt_path, 'w') as f:
        json.dump(new_samples, f, indent=4, ensure_ascii=False)

In [9]:
def create_gold_sql(src_files, tgt_file):
    with open(tgt_file, 'w') as f:
        for src_file in src_files:
            with open(src_file) as g:
                for sample in json.load(g):
                    f.write(f"{sample['query_pl']}\t{sample['db_id']}\n")

In [418]:
def create_tables_translation_dict(tables_names_path):
    with open(tables_names_path) as f:
        tables_names = json.load(f)
        
    tables_trans_dict = {}
    for entry in tables_names:
        if entry['db_id'] not in tables_trans_dict:
            tables_trans_dict[entry['db_id']] = {}
        tables_trans_dict[entry['db_id']][entry['name_original'].lower()] = {'name': entry['name_pl'], 'name_original': entry['name_original_pl']}
        
    return tables_trans_dict

In [416]:
def create_columns_translation_dict(columns_names_path):
    with open(columns_names_path) as f:
        columns_names = json.load(f)
        
    columns_trains_dict = {}
    for entry in columns_names:
        if entry['db_id'] not in columns_trains_dict:
            columns_trains_dict[entry['db_id']] = {}
        db_stuff = columns_trains_dict[entry['db_id']]
        if entry['table_name_original'].lower() not in db_stuff:
            db_stuff[entry['table_name_original'].lower()] = {}
        table_stuff = db_stuff[entry['table_name_original'].lower()]
        table_stuff[entry['column_name_original'].lower()] = {'name': entry['column_name_pl'], 'name_original': entry['column_name_original_pl']}
        
    return columns_trains_dict

In [434]:
# tables_translation_dict = create_tables_translation_dict('../auxiliary/translated_schema/tables_names.json')
columns_translation_dict = create_columns_translation_dict('../auxiliary/translated_schema/columns_names.json')

In [437]:
columns_translation_dict['perpetrator']['perpetrator']

{'perpetrator_id': {'name': 'id sprawcy', 'name_original': 'Sprawca_ID'},
 'people_id': {'name': 'id osoby', 'name_original': 'Ludzie_ID'},
 'date': {'name': 'data', 'name_original': 'Data'},
 'year': {'name': 'rok', 'name_original': 'Rok'},
 'location': {'name': 'lokalizacja', 'name_original': 'Lokalizacja'},
 'country': {'name': 'kraj', 'name_original': 'Kraj'},
 'killed': {'name': 'zabity', 'name_original': 'Zabity'},
 'injured': {'name': 'obrażenia', 'name_original': 'Obrażenia'}}

In [116]:
def translate_tables(tables_path, tables_names_path, columns_names_path, output_path):
    tables_trans_dict = create_tables_translation_dict(tables_names_path)
    columns_trans_dict = create_columns_translation_dict(columns_names_path)
        
    with open(tables_path) as f:
        tables_json = json.load(f)
        
    # perform translation
    for db in tables_json:
        db_id = db['db_id']
        
        # translate columns
        assert db['column_names'][0][1] == '*'
        assert db['column_names_original'][0][1] == '*'
        for i in range(1, len(db['column_names_original'])):
            table_idx, column_name_original = db['column_names_original'][i]
            table_name = db['table_names_original'][table_idx]
            translations = columns_trans_dict[db_id][table_name][column_name_original]
            db['column_names_original'][i][1] = translations['name_original']
            db['column_names'][i][1] = translations['name']
        
        # translate tables
        for i in range(len(db['table_names_original'])):
            table_name = db['table_names_original'][i]
            translations = tables_trans_dict[db_id][table_name]
            db['table_names_original'][i] = translations['name']
            db['table_names'][i] = translations['name_original']
        
    # save translated
    with open(output_path, 'w') as f:
        json.dump(tables_json, f, indent=4, ensure_ascii=False)

In [None]:
x = translate_tables(
    '../auxiliary/translated_schema/tables.json',
    '../auxiliary/translated_schema/tables_names.json',
    '../auxiliary/translated_schema/columns_names.json',
    '../auxiliary/translated_likes/tables.json'
)

### Experiments

Nazwy kolumn:
- po select
    - w funkcjach agregacyjnych (min, max, avg)
- w porównaniach (where, on)
- po order by

Nazwy tabel:
- tam gdzie nazwy kolumn
- po from
- po join

In [325]:
q1 = "select horsepower ,  T1.Make FROM CAR_NAMES AS T1 JOIN CARS_DATA AS T2 ON T1.MakeId  =  T2.Id WHERE T2.cylinders  =  3 ORDER BY T2.horsepower DESC LIMIT 1;"
q2 = "SELECT LOCATION ,  name FROM stadium WHERE capacity BETWEEN 5000 AND 10000"
q3 = "select max(capacity) from stadium as ST"

In [558]:
def get_tables_aliasing(sql):
    aliasing = []
    tokens = [token for token in sqlparse.parse(sql)[0].flatten() if str(token).strip() != '']
    for i in range(1, len(tokens)-1):
        prev, current, next = tokens[i-1:i+2]
        if str(current).upper() == 'AS':
            assert str(current.ttype) == 'Token.Keyword'
            aliasing.append((str(prev), str(next))) 
    return aliasing

# get_tables_aliasing(q1)

q5 = 'SELECT T1.name ,  T1.id FROM station AS T1 JOIN status AS T2 ON T1.id  =  T2.station_id GROUP BY T2.station_id HAVING avg(T2.bikes_available)  >  14 UNION SELECT name ,  id FROM station WHERE installation_date LIKE "12/%"'
q6 = 'SELECT count(*) FROM station AS T1 JOIN trip AS T2 JOIN station AS T3 JOIN trip AS T4 ON T1.id  =  T2.start_station_id AND T2.id  =  T4.id AND T3.id  =  T4.end_station_id WHERE T1.city  =  "Mountain View" AND T3.city  =  "Palo Alto"'
q7 = 'SELECT DISTINCT T2.Hardware_Model_name FROM screen_mode AS T1 JOIN phone AS T2 ON T1.Graphics_mode = T2.screen_mode WHERE T1.Type  =  "Graphics" OR t2.Company_name  =  "Nokia Corporation"'
get_tables_aliasing(q7)

[('screen_mode', 'T1'), ('phone', 'T2')]

In [557]:
def get_tables_aliasing(sql):
    parsed = sqlglot.parse_one(sql)
    aliasing = {table.this.output_name: table.alias for table in parsed.find_all(exp.Table)}
    return aliasing

get_tables_aliasing(q1)

{'CAR_NAMES': 'T1', 'CARS_DATA': 'T2'}

In [536]:
q5 = 'SELECT T1.name ,  T1.id FROM station AS T1 JOIN status AS T2 ON T1.id  =  T2.station_id GROUP BY T2.station_id HAVING avg(T2.bikes_available)  >  14 UNION SELECT name ,  id FROM station WHERE installation_date LIKE "12/%"'
q6 = 'SELECT count(*) FROM station AS T1 JOIN trip AS T2 JOIN station AS T3 JOIN trip AS T4 ON T1.id  =  T2.start_station_id AND T2.id  =  T4.id AND T3.id  =  T4.end_station_id WHERE T1.city  =  "Mountain View" AND T3.city  =  "Palo Alto"'
get_tables_aliasing(q6)

{'station': 'T3', 'trip': 'T4'}

In [352]:
def find_tables(sql):
    parsed = sqlglot.parse_one(sql)
    tables = [table.this.output_name for table in parsed.find_all(exp.Table)]
    return tables

find_tables(q1)

['CAR_NAMES', 'CARS_DATA']

In [450]:
q4 = "SELECT student_id FROM students WHERE student_id NOT IN (SELECT student_id FROM student_course_attendance)"

In [385]:
def find_columns(sql):
    parsed = sqlglot.parse_one(sql)
    columns = [col.output_name for col in parsed.find_all(exp.Column)]
    return columns

find_columns(q1)

['horsepower', 'Make', 'MakeId', 'Id', 'cylinders', 'horsepower']

In [None]:
def translate_query(query, db_id, tables_trans_dict, columns_trans_dict):
    tables_trans_dict = tables_trans_dict[db_id]
    columns_trans_dict = columns_trans_dict[db_id]
    
    aliasing_rev = {new.lower(): old for old, new in get_tables_aliasing(query)}
    
    columns_names = find_columns(query)
    tables_names = find_tables(query)
    
    statement = sqlparse.parse(query)[0]
    tokens = [token for token in statement.flatten() if str(token).strip() != '']

    for i in reversed(range(len(tokens))):
        if str(tokens[i].ttype).startswith('Token.Literal.String'): # skip if token is string value
            continue
        
        if str(tokens[i]) in columns_names:
            table_name = str(tokens[i-2]) if i >= 2 and str(tokens[i-1]) == '.' else None
            if table_name is not None:
                table_name = aliasing_rev.get(table_name.lower(), table_name)
            
            if table_name is None:
                possible_table_names = [table_name for table_name, table_trans in columns_trans_dict.items() if table_name in [x.lower() for x in tables_names] and str(tokens[i]).lower() in table_trans]
                assert len(possible_table_names) > 0
                if len(possible_table_names) > 1 and not( ' union ' not in query.lower() and ' except ' not in query.lower() and ' intersect ' not in query.lower()):
                    print(query, '|', str(tokens[i]), f'({db_id})')
                table_name = possible_table_names[0]
            column_name_pl = columns_trans_dict[table_name.lower()][str(tokens[i]).lower()]['name_original']
            tokens[i].value = column_name_pl
            
        elif str(tokens[i]) in tables_names:
            table_name_pl = tables_trans_dict[str(tokens[i]).lower()]['name_original']
            tokens[i].value = table_name_pl
            
    return str(statement)
                

q0 = "1SELECT student_id FROM student_course_registrations UNION 2SELECT student_id FROM student_course_attendance intersect 3SELECT student_id FROM student_course_attendance"
translate_query(
    q0,
    "student_assessment",
    create_tables_translation_dict('../auxiliary/translated_schema/tables_names.json'),
    create_columns_translation_dict('../auxiliary/translated_schema/columns_names.json'),
)

In [685]:
q0 = "1SELECT student_id FROM student_course_registrations"

In [686]:
match = re.fullmatch(r'^(.+?)(?: (UNION|INTERSECT|EXCEPT) (.+?))*$', q0, flags=re.IGNORECASE)
match.groups()

('1SELECT student_id FROM student_course_registrations', None, None)

In [648]:
dir(match)

['__class__',
 '__class_getitem__',
 '__copy__',
 '__deepcopy__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'end',
 'endpos',
 'expand',
 'group',
 'groupdict',
 'groups',
 'lastgroup',
 'lastindex',
 'pos',
 're',
 'regs',
 'span',
 'start',
 'string']

In [613]:
def translate_samples(samples, tables_translation_dict, columns_translation_dict):
    for sample in samples:
        query_pl = translate_query(sample['query'], sample['db_id'], tables_translation_dict, columns_translation_dict)
        sample['query'] = query_pl
    return samples

In [619]:
with open('../auxiliary/translated_likes/train_spider.json') as f:
    samples = json.load(f)
    
tables_translation_dict = create_tables_translation_dict('../auxiliary/translated_schema/tables_names.json')
columns_translation_dict = create_columns_translation_dict('../auxiliary/translated_schema/columns_names.json')

trans_samples = translate_samples(samples, tables_translation_dict, columns_translation_dict)

SELECT student_id FROM student_course_registrations UNION SELECT student_id FROM student_course_attendance | student_id (student_assessment)
SELECT student_id FROM student_course_registrations UNION SELECT student_id FROM student_course_attendance | student_id (student_assessment)
SELECT student_id FROM student_course_registrations UNION SELECT student_id FROM student_course_attendance | student_id (student_assessment)
SELECT student_id FROM student_course_registrations UNION SELECT student_id FROM student_course_attendance | student_id (student_assessment)
SELECT course_id FROM student_course_registrations WHERE student_id = 121 UNION SELECT course_id FROM student_course_attendance WHERE student_id = 121 | student_id (student_assessment)
SELECT course_id FROM student_course_registrations WHERE student_id = 121 UNION SELECT course_id FROM student_course_attendance WHERE student_id = 121 | course_id (student_assessment)
SELECT course_id FROM student_course_registrations WHERE student_id

In [591]:
trans_samples

In [596]:
with open('../auxiliary/translated_likes/train_spider.json', 'w') as f:
        json.dump(trans_samples, f, indent=4, ensure_ascii=False)

In [18]:
def create_complete(src_name, dst_name):
    src_path = Path('../auxiliary') / src_name
    dst_path = Path('../complete') / dst_name
    
    dst_path.mkdir(exist_ok=True, parents=True)
    
    for file in ['dev.json', 'train_others.json', 'train_spider.json']:
        add_calculated_attributes_to_samples(
            src_path / file,
            dst_path / file,
            src_path / 'tables.json'
        )

    create_gold_sql(
        [
            src_path / 'dev.json'
        ],
        dst_path / 'dev_gold.sql'
    )

    create_gold_sql(
        [
            src_path / 'train_spider.json',
            src_path / 'train_others.json'
        ],
        dst_path / 'train_gold.sql'
    )
    
    shutil.copyfile(src_path / 'tables.json', dst_path / 'tables.json')

In [19]:
create_complete('machine_translated', 'machine_translated')