In [129]:
import re
import json
import ipdb
import random
from collections import defaultdict

import datasets
import numpy as np

# WikiSQL

In [88]:
# load dataset
all_data = datasets.load_dataset('wikisql')

Using custom data configuration default
Reusing dataset wikisql (/Users/alexwang/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)


  0%|          | 0/3 [00:00<?, ?it/s]

In [91]:
all_data['train'][3]['sql']

{'human_readable': 'SELECT Text/background colour FROM table WHERE State/territory = Australian Capital Territory',
 'sel': 1,
 'agg': 0,
 'conds': {'column_index': [0],
  'operator_index': [0],
  'condition': ['Australian Capital Territory']}}

In [92]:
to_replace = {
    ' ': '-',
    '/': '-',
    '.': '',
    '#': 'Number',
}


def process_column_name(col):
    
    for k, v in to_replace.items():
        col = col.replace(k, v)
    return col
    

def process_tables(examples):
    """ For each table, format them to be BigQuery ready:
        * remove whitespace from table names
        * quote strings
    """
    
    new_examples = []
    for example in examples:
        table = example['table']

        sql = example['sql']
        target = sql['human_readable']
        for condition in sql['conds']['condition']:
            target = target.replace(condition, f'\"{condition}\"')
        
        cols = table['header']
        new_cols = []
        
        for col in cols:
            new_col = process_column_name(col)
            new_cols.append(new_col)
            target = target.replace(col, new_col)
            
            
        example['modified_sql'] = target
        table['modified_header'] = new_cols
        example['table'] = table
        new_examples.append(example)
        
    return new_examples

In [93]:
SPLITS = ["train", "validation", "test"]
#SPLITS = ["validation"]

split2data = {}
for split in SPLITS:
    split2data[split] = process_tables(all_data[split])

In [12]:
#all_data['train'][0]['table']
split2data['validation'][0]

