# 06. Model v1

Model v1 is an ML model for the proposed UI which accepts "any" inputs.

We fine-tune model-v0 and TAPEX with the custom dataset annotated in the previous notebook.
We modify the tokenizer from the original one to support variable length and keyword arguments.
With these changes, the model performance is improved.

As a result, the model fine-tuned from model-v0 reaches an exact match ratio of ~64% on the test dataset ([Weights & Biases](https://wandb.ai/kwkty/vxnli/runs/3r21665c)).
The model fine-tuned from TAPEX reaches ~54% ([Weights & Biases](https://wandb.ai/kwkty/vxnli/runs/3grk8w92)).

The former model performs better; however, we intentionally use the latter one in the user study,
because we want to clarify that the performance of this model doesn't depend on the dataset size.


## Setup

### Define Parameters


In [1]:
# Root directory of the project data; every other path is derived from it.
data_dir: str = "../data/"
# Whether the trainer pushes the fine-tuned model to the Hugging Face Hub.
push_model_to_huggingface_hub: bool = True
# Whether training metrics are reported to Weights & Biases.
report_to_wandb: bool = True


### Load Modules

In [2]:
import functools
import multiprocessing
import os
import pandas as pd
import sqlite3

from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import evaluate
import datasets
import numpy as np
import torch
import transformers
import wandb

from datasets import Dataset, DatasetDict
from transformers import (
    BartConfig,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    EvalPrediction,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TapexTokenizer,
    trainer_utils,
)


In [3]:
# Seed the random number generators up front for reproducibility.
# NOTE(review): the trainer arguments below do not pass a `seed`, so training
# itself presumably runs with the Trainer's default seed — confirm.
transformers.set_seed(123)


### Define Variables

In [4]:
# Paths

DATA_DIR: Path = Path(data_dir)

# Name shared by the dataset, model, and result directories below.
MODEL_NAME: str = "vxnli-v1"

# nvBench SQLite databases, laid out as "<db_id>/<db_id>.sqlite".
DATABASE_DIR: Path = DATA_DIR.joinpath("datasets/nvBench/database")
# Annotated dataset as NDJSON files ("train", "test", "val").
DATASET_DIR: Path = DATA_DIR.joinpath(f"datasets/{MODEL_NAME}/")
# On-disk cache of the tokenized dataset.
DATASET_OUTPUT_DIR: Path = DATA_DIR.joinpath(f"datasets/{MODEL_NAME}.hf/")

# Training checkpoints and the final model.
MODEL_OUTPUT_DIR: Path = DATA_DIR.joinpath(f"models/{MODEL_NAME}/")
# Prediction CSV written during evaluation.
RESULT_OUTPUT_DIR: Path = DATA_DIR.joinpath(f"results/{MODEL_NAME}/")

# Model Parameters

# Checkpoint to fine-tune; swap the two lines to start from model-v0 instead of TAPEX.
BASE_MODEL: str = "microsoft/tapex-base-finetuned-wtq"
# BASE_MODEL: str = "kwkty/vxnli-v0"

# Maximum number of tokens of the encoded table + query.
MAX_SOURCE_LENGTH: int = 1024
# Maximum number of tokens of the target sequence (labels and generated predictions).
MAX_TARGET_LENGTH: int = 124


In [5]:
# Create the result directory together with any missing parents (e.g. "results/"),
# so the notebook also runs on a fresh data directory.
RESULT_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


### Load Tokenizer


In [6]:
# Load the tokenizer that ships with the base checkpoint.
tokenizer = TapexTokenizer.from_pretrained(
    BASE_MODEL, use_fast=True, add_prefix_space=True
)

# Register the delimiters emitted by `preprocess_query` as special tokens.
# The call returns the number of added tokens; the model's embedding matrix
# has to be resized to the new vocabulary size afterwards.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["[arg]", "[kwarg]", "[eq]"]}
)


3

### Load Model

In [7]:
# Override a few settings of the base checkpoint's configuration.
model_config = BartConfig.from_pretrained(
    BASE_MODEL,
    no_repeat_ngram_size=0,
    # NOTE(review): `max_length` in the model config is presumably the default
    # *generation* length, yet it is set to the source length here — confirm
    # whether MAX_TARGET_LENGTH was intended.
    max_length=MAX_SOURCE_LENGTH,
    early_stopping=False,
)

model = BartForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    config=model_config,
)

# Resize the embeddings to the tokenizer's vocabulary, which now includes the
# added special tokens.
model.resize_token_embeddings(len(tokenizer))


Embedding(50268, 768)

## Preprocess Dataset


In [8]:
def load_table(db_id: str, table_name: str) -> pd.DataFrame:
    """Load every row of one table from an nvBench SQLite database.

    Args:
        db_id: Database identifier; the file is expected at
            ``DATABASE_DIR/<db_id>/<db_id>.sqlite``.
        table_name: Name of the table to read.

    Returns:
        The full table as a DataFrame.
    """
    from contextlib import closing

    db_path = DATABASE_DIR.joinpath(f"{db_id}/{db_id}.sqlite")

    # Identifiers cannot be bound as SQL parameters, so quote the table name
    # (doubling embedded quotes) instead of interpolating it verbatim.
    quoted_table = '"' + table_name.replace('"', '""') + '"'

    # A sqlite3 connection used as a context manager only commits/rolls back;
    # `closing` makes sure the connection itself is released.
    with closing(sqlite3.connect(db_path)) as con:
        return pd.read_sql(f"SELECT * FROM {quoted_table}", con)


# Example: preview the first rows of the `products` table
load_table("customers_and_products_contacts", "products").head()


Unnamed: 0,product_id,product_type_code,product_name,product_price
0,1,Hardware,Apple,54753980.0
1,2,Clothes,jcrew,30590930.0
2,3,Hardware,Apple,10268.85
3,4,Hardware,Apple,22956670.0
4,5,Clothes,jcrew,5927022.0


In [9]:
def preprocess_table(df: pd.DataFrame) -> pd.DataFrame:
    """Normalize a table before it is handed to the tokenizer.

    Column names are lower-cased, every cell is converted to ``str``, and the
    resulting strings are lower-cased. The input frame is left untouched.

    Args:
        df: Raw table.

    Returns:
        A new DataFrame containing only lower-case strings.
    """
    df = df.rename(columns=str.lower)

    # The TAPEX tokenizer raises an error when the table contains non-str columns
    df = df.astype(str)

    # Every column is a string column at this point, so dtypes need no inspection
    for col_name in df.columns:
        df[col_name] = df[col_name].str.lower()

    return df


# Example: the same table preview after normalization (lower-cased, all-string cells)
preprocess_table(load_table("customers_and_products_contacts", "products").head())


Unnamed: 0,product_id,product_type_code,product_name,product_price
0,1,hardware,apple,54753982.574522
1,2,clothes,jcrew,30590929.528306
2,3,hardware,apple,10268.85297069
3,4,hardware,apple,22956668.699482
4,5,clothes,jcrew,5927021.8748021


In [10]:
# Memoized per (db_id, table_name), so a table is read and normalized once per process.
# functools.cache is supported in python3.9+, but use lru_cache to support python3.7+
@functools.lru_cache(maxsize=None)
def load_and_preprocess_table(db_id: str, table_name: str) -> pd.DataFrame:
    """Return the normalized table `table_name` of database `db_id`.

    The returned DataFrame is the cached object itself and is shared between
    callers, so it should be treated as read-only.
    """
    return preprocess_table(load_table(db_id, table_name))


In [11]:
def preprocess_dataset(example: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """Tokenize one dataset row into model inputs and labels.

    Args:
        example: A row providing the keys "db_id", "table", "query" and
            "vega_zero".

    Returns:
        The tokenizer output for the table + query, extended with a "labels"
        entry holding the token ids of the "vega_zero" target.
    """
    table = load_and_preprocess_table(example["db_id"], example["table"])
    question, target = example["query"], example["vega_zero"]

    shared_options = dict(padding=True, truncation=True)

    # Source side: table + query.
    # NOTE(review): the target is also passed as `answer` here for every split
    # (including validation/test) — confirm this is intended.
    encoded = tokenizer(
        table=table,
        query=question,
        answer=target,
        max_length=MAX_SOURCE_LENGTH,
        **shared_options,
    )

    # Target side: the Vega-Zero string alone.
    target_encoding = tokenizer(
        answer=target,
        max_length=MAX_TARGET_LENGTH,
        **shared_options,
    )
    encoded["labels"] = target_encoding["input_ids"]

    return encoded


In [12]:
def preprocess_query(*args, **kwargs) -> str:
    """Serialize positional and keyword arguments into a single query string.

    Positional values are prefixed/joined with "[arg]", keyword items are
    rendered as "name [eq] value" and prefixed/joined with "[kwarg]". The
    whole string is lower-cased.

    Example:
        preprocess_query("Use a bar chart", x="name")
        -> "[arg] use a bar chart [kwarg] x [eq] name"
    """
    positional_part = "[arg] " + " [arg] ".join(str(value) for value in args)
    keyword_part = "[kwarg] " + " [kwarg] ".join(
        f"{name} [eq] {value}" for name, value in kwargs.items()
    )

    return (positional_part + " " + keyword_part).lower()


In [13]:
def load_vxnli_dataset(subset: str) -> Dataset:
    """Load one subset of the annotated dataset and build its "query" column.

    Args:
        subset: File stem of the NDJSON file inside DATASET_DIR
            (e.g. "train", "test", "val").

    Returns:
        A Dataset in which the "args" and "kwargs" columns are replaced by the
        serialized "query" string.
    """
    # datasets.load_dataset("json", PATH) raises a JSON parse error, probably
    # because it cannot handle the list/dict typed args and kwargs columns,
    # so the file is read with pandas instead.
    records = pd.read_json(DATASET_DIR.joinpath(f"{subset}.ndjson"), lines=True)

    def build_query(row) -> str:
        return preprocess_query(*row["args"], **row["kwargs"])

    records["query"] = records.apply(build_query, axis=1)

    return Dataset.from_pandas(records.drop(columns=["args", "kwargs"]))


In [14]:
# Tokenizing is slow, so the processed dataset is cached on disk and reused.
# NOTE(review): the cache is keyed by path only; delete DATASET_OUTPUT_DIR
# after changing the preprocessing or the tokenizer.
if not DATASET_OUTPUT_DIR.exists():
    dataset = DatasetDict(
        {
            "train": load_vxnli_dataset("train"),
            "test": load_vxnli_dataset("test"),
            "validation": load_vxnli_dataset("val"),
        }
    ).map(
        preprocess_dataset,
        batched=False,
        num_proc=multiprocessing.cpu_count(),
    )

    # save_to_disk doesn't support pathlib.Path
    dataset.save_to_disk(str(DATASET_OUTPUT_DIR))
else:
    # load_from_disk doesn't support pathlib.Path
    dataset = datasets.load_from_disk(str(DATASET_OUTPUT_DIR))

dataset


DatasetDict({
    train: Dataset({
        features: ['db_id', 'table', 'chart', 'hardness', 'vega_zero', 'query', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1260
    })
    test: Dataset({
        features: ['db_id', 'table', 'chart', 'hardness', 'vega_zero', 'query', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 270
    })
    validation: Dataset({
        features: ['db_id', 'table', 'chart', 'hardness', 'vega_zero', 'query', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 270
    })
})

## Train Model


In [15]:
# Collator that batches the tokenized examples for the seq2seq trainer.
# Label padding uses -100; `compute_metrics` swaps it back to the pad token
# before decoding.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=None,
)


In [16]:
exact_match = evaluate.load("exact_match")


def compute_metrics(eval_pred: EvalPrediction):
    """Compute the exact match ratio between generated and reference sequences.

    Args:
        eval_pred: Predictions (generated token ids) and label ids produced by
            the trainer.

    Returns:
        The result of the `exact_match` metric, i.e. a dict with the
        "exact_match" score.
    """
    preds, labels = eval_pred

    # Predictions may arrive as a tuple whose first element holds the token ids
    if isinstance(preds, tuple):
        preds = preds[0]

    # Replace -100 in the predictions and labels as we can't decode them.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(
        preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    decoded_labels = tokenizer.batch_decode(
        labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    # Strip surrounding whitespace, as `predict` does, so both code paths
    # score predictions the same way.
    return exact_match.compute(
        predictions=[pred.strip() for pred in decoded_preds],
        references=[label.strip() for label in decoded_labels],
    )


In [17]:
trainer = Seq2SeqTrainer(
    model=model,
    args=Seq2SeqTrainingArguments(
        output_dir=MODEL_OUTPUT_DIR,
        # Evaluate with generated sequences so that exact match can be computed
        predict_with_generate=True,
        # Cap generation during evaluation at the target length, the same limit
        # `predict` uses, instead of falling back to the model config's
        # `max_length` (set to the source length).
        generation_max_length=MAX_TARGET_LENGTH,
        num_train_epochs=50,
        # Evaluate, log, and checkpoint once per epoch
        evaluation_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        # Restore the checkpoint with the best validation exact match at the end
        load_best_model_at_end=True,
        do_eval=True,
        metric_for_best_model="exact_match",
        push_to_hub=push_model_to_huggingface_hub,
        report_to="wandb" if report_to_wandb else "none",
    ),
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        # Stop when the monitored metric has not improved for 5 evaluations (epochs)
        EarlyStoppingCallback(early_stopping_patience=5),
    ],
)


Cloning https://huggingface.co/kwkty/vxnli-v1 into local empty directory.


In [18]:
# Fine-tune the model; early stopping may end the run before `num_train_epochs`.
trainer.train()


The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: table, db_id, chart, vega_zero, query, hardness. If table, db_id, chart, vega_zero, query, hardness are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1260
  Num Epochs = 50
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 7900
  Number of trainable parameters = 139422720
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33mkwkty[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Exact Match
1,0.8075,0.29903,0.274074
2,0.1626,0.219206,0.525926
3,0.085,0.234584,0.533333
4,0.0534,0.223286,0.588889
5,0.0371,0.219652,0.562963
6,0.0305,0.244793,0.574074
7,0.0245,0.246876,0.611111
8,0.0177,0.254745,0.592593
9,0.0159,0.249572,0.581481
10,0.0153,0.293903,0.559259


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: table, db_id, chart, vega_zero, query, hardness. If table, db_id, chart, vega_zero, query, hardness are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 270
  Batch size = 8
Saving model checkpoint to ../data/models/vxnli-v1/checkpoint-158
Configuration saved in ../data/models/vxnli-v1/checkpoint-158/config.json
Model weights saved in ../data/models/vxnli-v1/checkpoint-158/pytorch_model.bin
tokenizer config file saved in ../data/models/vxnli-v1/checkpoint-158/tokenizer_config.json
Special tokens file saved in ../data/models/vxnli-v1/checkpoint-158/special_tokens_map.json
added tokens file saved in ../data/models/vxnli-v1/checkpoint-158/added_tokens.json
tokenizer config file saved in ../data/models/vxnli-v1/tokenizer_config.json
Special tokens file save

TrainOutput(global_step=1896, training_loss=0.10589330501948731, metrics={'train_runtime': 1223.285, 'train_samples_per_second': 51.501, 'train_steps_per_second': 6.458, 'total_flos': 8911117321666560.0, 'train_loss': 0.10589330501948731, 'epoch': 12.0})

## Evaluate Model


In [19]:
# trainer.evaluate must be called for the model card
# The reported exact match comes from `compute_metrics`, i.e. decoded
# predictions compared with decoded labels.

trainer.evaluate(dataset["test"])


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: table, db_id, chart, vega_zero, query, hardness. If table, db_id, chart, vega_zero, query, hardness are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 270
  Batch size = 8


{'eval_loss': 0.25023138523101807,
 'eval_exact_match': 0.5444444444444444,
 'eval_runtime': 24.8449,
 'eval_samples_per_second': 10.867,
 'eval_steps_per_second': 1.368,
 'epoch': 12.0}

In [20]:
def predict(ds: Dataset) -> List[str]:
    """Generate and decode predictions for every row of `ds`.

    Args:
        ds: Tokenized dataset to run the trainer's prediction loop on.

    Returns:
        One decoded, whitespace-stripped prediction string per row.
    """
    output = trainer.predict(
        ds,
        max_length=MAX_TARGET_LENGTH,
    )

    token_ids = output.predictions

    # Predictions may arrive as a tuple whose first element holds the token ids
    if isinstance(token_ids, tuple):
        token_ids = token_ids[0]

    # Replace any -100 padding, which cannot be decoded, by the pad token
    token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)

    decoded = tokenizer.batch_decode(
        token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    return [pred.strip() for pred in decoded]


In [21]:
# Predict on the test split and show the first predictions next to their references
preds = predict(dataset["test"])

preds[:5], dataset["test"]["vega_zero"][:5]


The following columns in the test set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: table, db_id, chart, vega_zero, query, hardness. If table, db_id, chart, vega_zero, query, hardness are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 270
  Batch size = 8


(['mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark point encoding x investor_id y aggregate mean share_count transform group x',
  'mark point encoding x investor_id y aggregate mean share_count transform group x'],
 ['mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark point encoding x investor_id y aggregate mean share_count transform group x',
  'mark point encoding x investor_id y aggregate mean share_count transform group x'])

In [22]:
# Exact match of the stripped predictions against the raw `vega_zero` strings
# (unlike `trainer.evaluate`, the references are not passed through the tokenizer)
exact_match.compute(
    predictions=preds,
    references=dataset["test"]["vega_zero"],
)


{'exact_match': 0.5407407407407407}

In [23]:
# Per-example result table: metadata, query, prediction, and whether the
# prediction equals the reference exactly.
preds_df = (
    dataset["test"]
    .to_pandas()
    .drop(columns=["input_ids", "attention_mask", "labels"])
    .assign(pred=preds)
    .assign(exact_matched=lambda frame: frame["pred"] == frame["vega_zero"])
)

# Persist the table for later analysis
preds_df.to_csv(RESULT_OUTPUT_DIR.joinpath("prediction.csv"))

preds_df


Unnamed: 0,db_id,table,chart,hardness,vega_zero,query,pred,exact_matched
0,candidate_poll,people,bar,Easy,mark bar encoding x name y aggregate none weig...,[arg] [kwarg] use_bar_chart [eq] true [kwarg]...,mark bar encoding x name y aggregate none weig...,True
1,candidate_poll,people,bar,Easy,mark bar encoding x name y aggregate none weig...,[arg] use a bar chart [kwarg] x [eq] name [kwa...,mark bar encoding x name y aggregate none weig...,True
2,candidate_poll,people,bar,Easy,mark bar encoding x name y aggregate none weig...,[arg] [kwarg] graph [eq] bar [kwarg] x [eq] n...,mark bar encoding x name y aggregate none weig...,True
3,tracking_share_transactions,transactions,point,Easy,mark point encoding x investor_id y aggregate ...,[arg] scatter chart [arg] investor id and mean...,mark point encoding x investor_id y aggregate ...,True
4,tracking_share_transactions,transactions,point,Easy,mark point encoding x investor_id y aggregate ...,[arg] [kwarg] graph_type [eq] scatter [kwarg]...,mark point encoding x investor_id y aggregate ...,True
...,...,...,...,...,...,...,...,...
265,tracking_share_transactions,transactions,Line,Medium,mark line encoding x date_of_transaction y agg...,[arg] [kwarg] time_axis [eq] date_of_transact...,mark line encoding x date_of_transaction y agg...,False
266,tracking_share_transactions,transactions,Line,Medium,mark line encoding x date_of_transaction y agg...,[arg] show me a trend [kwarg] x [eq] date_of_t...,mark line encoding x date_of_transaction y agg...,False
267,customers_and_invoices,financial_transactions,Bar,Medium,mark bar encoding x transaction_type y aggrega...,[arg] show the transaction types and the total...,mark bar encoding x transaction_type y aggrega...,False
268,customers_and_invoices,financial_transactions,Bar,Medium,mark bar encoding x transaction_type y aggrega...,[arg] [kwarg] x [eq] type [kwarg] y [eq] amou...,mark bar encoding x transaction_type y aggrega...,False


In [24]:
# Exact-match counts (True/False) per hardness level, one column per level
pd.concat(
    [
        preds_df.loc[preds_df["hardness"] == level, "exact_matched"]
        .value_counts()
        .rename(level)
        for level in ("Easy", "Medium", "Hard", "Extra Hard")
    ],
    axis=1,
)


Unnamed: 0,Easy,Medium,Hard,Extra Hard
True,52,80,13,1
False,29,52,26,17


In [25]:
# Exact-match counts (True/False) per chart type, one column per distinct label.
# Labels are compared verbatim, so differently-cased labels ("bar" / "Bar")
# end up in separate columns.
pd.concat(
    [
        preds_df.loc[preds_df["chart"] == chart_type, "exact_matched"]
        .value_counts()
        .rename(chart_type)
        for chart_type in preds_df["chart"].unique()
    ],
    axis=1,
)


Unnamed: 0,bar,point,arc,line,Bar,Stacked Bar,Line
True,96,12,12,11,12,2,1
False,63,18,3,10,24,1,5


## Complete Training


In [26]:
# Close the Weights & Biases run (only when reporting was enabled)
if report_to_wandb:
    wandb.finish()


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/exact_match,▁▆▆█▇▇██▇▇▇▇▇
eval/loss,█▁▂▁▁▃▃▄▄█▇▄▄
eval/runtime,▄▂▁▂▄█▂▃▂▃▂▃▂
eval/samples_per_second,▅▆█▇▅▁▇▆▇▅▇▆▇
eval/steps_per_second,▅▆█▇▅▁▇▆▇▅▇▆▇
train/epoch,▁▁▂▂▂▂▃▃▄▄▄▄▅▅▅▅▆▆▇▇▇▇████
train/global_step,▁▁▂▂▂▂▃▃▄▄▄▄▅▅▅▅▆▆▇▇▇▇████
train/learning_rate,█▇▇▆▅▅▄▄▃▂▂▁
train/loss,█▂▂▁▁▁▁▁▁▁▁▁
train/total_flos,▁

0,1
eval/exact_match,0.54444
eval/loss,0.25023
eval/runtime,24.8449
eval/samples_per_second,10.867
eval/steps_per_second,1.368
train/epoch,12.0
train/global_step,1896.0
train/learning_rate,4e-05
train/loss,0.0096
train/total_flos,8911117321666560.0


In [27]:
# Upload the final model to the Hugging Face Hub (only when enabled in the parameters)
if push_model_to_huggingface_hub:
    trainer.push_to_hub()


Saving model checkpoint to ../data/models/vxnli-v1
Configuration saved in ../data/models/vxnli-v1/config.json
Model weights saved in ../data/models/vxnli-v1/pytorch_model.bin
tokenizer config file saved in ../data/models/vxnli-v1/tokenizer_config.json
Special tokens file saved in ../data/models/vxnli-v1/special_tokens_map.json
added tokens file saved in ../data/models/vxnli-v1/added_tokens.json
remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/kwkty/vxnli-v1
   609b05e..94c14ff  main -> main

Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Sequence-to-sequence Language Modeling', 'type': 'text2text-generation'}}
To https://huggingface.co/kwkty/vxnli-v1
   94c14ff..1c5562d  main -> main

