In [1]:
import json
import re

import srsly

In [2]:
tables_json = json.load(open('../stackexchange_schema/tables_so.json'))
train_sede = [line for line in srsly.read_jsonl('../data/sede/train.jsonl')]
dev_sede = [line for line in srsly.read_jsonl('../data/sede/val.jsonl')]
test_sede = [line for line in srsly.read_jsonl('../data/sede/test.jsonl')]
train_dev_test = train_sede + dev_sede + test_sede

## Number of examples

In [3]:
len(train_dev_test)

12023

In [5]:
train_dev_test[30]

{'QuerySetId': 2803,
 'Title': 'Users by location, with a minimum reputation',
 'Description': None,
 'QueryBody': 'WITH  a minimum reputation\n\nselect\n  Id as "User Link",\n  Reputation,\n  WebsiteUrl as "Website URL",\n  Location\nfrom Users\nwhere\n  Location like \'%##location##%\' and\n  Reputation >= ##minimumReputation##\norder by Reputation desc',
 'CreationDate': '2020-12-09 04:38:29',
 'validated': False}

## Number of different queries

In [6]:
queries = set([ex['QueryBody'].lower() for ex in train_dev_test])
len(queries)

11767

## Number of SQL n-grams

In [7]:
ngrams = set()
n = 3
for ex in train_dev_test:
    tokens = [t.lower() for t in ex['QueryBody'].split()]
    for i in range(len(tokens)-n+1):
        ngrams.add(tuple(tokens[i:i+n]))
print(len(ngrams))

173343


## Number of question n-grams

In [8]:
ngrams = set()
n = 3
for ex in train_dev_test:
    tokens = [t.lower() for t in ex['Title'].split() if t not in ['.', '?', ',']]  # spider doesn't contain other punctuations AFAIR, but we should replace this with a better filter for other datasets
    for i in range(len(tokens)-n+1):
        ngrams.add(tuple(tokens[i:i+n]))
print(len(ngrams))

42615


## Average tables per question

In [9]:
schemas = {}
for db_json in tables_json:
    db_id = db_json['db_id']
    table_names = db_json["table_names_original"]
    columns = [(column_name[0], column_name[1]) for column_name in db_json["column_names_original"]]
    schemas[db_id] = {}
    for table_index, table_name in enumerate(table_names):
        schemas[db_id][table_name] = []
        table_columns = [column for column in columns if column[0] == table_index]
        for table_column in table_columns:
            schemas[db_id][table_name].append(table_column[1])

# print(schemas["farm"])

In [10]:
print(f"Number of tables: {len(schemas['stackexchange'])}")
num_of_columns = sum([len(columns) for _, columns in schemas["stackexchange"].items()])
print(f"Number of columns: {num_of_columns}")

Number of tables: 29
Number of columns: 211


In [11]:
counts = []
for ex in train_dev_test:
    available_tables = set([t.lower() for t in schemas[ex.get('db_id', "stackexchange")].keys()])
    table_tokens_used = [t for t in ex['QueryBody'].split() if t.lower() in available_tables]
#     print(table_tokens_used)
#     print(available_tables)
#     print([t.lower() for t in ex['QueryBody'].split()])
    counts.append(len(table_tokens_used))
print(sum(counts) / len(counts))

2.1445562671546203


## Anonymized queries

In [12]:
templates = dict()
for ex in train_dev_test:
    query = ex['QueryBody'].replace('(', ' ( ').replace(')', ' ) ').lower().strip('; ')
    query = re.sub(r'".*"', '{value}', query)
    query = re.sub(r"'.*'", '{value}', query)
    query = re.sub(r"\s\d+.\d+", '{number}', query)
    query = re.sub(r"\s\d+", '{number}', query)
    query_tokens = [t for t in query.split() if t]
    
    for i, token in enumerate(query_tokens):
        if token.startswith("t1.") or token.startswith("t2.") or token.startswith("t3."):
            query_tokens[i] = "{item}"
    
    available_tables = set([t.lower() for t in schemas[ex.get('db_id', "stackexchange")].keys()])
    available_columns = set([c.lower() for t in schemas[ex.get('db_id', "stackexchange")].values() for c in t])
    
    available_items = available_tables.union(available_columns)
    
    anonymized = " ".join(['{item}' if t in available_items else t for t in query_tokens])
    if anonymized not in templates:
        templates[anonymized] = 1
    else:
        templates[anonymized] += 1
#     print(query)
#     print(available_items)
#     print(' '.join(anonymized))
#     print("****")

print(len(templates))

10664


In [13]:
print(f"Avg. # queries / templates = {sum(list(templates.values()))/len(templates)}")

Avg. # queries / templates = 1.1274381095273818
