# 08. Data Augmentation

TODO: Cleanup Documents & Codes
TODO: Share Results with wandb (I was too cornered due to the internal deadline)

Results

| Params | Exact Match [%] |
| :---: | :---: |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=False BASE_MODEL=model-v0 Augmentation=x1 | 0.717 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=True BASE_MODEL=tapex Augmentation=x1 | 0.533 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=True BASE_MODEL=tapex Augmentation=x0 | 0.470 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=False BASE_MODEL=tapex Augmentation=x1  | 0.7074074074074074 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=False BASE_MODEL=tapex Augmentation=x0.5 | 0.6444444444444445 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=False BASE_MODEL=tapex Augmentation=x1.5 | 0.6296296296296297 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=False BASE_MODEL=model-v0 Augmentation=x0.5 | 0.6148148148148148 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=False BASE_MODEL=model-v0 Augmentation=x1.5 | 0.6666666666666666 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=True BASE_MODEL=tapex Augmentation=x0.5 | 0.5740740740740741 |
| REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF=True BASE_MODEL=tapex Augmentation=x1.5 | 0.5518518518518518 |


## Setup

### Define Parameters


In [1]:
data_dir: str = "../data/"
push_model_to_huggingface_hub: bool = True
report_to_wandb: bool = False


### Load Modules

In [2]:
import functools
import multiprocessing
import os
import pandas as pd
import sqlite3

from datetime import datetime
from itertools import product
from pathlib import Path
from random import Random
from typing import Any, Dict, List, Optional, Tuple, Union

import evaluate
import datasets
import numpy as np
import torch
import transformers
import wandb

from datasets import Dataset, DatasetDict
from transformers import (
    BartConfig,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    EvalPrediction,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TapexTokenizer,
    trainer_utils,
)

from vxnli._vega_zero import VegaZero, VegaZeroTransform


### Define Variables

In [3]:
RANDOM_SEED: int = 123

# Paths

DATA_DIR: Path = Path(data_dir)

MODEL_NAME: str = "vxnli-v2beta1"

DATABASE_DIR: Path = DATA_DIR.joinpath("datasets/nvBench/database")
DATASET_DIR: Path = DATA_DIR.joinpath(f"datasets/vxnli-v1/")
DATASET_OUTPUT_DIR: Path = DATA_DIR.joinpath(f"datasets/{MODEL_NAME}.hf/")

MODEL_OUTPUT_DIR: Path = DATA_DIR.joinpath(f"models/{MODEL_NAME}/")
RESULT_OUTPUT_DIR: Path = DATA_DIR.joinpath(f"results/{MODEL_NAME}/")

# Model Parameters
# BASE_MODEL: str = "microsoft/tapex-base-finetuned-wtq"
BASE_MODEL: str = "kwkty/vxnli-v0"

MAX_SOURCE_LENGTH: int = 1024
MAX_TARGET_LENGTH: int = 124

# Data Augmentation Parameters

REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF: bool = False

INSERT_FILTER_FRAC: float = 1.0  # 1.0

INSERT_FILTER_KWARG_FRAC: float = 0.75
SHUFFLE_FRAC: float = 0.0


In [4]:
RESULT_OUTPUT_DIR.mkdir(exist_ok=True)

transformers.set_seed(RANDOM_SEED)


### Load Datasets

In [5]:
train_df = pd.read_json(DATASET_DIR.joinpath("train.ndjson"), lines=True)
test_df = pd.read_json(DATASET_DIR.joinpath("test.ndjson"), lines=True)
val_df = pd.read_json(DATASET_DIR.joinpath("val.ndjson"), lines=True)


In [6]:
if REMOVE_FILTER_EXAMPLES_FROM_TRAIN_DF:
    train_df = train_df[~train_df["vega_zero"].str.contains(" filter ")]


### Load Tokenizer


In [7]:
tokenizer = TapexTokenizer.from_pretrained(
    BASE_MODEL, use_fast=True, add_prefix_space=True
)

tokenizer.add_special_tokens(
    {"additional_special_tokens": ["[arg]", "[kwarg]", "[eq]"]}
)


Downloading:   0%|          | 0.00/999k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/957 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

3

### Load Model

In [8]:
model_config = BartConfig.from_pretrained(
    BASE_MODEL,
    no_repeat_ngram_size=0,
    max_length=MAX_SOURCE_LENGTH,
    early_stopping=False,
)

model = BartForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    config=model_config,
)

model.resize_token_embeddings(len(tokenizer))


Downloading:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/558M [00:00<?, ?B/s]

Embedding(50268, 768)

## Augment Data

In [9]:
def load_table(db_id: str, table_name: str) -> pd.DataFrame:
    db_path = DATABASE_DIR.joinpath(f"{db_id}/{db_id}.sqlite")

    with sqlite3.connect(db_path) as con:
        return pd.read_sql(f"SELECT * FROM {table_name}", con)


In [10]:
def get_column_candidates(db_id: str, table: str) -> List[Tuple[str, int, int]]:
    df = load_table(db_id, table)

    columns = [
        (col.lower(), dtype, df[col].unique().tolist())
        for col, dtype in zip(df.columns, df.dtypes)
        if pd.api.types.is_string_dtype(dtype) or pd.api.types.is_integer_dtype(dtype)
    ]

    return columns


COLUMN_CANDIDATES = {
    (db_id, table): get_column_candidates(db_id, table)
    for _, db_id, table in train_df[["db_id", "table"]].drop_duplicates().itertuples()
}


In [11]:
test_df["vega_zero"][test_df["vega_zero"].str.contains(" topk ")]


Series([], Name: vega_zero, dtype: object)

### Compose Filter Examples

In [12]:
FILTER_KEYWORDS = [
    "filter",
    "condition",
    "cond",
    "where",
    "if_",
    "when",
]

INT_FILTER_PATTERNS = {
    "{col} < {val1}": [
        "{col} < {val1}",
        "{col} is lower than {val1}",
        "{col} is smaller than {val1}",
    ],
    "{col} <= {val1}": [
        "{col} <= {val1}",
        "{col} is lower than or equal to {val1}",
        "{col} is smaller than or equal to {val1}",
    ],
    "{col} > {val1}": [
        "{col} > {val1}",
        "{col} is bigger than {val1}",
        "{col} is greater than {val1}",
    ],
    "{col} >= {val1}": [
        "{col} >= {val1}",
        "{col} is bigger than or equal to {val1}",
        "{col} is greater than or equal to {val1}",
    ],
    "{col} = {val1}": [
        "{col} = {val1}",
        "{col} == {val1}",
        "{col} is {val1}",
    ],
    "{col} != {val1}": [
        "{col} != {val1}",
        "{col} is not {val1}",
        "{col} does not equal to {val1}",
    ],
    "{col} between {val1} and {val2}": [
        "{col} between {val1} and {val2}",
        "{col} in range({val1}, {val2})",
        "{col} <- [{val1}, {val2}]",
    ],
}

STR_FILTER_PATTERNS = {
    '{col} = "{val1}"': [
        '{col} = "{val1}"',
        '{col} == "{val1}"',
        '{col} is "{val1}"',
    ],
    '{col} != "{val1}"': [
        '{col} != "{val1}"',
        '{col} is not "{val1}"',
        '{col} does not equal to "{val1}"',
    ],
}

DOUBLE_FILTER_PATTERNS = {
    "{cond1} and {cond2}": [
        "{cond1} and {cond2}",
        "{cond1} && {cond2}",
        "{cond1} & {cond2}",
    ],
    "{cond1} or {cond2}": [
        "{cond1} or {cond2}",
        "{cond1} || {cond2}",
        "{cond1} | {cond2}",
    ],
}

TRIPLE_FILTER_PATTERNS = {
    "{cond1} and {cond2} and {cond3}": [
        "{cond1} and {cond2} and {cond3}",
        "{cond1} && {cond2} && {cond3}",
        "{cond1} & {cond2} & {cond3}",
    ],
    "{cond1} and {cond2} or {cond3}": [
        "{cond1} and {cond2} or {cond3}",
        "{cond1} && {cond2} || {cond3}",
        "{cond1} & {cond2} | {cond3}",
    ],
    "{cond1} or {cond2} and {cond3}": [
        "{cond1} or {cond2} and {cond3}",
        "{cond1} || {cond2} && {cond3}",
        "{cond1} | {cond2} & {cond3}",
    ],
    "{cond1} or+ {cond2} or {cond3}": [
        "{cond1} or {cond2} or {cond3}",
        "{cond1} || {cond2} || {cond3}",
        "{cond1} | {cond2} | {cond3}",
    ],
}


