<a href="https://colab.research.google.com/github/dmi3eva/araneae/blob/main/p3_araneae_wrapper/AraneaeWrapper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; the later %cd cells navigate below this mount point.
# Colab-only: `google.colab` is not available in a plain Jupyter kernel.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Preprocessing

In [2]:
# Download NLTK's 'punkt' tokenizer data.
# NOTE(review): presumably needed by `tokenize` from process_sql — confirm.
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
# Step into the spider_preprocessing folder on Drive, one directory per %cd.
# NOTE(review): presumably so that `process_sql` (imported in the next cell) is importable — confirm.
%cd /content
%cd drive
%cd My\ Drive
%cd PhD
%cd Paper_01
%cd spider_preprocessing

/content
/content/drive
/content/drive/My Drive
/content/drive/My Drive/PhD
/content/drive/My Drive/PhD/Paper_01
/content/drive/My Drive/PhD/Paper_01/spider_preprocessing


In [4]:
from process_sql import *

### Utils

In [5]:
def unfold_list(values):
  """Flatten arbitrarily nested lists into one flat list.

  Only `list` instances are descended into; every other element (including
  tuples) is copied to the output as-is. Floats that lie within 0.01 of their
  truncated integer value are converted to `int`, so that e.g. 20.0 becomes 20.

  Args:
    values: an iterable whose items may themselves be (nested) lists.

  Returns:
    A new flat list with the elements in depth-first, left-to-right order.
  """
  unfolded_list = []
  for _value in values:
    if isinstance(_value, list):
      unfolded_list.extend(unfold_list(_value))
    elif isinstance(_value, float) and abs(_value - int(_value)) < 0.01:
      # abs() is required: without it the difference is never positive for
      # positive floats, so every such float (e.g. 2.5) would be truncated.
      unfolded_list.append(int(_value))
    else:
      unfolded_list.append(_value)
  return unfolded_list

In [6]:
def _split_top_level_commas(text):
    """Split `text` on commas that are not enclosed in parentheses."""
    parts, current, depth = [], [], 0
    for char in text:
        if char == ',' and depth == 0:
            parts.append(''.join(current))
            current = []
            continue
        if char == '(':
            depth += 1
        elif char == ')' and depth > 0:
            depth -= 1
        current.append(char)
    parts.append(''.join(current))
    return parts


