In [1]:
import json
import re
from collections import Counter

In [2]:
tables_json = json.load(open('../data/spider/tables.json'))
train_spider = json.load(open('../data/spider/train_spider.json'))
dev_spider = json.load(open('../data/spider/dev.json'))
train_dev_spider = train_spider + dev_spider

## Number of examples

In [3]:
len(train_dev_spider)

8034

In [4]:
train_dev_spider[30]

{'db_id': 'farm',
 'query': 'SELECT Official_Name FROM city ORDER BY Population DESC',
 'query_toks': ['SELECT',
  'Official_Name',
  'FROM',
  'city',
  'ORDER',
  'BY',
  'Population',
  'DESC'],
 'query_toks_no_value': ['select',
  'official_name',
  'from',
  'city',
  'order',
  'by',
  'population',
  'desc'],
 'question': 'List official names of cities in descending order of population.',
 'question_toks': ['List',
  'official',
  'names',
  'of',
  'cities',
  'in',
  'descending',
  'order',
  'of',
  'population',
  '.'],
 'sql': {'from': {'table_units': [['table_unit', 0]], 'conds': []},
  'select': [False, [[0, [0, [0, 2, False], None]]]],
  'where': [],
  'groupBy': [],
  'having': [],
  'orderBy': ['desc', [[0, [0, 5, False], None]]],
  'limit': None,
  'intersect': None,
  'union': None,
  'except': None}}

## Number of different queries

In [5]:
queries = set([ex['query'].lower() for ex in train_dev_spider])
len(queries)

4491

## Number of SQL n-grams

In [6]:
ngrams = set()
n = 3
for ex in train_dev_spider:
    tokens = [t.lower() for t in ex['query_toks']]
    for i in range(len(tokens)-n+1):
        ngrams.add(tuple(tokens[i:i+n]))
print(len(ngrams))

25263


## Number of question n-grams

In [7]:
ngrams = set()
n = 3
for ex in train_dev_spider:
    tokens = [t.lower() for t in ex['question_toks'] if t not in ['.', '?', ',']]  # spider doesn't contain other punctuations AFAIR, but we should replace this with a better filter for other datasets
    for i in range(len(tokens)-n+1):
        ngrams.add(tuple(tokens[i:i+n]))
print(len(ngrams))

41706


## Average tables per question

In [8]:
# tables_json[0]

In [9]:
schemas = {}
for db_json in tables_json:
    db_id = db_json['db_id']
    table_names = db_json["table_names_original"]
    columns = [(column_name[0], column_name[1]) for column_name in db_json["column_names_original"]]
    schemas[db_id] = {}
    for table_index, table_name in enumerate(table_names):
        schemas[db_id][table_name] = []
        table_columns = [column for column in columns if column[0] == table_index]
        for table_column in table_columns:
            schemas[db_id][table_name].append(table_column[1])

# print(schemas["farm"])

In [10]:
counts = []
for ex in train_dev_spider:
    available_tables = set([t.lower() for t in schemas[ex['db_id']].keys()])
    table_tokens_used = [t for t in ex['query_toks'] if t.lower() in available_tables]
#     print(table_tokens_used)
#     print(available_tables)
#     print([t.lower() for t in ex['query_toks']])
    counts.append(len(table_tokens_used))
print(sum(counts) / len(counts))

1.7082399800846402


## Anonymized queries

In [12]:
templates = dict()
for ex in train_dev_spider:
    query = ex['query'].replace('(', ' ( ').replace(')', ' ) ').lower().strip('; ')
    query = re.sub(r'".*"', '{value}', query)
    query = re.sub(r"'.*'", '{value}', query)
    query = re.sub(r"\s\d+.\d+", '{number}', query)
    query = re.sub(r"\s\d+", '{number}', query)
    query_tokens = [t for t in query.split() if t]
    
    for i, token in enumerate(query_tokens):
        if token.startswith("t1.") or token.startswith("t2.") or token.startswith("t3."):
            query_tokens[i] = "{item}"
    
    available_tables = set([t.lower() for t in schemas[ex['db_id']].keys()])
    available_columns = set([c.lower() for t in schemas[ex['db_id']].values() for c in t])
    
    available_items = available_tables.union(available_columns)
    
    anonymized = " ".join(['{item}' if t in available_items else t for t in query_tokens])
    if anonymized not in templates:
        templates[anonymized] = 1
    else:
        templates[anonymized] += 1
#     print(query)
#     print(available_items)
#     print(' '.join(anonymized))
#     print("****")

print(len(templates))

1059


In [13]:
print(f"Avg. # queries / templates = {sum(list(templates.values()))/len(templates)}")

Avg. # queries / templates = 7.586402266288951