In [13]:
def insert_filter_to_vega_zero(
    vega_zero: str,
    filter_: str,
) -> str:
    vega_zero = VegaZero.parse(vega_zero)

    if vega_zero.transform is not None:
        transform = vega_zero.transform
    else:
        transform = VegaZeroTransform()

    transform.filter = filter_

    return str(vega_zero)


def insert_filter_to_args_or_kwargs(
    args: list, kwargs: dict, filter_arg: str, rand: Random
):
    use_kwarg = rand.random() < INSERT_FILTER_KWARG_FRAC

    if use_kwarg:
        kwargs = list(kwargs.items())

        i = rand.randrange(len(kwargs) + 1)
        k = rand.choice(FILTER_KEYWORDS)

        kwargs.insert(i, (k, filter_arg))

        kwargs = dict(kwargs)
    else:
        # Copy args to change the original argument (just in case)
        args = [*args]

        i = rand.randrange(len(args) + 1)

        args.insert(i, filter_arg)

    return args, kwargs


def compose_single_filter(
    column: str, dtype: Any, values: List[Any], rand: Random
) -> Tuple[str, str]:
    [val1, val2] = rand.choices(values, k=2)

    if pd.api.types.is_integer_dtype(dtype):
        patterns = INT_FILTER_PATTERNS

        if val1 > val2:
            val1, val2 = val2, val1

    elif pd.api.types.is_string_dtype(dtype):
        patterns = STR_FILTER_PATTERNS
    else:
        raise TypeError(f"Unexpected dtype: {dtype}")

    vega_zero_pattern = rand.choice(list(patterns.keys()))
    arg_pattern = rand.choice(patterns[vega_zero_pattern])

    vega_zero = vega_zero_pattern.format(col=column, val1=val1, val2=val2)
    arg = arg_pattern.format(col=column, val1=val1, val2=val2)

    return vega_zero, arg


def compose_filter(db_id: str, table: str, rand: Random) -> str:
    n_filters = rand.choice([1, 2, 3])

    candidates = rand.choices(COLUMN_CANDIDATES[(db_id, table)], k=n_filters)

    filters = [
        compose_single_filter(col, dtype, values, rand)
        for col, dtype, values in candidates
    ]

    if n_filters == 1:
        vega_zero_filter, arg = filters[0]
    elif n_filters == 2:
        vega_zero_filter = rand.choice(list(DOUBLE_FILTER_PATTERNS.keys()))
        arg = rand.choice(DOUBLE_FILTER_PATTERNS[vega_zero_filter])

        vega_zero_filter = vega_zero_filter.format(
            cond1=filters[0][0], cond2=filters[1][0]
        )
        arg = arg.format(cond1=filters[0][1], cond2=filters[1][1])
    else:
        vega_zero_filter = rand.choice(list(TRIPLE_FILTER_PATTERNS.keys()))
        arg = rand.choice(TRIPLE_FILTER_PATTERNS[vega_zero_filter])

        vega_zero_filter = vega_zero_filter.format(
            cond1=filters[0][0], cond2=filters[1][0], cond3=filters[2][0]
        )
        arg = arg.format(cond1=filters[0][1], cond2=filters[1][1], cond3=filters[2][0])

    return vega_zero_filter, arg


def augment_filter_data(df: pd.DataFrame, frac: float) -> pd.DataFrame:
    rand = Random(RANDOM_SEED)

    df = df[~df["vega_zero"].str.contains(" filter ")]

    df = df.copy()

    df = df.sample(frac=frac, replace=True, random_state=RANDOM_SEED)

    # df.apply raises an error when frac is small and the number of rows is 0
    if len(df) == 0:
        return df

    df[["_filter", "_arg"]] = df[["db_id", "table"]].apply(
        lambda row: compose_filter(row[0], row[1], rand), axis=1, result_type="expand"
    )

    df["vega_zero"] = df[["vega_zero", "_filter"]].apply(
        lambda row: insert_filter_to_vega_zero(row[0], row[1]),
        axis=1,
    )

    df[["args", "kwargs"]] = df[["args", "kwargs", "_filter"]].apply(
        lambda row: insert_filter_to_args_or_kwargs(
            row["args"], row["kwargs"], row["_filter"], rand
        ),
        axis=1,
        result_type="expand",
    )

    df = df.drop(columns=["_filter", "_arg"])

    return df


augmented_filter_train_df = augment_filter_data(train_df, INSERT_FILTER_FRAC)

augmented_filter_train_df