def preprocess_SQL(sentence, splitter='select'):
    """Normalise a SQL query string into a canonical textual form.

    The steps are applied in this order:

    1. Every ``<source> AS <alias>`` pair is removed and the alias is replaced
       by ``<source>`` wherever it appears as a whole identifier
       (case-insensitively). Pairs whose source token does not start with an
       identifier or has unbalanced parentheses (sub-query aliases,
       ``CAST(x AS int)``) are left untouched.
    2. The items of the first SELECT list (the tokens between the first token
       containing `splitter` and the next ``FROM``) are sorted alphabetically.
       Commas inside parentheses do not separate items. Everything after that
       ``FROM`` is kept in its original order, including nested sub-queries.
    3. If the query has no ``JOIN`` keyword, ``table.`` prefixes are stripped
       from identifiers. Numeric literals such as ``20.5`` and tokens
       containing a quote character are not modified.
    4. Spacing around ``*``, ``,``, ``(`` and ``)`` is compacted.

    Limitation: the query is tokenised on whitespace only, so string literals
    that span several tokens are not recognised as literals.

    Args:
        sentence: the SQL query text.
        splitter: keyword that opens the clause whose items are sorted; it is
            matched case-insensitively as a substring of a token.

    Returns:
        The normalised query string.
    """
    import re  # local import: no cell above this definition imports `re`

    identifier = r'[A-Za-z_][A-Za-z0-9_]*'

    # --- 1. Remove "AS alias" declarations and substitute the aliases ---
    tokens = sentence.split()
    kept_tokens = []
    aliases = []
    index = 0
    while index < len(tokens):
        token = tokens[index]
        source = tokens[index - 1] if index > 0 else ''
        is_alias_declaration = (
            token.lower() == 'as'
            and index + 1 < len(tokens)
            and re.match(identifier, source) is not None
            and source.count('(') == source.count(')')
        )
        if not is_alias_declaration:
            kept_tokens.append(token)
            index += 1
            continue
        declared = tokens[index + 1]
        match = re.match(identifier, declared)
        alias = match.group(0) if match else declared
        # Punctuation glued to the alias ("cnt,") belongs to the query, not to
        # the alias, so it is re-attached to the preceding token.
        trailing = declared[len(alias):]
        if trailing and kept_tokens:
            kept_tokens[-1] += trailing
        aliases.append((alias, source))
        index += 2

    sentence = ' '.join(kept_tokens)
    for alias, source in aliases:
        # Whole-identifier match only, so alias "c" cannot corrupt "concert".
        # A preceding '.' is excluded so that a qualified column ("T1.name")
        # is not mistaken for a use of the alias "name".
        pattern = r'(?<![A-Za-z0-9_.])' + re.escape(alias) + r'(?![A-Za-z0-9_])'
        sentence = re.sub(pattern, lambda _match, _source=source: _source,
                          sentence, flags=re.IGNORECASE)

    # --- 2. Sort the items of the first SELECT list ---
    splitter = splitter.lower()
    before_tokens = []
    select_tokens = []
    after_tokens = []
    state = 'before'
    for token in sentence.split():
        lowered = token.lower()
        if state == 'before':
            before_tokens.append(token)
            if splitter in lowered:
                state = 'during'
        elif state == 'during' and lowered != 'from':
            select_tokens.append(token)
        else:
            # The first FROM closes the list; later SELECT/FROM tokens of
            # sub-queries stay in place instead of restarting the collection.
            state = 'after'
            after_tokens.append(token)

    select_items = sorted(
        item.strip() for item in _split_top_level_commas(' '.join(select_tokens))
    )
    select_part = ' , '.join(select_items)
    sections = [' '.join(before_tokens), select_part, ' '.join(after_tokens)]
    sentence = ' '.join(section for section in sections if section)

    # --- 3. Drop table prefixes when only one table is involved ---
    if '.' in sentence and not re.search(r'\bjoin\b', sentence, flags=re.IGNORECASE):
        # The prefix must start with a letter/underscore, so "20.5" is kept.
        prefix = re.compile(r'(?<![A-Za-z0-9_])' + identifier + r'\.(?=[A-Za-z_*])')
        new_tokens = []
        for token in sentence.split():
            if "'" in token or '"' in token:
                new_tokens.append(token)
            else:
                new_tokens.append(prefix.sub('', token))
        sentence = ' '.join(new_tokens)

    # --- 4. Compact spacing ---
    # NOTE(review): the first replace also turns "SELECT * FROM t" into
    # "SELECT *FROM t"; kept unchanged because existing data may rely on this
    # canonical form — confirm whether it is intended.
    result = sentence.replace('* ', '*').replace(' * ', '*').replace(' ,', ',').replace('( ', '(').replace(' )', ')')
    return result

### Main class

In [7]:
import pandas as pd
from copy import deepcopy
from typing import Tuple

