In [1]:
import json
import os
import re

import datasets
import ipdb

In [2]:
# Download (or reuse the locally cached copy of) WikiSQL with its train/validation/test splits.
all_data = datasets.load_dataset(path='wikisql')

Using custom data configuration default
Reusing dataset wikisql (/Users/alexwang/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d)


  0%|          | 0/3 [00:00<?, ?it/s]

In [81]:
# Peek at the structured SQL annotation (sel / agg / conds / human_readable) of one training example.
all_data["train"][3]["sql"]

{'human_readable': 'SELECT Text/background colour FROM table WHERE State/territory = Australian Capital Territory',
 'sel': 1,
 'agg': 0,
 'conds': {'column_index': [0],
  'operator_index': [0],
  'condition': ['Australian Capital Territory']}}

In [25]:
# Substrings of WikiSQL column names rewritten to make them BigQuery-friendlier.
# `process_column_name` applies them in this (insertion) order.
to_replace = {
    ' ': '-',
    '/': '-',
    '.': '',
    '#': 'Number',
}


def process_column_name(col, replacements=None):
    """Rewrite a column name by applying each ``old -> new`` substitution in turn.

    Args:
        col: original column name, e.g. ``'School/Club Team'``.
        replacements: mapping of substrings to their replacements. Defaults to
            the module-level ``to_replace`` (looked up at call time).

    Returns:
        The rewritten name, e.g. ``'School-Club-Team'``.
    """
    if replacements is None:
        replacements = to_replace
    for old, new in replacements.items():
        col = col.replace(old, new)
    return col


def _single_pass_replace(text, mapping):
    """Replace every occurrence of a ``mapping`` key in ``text`` in ONE left-to-right pass.

    Keys are tried longest-first, so at any position the longest matching key
    wins. Text produced by a replacement is never rescanned. Empty keys are ignored.
    """
    keys = sorted((k for k in mapping if k), key=len, reverse=True)
    if not keys:
        return text
    pattern = re.compile('|'.join(re.escape(k) for k in keys))
    return pattern.sub(lambda m: mapping[m.group(0)], text)


def process_tables(examples):
    """Add BigQuery-friendlier column names and SQL to each WikiSQL example.

    For every example, a shallow copy is returned with:
      * ``table['modified_header']``: each header passed through ``process_column_name``;
      * ``modified_sql``: ``sql['human_readable']`` with column names rewritten
        the same way, and the ``sql['conds']['condition']`` values in the WHERE
        clause wrapped in double quotes.

    Replacement is a single longest-match pass, so each of these is rewritten
    exactly once:
      * a header contained in another header (``'No.'`` in ``'No. of wins'``);
      * a repeated condition value;
      * a value that contains a header.

    NOTE: a condition value that is textually identical to a column name cannot
    be told apart without parsing the SQL. It is treated as a value.

    Args:
        examples: iterable of dicts with 'table' (containing 'header') and 'sql'
            (containing 'human_readable' and 'conds' -> 'condition').

    Returns:
        A list of new example dicts. The input dicts are left unmodified.
    """
    new_examples = []
    for example in examples:
        example = dict(example)
        table = dict(example['table'])
        sql = example['sql']

        new_cols = [process_column_name(col) for col in table['header']]
        # Unchanged headers are included as identity entries so that a condition
        # value occurring inside a longer column name is not quoted.
        col_map = dict(zip(table['header'], new_cols))
        value_map = {str(cond): f'"{cond}"' for cond in sql['conds']['condition']}

        # Only the WHERE clause contains literal values that need quoting.
        select_part, where_kw, where_part = sql['human_readable'].partition(' WHERE ')
        target = (_single_pass_replace(select_part, col_map)
                  + where_kw
                  + _single_pass_replace(where_part, {**col_map, **value_map}))

        example['modified_sql'] = target
        table['modified_header'] = new_cols
        example['table'] = table
        new_examples.append(example)

    return new_examples

