# 03. Model v0

Model v0 is an ML model for typical V-NLI, which accepts text and tabular data as input, and returns the corresponding figure.

We adopt [TAPEX](https://arxiv.org/abs/2107.07653), a pre-trained [BART](https://arxiv.org/abs/1910.13461) model, as a base model.
And we use almost the same hyperparameters as TAPEX.

We fine-tune the model with the two nvBench datasets preprocessed in the previous notebook.
One model is for the user study, and another is for real-world usage (Check the details in the previous notebook).

We use Hugging Face Transformers, one of the most famous NLP libraries, for implementation.
Primarily we refer to [this TAPEX example](https://github.com/huggingface/transformers/blob/main/examples/research_projects/tapex/run_wikisql_with_tapex.py
).

As a result, the exact match ratio of the user study model gets to be ~90% for the test dataset, and the model for real-world usage gets to be ~60%.
It's not comparable to the existing work because we adopt a different way to preprocess the nvBench dataset.
However, our goal is not to improve an ML model for V-NLI but to propose a novel UI for data visualization.

In the final user study, we use this model as the baseline model to compare our proposed interface with typical V-NLI.

## Setup

### Define Parameters


In [1]:
data_dir: str = "../data/"
load_model_from_last_checkpoint: bool = False
push_model_to_huggingface_hub: bool = True
skip_training: bool = False

# If user_study is True, the model is trained by the preprocessed dataset with the stratified sampling
# Otherwise, the group shuffled dataset is used
# See the preprocess notebook for the details of the datasets
user_study: bool = False


### Load Modules

In [2]:
import functools
import multiprocessing
import os
import pandas as pd
import sqlite3

from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import evaluate
import datasets
import numpy as np
import torch
import transformers

from datasets import Dataset
from transformers import (
    BartConfig,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    EvalPrediction,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TapexTokenizer,
    trainer_utils,
)


In [3]:
transformers.set_seed(123)


### Define Variables

In [4]:
# Paths

DATA_DIR: Path = Path(data_dir)
DATABASE_DIR: Path = DATA_DIR.joinpath("database")

if user_study:
    MODEL_NAME: str = "vxnli-v0-user-study"
    PREPROCESSED_NVBENCH_DIR: Path = DATA_DIR.joinpath(
        "preprocessed-nvBench/stratified"
    )
else:
    MODEL_NAME: str = "vxnli-v0"
    PREPROCESSED_NVBENCH_DIR: Path = DATA_DIR.joinpath("preprocessed-nvBench/grouped")

DATASET_OUTPUT_DIR: Path = DATA_DIR.joinpath(f"{MODEL_NAME}.hf")
MODEL_OUTPUT_DIR: Path = DATA_DIR.joinpath(MODEL_NAME)
PREDS_OUTPUT_PATH: Path = DATA_DIR.joinpath(f"{MODEL_NAME}-preds.csv")

# Model Parameters

BASE_MODEL: str = "microsoft/tapex-base-finetuned-wtq"
MAX_SOURCE_LENGTH: int = 1024
MAX_TARGET_LENGTH: int = 124


### Load Tokenizer


In [5]:
tokenizer = TapexTokenizer.from_pretrained(
    BASE_MODEL, use_fast=True, add_prefix_space=True
)


In [6]:
# Example
tokenizer(
    table=pd.DataFrame.from_dict(
        {
            "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
            "Number of movies": ["87", "53", "69"],
        }
    ),
    answer="how many movies does Leonardo Di Caprio have?",
    return_tensors="pt",
)


{'input_ids': tensor([[    0, 11311,  4832,  5552,  1721,   346,     9,  4133,  3236,   112,
          4832,  5378,   625,   181,  2582,  1721,  8176,  3236,   132,  4832,
          2084,   261,  6782,  2269,  2927, 12834,  1721,  4268,  3236,   155,
          4832,  5473, 26875, 42771,  6071,  1721,  5913,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

### Load Model

In [7]:
model_config = BartConfig.from_pretrained(
    BASE_MODEL,
    no_repeat_ngram_size=0,
    max_length=MAX_SOURCE_LENGTH,
    early_stopping=False,
)

if load_model_from_last_checkpoint:
    model = trainer_utils.get_last_checkpoint(MODEL_OUTPUT_DIR)
else:
    model = BASE_MODEL

model = BartForConditionalGeneration.from_pretrained(
    model,
    config=model_config,
)


## Preprocess Dataset


In [8]:
def load_table(db_id: str, table_name: str) -> pd.DataFrame:
    db_path = DATABASE_DIR.joinpath(f"{db_id}/{db_id}.sqlite")

    with sqlite3.connect(db_path) as con:
        return pd.read_sql(f"SELECT * FROM {table_name}", con)


# Example
load_table("customers_and_products_contacts", "products").head()


Unnamed: 0,product_id,product_type_code,product_name,product_price
0,1,Hardware,Apple,54753980.0
1,2,Clothes,jcrew,30590930.0
2,3,Hardware,Apple,10268.85
3,4,Hardware,Apple,22956670.0
4,5,Clothes,jcrew,5927022.0


In [9]:
def preprocess_table(df: pd.DataFrame) -> pd.DataFrame:
    df = df.rename(columns={col: col.lower() for col in df.columns})

    # The TAPEX tokenizer raises an error when the table contains non-str columns
    df = df.astype(str)

    for col_name, col_dtype in zip(df.columns, df.dtypes):
        df[col_name] = df[col_name].str.lower()

    return df


preprocess_table(load_table("customers_and_products_contacts", "products").head())


Unnamed: 0,product_id,product_type_code,product_name,product_price
0,1,hardware,apple,54753982.574522
1,2,clothes,jcrew,30590929.528306
2,3,hardware,apple,10268.85297069
3,4,hardware,apple,22956668.699482
4,5,clothes,jcrew,5927021.8748021


In [10]:
# functools.cache is supported in python3.9+, but use lru_cache to support python3.7+
@functools.lru_cache(maxsize=None)
def load_and_preprocess_table(db_id: str, table_name: str) -> pd.DataFrame:
    table = load_table(db_id, table_name)
    table = preprocess_table(table)

    return table


In [11]:
def preprocess_dataset(example: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    table = load_and_preprocess_table(example["db_id"], example["table"])

    query = example["question"]
    answer = example["vega_zero"]

    model_inputs = tokenizer(
        table=table,
        query=query,
        answer=answer,
        max_length=MAX_SOURCE_LENGTH,
        padding=True,
        truncation=True,
    )

    labels = tokenizer(
        answer=answer,
        max_length=MAX_TARGET_LENGTH,
        padding=True,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs


In [12]:
if DATASET_OUTPUT_DIR.exists():
    # load_from_dist doesn't support pathlib.Path
    dataset = datasets.load_from_disk(str(DATASET_OUTPUT_DIR))
else:
    dataset = datasets.load_dataset(
        "csv",
        data_files={
            # load_dataset doesn't support pathlib.Path
            "train": str(PREPROCESSED_NVBENCH_DIR.joinpath("train.csv")),
            "test": str(PREPROCESSED_NVBENCH_DIR.joinpath("test.csv")),
            "validation": str(PREPROCESSED_NVBENCH_DIR.joinpath("val.csv")),
        },
    )

    dataset = dataset.map(
        preprocess_dataset,
        batched=False,
        num_proc=multiprocessing.cpu_count(),
    )

    # save_to_disk doesn't support pathlib.Path
    dataset.save_to_disk(str(DATASET_OUTPUT_DIR))

dataset


DatasetDict({
    train: Dataset({
        features: ['db_id', 'chart', 'hardness', 'query', 'question', 'vega_zero', 'SQL', 'table', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 12798
    })
    test: Dataset({
        features: ['db_id', 'chart', 'hardness', 'query', 'question', 'vega_zero', 'SQL', 'table', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1543
    })
    validation: Dataset({
        features: ['db_id', 'chart', 'hardness', 'query', 'question', 'vega_zero', 'SQL', 'table', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1385
    })
})

## Training


In [13]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=None,
)


In [14]:
exact_match = evaluate.load("exact_match")


def compute_metrics(eval_pred: EvalPrediction):
    preds, labels = eval_pred

    preds = tokenizer.batch_decode(
        preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    labels = tokenizer.batch_decode(
        labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    return exact_match.compute(predictions=preds, references=labels)


In [15]:
# Huggingface trainer uses environment variables to configure mlflow
# https://github.com/huggingface/transformers/blob/94b3f544a1f5e04b78d87a2ae32a7ac252e22e31/src/transformers/integrations.py#L884

# MLFlow experiment name must be updated if you update training arguments
os.environ["MLFLOW_EXPERIMENT_NAME"] = f"{MODEL_NAME}-{datetime.now().strftime('%Y%m%d%H%M')}"


In [16]:
trainer = Seq2SeqTrainer(
    model=model,
    args=Seq2SeqTrainingArguments(
        output_dir=MODEL_OUTPUT_DIR,
        predict_with_generate=True,
        num_train_epochs=50,
        evaluation_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        load_best_model_at_end=True,
        do_eval=True,
        metric_for_best_model="exact_match",
        push_to_hub=push_model_to_huggingface_hub,
        report_to="mlflow",
    ),
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=5),
    ],
)


/home/jupyter/vxnli/notebooks/../data/vxnli-v0 is already a clone of https://huggingface.co/kwkty/vxnli-v0. Make sure you pull the latest changes with `repo.git_pull()`.


In [17]:
if not skip_training:
    trainer.train()


The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: vega_zero, query, table, hardness, question, SQL, chart, db_id. If vega_zero, query, table, hardness, question, SQL, chart, db_id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 12798
  Num Epochs = 50
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 40000
  Number of trainable parameters = 139420416
2022/11/30 09:28:14 INFO mlflow.tracking.fluent: Experiment with name 'vxnli-v0-202211300927' does not exist. Creating a new experiment.


Epoch,Training Loss,Validation Loss,Exact Match
1,0.157,0.398999,0.524188
2,0.0195,0.430023,0.574007
3,0.0113,0.498301,0.550181
4,0.0089,0.448108,0.555235
5,0.0088,0.424967,0.580505
6,0.0077,0.481935,0.582671
7,0.0058,0.369306,0.566787
8,0.0058,0.447385,0.554513
9,0.0059,0.488884,0.574007
10,0.0049,0.449606,0.590614


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: vega_zero, query, table, hardness, question, SQL, chart, db_id. If vega_zero, query, table, hardness, question, SQL, chart, db_id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1385
  Batch size = 16
Saving model checkpoint to ../data/vxnli-v0/checkpoint-800
Configuration saved in ../data/vxnli-v0/checkpoint-800/config.json
Model weights saved in ../data/vxnli-v0/checkpoint-800/pytorch_model.bin
tokenizer config file saved in ../data/vxnli-v0/checkpoint-800/tokenizer_config.json
Special tokens file saved in ../data/vxnli-v0/checkpoint-800/special_tokens_map.json
tokenizer config file saved in ../data/vxnli-v0/tokenizer_config.json
Special tokens file saved in ../data/vxnli-v0/special_tokens_map.json
The following columns in the evaluation set don'

## Evaluation

In [18]:
# trainer.evaluate must be called for the model card

trainer.evaluate(dataset["test"])

The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: vega_zero, query, table, hardness, question, SQL, chart, db_id. If vega_zero, query, table, hardness, question, SQL, chart, db_id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1543
  Batch size = 16


{'eval_loss': 0.33394524455070496,
 'eval_exact_match': 0.609850939727803,
 'eval_runtime': 115.3458,
 'eval_samples_per_second': 13.377,
 'eval_steps_per_second': 0.841,
 'epoch': 18.0}

In [19]:
def predict(ds: Dataset) -> List[str]:
    preds = trainer.predict(
        ds,
        max_length=MAX_TARGET_LENGTH,
    )

    preds = tokenizer.batch_decode(
        preds.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    return [pred.strip() for pred in preds]


In [20]:
preds = predict(dataset["test"])

preds[:5], dataset["test"]["vega_zero"][:5]


The following columns in the test set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: vega_zero, query, table, hardness, question, SQL, chart, db_id. If vega_zero, query, table, hardness, question, SQL, chart, db_id are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 1543
  Batch size = 16


(['mark bar encoding x openning_year y aggregate count openning_year transform group x sort x desc',
  'mark bar encoding x year y aggregate count year transform group x sort y desc',
  'mark bar encoding x sex y aggregate min weight transform group x sort y desc',
  'mark bar encoding x sex y aggregate mean weight transform group x sort y desc',
  'mark bar encoding x date_address_from y aggregate count date_address_from transform bin x by month'],
 ['mark bar encoding x openning_year y aggregate count openning_year transform group x sort x desc',
  'mark bar encoding x year y aggregate count year transform group x sort y desc',
  'mark bar encoding x sex y aggregate min weight transform group x sort y desc',
  'mark bar encoding x sex y aggregate mean weight transform group x sort y desc',
  'mark bar encoding x date_address_from y aggregate count date_address_from transform sort monthly_rental desc bin x by year'])

In [21]:
exact_match.compute(
    predictions=preds,
    references=dataset["test"]["vega_zero"],
)

{'exact_match': 0.6072585871678549}

In [22]:
preds_df = dataset["test"].to_pandas()
preds_df = preds_df.drop(columns=["input_ids", "attention_mask", "labels"])
preds_df["pred"] = preds
preds_df["exact_matched"] = preds_df["pred"] == preds_df["vega_zero"]

preds_df.to_csv(PREDS_OUTPUT_PATH)

preds_df


Unnamed: 0,db_id,chart,hardness,query,question,vega_zero,SQL,table,pred,exact_matched
0,cinema,Bar,Medium,"Visualize BAR SELECT Openning_year , COUNT(Ope...",give me a bar chart showing the number of cine...,mark bar encoding x openning_year y aggregate ...,"SELECT Openning_year , COUNT(Openning_year) FR...",cinema,mark bar encoding x openning_year y aggregate ...,True
1,wta_1,Bar,Medium,"Visualize BAR SELECT year , count(*) FROM matc...",find the number of matches happened in each ye...,mark bar encoding x year y aggregate count yea...,"SELECT year , count(*) FROM matches GROUP BY Y...",matches,mark bar encoding x year y aggregate count yea...,True
2,candidate_poll,Bar,Easy,"Visualize BAR SELECT Sex , min(weight) FROM pe...",what is the minimum weights for people of each...,mark bar encoding x sex y aggregate min weight...,"SELECT Sex , min(weight) FROM people GROUP BY ...",people,mark bar encoding x sex y aggregate min weight...,True
3,candidate_poll,Bar,Medium,"Visualize BAR SELECT Sex , AVG(Weight) FROM pe...",show me the average of weight by sex in a hist...,mark bar encoding x sex y aggregate mean weigh...,"SELECT Sex , AVG(Weight) FROM people GROUP BY ...",people,mark bar encoding x sex y aggregate mean weigh...,True
4,behavior_monitoring,Bar,Medium,"Visualize BAR SELECT date_address_from , COUNT...",visualize a bar chart about the distribution o...,mark bar encoding x date_address_from y aggreg...,"SELECT date_address_from , COUNT(date_address_...",student_addresses,mark bar encoding x date_address_from y aggreg...,False
...,...,...,...,...,...,...,...,...,...,...
1538,local_govt_in_alabama,Pie,Easy,"Visualize PIE SELECT Event_Details , COUNT(Eve...",group and count details for the events using a...,mark arc encoding x event_details y aggregate ...,"SELECT Event_Details , COUNT(Event_Details) FR...",events,mark arc encoding x event_details y aggregate ...,True
1539,riding_club,Bar,Medium,"Visualize BAR SELECT Occupation , COUNT(Occupa...",bar chart x axis occupation y axis how many oc...,mark bar encoding x occupation y aggregate cou...,"SELECT Occupation , COUNT(Occupation) FROM pla...",player,mark bar encoding x occupation y aggregate cou...,True
1540,store_product,Pie,Easy,"Visualize PIE SELECT Type , count(*) FROM stor...","for each type of store , how many of them are ...",mark arc encoding x type y aggregate count typ...,"SELECT Type , count(*) FROM store GROUP BY TYPE",store,mark arc encoding x type y aggregate count typ...,True
1541,candidate_poll,Bar,Easy,"Visualize BAR SELECT Name , Weight FROM people...",return a bar chart about the distribution of n...,mark bar encoding x name y aggregate none weig...,"SELECT Name , Weight FROM people ORDER BY Name...",people,mark bar encoding x name y aggregate none weig...,True


In [23]:
pd.concat(
    [
        preds_df[preds_df["hardness"] == hardness]["exact_matched"]
        .value_counts()
        .rename(hardness)
        for hardness in ("Easy", "Medium", "Hard", "Extra Hard")
    ],
    axis=1,
)


Unnamed: 0,Easy,Medium,Hard,Extra Hard
True,422,495,20,
False,166,274,62,104.0


In [24]:
pd.concat(
    [
        preds_df[preds_df["chart"] == chart]["exact_matched"]
        .value_counts()
        .rename(chart)
        
        for chart in preds_df["chart"].unique()
    ],
    axis=1,
)


Unnamed: 0,Bar,Line,Scatter,Stacked Bar,Grouping Line,Pie,Grouping Scatter
True,766,32,46,6,,68,19
False,280,96,88,73,27.0,21,21


## Push Model


In [25]:
if push_model_to_huggingface_hub:
    # huggingface_hub.notebook_login()

    trainer.push_to_hub()


Saving model checkpoint to ../data/vxnli-v0
Configuration saved in ../data/vxnli-v0/config.json
Model weights saved in ../data/vxnli-v0/pytorch_model.bin
tokenizer config file saved in ../data/vxnli-v0/tokenizer_config.json
Special tokens file saved in ../data/vxnli-v0/special_tokens_map.json
remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/kwkty/vxnli-v0
   956d0f7..7987400  main -> main

Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Sequence-to-sequence Language Modeling', 'type': 'text2text-generation'}}
To https://huggingface.co/kwkty/vxnli-v0
   7987400..358b03d  main -> main