In [8]:
class AraneaeWrapper:
  """In-memory collection of text-to-SQL samples with tagging helpers.

  Samples are dictionaries loaded from a JSON file. Each sample is augmented
  on load (and when added) with `description`, `tags`, `question_toks`, a
  parsed `sql` structure and an NL-length tag.

  Attributes set by the constructor:
    samples: list of augmented sample dictionaries.
    tables: content of the tables JSON file, stored as loaded.
    db_path: folder containing `<db_id>/<db_id>.sqlite` database files.
  """

  # Inclusive upper bounds, in question tokens, of the "NL-short" and
  # "NL-avg" length classes; anything longer is "NL-long".
  NL_SHORT_MAX_TOKENS = 13
  NL_AVG_MAX_TOKENS = 17

  def __init__(self, samples_path, tables_path, db_path):
    """Load samples and table metadata from JSON files and augment every sample.

    Args:
      samples_path: path to the JSON file with the list of samples.
      tables_path: path to the JSON file with table metadata.
      db_path: folder containing `<db_id>/<db_id>.sqlite` database files.
    """
    self.db_path = db_path
    with open(samples_path) as sample_file:
      loaded_samples = json.load(sample_file)
    with open(tables_path) as tables_file:
      self.tables = json.load(tables_file)
    # Keep the returned samples: augment_sample adds the NL-length tag to the
    # deep copy it returns, not to the dictionary passed in.
    self.samples = [self.augment_sample(_sample) for _sample in loaded_samples]

  def augment_sample(self, sample):
    """Fill in the derived fields of a sample and return a tagged deep copy.

    Missing `description`, `tags`, `question_toks` and `sql` fields are added
    to `sample` in place. The NL-length tag is added only to the returned copy.

    Args:
      sample: dict with at least `question`, `db_id` and `query` unless the
        corresponding derived fields are already present.

    Returns:
      A deep copy of `sample` that additionally carries the NL-length tag.
    """
    sample.setdefault('description', {})
    sample.setdefault('tags', [])
    if 'question_toks' not in sample:
      sample['question_toks'] = tokenize(sample['question'])

    if 'sql' not in sample:
      db_id = sample['db_id']
      db = get_schema(f"{self.db_path}/{db_id}/{db_id}.sqlite")
      schema = Schema(db)
      sql_dict = get_sql(schema, sample['query'])
      sample['sql'] = sql_dict
      query_no_value = sample['query']
      # NOTE(review): unfold_list only descends into lists, while the get_sql
      # output shown in this notebook nests tuples — confirm that the literal
      # values of the WHERE clause are actually reached here.
      for _value in unfold_list(sql_dict['where']):
        value_text = str(_value)
        # An empty string would match everywhere and make replace() insert
        # 'value' between all characters.
        if value_text and value_text in query_no_value:
          query_no_value = query_no_value.replace(value_text, 'value')
      query_toks_no_value = tokenize(query_no_value)
      # These are tokens of the query, so they belong under the query key.
      sample['query_toks_no_value'] = query_toks_no_value
      # Legacy key kept so existing consumers of the saved JSON keep working.
      sample['question_toks_no_value'] = list(query_toks_no_value)

    augmented_sample = deepcopy(sample)
    augmented_sample = self.add_nl_tag(augmented_sample)
    return augmented_sample

  def add_nl_tag(self, sample):
    """Tag a sample by the length of its question in tokens.

    Args:
      sample: dict with `question_toks`, `tags` and `description`.

    Returns:
      `sample` itself when `description` already has an 'NL-length' entry;
      otherwise a deep copy with the tag appended to `tags` and stored under
      `description['NL-length']`.
    """
    if 'NL-length' in sample['description']:
      return sample
    augmented_sample = deepcopy(sample)
    request_len = len(sample['question_toks'])
    if request_len <= self.NL_SHORT_MAX_TOKENS:
      tag = "NL-short"
    elif request_len <= self.NL_AVG_MAX_TOKENS:
      tag = "NL-avg"
    else:
      tag = "NL-long"
    augmented_sample['tags'].append(tag)
    augmented_sample['description']['NL-length'] = tag
    return augmented_sample

  def get_samples_with_tag(self, tag):
    """Return the samples whose `tags` list contains `tag`."""
    return [_s for _s in self.samples if tag in _s['tags']]

  def add_sample(self, nl, db_id, sql, source, tag=None, description: Tuple[str, str]=(None, None)):
    """Build a sample from its parts, augment it and append it to `samples`.

    Args:
      nl: the natural-language question.
      db_id: identifier of the database the query runs against.
      sql: the SQL query; it is normalised with `preprocess_SQL` first.
      source: stored as-is under the `source` key.
      tag: optional tag appended to the sample's `tags`.
      description: optional (key, value) pair stored in the sample's
        `description`; ignored unless both elements are truthy.
    """
    new_sample = {
        'db_id': db_id,
        'question': nl,
        'source': source,
        'query': preprocess_SQL(sql),
        'tags': [],
        'description': {}
    }
    if tag:
      new_sample['tags'].append(tag)
    if description[0] and description[1]:
      new_sample['description'][description[0]] = description[1]
    new_sample = self.augment_sample(new_sample)
    self.samples.append(new_sample)

  def add_samples_from_csv(self, file_path, tag=None, description: Tuple[str, str]=(None, None)):
    """Add one sample per row of a CSV file with `nl`, `db_id` and `sql` columns.

    The file path is recorded as the `source` of every added sample; `tag` and
    `description` are forwarded to `add_sample` unchanged.
    """
    new_samples = pd.read_csv(file_path)
    for _, row in new_samples.iterrows():
      self.add_sample(row['nl'], row['db_id'], row['sql'], file_path, tag=tag, description=description)

  def save_in_json(self, file_path):
    """Write all samples to `file_path` as JSON."""
    with open(file_path, 'w') as json_file:
      json.dump(self.samples, json_file)