In [26]:
SPLITS = ["train", "validation", "test"]

# Attach BigQuery-style headers/SQL to every example of every split; later cells read from this dict.
split2data = {split: process_tables(all_data[split]) for split in SPLITS}

In [12]:
# Sanity check: the first processed validation example.
split2data["validation"][0]

{'phase': 1,
 'question': 'What position does the player who played for butler cc (ks) play?',
 'table': {'header': ['Player',
   'No',
   'Nationality',
   'Position',
   'Years-in-Toronto',
   'School-Club-Team'],
  'page_title': 'Toronto Raptors all-time roster',
  'page_id': '',
  'types': ['text', 'text', 'text', 'text', 'text', 'text'],
  'id': '1-10015132-11',
  'section_title': 'L',
  'caption': 'L',
  'rows': [['Antonio Lang',
    '21',
    'United States',
    'Guard-Forward',
    '1999-2000',
    'Duke'],
   ['Voshon Lenard', '2', 'United States', 'Guard', '2002-03', 'Minnesota'],
   ['Martin Lewis',
    '32, 44',
    'United States',
    'Guard-Forward',
    '1996-97',
    'Butler CC (KS)'],
   ['Brad Lohaus', '33', 'United States', 'Forward-Center', '1996', 'Iowa'],
   ['Art Long',
    '42',
    'United States',
    'Forward-Center',
    '2002-03',
    'Cincinnati'],
   ['John Long', '25', 'United States', 'Guard', '1996-97', 'Detroit'],
   ['Kyle Lowry', '3', 'United Stat

## Format Data

In [37]:
# NB: model must be cased
# NB: position of question vs table

def format_input(example, use_modified_fields=True,
                 input_token="<i>",
                 col_tok="<c>", typ_tok="<t>",
                 use_prefix=False, prefix="Generate SQL for the question:"):
    # input_token: string for separating table from question
    question = example['question']
    
    table = example['table']
    table_name = table['name'] if table['name'] else table['id']
    cols = table['modified_header'] if use_modified_fields else table['header']
    types = table['types']
    
    linearized_table = f'{table_name}'
    for col, typ in zip(cols, types):
        linearized_table += f' {col_tok} {col} {typ_tok} {typ}'
    
    if use_prefix:
        input_str = f"{prefix} {linearized_table} {input_token} {question}"
    else:
        input_str = f"{linearized_table} {input_token} {question}"
    
    return f"{linearized_table} {input_token} {question}".strip()
    
    
def format_output(example, use_modified_fields=True):
    target = example['modified_sql'] if use_modified_fields else example['sql']['human_readable']
    
    table = example['table']
    table_name = table['name'] if table['name'] else table['id']
    target = target.replace('FROM table', f'FROM {table_name}')
    return target
    
    
def format_example(example, use_modified_fields=True,
                   sep_token="<s>", end_tok="</s>",
                   input_token="<i>",
                   col_tok="<c>", typ_tok="<t>"):
    example_input = format_input(example, input_token=input_token, use_modified_fields=use_modified_fields)
    example_output = format_output(example, use_modified_fields=use_modified_fields)
    return f"{example_input} {sep_token} {example_output} {end_tok}"


def process_dataset(data, max_examples=-1,
                    example_separator="--SEPARATOR--",
                    use_modified_fields=True):
    """
    """
    n_exs = 0
    processed_examples = []
    for example in data:
        processed_examples.append(format_example(example, use_modified_fields=use_modified_fields))
        n_exs += 1
        if max_examples > 0 and n_exs >= max_examples:
            break
            
    print(f"\tProcessed {len(processed_examples)} examples")
    return f"\n{example_separator}\n".join(processed_examples)

In [40]:
SPLITS = ["train", "validation", "test"]
#SPLITS = ["validation"]
max_examples = 1000
use_prefix = False
use_modified_fields = False

for split in SPLITS:
    #split_data = all_data[split]
    split_data = split2data[split]
    print(f"Processing {split}")
    processed_data = process_dataset(split_data, max_examples=max_examples,
                                     use_modified_fields=use_modified_fields)
    
    split_file_name = f'{split}'
    if max_examples > 0:
        split_file_name += f'.nexs{max_examples}'
    if use_prefix:
        split_file_name += '.prefix'
    out_dir = 'bigquery' if use_modified_fields else 'wikisql'
    split_file_name = f'data/{out_dir}/{split_file_name}.txt'
    
    with open(split_file_name, 'w') as out_fh:
        out_fh.write(processed_data)
    print(f'\tWrote to {split_file_name}')

Processing train
	Processed 1000 examples
	Wrote to data/wikisql/train.nexs1000.txt
Processing validation
	Processed 1000 examples
	Wrote to data/wikisql/validation.nexs1000.txt
Processing test
	Processed 1000 examples
	Wrote to data/wikisql/test.nexs1000.txt


## Extract Tables

In [33]:
def extract_tables(data, splits=("train", "validation", "test"),
                   out_path="data/all_tables.jsonl"):
    """Collect every distinct table referenced in ``data`` and write them as JSON Lines.

    Tables are keyed by ``name``, falling back to ``id`` when the name is empty.
    A table seen more than once must be identical every time. Tables are written
    in first-seen order, one JSON object per line.

    Args:
        data: mapping from split name to an iterable of examples, each with a 'table' dict.
        splits: which splits of ``data`` to scan, in order.
        out_path: destination file. Its parent directory is created if needed.

    Raises:
        AssertionError: if two different tables share the same name. Nothing is
            written in that case.
    """
    all_tables = {}
    for split in splits:
        print(f"Extracting tables from {split}")
        for example in data[split]:
            table = example["table"]
            table_name = table['name'] if table['name'] else table['id']
            seen = all_tables.setdefault(table_name, table)
            if seen != table:
                # Explicit raise rather than `assert`, so the check survives `python -O`.
                raise AssertionError(f"Table name collision: {table_name!r}")

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as out_fh:
        for table in all_tables.values():
            out_fh.write(f'{json.dumps(table)}\n')

    print(f"Wrote {len(all_tables)} tables")
    
# Dump the deduplicated tables of the *processed* splits (these include 'modified_header').
extract_tables(split2data)

Extracting tables from train
Extracting tables from validation
Extracting tables from test
Wrote 25683 tables


In [29]:
# Reload the extracted tables to confirm the JSON Lines file round-trips.
# `with` closes the handle; the original list comprehension leaked it.
with open("data/all_tables.jsonl", encoding='utf-8') as in_fh:
    tables = [json.loads(line) for line in in_fh]

In [30]:
# Inspect the first reloaded table record.
tables[0]

{'header': ['Player',
  'No.',
  'Nationality',
  'Position',
  'Years in Toronto',
  'School/Club Team'],
 'page_title': 'Toronto Raptors all-time roster',
 'page_id': '',
 'types': ['text', 'text', 'text', 'text', 'text', 'text'],
 'id': '1-10015132-11',
 'section_title': 'L',
 'caption': 'L',
 'rows': [['Antonio Lang',
   '21',
   'United States',
   'Guard-Forward',
   '1999-2000',
   'Duke'],
  ['Voshon Lenard', '2', 'United States', 'Guard', '2002-03', 'Minnesota'],
  ['Martin Lewis',
   '32, 44',
   'United States',
   'Guard-Forward',
   '1996-97',
   'Butler CC (KS)'],
  ['Brad Lohaus', '33', 'United States', 'Forward-Center', '1996', 'Iowa'],
  ['Art Long',
   '42',
   'United States',
   'Forward-Center',
   '2002-03',
   'Cincinnati'],
  ['John Long', '25', 'United States', 'Guard', '1996-97', 'Detroit'],
  ['Kyle Lowry', '3', 'United States', 'Guard', '2012-Present', 'Villanova']],
 'name': 'table_10015132_11'}