{'phase': 1,
 'question': 'What position does the player who played for butler cc (ks) play?',
 'table': {'header': ['Player',
   'No',
   'Nationality',
   'Position',
   'Years-in-Toronto',
   'School-Club-Team'],
  'page_title': 'Toronto Raptors all-time roster',
  'page_id': '',
  'types': ['text', 'text', 'text', 'text', 'text', 'text'],
  'id': '1-10015132-11',
  'section_title': 'L',
  'caption': 'L',
  'rows': [['Antonio Lang',
    '21',
    'United States',
    'Guard-Forward',
    '1999-2000',
    'Duke'],
   ['Voshon Lenard', '2', 'United States', 'Guard', '2002-03', 'Minnesota'],
   ['Martin Lewis',
    '32, 44',
    'United States',
    'Guard-Forward',
    '1996-97',
    'Butler CC (KS)'],
   ['Brad Lohaus', '33', 'United States', 'Forward-Center', '1996', 'Iowa'],
   ['Art Long',
    '42',
    'United States',
    'Forward-Center',
    '2002-03',
    'Cincinnati'],
   ['John Long', '25', 'United States', 'Guard', '1996-97', 'Detroit'],
   ['Kyle Lowry', '3', 'United Stat

### Format Data

In [115]:
# NB: model must be cased
# NB: position of question vs table

def format_input(example, use_modified_fields=True,
                 input_token="<i>",
                 col_tok="<c>", typ_tok="<t>",
                 use_prefix=False, prefix="Generate SQL for the question:",
                 augmentation_factor=1):
    # input_token: string for separating table from question
    question = example['question']
    
    table = example['table']
    table_name = table['name'] if table['name'] else table['id']
    cols = table['modified_header'] if use_modified_fields else table['header']
    types = table['types']
    cols_and_types = list(zip(cols, types))
    
    assert isinstance(augmentation_factor, int), "Augmentation factor must be an integer!"
    linearized_tables = []

    for aug_idx in range(augmentation_factor):  
        
        linearized_table = f'{table_name}'
        
        if aug_idx > 0: # zero-index
            # keep table columns in order to use original data
            random.shuffle(cols_and_types)
        
        # shuffle (col and typ)
        for col, typ in cols_and_types:
            linearized_table += f' {col_tok} {col} {typ_tok} {typ}'
    
        if use_prefix:
            input_str = f"{prefix} {linearized_table} {input_token} {question}"
        else:
            input_str = f"{linearized_table} {input_token} {question}"
            
        linearized_tables.append(f"{linearized_table} {input_token} {question}".strip())
    
    return linearized_tables
    
    
def format_output(example, use_modified_fields=True):
    target = example['modified_sql'] if use_modified_fields else example['sql']['human_readable']
    
    table = example['table']
    table_name = table['name'] if table['name'] else table['id']
    target = target.replace('FROM table', f'FROM {table_name}')
    return target
    
    
def format_example(example, use_modified_fields=True,
                   sep_token="<s>", end_tok="</s>",
                   input_token="<i>",
                   col_tok="<c>", typ_tok="<t>",
                   augmentation_factor=1):
    example_inputs = format_input(example, input_token=input_token, 
                                  use_modified_fields=use_modified_fields,
                                  augmentation_factor=augmentation_factor)
    example_output = format_output(example, use_modified_fields=use_modified_fields)
    examples = [f"{example_input} {sep_token} {example_output} {end_tok}" for example_input in example_inputs]
    return examples


def process_dataset(data, max_examples=-1,
                    example_separator="--SEPARATOR--",
                    use_modified_fields=True,
                    augmentation_factor=1):
    """
    """
    n_exs = 0
    processed_examples = []
    for example in data:
        processed_examples += format_example(example, use_modified_fields=use_modified_fields, 
                                             augmentation_factor=augmentation_factor)
        n_exs += 1
        if max_examples > 0 and n_exs >= max_examples:
            break
            
    print(f"\tProcessed {len(processed_examples)} examples")
    return f"\n{example_separator}\n".join(processed_examples)

In [128]:
SPLITS = ["train", "validation", "test"]
#SPLITS = ["validation"]
max_examples = -1
use_prefix = False
use_modified_fields = True
augmentation_factor = 5

for split in SPLITS:
    #split_data = all_data[split]
    split_data = split2data[split]
    print(f"Processing {split}")
    processed_data = process_dataset(split_data, max_examples=max_examples,
                                     use_modified_fields=use_modified_fields,
                                     augmentation_factor=augmentation_factor)
    
    split_file_name = f'{split}'
    if max_examples > 0:
        split_file_name += f'.nexs{max_examples}'
    if use_prefix:
        split_file_name += '.prefix'
    if augmentation_factor > 1:
        split_file_name += f'.aug{augmentation_factor}'
    out_dir = 'bigquery' if use_modified_fields else 'wikisql'
    split_file_name = f'data/wikisql/{out_dir}/{split_file_name}.txt'
    
    with open(split_file_name, 'w') as out_fh:
        out_fh.write(processed_data)
    print(f'\tWrote to {split_file_name}')

Processing train
	Processed 281775 examples
	Wrote to data/wikisql/bigquery/train.aug5.txt
Processing validation
	Processed 42105 examples
	Wrote to data/wikisql/bigquery/validation.aug5.txt
Processing test
	Processed 79390 examples
	Wrote to data/wikisql/bigquery/test.aug5.txt


### Extract Tables

In [120]:
def extract_tables(data, do_write=True):
    """ 
    """
    
    all_tables = {}
    
    for split in ["train", "validation", "test"]:
        print(f"Extracting tables from {split}")
        for example in data[split]:
            table = example["table"]
            table_name = table['name'] if table['name'] else table['id']
            
            if table_name in all_tables:
                assert all_tables[table_name] == table, "Table name collision"
            else:
                all_tables[table_name] = table
            
    all_tables = list(all_tables.values())
    if do_write:
        with open("data/all_tables.jsonl", 'w', encoding='utf-8') as out_fh:
            for table in all_tables:
                out_fh.write(f'{json.dumps(table)}\n')

        print(f"Wrote {len(all_tables)} tables")
        
    return all_tables
    
all_tables = extract_tables(split2data, do_write=False)

Extracting tables from train
Extracting tables from validation
Extracting tables from test


In [127]:
n_colss = np.array([len(t['header']) for t in all_tables])
n_colss.min()

5

In [29]:
# try loading 
tables = [json.loads(l) for l in open("data/all_tables.jsonl")]

# Spider

In [46]:
all_spider = datasets.load_dataset('spider')
spider_tables = {t['db_id']: t for t in json.load(open('data/spider/tables.json', encoding='utf-8'))}

Reusing dataset spider (/Users/alexwang/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa)


  0%|          | 0/2 [00:00<?, ?it/s]

In [131]:
def linearize_spider_table(table,                    
                           sep_token="<s>", end_tok="</s>",
                           input_token="<i>",
                           col_tok="<c>", typ_tok="<t>",
                           augmentation_factor=1):
    """ Expects tables as a dict (see below function) 
    """

    table_name = table['name']
    cols = table['columns']
    types = table['types']

    linearized_table = f'{table_name}'
    for col, typ in zip(cols, types):
        linearized_table += f' {col_tok} {col} {typ_tok} {typ}'

    return linearized_table


def linearize_spider_tables(tables, table_separator="</t>",
                            augmentation_factor=1):
    id2db = defaultdict(list)
    
    for db_name, db in spider_tables.items():
        
        all_table_names = db['table_names_original']
        idx2table = {k: {'name': v, 'columns': [], 'types': []} for k, v in enumerate(all_table_names)}

        # map columns to tables
        all_cols = db['column_names_original']
        all_types = db['column_types']
        for col, typ in zip(all_cols, all_types):
            table_idx, col_name = col
            if table_idx not in idx2table:
                continue
            idx2table[table_idx]['columns'].append(col_name)
            idx2table[table_idx]['types'].append(typ)
            
        # linearize tables and join them into one long string
        linearized_db = []
        for table in list(idx2table.values()):
            linearized_db.append(linearize_spider_table(table))
            
        for aug_idx in range(augmentation_factor):
            if augmentation_factor > 0:
                random.shuffle(linearized_db)

            id2db[db_name].append(f" {table_separator} ".join(linearized_db))
        
    return id2db

In [132]:
id2db['perpetrator']

['people <c> People_ID <t> number <c> Name <t> text <c> Height <t> number <c> Weight <t> number <c> Home Town <t> text </t> perpetrator <c> Perpetrator_ID <t> number <c> People_ID <t> number <c> Date <t> text <c> Year <t> number <c> Location <t> text <c> Country <t> text <c> Killed <t> number <c> Injured <t> number',
 'perpetrator <c> Perpetrator_ID <t> number <c> People_ID <t> number <c> Date <t> text <c> Year <t> number <c> Location <t> text <c> Country <t> text <c> Killed <t> number <c> Injured <t> number </t> people <c> People_ID <t> number <c> Name <t> text <c> Height <t> number <c> Weight <t> number <c> Home Town <t> text']

In [135]:
# NB: model must be cased
# NB: position of question vs table

def format_spider_input(example,
                        input_token="<i>",
                        use_prefix=False, prefix="Generate SQL for the question:",
                        augmentation_factor=1):
    # input_token: string for separating table from question
    question = example['question']
    
    table_id = example['db_id']
    linearized_dbs = id2db[table_id]
    linearized_inputs = []
    for linearized_db in linearized_dbs:
        
        if use_prefix:
            # TODO(AW): add prefix separator?
            input_str = f"{prefix} {linearized_db} {input_token} {question}"
        else:
            input_str = f"{linearized_db} {input_token} {question}"
    
        linearized_inputs.append(f"{input_str}".strip())
        
    return linearized_inputs
    
    
def format_spider_output(example):
    target = example['query']
    return target
    
    
def format_spider(example,
                   sep_token="<s>", end_tok="</s>",
                   input_token="<i>",
                   col_tok="<c>", typ_tok="<t>",
                   augmentation_factor=1):
    example_inputs = format_spider_input(example, input_token=input_token,
                                         augmentation_factor=augmentation_factor)
    example_output = format_spider_output(example)
    examples = [f"{example_input} {sep_token} {example_output} {end_tok}" for example_input in example_inputs]
    return examples


def process_spider_split(data, max_examples=-1,
                         example_separator="--SEPARATOR--",
                         augmentation_factor=1):
    """
    """
    n_exs = 0
    processed_examples = []
    for example in data:
        processed_examples += format_spider(example, augmentation_factor=augmentation_factor)
        n_exs += 1
        if max_examples > 0 and n_exs >= max_examples:
            break
            
    print(f"\tProcessed {len(processed_examples)} examples")
    return f"\n{example_separator}\n".join(processed_examples)

In [136]:
SPLITS = ["train", "validation"]
#SPLITS = ["validation"]
max_examples = -1
use_prefix = False
use_modified_fields = False

augmentation_factor = 2
id2db = linearize_spider_tables(spider_tables, augmentation_factor=augmentation_factor)

for split in SPLITS:
    #split_data = all_data[split]
    split_data = all_spider[split]
    print(f"Processing {split}")
    processed_data = process_spider_split(split_data, max_examples=max_examples)
    
    split_file_name = f'{split}'
    if max_examples > 0:
        split_file_name += f'.nexs{max_examples}'
    if use_prefix:
        split_file_name += '.prefix'
    if augmentation_factor > 1:
        split_file_name += f'.aug{augmentation_factor}'
    #out_dir = 'bigquery' if use_modified_fields else 'wikisql'
    split_file_name = f'data/spider/{split_file_name}.txt'
    
    with open(split_file_name, 'w') as out_fh:
        out_fh.write(processed_data)
    print(f'\tWrote to {split_file_name}')

Processing train
	Processed 14000 examples
	Wrote to data/spider/train.aug2.txt
Processing validation
	Processed 2068 examples
	Wrote to data/spider/validation.aug2.txt
