In [1]:
import json
import os
import numpy as np
import re
import pickle
# import polars as pl
import pandas as pd
from cleantext import clean as cl
from collections import Counter
from nltk import word_tokenize, pos_tag
import sqlite3 as sqlite
from pprint import pprint as PP
from tqdm.notebook import tqdm

In [2]:
DATA = "../data/"
SPIDER = os.path.join(DATA, 'spider')
WIKI = os.path.join(DATA, 'wikisql')
VALUE_NODE = "{value}"
SW_PATH = os.path.join(*'../utils/stopwords-en.txt'.split('/'))
with open(SW_PATH, 'rb') as f:
    SW = f.readlines()

In [3]:
def load_json(file):
    with open(file, 'r') as f:
        return json.load(f)

def read_pickle(file):
    with open(file, 'rb') as f:
        return pickle.load(f)

def write_pickle(path, obj):
    with open(path, 'wb+') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

In [4]:
spider = load_json(os.path.join(SPIDER, 'spider.json'))
wiki = load_json(os.path.join(WIKI, 'wikisql.json'))

In [5]:
print(f"SPIDER ds size : {len(spider)}\nWIKISQL ds size : {len(wiki)}")

SPIDER ds size : 5183
WIKISQL ds size : 51159


In [9]:
def clean_text(t):
    return cl(t, punct=False)

def tokenize_nl(t):
    _toks = [i for i in word_tokenize(clean_text(t)) if i not in SW]
    return _toks

def tokenize_query(t):
    string = str(t)
    quote_idxs = [idx for idx, char in enumerate(string) if char == "\""]
    assert len(quote_idxs) % 2 == 0, "Unexpected quote"
    vals = {}
    for i in range(len(quote_idxs)-1, -1, -2):
        qidx1 = quote_idxs[i-1]
        qidx2 = quote_idxs[i]
        val = string[qidx1: qidx2+1]
        key = "__val_{}_{}__".format(qidx1, qidx2)
        string = string[:qidx1] + key + string[qidx2+1:]
        vals[key] = val
    
    toks = [word.lower() for word in word_tokenize(string)]
    for i in range(len(toks)):
        if toks[i] in vals:
            toks[i] = vals[toks[i]]
    
    eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
    eq_idxs.reverse()
    prefix = ('!', '>', '<')
    for eq_idx in eq_idxs:
        pre_tok = toks[eq_idx-1]
        if pre_tok in prefix:
            toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1:]
    
    return toks

In [6]:
tokenize_nl('What is the current series in June 2011 ?')

['what', 'is', 'the', 'current', 'series', 'in', 'june', '2011']

In [7]:
tokenize_query('SELECT TABLEalias0.CURRENT_SERIES_FIELD tbl WHERE TABLEalias0.NOTES_FIELD = "var0" ;')

['select',
 'tablealias0.current_series_field',
 'tbl',
 'where',
 'tablealias0.notes_field',
 '=',
 '"var0"',
 ';']

In [7]:
def get_all_parellel(dump, sep='<--->'):
    cnt_que = Counter()
    cnt_sql = Counter()
    
    nl_vocab = open(os.path.join('..', 'data', 'temp', 'vocab_nl.tsv'), 'w+', encoding='utf-8')
    nl_sql = open(os.path.join('..', 'data', 'temp', 'nl_sql.csv'), 'w+', encoding='utf-8')
    sql_vocab = open(os.path.join('..', 'data', 'temp', 'vocab_sql.tsv'), 'w+', encoding='utf-8')
    
    nl_sql.write(f"question{sep}query\n")
    
    for i in tqdm(range(len(dump))):
        k = dump[i]
        q_toks = tokenize_nl(k["sentences"][0]["text"])
        temp_ques = " ".join(q_toks)
        for j in k["sql"]:
            try:
                sql_toks = tokenize_query(cl(j))
                temp_sql = " ".join(sql_toks)
                nl_sql.write(f"{temp_ques}{sep}{temp_sql}\n")
                cnt_sql.update(sql_toks)
            except AssertionError:
                continue
        cnt_que.update(q_toks)
        
        
    write_pickle("../data/temp/counter_nl.pickle", cnt_que)
    write_pickle("../data/temp/counter_sql.pickle", cnt_sql)
    
    for n, i in enumerate(cnt_que):
        nl_vocab.write(f"{i}\t{n}\n")
    
    for n, i in enumerate(cnt_sql):
        sql_vocab.write(f"{i}\t{n}\n")
    
    nl_sql.close()
    nl_vocab.close()
    sql_vocab.close()

In [10]:
get_all_parellel(wiki, sep="\t")

  0%|          | 0/51159 [00:00<?, ?it/s]

In [15]:
nl_p = read_pickle("../data/temp/counter_nl.pickle")

In [16]:
len(nl_p)

19506