# 02. Preprocess

We preprocess the ncNet version of the nvBench dataset, based on the EDA results from the previous notebook.

The nvBench dataset has two primary columns: nl_queries and vis_query.VQL (SQL plus visualization attributes).
VQL is troublesome to parse, but ncNet provides it in a parsed format, VegaZero.
So we preprocess the ncNet version of the dataset and use VegaZero instead of the original VQL.

Our goal is NOT to build a better V-NLI model but to propose a new UI for data visualization.
We therefore do not need exactly the same dataset as nvBench or ncNet, because we do not have to compare our model with the existing ones.
To make experiments easier, we filter out samples that contain broken SQL, Chinese characters, etc.
Finally, we split the dataset into train / test / validation subsets.

We use this preprocessed dataset to train our models in the following notebooks.


## Setup

### Define Parameters


In [1]:
data_dir: str = "../data/"


### Load Modules


In [2]:
import sqlite3

from itertools import product
from pathlib import Path
from random import Random

import numpy as np
import pandas as pd

from sklearn.model_selection import GroupShuffleSplit

from vxnli._vega_zero import VegaZero


### Load Dataset


In [3]:
DATA_DIR: Path = Path(data_dir)
DATABASE_DIR: Path = DATA_DIR.joinpath("datasets/nvBench/database/")
NCNET_DATASET_DIR: Path = DATA_DIR.joinpath("datasets/ncNet/dataset/")
OUTPUT_DIR: Path = DATA_DIR.joinpath("datasets/vxnli-v0/")


In [4]:
OUTPUT_DIR.mkdir(exist_ok=True)


In [5]:
ncnet_train_df = pd.read_csv(NCNET_DATASET_DIR.joinpath("dataset_final/train.csv"))
ncnet_test_df = pd.read_csv(NCNET_DATASET_DIR.joinpath("dataset_final/test.csv"))
ncnet_val_df = pd.read_csv(NCNET_DATASET_DIR.joinpath("dataset_final/dev.csv"))

df = pd.concat(
    [
        ncnet_train_df,
        ncnet_test_df,
        ncnet_val_df,
    ],
    ignore_index=True,
)

df


Unnamed: 0,tvBench_id,db_id,chart,hardness,query,question,vega_zero,mentioned_columns,mentioned_values,query_template,source,labels,token_types
0,1000@y_name@DESC,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...,product_name,Sony,mark [T] data products encoding x [X] y aggreg...,<N> Bar chart x axis product name y axis how m...,mark bar data products encoding x product_name...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
1,1000@y_name@DESC,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...,product_name,Sony,mark bar data products encoding x [X] y aggreg...,<N> Bar chart x axis product name y axis how m...,mark bar data products encoding x product_name...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
2,2463@x_name@ASC,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...,job age name,Zach Bob Dan,mark [T] data person encoding x [X] y aggregat...,<N> how old is the youngest person for each jo...,mark bar data person encoding x job y aggregat...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
3,2463@x_name@ASC,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...,job age name,Zach Bob Dan,mark bar data person encoding x [X] y aggregat...,<N> how old is the youngest person for each jo...,mark bar data person encoding x job y aggregat...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
4,2545@y_name@DESC,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...,pet_age petid pettype,cat,mark [T] data pets encoding x [X] y aggregate ...,<N> Please give me a bar chart to show the ave...,mark bar data pets encoding x pettype y aggreg...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
...,...,...,...,...,...,...,...,...,...,...,...,...,...
31583,2911,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...,id meter_200 meter_300 meter_700 meter_100 met...,,mark bar data [D] encoding x [X] y aggregate [...,<N> Return a bar chart about the distribution ...,mark bar data swimmer encoding x meter_200 y a...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
31584,218,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...,monthly_rental other_details,,mark [T] data student_addresses encoding x [X]...,<N> Return a bar chart about the distribution ...,mark bar data student_addresses encoding x oth...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
31585,218,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...,monthly_rental other_details,,mark bar data student_addresses encoding x [X]...,<N> Return a bar chart about the distribution ...,mark bar data student_addresses encoding x oth...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...
31586,1761@y_name@ASC,hr_1,Bar,Hard,"Visualize BAR SELECT HIRE_DATE , AVG(MANAGER_I...",For those employees who was hired before 2002-...,mark bar data employees encoding x hire_date y...,hire_date manager_id first_name last_name,Luis King Banda Gee Lee Chen Ande Baer Fay Raj...,mark [T] data employees encoding x [X] y aggre...,<N> For those employees who was hired before 2...,mark bar data employees encoding x hire_date y...,nl nl nl nl nl nl nl nl nl nl nl nl nl nl nl n...


## Preprocess

### Remove Unnecessary Columns


In [6]:
df = df[["db_id", "chart", "hardness", "query", "question", "vega_zero"]]

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...
1,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...
2,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...
3,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...
4,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...
...,...,...,...,...,...,...
31583,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...
31584,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...
31585,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...
31586,hr_1,Bar,Hard,"Visualize BAR SELECT HIRE_DATE , AVG(MANAGER_I...",For those employees who was hired before 2002-...,mark bar data employees encoding x hire_date y...


### Remove Duplicated Rows


In [7]:
df = df.drop_duplicates()
df = df.reset_index(drop=True)

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar data payments encoding x payment_date...
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar data county encoding x county_name y ...
...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar data documents_with_expenses encoding...
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar data wrestler encoding x location y a...
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...


### Add SQL Column


In [8]:
def retrieve_sql_from_query(query: str) -> str:
    # Remove "Visualize CHART_TYPE"
    [_, _, sql] = query.split(" ", maxsplit=2)

    # Remove "BIN ..." because it is not a standard sqlite3 format and raises an error
    [sql, *_] = sql.rsplit("BIN", maxsplit=1)

    return sql


df["SQL"] = df["query"].apply(retrieve_sql_from_query)

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...,"SELECT product_name , COUNT(product_name) FROM..."
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...,"SELECT job , min(age) FROM Person GROUP BY job..."
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...,"SELECT PetType , avg(pet_age) FROM pets GROUP ..."
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar data payments encoding x payment_date...,"SELECT payment_date , COUNT(payment_date) FROM..."
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar data county encoding x county_name y ...,"SELECT County_name , Population FROM county"
...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar data documents_with_expenses encoding...,"SELECT Budget_Type_Code , count(*) FROM Docume..."
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar data wrestler encoding x location y a...,"SELECT Location , COUNT(Location) FROM wrestle..."
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ..."
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...,"SELECT other_details , SUM(monthly_rental) FRO..."
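
To make the string handling in `retrieve_sql_from_query` concrete, here is a small standalone check. The two query strings are made up for illustration and are not taken from the dataset. The second one shows a caveat of splitting on the bare substring `BIN`: an identifier that happens to contain those letters is cut as well.

```python
def retrieve_sql_from_query(query: str) -> str:
    # Same logic as the notebook cell above.
    _, _, sql = query.split(" ", maxsplit=2)
    sql, *_ = sql.rsplit("BIN", maxsplit=1)
    return sql


# Hypothetical VQL string with a trailing BIN clause.
with_bin = "Visualize BAR SELECT date , COUNT(date) FROM weather BIN date BY YEAR"
print(repr(retrieve_sql_from_query(with_bin)))

# Hypothetical VQL string whose column name contains the letters "BIN".
with_cabin = "Visualize BAR SELECT CABIN FROM ship"
print(repr(retrieve_sql_from_query(with_cabin)))
```

The first call keeps everything before `BIN` (including a trailing space); the second returns a truncated statement. Rows affected in this way would presumably fail in the broken-SQL check that follows, so they are filtered out rather than silently kept.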


### Remove Rows Containing Broken SQLs

In [9]:
has_sql_error = np.zeros(len(df), dtype=bool)

# The index was reset above, so the label i is also the row position
for i, row in df.iterrows():
    db_id = row["db_id"]
    con = sqlite3.connect(DATABASE_DIR.joinpath(f"{db_id}/{db_id}.sqlite"))

    cur = con.cursor()

    try:
        # Run the query against its database; a failure marks the row as broken
        cur.execute(row["SQL"]).fetchall()
    except sqlite3.OperationalError:
        has_sql_error[i] = True
    finally:
        con.close()

df = df[~has_sql_error]

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...,"SELECT product_name , COUNT(product_name) FROM..."
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...,"SELECT job , min(age) FROM Person GROUP BY job..."
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...,"SELECT PetType , avg(pet_age) FROM pets GROUP ..."
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar data payments encoding x payment_date...,"SELECT payment_date , COUNT(payment_date) FROM..."
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar data county encoding x county_name y ...,"SELECT County_name , Population FROM county"
...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar data documents_with_expenses encoding...,"SELECT Budget_Type_Code , count(*) FROM Docume..."
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar data wrestler encoding x location y a...,"SELECT Location , COUNT(Location) FROM wrestle..."
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ..."
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...,"SELECT other_details , SUM(monthly_rental) FRO..."
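
The check above can be tried in isolation. The sketch below is only an illustration: an in-memory SQLite database stands in for the nvBench `.sqlite` files, the `person` table and the helper name `sql_runs` are hypothetical, and it catches `sqlite3.Error`, which is broader than the `sqlite3.OperationalError` used in the notebook.

```python
import sqlite3


def sql_runs(con: sqlite3.Connection, sql: str) -> bool:
    """Return True if the statement executes and its rows can be fetched."""
    try:
        con.execute(sql).fetchall()
    except sqlite3.Error:
        return False
    return True


# In-memory stand-in for one of the nvBench databases (hypothetical schema).
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE person (name TEXT, age INTEGER)")
con.execute("INSERT INTO person VALUES ('Zach', 45)")

valid = sql_runs(con, "SELECT name , age FROM person")
broken = sql_runs(con, "SELECT name FROM missing_table")
con.close()

print(valid, broken)
```

Executing the statement is a stricter test than parsing it: a query that is syntactically fine but refers to a missing table or column is also flagged.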


### Remove Rows Containing Broken Table


In [10]:
tables = set()

for _, row in df.iterrows():
    db_id = row["db_id"]

    vega_zero = row["vega_zero"]
    vega_zero = VegaZero.from_str(vega_zero)

    table = vega_zero.data

    tables.add((db_id, table))

broken_tables = set()

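for db_id, table in tables:
    con = sqlite3.connect(DATABASE_DIR.joinpath(f"{db_id}/{db_id}.sqlite"))
    try:
        pd.read_sql(f"SELECT * FROM {table}", con)
    except Exception:
        # Any failure to read the whole table marks it as broken
        broken_tables.add((db_id, table))
    finally:
        # sqlite3's context manager only commits or rolls back, so close explicitly
        con.close()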

broken_tables


{('wta_1', 'players')}

In [11]:
def has_broken_table(row) -> bool:
    db_id = row["db_id"]

    vega_zero = row["vega_zero"]
    vega_zero = VegaZero.from_str(vega_zero)

    table = vega_zero.data

    return (db_id, table) in broken_tables


df = df[~df.apply(has_broken_table, axis=1)]

len(df)


15731

### Replace Curly Quotes in Question with Double Quote


In [12]:
# Copy dataframe to avoid SettingWithCopyWarning
df = df.copy()

df["question"] = df["question"].str.replace("’", '"')
df["question"] = df["question"].str.replace("‘", '"')

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...,"SELECT product_name , COUNT(product_name) FROM..."
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...,"SELECT job , min(age) FROM Person GROUP BY job..."
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...,"SELECT PetType , avg(pet_age) FROM pets GROUP ..."
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar data payments encoding x payment_date...,"SELECT payment_date , COUNT(payment_date) FROM..."
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar data county encoding x county_name y ...,"SELECT County_name , Population FROM county"
...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar data documents_with_expenses encoding...,"SELECT Budget_Type_Code , count(*) FROM Docume..."
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar data wrestler encoding x location y a...,"SELECT Location , COUNT(Location) FROM wrestle..."
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ..."
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...,"SELECT other_details , SUM(monthly_rental) FRO..."


### Remove Rows Containing Non-ASCII Characters (e.g. Chinese)


In [13]:
def is_ascii(question: str) -> bool:
    return question.isascii()


# The pandas .str accessor used here has no isascii method, so apply Python's str.isascii row by row
df = df[df["question"].apply(is_ascii)]

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...,"SELECT product_name , COUNT(product_name) FROM..."
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...,"SELECT job , min(age) FROM Person GROUP BY job..."
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...,"SELECT PetType , avg(pet_age) FROM pets GROUP ..."
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar data payments encoding x payment_date...,"SELECT payment_date , COUNT(payment_date) FROM..."
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar data county encoding x county_name y ...,"SELECT County_name , Population FROM county"
...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar data documents_with_expenses encoding...,"SELECT Budget_Type_Code , count(*) FROM Docume..."
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar data wrestler encoding x location y a...,"SELECT Location , COUNT(Location) FROM wrestle..."
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ..."
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...,"SELECT other_details , SUM(monthly_rental) FRO..."
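
The order of the last two steps matters: curly quotes are themselves non-ASCII, so a question that contains them would be dropped by the ASCII filter unless they are replaced first. The sketch below illustrates this on three made-up questions (none are from the dataset); `regex=False` is passed explicitly so the behaviour does not depend on the pandas version.

```python
import pandas as pd

# Hypothetical questions, not taken from the dataset.
questions = pd.Series(
    [
        "show the ‘name’ of each pet",  # curly quotes only
        "how many pets are there ?",  # plain ASCII
        "每个宠物的名字",  # Chinese characters
    ]
)

# Step 1: normalise curly single quotes to a plain double quote.
normalised = questions.str.replace("’", '"', regex=False)
normalised = normalised.str.replace("‘", '"', regex=False)

# Step 2: keep only the rows that are pure ASCII.
kept = normalised[normalised.apply(str.isascii)]

# Filtering without step 1 also drops the curly-quote row.
kept_without_step_1 = questions[questions.apply(str.isascii)]

print(kept.tolist())
print(kept_without_step_1.tolist())
```

Note that `str.isascii` removes every non-ASCII question, not only those with Chinese characters; accented Latin letters would be filtered out as well.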


### Remove Commas

In [14]:
# Copy dataframe to avoid SettingWithCopyWarning
df = df.copy()

df["vega_zero"] = df["vega_zero"].str.replace(" , ", " ")

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...,"SELECT product_name , COUNT(product_name) FROM..."
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...,"SELECT job , min(age) FROM Person GROUP BY job..."
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...,"SELECT PetType , avg(pet_age) FROM pets GROUP ..."
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar data payments encoding x payment_date...,"SELECT payment_date , COUNT(payment_date) FROM..."
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar data county encoding x county_name y ...,"SELECT County_name , Population FROM county"
...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar data documents_with_expenses encoding...,"SELECT Budget_Type_Code , count(*) FROM Docume..."
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar data wrestler encoding x location y a...,"SELECT Location , COUNT(Location) FROM wrestle..."
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ..."
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...,"SELECT other_details , SUM(monthly_rental) FRO..."


### Add Table Column


In [15]:
df["table"] = df["vega_zero"].apply(lambda x: VegaZero.from_str(x).data)

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL,table
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar data products encoding x product_name...,"SELECT product_name , COUNT(product_name) FROM...",products
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar data person encoding x job y aggregat...,"SELECT job , min(age) FROM Person GROUP BY job...",person
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar data pets encoding x pettype y aggreg...,"SELECT PetType , avg(pet_age) FROM pets GROUP ...",pets
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar data payments encoding x payment_date...,"SELECT payment_date , COUNT(payment_date) FROM...",payments
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar data county encoding x county_name y ...,"SELECT County_name , Population FROM county",county
...,...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar data documents_with_expenses encoding...,"SELECT Budget_Type_Code , count(*) FROM Docume...",documents_with_expenses
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar data wrestler encoding x location y a...,"SELECT Location , COUNT(Location) FROM wrestle...",wrestler
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar data swimmer encoding x meter_200 y a...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ...",swimmer
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar data student_addresses encoding x oth...,"SELECT other_details , SUM(monthly_rental) FRO...",student_addresses


### Remove Data Attribute from VegaZero

We assume the tabular data is provided as a pandas DataFrame,
so we do not use the data attribute of VegaZero in the prediction.


In [16]:
def remove_data_attr(vega_zero_str: str) -> str:
    vega_zero = VegaZero.from_str(vega_zero_str)

    vega_zero.data = None

    return str(vega_zero)


df["vega_zero"] = df["vega_zero"].apply(remove_data_attr)

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL,table
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar encoding x product_name y aggregate c...,"SELECT product_name , COUNT(product_name) FROM...",products
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar encoding x job y aggregate min age tr...,"SELECT job , min(age) FROM Person GROUP BY job...",person
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar encoding x pettype y aggregate mean p...,"SELECT PetType , avg(pet_age) FROM pets GROUP ...",pets
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar encoding x payment_date y aggregate c...,"SELECT payment_date , COUNT(payment_date) FROM...",payments
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar encoding x county_name y aggregate no...,"SELECT County_name , Population FROM county",county
...,...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar encoding x budget_type_code y aggrega...,"SELECT Budget_Type_Code , count(*) FROM Docume...",documents_with_expenses
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar encoding x location y aggregate count...,"SELECT Location , COUNT(Location) FROM wrestle...",wrestler
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar encoding x meter_200 y aggregate mean...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ...",swimmer
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar encoding x other_details y aggregate ...,"SELECT other_details , SUM(monthly_rental) FRO...",student_addresses
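
For readers without access to the `VegaZero` class, the effect of the last two steps can be approximated with plain string handling. This is a rough sketch, not the `vxnli` implementation: the helper `split_data_clause` is hypothetical and assumes the VegaZero string contains a single `data <table>` pair in which the table name is one whitespace-free token.

```python
from typing import Optional, Tuple


def split_data_clause(vega_zero: str) -> Tuple[Optional[str], str]:
    """Return (table, VegaZero string without the "data <table>" pair)."""
    tokens = vega_zero.split(" ")
    if "data" not in tokens:
        return None, vega_zero

    i = tokens.index("data")  # first occurrence of the keyword
    table = tokens[i + 1]
    del tokens[i : i + 2]  # drop the keyword and the table name
    return table, " ".join(tokens)


# Shortened example in the style of the rows shown above.
table, stripped = split_data_clause(
    "mark bar data pets encoding x pettype y aggregate mean pet_age"
)
print(table)
print(stripped)
```

Using the parser, as the notebook does, is the more robust choice; the string version only illustrates what the `table` and `vega_zero` columns contain after these steps.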


### Replace Single Quote with Double Quote

In [17]:
# Copy dataframe to avoid SettingWithCopyWarning
df = df.copy()

df["vega_zero"] = df["vega_zero"].str.replace("'", '"')

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL,table
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",Bar chart x axis product name y axis how many ...,mark bar encoding x product_name y aggregate c...,"SELECT product_name , COUNT(product_name) FROM...",products
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar encoding x job y aggregate min age tr...,"SELECT job , min(age) FROM Person GROUP BY job...",person
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",Please give me a bar chart to show the average...,mark bar encoding x pettype y aggregate mean p...,"SELECT PetType , avg(pet_age) FROM pets GROUP ...",pets
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",What are the payment date of the payment with ...,mark bar encoding x payment_date y aggregate c...,"SELECT payment_date , COUNT(payment_date) FROM...",payments
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",What are the name and population of each count...,mark bar encoding x county_name y aggregate no...,"SELECT County_name , Population FROM county",county
...,...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar encoding x budget_type_code y aggrega...,"SELECT Budget_Type_Code , count(*) FROM Docume...",documents_with_expenses
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",Give the number of locations of all wrestlers ...,mark bar encoding x location y aggregate count...,"SELECT Location , COUNT(Location) FROM wrestle...",wrestler
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",Return a bar chart about the distribution of m...,mark bar encoding x meter_200 y aggregate mean...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ...",swimmer
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",Return a bar chart about the distribution of o...,mark bar encoding x other_details y aggregate ...,"SELECT other_details , SUM(monthly_rental) FRO...",student_addresses


### Convert Questions to Lower Case

In the VegaZero provided by the ncNet authors, all characters are lowercased.
We could restore the original casing, but that would be somewhat time-consuming.
Moreover, unifying the case should help model accuracy and avoid unnecessary errors in our final user study.

Again, the purpose of this study is to compare V-XNLI with V-NLI.


In [18]:
df["question"] = df["question"].str.lower()

df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL,table
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",bar chart x axis product name y axis how many ...,mark bar encoding x product_name y aggregate c...,"SELECT product_name , COUNT(product_name) FROM...",products
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar encoding x job y aggregate min age tr...,"SELECT job , min(age) FROM Person GROUP BY job...",person
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",please give me a bar chart to show the average...,mark bar encoding x pettype y aggregate mean p...,"SELECT PetType , avg(pet_age) FROM pets GROUP ...",pets
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",what are the payment date of the payment with ...,mark bar encoding x payment_date y aggregate c...,"SELECT payment_date , COUNT(payment_date) FROM...",payments
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",what are the name and population of each count...,mark bar encoding x county_name y aggregate no...,"SELECT County_name , Population FROM county",county
...,...,...,...,...,...,...,...,...
15780,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar encoding x budget_type_code y aggrega...,"SELECT Budget_Type_Code , count(*) FROM Docume...",documents_with_expenses
15781,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",give the number of locations of all wrestlers ...,mark bar encoding x location y aggregate count...,"SELECT Location , COUNT(Location) FROM wrestle...",wrestler
15782,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",return a bar chart about the distribution of m...,mark bar encoding x meter_200 y aggregate mean...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ...",swimmer
15783,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT other_details , SUM(month...",return a bar chart about the distribution of o...,mark bar encoding x other_details y aggregate ...,"SELECT other_details , SUM(monthly_rental) FRO...",student_addresses
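
The benefit of unifying the cases can be illustrated with a minimal sketch. The question and VegaZero strings below are made-up examples (not rows from the dataset); the point is only that a column name mentioned in the question matches a VegaZero token once the question is lowercased.

```python
# Hypothetical example strings (not taken from the dataset).
question = "Show the Location of each wrestler"
vega_zero = "mark bar encoding x location y aggregate count location"

vega_tokens = set(vega_zero.split())

# Before lowercasing, "Location" does not match the lowercased VegaZero token.
overlap_before = set(question.split()) & vega_tokens

# After lowercasing, the column name appears verbatim in both strings.
overlap_after = set(question.lower().split()) & vega_tokens

print(overlap_before)
print(overlap_after)
```

In this toy case the overlap is empty before lowercasing and contains only `location` afterwards.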


## Split Dataset

We split the dataset into train, test, and validation subsets (Train : Test : Validation = 8 : 1 : 1).

We use sklearn's GroupShuffleSplit (grouped by the db_id column) to avoid data leakage.
Without grouping, samples from the same database end up in both the training and the evaluation subsets, so the trained model knows the database contents beforehand.
This increases the accuracy on known tables but decreases the accuracy on unknown tables.
The ncNet authors train their model on data with this table-content leakage; we avoid it to reflect real-world usage.

It might be better to allocate more samples to the training subset; however, the number of databases is not that large (141), while the number of samples is large.
Because we use GroupShuffleSplit, the databases themselves are also split 8 : 1 : 1.
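
The effect of grouping can be illustrated without scikit-learn. The following is a simplified stand-in for a group-wise split, not GroupShuffleSplit's actual implementation; the toy db_id values and the helper name `group_split` are assumptions made for this example.

```python
from random import Random


def group_split(groups, train_size, seed):
    """Assign whole groups to train or test (simplified sketch)."""
    unique = sorted(set(groups))
    Random(seed).shuffle(unique)
    n_train = int(len(unique) * train_size)
    train_groups = set(unique[:n_train])
    train_idx = [i for i, g in enumerate(groups) if g in train_groups]
    test_idx = [i for i, g in enumerate(groups) if g not in train_groups]
    return train_idx, test_idx


# Toy data: 10 databases with 20 samples each.
db_ids = [f"db_{i}" for i in range(10) for _ in range(20)]

# Row-wise split: shuffle the samples and ignore db_id.
rows = list(range(len(db_ids)))
Random(123).shuffle(rows)
row_train, row_test = rows[:160], rows[160:]
row_shared = {db_ids[i] for i in row_train} & {db_ids[i] for i in row_test}

# Group-wise split: every db_id lands in exactly one subset.
grp_train, grp_test = group_split(db_ids, train_size=0.8, seed=123)
grp_shared = {db_ids[i] for i in grp_train} & {db_ids[i] for i in grp_test}

print(len(row_shared), len(grp_shared))
```

The row-wise split shares databases between the two subsets, while the group-wise split shares none, which is the property we rely on to avoid leakage.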

### Split Dataset with GroupShuffleSplit

With GroupShuffleSplit, the db_id values are split 8 : 1 : 1, but the subset sizes are not guaranteed to follow the same ratio.
In practice, the sample counts turn out to be close to 8 : 1 : 1 (12798 : 1543 : 1385).
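
As a quick check of how close the sample counts are to 8 : 1 : 1, the fractions can be computed directly from the three numbers above. This is only arithmetic on the reported counts; the variable names are ours.

```python
# Sample counts reported above (train, test, validation).
sizes = {"train": 12798, "test": 1543, "val": 1385}
total = sum(sizes.values())

# Fraction of all samples in each subset.
fractions = {name: round(n / total, 3) for name, n in sizes.items()}

print(total, fractions)
```

The subsets hold about 81.4%, 9.8%, and 8.8% of the 15726 samples, respectively.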


In [19]:
# First split: 80% of the databases go to train, the remaining 20% are held out.
gss = GroupShuffleSplit(n_splits=1, train_size=0.8, random_state=123)
train_idx, holdout_idx = next(gss.split(df, groups=df["db_id"]))
train_df, holdout_df = df.iloc[train_idx, :], df.iloc[holdout_idx, :]

# Second split: divide the held-out databases evenly into validation and test.
gss = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=123)
val_idx, test_idx = next(gss.split(holdout_df, groups=holdout_df["db_id"]))
test_df, val_df = holdout_df.iloc[test_idx, :], holdout_df.iloc[val_idx, :]

train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)


In [20]:
train_df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL,table
0,customers_and_products_contacts,Bar,Medium,"Visualize BAR SELECT product_name , COUNT(prod...",bar chart x axis product name y axis how many ...,mark bar encoding x product_name y aggregate c...,"SELECT product_name , COUNT(product_name) FROM...",products
1,network_2,Bar,Easy,"Visualize BAR SELECT job , min(age) FROM Perso...",how old is the youngest person for each job ? ...,mark bar encoding x job y aggregate min age tr...,"SELECT job , min(age) FROM Person GROUP BY job...",person
2,pets_1,Bar,Medium,"Visualize BAR SELECT PetType , avg(pet_age) FR...",please give me a bar chart to show the average...,mark bar encoding x pettype y aggregate mean p...,"SELECT PetType , avg(pet_age) FROM pets GROUP ...",pets
3,products_for_hire,Bar,Extra Hard,"Visualize BAR SELECT payment_date , COUNT(paym...",what are the payment date of the payment with ...,mark bar encoding x payment_date y aggregate c...,"SELECT payment_date , COUNT(payment_date) FROM...",payments
4,election,Bar,Easy,"Visualize BAR SELECT County_name , Population ...",what are the name and population of each count...,mark bar encoding x county_name y aggregate no...,"SELECT County_name , Population FROM county",county
...,...,...,...,...,...,...,...,...
12793,hr_1,Scatter,Medium,"Visualize SCATTER SELECT COMMISSION_PCT , DEPA...",for those employees whose salary is in the ran...,mark point encoding x commission_pct y aggrega...,"SELECT COMMISSION_PCT , DEPARTMENT_ID FROM emp...",employees
12794,cre_Docs_and_Epenses,Bar,Medium,"Visualize BAR SELECT Budget_Type_Code , count(...","what are the different budget type codes , and...",mark bar encoding x budget_type_code y aggrega...,"SELECT Budget_Type_Code , count(*) FROM Docume...",documents_with_expenses
12795,wrestler,Bar,Medium,"Visualize BAR SELECT Location , COUNT(Location...",give the number of locations of all wrestlers ...,mark bar encoding x location y aggregate count...,"SELECT Location , COUNT(Location) FROM wrestle...",wrestler
12796,swimming,Bar,Easy,"Visualize BAR SELECT meter_200 , AVG(ID) FROM ...",return a bar chart about the distribution of m...,mark bar encoding x meter_200 y aggregate mean...,"SELECT meter_200 , AVG(ID) FROM swimmer GROUP ...",swimmer


In [21]:
test_df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL,table
0,cinema,Bar,Medium,"Visualize BAR SELECT Openning_year , COUNT(Ope...",give me a bar chart showing the number of cine...,mark bar encoding x openning_year y aggregate ...,"SELECT Openning_year , COUNT(Openning_year) FR...",cinema
1,wta_1,Bar,Medium,"Visualize BAR SELECT year , count(*) FROM matc...",find the number of matches happened in each ye...,mark bar encoding x year y aggregate count yea...,"SELECT year , count(*) FROM matches GROUP BY Y...",matches
2,candidate_poll,Bar,Easy,"Visualize BAR SELECT Sex , min(weight) FROM pe...",what is the minimum weights for people of each...,mark bar encoding x sex y aggregate min weight...,"SELECT Sex , min(weight) FROM people GROUP BY ...",people
3,candidate_poll,Bar,Medium,"Visualize BAR SELECT Sex , AVG(Weight) FROM pe...",show me the average of weight by sex in a hist...,mark bar encoding x sex y aggregate mean weigh...,"SELECT Sex , AVG(Weight) FROM people GROUP BY ...",people
4,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT date_address_from , COUNT...",visualize a bar chart about the distribution o...,mark bar encoding x date_address_from y aggreg...,"SELECT date_address_from , COUNT(date_address_...",student_addresses
...,...,...,...,...,...,...,...,...
1538,local_govt_in_alabama,Pie,Easy,"Visualize PIE SELECT Event_Details , COUNT(Eve...",group and count details for the events using a...,mark arc encoding x event_details y aggregate ...,"SELECT Event_Details , COUNT(Event_Details) FR...",events
1539,riding_club,Bar,Medium,"Visualize BAR SELECT Occupation , COUNT(Occupa...",bar chart x axis occupation y axis how many oc...,mark bar encoding x occupation y aggregate cou...,"SELECT Occupation , COUNT(Occupation) FROM pla...",player
1540,store_product,Pie,Easy,"Visualize PIE SELECT Type , count(*) FROM stor...","for each type of store , how many of them are ...",mark arc encoding x type y aggregate count typ...,"SELECT Type , count(*) FROM store GROUP BY TYPE",store
1541,candidate_poll,Bar,Easy,"Visualize BAR SELECT Name , Weight FROM people...",return a bar chart about the distribution of n...,mark bar encoding x name y aggregate none weig...,"SELECT Name , Weight FROM people ORDER BY Name...",people


In [22]:
val_df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL,table
0,college_1,Bar,Medium,"Visualize BAR SELECT CRS_CODE , count(*) FROM ...",visualize a bar chart for how many sections do...,mark bar encoding x crs_code y aggregate count...,"SELECT CRS_CODE , count(*) FROM CLASS GROUP BY...",class
1,game_injury,Grouping Line,Hard,"Visualize LINE SELECT Season , COUNT(Season) F...",list the number of games in each season and gr...,mark line encoding x season y aggregate count ...,"SELECT Season , COUNT(Season) FROM game GROUP ...",game
2,cre_Doc_Tracking_DB,Grouping Line,Extra Hard,"Visualize LINE SELECT Date_in_Locaton_To , COU...",show the number of documents in different endi...,mark line encoding x date_in_locaton_to y aggr...,"SELECT Date_in_Locaton_To , COUNT(Date_in_Loca...",document_locations
3,department_store,Bar,Medium,"Visualize BAR SELECT payment_method_code , cou...","for each payment method , return how many cust...",mark bar encoding x payment_method_code y aggr...,"SELECT payment_method_code , count(*) FROM cus...",customers
4,college_1,Scatter,Easy,"Visualize SCATTER SELECT max(stu_gpa) , avg(st...",find the relationship between maximum and aver...,mark point encoding x max(stu_gpa) y aggregate...,"SELECT max(stu_gpa) , avg(stu_gpa) FROM studen...",student
...,...,...,...,...,...,...,...,...
1380,school_finance,Bar,Medium,"Visualize BAR SELECT County , count(*) FROM sc...",draw a bar chart of county versus the total nu...,mark bar encoding x county y aggregate count c...,"SELECT County , count(*) FROM school GROUP BY ...",school
1381,cre_Doc_Tracking_DB,Stacked Bar,Hard,"Visualize BAR SELECT Date_in_Locaton_To , COUN...",stacked bar of date in locaton to and the numb...,mark bar encoding x date_in_locaton_to y aggre...,"SELECT Date_in_Locaton_To , COUNT(Date_in_Loca...",document_locations
1382,cre_Doc_Tracking_DB,Bar,Medium,"Visualize BAR SELECT Location_Code , count(*) ...",show the location codes and the number of docu...,mark bar encoding x location_code y aggregate ...,"SELECT Location_Code , count(*) FROM Document_...",document_locations
1383,cre_Doc_Tracking_DB,Bar,Easy,"Visualize BAR SELECT Location_Code , count(*) ...",what is the code of each location and the numb...,mark bar encoding x location_code y aggregate ...,"SELECT Location_Code , count(*) FROM Document_...",document_locations


In [23]:
train_df.to_csv(OUTPUT_DIR.joinpath("train.csv"), index=False)
test_df.to_csv(OUTPUT_DIR.joinpath("test.csv"), index=False)
val_df.to_csv(OUTPUT_DIR.joinpath("val.csv"), index=False)


## Analytics


In [24]:
df["hardness"].value_counts()


Medium        6931
Easy          5582
Hard          1834
Extra Hard    1379
Name: hardness, dtype: int64

In [25]:
df["chart"].value_counts()


Bar                 11592
Pie                  1124
Line                 1078
Scatter               722
Stacked Bar           624
Grouping Scatter      377
Grouping Line         209
Name: chart, dtype: int64

In [26]:
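# Check that no database appears in more than one subset (just in case)

train_df_unique_db_ids = set(train_df["db_id"].unique())
test_df_unique_db_ids = set(test_df["db_id"].unique())
val_df_unique_db_ids = set(val_df["db_id"].unique())

# Check every pair: an empty three-way intersection alone would miss a pairwise overlap
assert train_df_unique_db_ids.isdisjoint(test_df_unique_db_ids)
assert train_df_unique_db_ids.isdisjoint(val_df_unique_db_ids)
assert test_df_unique_db_ids.isdisjoint(val_df_unique_db_ids)

# The number of databases in each subset (train, test, val)
len(train_df_unique_db_ids), len(test_df_unique_db_ids), len(val_df_unique_db_ids)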


(112, 15, 14)
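
A note on checking disjointness: an empty three-way intersection does not imply that the subsets are pairwise disjoint. The sketch below uses hypothetical db_id sets to show the difference; the set contents are invented for illustration.

```python
from itertools import combinations

# Hypothetical db_id sets: "db_b" leaks between train and test.
train_ids = {"db_a", "db_b"}
test_ids = {"db_b", "db_c"}
val_ids = {"db_d"}

# The three-way intersection is empty even though train and test overlap.
three_way = train_ids & test_ids & val_ids

# Checking each pair catches the overlap.
pairwise = [a & b for a, b in combinations([train_ids, test_ids, val_ids], 2)]

print(three_way, pairwise)
```

Only the pairwise check reports the shared `db_b`, so a leakage check should compare every pair of subsets.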

In [27]:
# The chart types are not evenly distributed across the subsets
# However, this is partly unavoidable because the number of examples per database is highly imbalanced
pd.concat(
    [
        train_df["chart"].value_counts().rename("train"),
        val_df["chart"].value_counts().rename("val"),
        test_df["chart"].value_counts().rename("test"),
    ],
    axis=1,
)


chart,train,val,test
Bar,9600,946,1046
Pie,887,148,89
Line,837,113,128
Scatter,557,31,134
Stacked Bar,480,65,79
Grouping Scatter,322,15,40
Grouping Line,115,67,27
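
Because the subsets differ in size, raw counts are hard to compare. A minimal sketch that converts the counts in the table above into per-subset shares follows; the numbers are copied from the table, and the dictionary layout is our own choice for the example.

```python
# Chart counts copied from the table above.
counts = {
    "train": {"Bar": 9600, "Pie": 887, "Line": 837, "Scatter": 557,
              "Stacked Bar": 480, "Grouping Scatter": 322, "Grouping Line": 115},
    "val": {"Bar": 946, "Pie": 148, "Line": 113, "Scatter": 31,
            "Stacked Bar": 65, "Grouping Scatter": 15, "Grouping Line": 67},
    "test": {"Bar": 1046, "Pie": 89, "Line": 128, "Scatter": 134,
             "Stacked Bar": 79, "Grouping Scatter": 40, "Grouping Line": 27},
}

# Convert counts to shares so that subsets of different sizes are comparable.
shares = {
    split: {chart: n / sum(by_chart.values()) for chart, n in by_chart.items()}
    for split, by_chart in counts.items()
}

for split, by_chart in shares.items():
    print(split, {chart: round(s, 3) for chart, s in by_chart.items()})
```

Bar charts account for roughly 75% of the train subset and about 68% of the val and test subsets. With the DataFrames at hand, `value_counts(normalize=True)` gives the same shares directly.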


In [28]:
# As with the chart types above, the hardness levels are not evenly distributed across the subsets
pd.concat(
    [
        train_df["hardness"].value_counts().rename("train"),
        val_df["hardness"].value_counts().rename("val"),
        test_df["hardness"].value_counts().rename("test"),
    ],
    axis=1,
)


hardness,train,val,test
Medium,5445,717,769
Easy,4522,472,588
Hard,1595,157,82
Extra Hard,1236,39,104