### Spider preprocessing

In [49]:
# Change into the spider_preprocessing folder on Drive (same chain as at the top of the notebook);
# the `git clone` in the next cell runs relative to this directory.
%cd /content
%cd drive
%cd My\ Drive
%cd PhD
%cd Paper_01
%cd spider_preprocessing

/content
/content/drive
/content/drive/My Drive
/content/drive/My Drive/PhD
/content/drive/My Drive/PhD/Paper_01
/content/drive/My Drive/PhD/Paper_01/spider_preprocessing


In [50]:
# Clone the Spider repository into a `spider` sub-folder of the current directory.
# NOTE(review): git refuses to clone into an existing non-empty folder, so this cell is not re-runnable as-is.
!git clone https://github.com/taoyds/spider

Cloning into 'spider'...
remote: Enumerating objects: 380, done.[K
remote: Counting objects: 100% (19/19), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 380 (delta 8), reused 0 (delta 0), pack-reused 361[K
Receiving objects: 100% (380/380), 44.95 MiB | 15.96 MiB/s, done.
Resolving deltas: 100% (102/102), done.
Checking out files: 100% (261/261), done.


### Test

In [9]:
# Change into the Paper_01 folder; the relative "datasets/..." paths used in the cells below are resolved from here.
%cd /content
%cd drive
%cd My\ Drive
%cd PhD
%cd Paper_01

/content
/content/drive
/content/drive/My Drive
/content/drive/My Drive/PhD
/content/drive/My Drive/PhD/Paper_01


In [10]:
import json

In [11]:
# Load the existing Araneae samples and Spider's table metadata; the third argument is the
# folder in which each database is expected at <db_id>/<db_id>.sqlite.
araneae = AraneaeWrapper("datasets/araneae/araneae.json", "datasets/spider/tables.json", "datasets/spider/database")

In [None]:
# Append every row of the CSV as a sample tagged 'binary-values', with description['values'] = 'containing-binary'.
araneae.add_samples_from_csv("datasets/to_add/binary_values.csv", tag='binary-values', description=('values', 'containing-binary'))

In [None]:
# Show the first sample carrying the 'binary-values' tag (raises IndexError if none was added).
araneae.get_samples_with_tag('binary-values')[0]

### Test zone

Tokenization from Spider

In [58]:
# `tokenize` is not defined in this notebook; presumably it comes from `from process_sql import *`.
# The recorded output below shows lower-cased tokens with punctuation split off.
tokens = tokenize("My dogs is good!")
tokens

['my', 'dogs', 'is', 'good', '!']

SQL-processing from Spider

In [13]:
# Read the schema from the SQLite file, wrap it in `Schema`, and parse the query into
# the nested dictionary shown in the output below.
query = "SELECT name FROM singer"
schema_sql = get_schema("datasets/spider/database/singer/singer.sqlite")
schema = Schema(schema_sql)
sql = get_sql(schema, query)
sql

{'except': None,
 'from': {'conds': [], 'table_units': [('table_unit', '__singer__')]},
 'groupBy': [],
 'having': [],
 'intersect': None,
 'limit': None,
 'orderBy': [],
 'select': (False, [(0, (0, (0, '__singer.name__', False), None))]),
 'union': None,
 'where': []}

SQL-preprocessing from Araneae

In [16]:
# Example input with a table alias used in two different casings (T1 / t1), an unsorted
# SELECT list and a single table, so alias removal, sorting and prefix stripping all apply.
sentence = "SELECT T1.name , zaza, count(*), A FROM concert AS T1 GROUP BY t1.stadium_id"
preprocessed_sql = preprocess_SQL(sentence)
preprocessed_sql

'SELECT A, name, count(*), zaza FROM concert GROUP BY stadium_id'