Unnamed: 0,db_id,table,chart,hardness,vega_zero,args,kwargs
708,swimming,swimmer,bar,Easy,mark bar encoding x name y aggregate none id,[],"{'x': 'name', 'filter': 'nationality = ""Canada..."
434,sports_competition,competition,bar,Medium,mark bar encoding x competition_type y aggrega...,[],"{'y': 'count', 'where': 'competition_type = ""F..."
460,manufactory_1,products,bar,Medium,mark bar encoding x name y aggregate count nam...,"[count names, price > 66, sort name in alphabet]",{}
382,e_learning,student_course_enrolment,bar,Medium,mark bar encoding x date_of_completion y aggre...,[count of records],{'cond': 'course_id >= 1 or+ course_id = 9 or ...
101,university_basketball,basketball_match,arc,Easy,mark arc encoding x all_neutral y aggregate no...,[arc],"{'filter': 'all_games != ""28–6"" or all_games =..."
...,...,...,...,...,...,...,...
800,university_basketball,university,arc,Easy,mark arc encoding x affiliation y aggregate su...,[show a pie chart],"{'when': 'school_id > 1', 'color': 'affiliatio..."
71,university_basketball,basketball_match,point,Easy,mark point encoding x team_id y aggregate none...,"[acc_home = ""7–1"" or team_name = ""Virginia Tech""]","{'chart': 'scatter', 'x': 'team id', 'y': 'all..."
484,college_2,section,line,Hard,mark line encoding x year y aggregate count ye...,[time series],"{'time': 'year', 'value': 'count', 'filter': '..."
457,university_basketball,basketball_match,bar,Medium,mark bar encoding x all_home y aggregate mean ...,[draw hist],"{'x_axis': 'all home', 'y_axis': 'mean team_id..."


### Shuffle Arguments

In [14]:
def augment_shuffled_data(
    df: pd.DataFrame,
    frac: float,
) -> pd.DataFrame:
    rand = Random(RANDOM_SEED)

    df = df.copy()

    df = df.sample(frac=frac, replace=True, random_state=RANDOM_SEED)
    df["args"] = df["args"].apply(lambda args: rand.sample(args, len(args)))
    df["kwargs"] = df["kwargs"].apply(
        lambda kwargs: dict(rand.sample(list(kwargs.items()), len(kwargs)))
    )

    df = df.reset_index(drop=True)

    return df


augmented_shuffle_train_df = augment_shuffled_data(train_df, SHUFFLE_FRAC)

augmented_shuffle_train_df


Unnamed: 0,db_id,table,chart,hardness,vega_zero,args,kwargs


### Concatenate Augmented Data

In [15]:
augmented_train_df = pd.concat([augmented_filter_train_df, augmented_shuffle_train_df])

augmented_train_df = augmented_train_df.reset_index(drop=True)
augmented_train_df


Unnamed: 0,db_id,table,chart,hardness,vega_zero,args,kwargs
0,swimming,swimmer,bar,Easy,mark bar encoding x name y aggregate none id,[],"{'x': 'name', 'filter': 'nationality = ""Canada..."
1,sports_competition,competition,bar,Medium,mark bar encoding x competition_type y aggrega...,[],"{'y': 'count', 'where': 'competition_type = ""F..."
2,manufactory_1,products,bar,Medium,mark bar encoding x name y aggregate count nam...,"[count names, price > 66, sort name in alphabet]",{}
3,e_learning,student_course_enrolment,bar,Medium,mark bar encoding x date_of_completion y aggre...,[count of records],{'cond': 'course_id >= 1 or+ course_id = 9 or ...
4,university_basketball,basketball_match,arc,Easy,mark arc encoding x all_neutral y aggregate no...,[arc],"{'filter': 'all_games != ""28–6"" or all_games =..."
...,...,...,...,...,...,...,...
868,university_basketball,university,arc,Easy,mark arc encoding x affiliation y aggregate su...,[show a pie chart],"{'when': 'school_id > 1', 'color': 'affiliatio..."
869,university_basketball,basketball_match,point,Easy,mark point encoding x team_id y aggregate none...,"[acc_home = ""7–1"" or team_name = ""Virginia Tech""]","{'chart': 'scatter', 'x': 'team id', 'y': 'all..."
870,college_2,section,line,Hard,mark line encoding x year y aggregate count ye...,[time series],"{'time': 'year', 'value': 'count', 'filter': '..."
871,university_basketball,basketball_match,bar,Medium,mark bar encoding x all_home y aggregate mean ...,[draw hist],"{'x_axis': 'all home', 'y_axis': 'mean team_id..."


## Preprocess Dataset


In [16]:
def preprocess_table(df: pd.DataFrame) -> pd.DataFrame:
    df = df.rename(columns={col: col.lower() for col in df.columns})

    # The TAPEX tokenizer raises an error when the table contains non-str columns
    df = df.astype(str)

    for col_name, col_dtype in zip(df.columns, df.dtypes):
        df[col_name] = df[col_name].str.lower()

    return df


In [17]:
# functools.cache is supported in python3.9+, but use lru_cache to support python3.7+
@functools.lru_cache(maxsize=None)
def load_and_preprocess_table(db_id: str, table_name: str) -> pd.DataFrame:
    table = load_table(db_id, table_name)
    table = preprocess_table(table)

    return table


In [18]:
def preprocess_dataset(example: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    table = load_and_preprocess_table(example["db_id"], example["table"])

    query = example["query"]
    answer = example["vega_zero"]

    model_inputs = tokenizer(
        table=table,
        query=query,
        answer=answer,
        max_length=MAX_SOURCE_LENGTH,
        padding=True,
        truncation=True,
    )

    labels = tokenizer(
        answer=answer,
        max_length=MAX_TARGET_LENGTH,
        padding=True,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs


In [19]:
def preprocess_query(*args, **kwargs) -> str:
    args = (str(arg) for arg in args)
    args = " [arg] ".join(args)
    args = f"[arg] {args}"

    kwargs = (f"{k} [eq] {v}" for k, v in kwargs.items())
    kwargs = " [kwarg] ".join(kwargs)
    kwargs = f"[kwarg] {kwargs}"

    return f"{args} {kwargs}".lower()


In [20]:
train_df["query"] = train_df.apply(
    lambda row: preprocess_query(*row["args"], **row["kwargs"]), axis=1
)

if len(augmented_train_df) > 0:
    augmented_train_df["query"] = augmented_train_df.apply(
        lambda row: preprocess_query(*row["args"], **row["kwargs"]), axis=1
    )

test_df["query"] = test_df.apply(
    lambda row: preprocess_query(*row["args"], **row["kwargs"]), axis=1
)

val_df["query"] = val_df.apply(
    lambda row: preprocess_query(*row["args"], **row["kwargs"]), axis=1
)


In [21]:
dataset = DatasetDict()

dataset["train"] = Dataset.from_pandas(
    pd.concat([train_df, augmented_train_df])
    .reset_index(drop=True)
    .drop(columns=["args", "kwargs"])
)
dataset["test"] = Dataset.from_pandas(test_df.drop(columns=["args", "kwargs"]))
dataset["validation"] = Dataset.from_pandas(val_df.drop(columns=["args", "kwargs"]))

dataset = dataset.map(
    preprocess_dataset,
    batched=False,
    num_proc=multiprocessing.cpu_count(),
)

dataset


                 

#0:   0%|          | 0/178 [00:00<?, ?ex/s]

     

#3:   0%|          | 0/178 [00:00<?, ?ex/s]

#1:   0%|          | 0/178 [00:00<?, ?ex/s]

#2:   0%|          | 0/178 [00:00<?, ?ex/s]

  

#4:   0%|          | 0/178 [00:00<?, ?ex/s]

#6:   0%|          | 0/178 [00:00<?, ?ex/s]

#5:   0%|          | 0/178 [00:00<?, ?ex/s]

#8:   0%|          | 0/178 [00:00<?, ?ex/s]

#7:   0%|          | 0/178 [00:00<?, ?ex/s]

#9:   0%|          | 0/177 [00:00<?, ?ex/s]

#10:   0%|          | 0/177 [00:00<?, ?ex/s]

#11:   0%|          | 0/177 [00:00<?, ?ex/s]

                    

#0:   0%|          | 0/23 [00:00<?, ?ex/s]

    

#3:   0%|          | 0/23 [00:00<?, ?ex/s]

#5:   0%|          | 0/23 [00:00<?, ?ex/s]

#8:   0%|          | 0/22 [00:00<?, ?ex/s]

#2:   0%|          | 0/23 [00:00<?, ?ex/s]

#1:   0%|          | 0/23 [00:00<?, ?ex/s]

#6:   0%|          | 0/22 [00:00<?, ?ex/s]

#9:   0%|          | 0/22 [00:00<?, ?ex/s]

#4:   0%|          | 0/23 [00:00<?, ?ex/s]

#10:   0%|          | 0/22 [00:00<?, ?ex/s]

#11:   0%|          | 0/22 [00:00<?, ?ex/s]

#7:   0%|          | 0/22 [00:00<?, ?ex/s]

                    

#1:   0%|          | 0/23 [00:00<?, ?ex/s]

#5:   0%|          | 0/23 [00:00<?, ?ex/s]

 

#2:   0%|          | 0/23 [00:00<?, ?ex/s]

 

#0:   0%|          | 0/23 [00:00<?, ?ex/s]

#3:   0%|          | 0/23 [00:00<?, ?ex/s]

#8:   0%|          | 0/22 [00:00<?, ?ex/s]

#4:   0%|          | 0/23 [00:00<?, ?ex/s]

  

#6:   0%|          | 0/22 [00:00<?, ?ex/s]

#9:   0%|          | 0/22 [00:00<?, ?ex/s]

#10:   0%|          | 0/22 [00:00<?, ?ex/s]

#11:   0%|          | 0/22 [00:00<?, ?ex/s]

#7:   0%|          | 0/22 [00:00<?, ?ex/s]

DatasetDict({
    train: Dataset({
        features: ['db_id', 'table', 'chart', 'hardness', 'vega_zero', 'query', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 2133
    })
    test: Dataset({
        features: ['db_id', 'table', 'chart', 'hardness', 'vega_zero', 'query', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 270
    })
    validation: Dataset({
        features: ['db_id', 'table', 'chart', 'hardness', 'vega_zero', 'query', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 270
    })
})

## Train Model


In [22]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=None,
)


In [23]:
exact_match = evaluate.load("exact_match")


def compute_metrics(eval_pred: EvalPrediction):
    preds, labels = eval_pred

    preds = tokenizer.batch_decode(
        preds, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    labels = tokenizer.batch_decode(
        labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    return exact_match.compute(predictions=preds, references=labels)


Downloading builder script:   0%|          | 0.00/5.67k [00:00<?, ?B/s]

In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    args=Seq2SeqTrainingArguments(
        output_dir=MODEL_OUTPUT_DIR,
        predict_with_generate=True,
        num_train_epochs=50,
        evaluation_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        load_best_model_at_end=True,
        do_eval=True,
        metric_for_best_model="exact_match",
        push_to_hub=push_model_to_huggingface_hub,
        report_to="wandb" if report_to_wandb else "none",
    ),
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=5),
    ],
)


In [None]:
trainer.train()


The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: chart, db_id, vega_zero, query, hardness, table. If chart, db_id, vega_zero, query, hardness, table are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 2183
  Num Epochs = 50
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 13650
  Number of trainable parameters = 139422720


Epoch,Training Loss,Validation Loss,Exact Match
1,0.4184,0.318852,0.418519
2,0.066,0.302561,0.455556
3,0.0366,0.337232,0.522222
4,0.0196,0.329871,0.574074
5,0.0159,0.31989,0.533333
6,0.0113,0.362631,0.540741
7,0.0125,0.364245,0.544444
8,0.0082,0.370134,0.562963
9,0.005,0.37455,0.592593
10,0.0077,0.413344,0.533333


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: chart, db_id, vega_zero, query, hardness, table. If chart, db_id, vega_zero, query, hardness, table are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 270
  Batch size = 8
Saving model checkpoint to ../data/models/vxnli-v2beta10/checkpoint-273
Configuration saved in ../data/models/vxnli-v2beta10/checkpoint-273/config.json
Model weights saved in ../data/models/vxnli-v2beta10/checkpoint-273/pytorch_model.bin
tokenizer config file saved in ../data/models/vxnli-v2beta10/checkpoint-273/tokenizer_config.json
Special tokens file saved in ../data/models/vxnli-v2beta10/checkpoint-273/special_tokens_map.json
added tokens file saved in ../data/models/vxnli-v2beta10/checkpoint-273/added_tokens.json
The following columns in the evaluation set don't have a correspo

TrainOutput(global_step=3822, training_loss=0.04471869168638372, metrics={'train_runtime': 1340.9279, 'train_samples_per_second': 81.399, 'train_steps_per_second': 10.18, 'total_flos': 1.579095420005376e+16, 'train_loss': 0.04471869168638372, 'epoch': 14.0})

## Evaluate Model


In [None]:
# trainer.evaluate must be called for the model card

trainer.evaluate(dataset["test"])


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: chart, db_id, vega_zero, query, hardness, table. If chart, db_id, vega_zero, query, hardness, table are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 270
  Batch size = 8


{'eval_loss': 0.42609551548957825,
 'eval_exact_match': 0.5518518518518518,
 'eval_runtime': 24.0048,
 'eval_samples_per_second': 11.248,
 'eval_steps_per_second': 1.416,
 'epoch': 14.0}

In [None]:
def predict(ds: Dataset) -> List[str]:
    preds = trainer.predict(
        ds,
        max_length=MAX_TARGET_LENGTH,
    )

    preds = tokenizer.batch_decode(
        preds.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    return [pred.strip() for pred in preds]


In [None]:
preds = predict(dataset["test"])

preds[:5], dataset["test"]["vega_zero"][:5]


The following columns in the test set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: chart, db_id, vega_zero, query, hardness, table. If chart, db_id, vega_zero, query, hardness, table are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 270
  Batch size = 8


(['mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark point encoding x investor_id y aggregate mean share_count transform group x',
  'mark point encoding x investor_id y aggregate mean share_count transform group x'],
 ['mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark bar encoding x name y aggregate none weight transform sort x asc',
  'mark point encoding x investor_id y aggregate mean share_count transform group x',
  'mark point encoding x investor_id y aggregate mean share_count transform group x'])

In [None]:
exact_match.compute(
    predictions=preds,
    references=dataset["test"]["vega_zero"],
)


{'exact_match': 0.5481481481481482}

In [None]:
preds_df = dataset["test"].to_pandas()
preds_df = preds_df.drop(columns=["input_ids", "attention_mask", "labels"])
preds_df["pred"] = preds
preds_df["exact_matched"] = preds_df["pred"] == preds_df["vega_zero"]

preds_df.to_csv(RESULT_OUTPUT_DIR.joinpath("prediction.csv"))

preds_df


Unnamed: 0,db_id,table,chart,hardness,vega_zero,query,pred,exact_matched
0,candidate_poll,people,bar,Easy,mark bar encoding x name y aggregate none weig...,[arg] [kwarg] use_bar_chart [eq] true [kwarg]...,mark bar encoding x name y aggregate none weig...,True
1,candidate_poll,people,bar,Easy,mark bar encoding x name y aggregate none weig...,[arg] use a bar chart [kwarg] x [eq] name [kwa...,mark bar encoding x name y aggregate none weig...,True
2,candidate_poll,people,bar,Easy,mark bar encoding x name y aggregate none weig...,[arg] [kwarg] graph [eq] bar [kwarg] x [eq] n...,mark bar encoding x name y aggregate none weig...,True
3,tracking_share_transactions,transactions,point,Easy,mark point encoding x investor_id y aggregate ...,[arg] scatter chart [arg] investor id and mean...,mark point encoding x investor_id y aggregate ...,True
4,tracking_share_transactions,transactions,point,Easy,mark point encoding x investor_id y aggregate ...,[arg] [kwarg] graph_type [eq] scatter [kwarg]...,mark point encoding x investor_id y aggregate ...,True
...,...,...,...,...,...,...,...,...
265,tracking_share_transactions,transactions,Line,Medium,mark line encoding x date_of_transaction y agg...,[arg] [kwarg] time_axis [eq] date_of_transact...,mark line encoding x date_of_transaction y agg...,False
266,tracking_share_transactions,transactions,Line,Medium,mark line encoding x date_of_transaction y agg...,[arg] show me a trend [kwarg] x [eq] date_of_t...,mark line encoding x date_of_transaction y agg...,True
267,customers_and_invoices,financial_transactions,Bar,Medium,mark bar encoding x transaction_type y aggrega...,[arg] show the transaction types and the total...,mark bar encoding x transaction_type y aggrega...,False
268,customers_and_invoices,financial_transactions,Bar,Medium,mark bar encoding x transaction_type y aggrega...,[arg] [kwarg] x [eq] type [kwarg] y [eq] amou...,mark bar encoding x transaction_type y aggrega...,True


## Complete Training


In [None]:
if report_to_wandb:
    wandb.finish()


In [None]:
if push_model_to_huggingface_hub:
    trainer.push_to_hub()
