In [None]:
# ===================== Training T5 Model ======================#
# ==============================================================#

import sys, os
from os import environ
from dotenv import load_dotenv
import logging
from logging.config import fileConfig

import datasets
import numpy as np
import evaluate
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelWithLMHead
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import Trainer, TrainingArguments, AdamW
from transformers import DataCollatorForSeq2Seq

# ==============================================================#
# ====================== GLOBAL VARIABLES ======================#

# Scratch placeholder; nothing in this section reads it.
GlobalVar = 0

# Per-device batch sizes and token-length limits.
train_batch_size = 8
eval_batch_size = 8
max_input_length = 512
max_target_length = 64

# Where the base checkpoint comes from and where the fine-tuned one is written.
MODEL_BASE_DIR = "/Users/sree/.cache/huggingface/hub"
IN_MODEL_NAME = "mrm8488/t5-base-finetuned-wikiSQL"
OUT_MODEL_NAME = "my--t5-base-finetuned-wiki-to-SQL"
OUTPUT_DIR = f"{MODEL_BASE_DIR}/models--{OUT_MODEL_NAME}"
LOG_DIR = f"{OUTPUT_DIR}/logs"

# Dataset selection: a local JSON file or a dataset from the Hugging Face hub.
DS_LOCAL = 'my_sql_data.json'
DS_HUGGINGFACE = 'wikisql'
USE_LOCAL_DATASET = False
if USE_LOCAL_DATASET:
    DATASET = DS_LOCAL
else:
    DATASET = DS_HUGGINGFACE

# Tokenization cache switch and per-split sample sizes used for quick runs.
USE_TKN_CACHE = False
DATA_SAMPLING = True
SAMPLING_SPLIT = dict(test=50, train=20, validation=20)
DATASET_PATH = "./dataset/"

# The cache path gets a "-sample" suffix when only samples are processed,
# so sampled and full encodings never share a directory.
ENCODED_DATA_PATH = f"./cache/encoded-{DATASET}{'-sample' if DATA_SAMPLING else ''}"

# Column names in the dataset records and the prompt prefix for every input.
INPUT_FIELD = 'question'
RESULT_FIELD = 'sql'
NESTED_RESULT_FIELD = 'human_readable'
TASK_PREFIX = "translate English to SQL : "

# Load the tokenizer and the seq2seq model for IN_MODEL_NAME once, at module
# level; the functions below read `tokenizer` and `model` as globals.
# NOTE(review): `use_fast` is not passed here, so the library default applies.
# Pass use_fast=True explicitly if the fast tokenizer must be guaranteed.
tokenizer           = AutoTokenizer.from_pretrained(IN_MODEL_NAME)
model               = AutoModelForSeq2SeqLM.from_pretrained(IN_MODEL_NAME)

# Sets TOKENIZERS_PARALLELISM to "false" for this process.
# Presumably this silences the tokenizers fork/parallelism warning - TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ==============================================================#
# ====================== GLOBAL FUNCTIONS ======================#

def compute_metric(eval_pred):
    """Compute ROUGE scores and mean generated length for one evaluation pass.

    Args:
        eval_pred: pair ``(predictions, labels)`` of token-id arrays as handed
            over by the trainer. ``predictions`` may also be a tuple whose
            first element holds the generated token ids.

    Returns:
        dict mapping each ROUGE key reported by the global ``metric`` to its
        score scaled to 0-100, plus ``"gen_len"`` (mean count of non-pad
        tokens per prediction). All values are rounded to 4 decimals.
    """
    predictions, labels = eval_pred

    # Some models hand back (sequences, extras); only the token ids are decoded.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id

    # -100 is the loss ignore-index, not a vocabulary id, so it cannot be
    # decoded. Labels always carry it; predictions can too when the trainer
    # pads generated batches, so both are mapped back to the pad token.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute ROUGE scores
    scores = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # `evaluate`'s ROUGE returns plain floats, whereas the older
    # `datasets.load_metric` returned aggregate objects exposing
    # `.mid.fmeasure`. Accept either so the function works with both.
    result = {}
    for key, value in scores.items():
        aggregate = getattr(value, "mid", None)
        f1 = aggregate.fmeasure if aggregate is not None else value
        result[key] = f1 * 100

    # Add mean generated length to metrics
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

def get_training_args():
    """Build the ``Seq2SeqTrainingArguments`` used for fine-tuning.

    Checkpoints go to ``OUTPUT_DIR`` and TensorBoard logs to ``LOG_DIR``.
    Evaluation and logging run every 100 steps, checkpoints are saved every
    200 steps, and the checkpoint with the best ``rouge1`` score is reloaded
    when training ends.

    Returns:
        Seq2SeqTrainingArguments: the configured training arguments.
    """
    return Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        # Without an explicit logging_dir the trainer writes its TensorBoard
        # events under output_dir, not under the LOG_DIR this script reports.
        logging_dir=LOG_DIR,
        num_train_epochs=3,
        save_total_limit=3,
        evaluation_strategy="steps",
        eval_steps=100,
        logging_strategy="steps",
        logging_steps=100,
        # save_steps is a multiple of eval_steps so that every saved
        # checkpoint has a matching evaluation for best-model selection.
        save_strategy="steps",
        save_steps=200,

        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        # Generate sequences during evaluation so ROUGE can be computed.
        predict_with_generate=True,
        metric_for_best_model="rouge1",
        load_best_model_at_end=True,

        weight_decay=0.01,
        learning_rate=4e-5,
        optim="adamw_torch",
        report_to=["tensorboard"],
    )

def get_trainer(tokenized_dataset, training_args):
    """Create the ``Seq2SeqTrainer`` for the module-level model.

    Args:
        tokenized_dataset: mapping of splits; its ``"train"`` and
            ``"validation"`` entries are used for training and evaluation.
        training_args: training arguments handed to the trainer as ``args``.

    Returns:
        Seq2SeqTrainer wired with the global ``model``, ``tokenizer``,
        ``data_collator`` and the ``compute_metric`` callback.
    """
    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": tokenized_dataset["train"],
        "eval_dataset": tokenized_dataset["validation"],
        "data_collator": data_collator,
        "tokenizer": tokenizer,
        "compute_metrics": compute_metric,
    }
    return Seq2SeqTrainer(**trainer_kwargs)

def select_slice(ds, split):
    """Return a shuffled sample of one split.

    Args:
        ds: mapping of split name to dataset.
        split: split name; also the key into ``SAMPLING_SPLIT`` that gives
            the requested sample size.

    Returns:
        The first ``SAMPLING_SPLIT[split]`` rows of ``ds[split]`` after a
        shuffle with a fixed seed of 20, or the whole shuffled split when it
        has fewer rows than requested.
    """
    shuffled = ds[split].shuffle(seed=20)
    # Clamp to the split size so a small split yields a smaller sample
    # instead of an out-of-range selection error.
    sample_size = min(SAMPLING_SPLIT[split], len(shuffled))
    return shuffled.select(range(sample_size))

def load_data():
    """Load the dataset selected by the module-level configuration.

    With ``USE_LOCAL_DATASET`` set, the JSON file ``DATASET_PATH + DATASET``
    is loaded. Otherwise ``DATASET`` is loaded by name and, when
    ``DATA_SAMPLING`` is on, each of the test/train/validation splits is
    reduced to a sample via ``select_slice`` for fast testing.

    Returns:
        The loaded dataset, or a ``DatasetDict`` of sampled splits.
    """
    # NOTE(review): the local JSON path is returned as loaded; confirm it
    # provides the "validation" split that get_trainer reads.
    if USE_LOCAL_DATASET:
        return load_dataset('json', data_files=DATASET_PATH + DATASET)

    hub_dataset = load_dataset(DATASET)
    if not DATA_SAMPLING:
        return hub_dataset

    sampled_splits = {
        split_name: select_slice(hub_dataset, split_name)
        for split_name in ("test", "train", "validation")
    }
    return DatasetDict(sampled_splits)

def prefix_input_batch(prefix, items):
    """Return a new list with ``prefix`` prepended to every entry of ``items``."""
    prefixed = []
    for text in items:
        prefixed.append(prefix + text)
    return prefixed

def extract_from_nested_batch(nested_col, items):
    """Return ``record[nested_col]`` for every record in ``items``, as a list."""
    extracted = []
    for record in items:
        extracted.append(record[nested_col])
    return extracted

def tokenize(data):
    """Run the global tokenizer on ``data``, truncating to ``max_input_length`` tokens.

    Returns whatever the tokenizer call returns; no tensor conversion is requested.
    """
    return tokenizer(data, truncation=True, max_length=max_input_length)

def tokenize_batch(record_batch):
    """Tokenize one batch of records into model inputs and labels.

    Args:
        record_batch: mapping of column name to a list of values. It must
            contain ``INPUT_FIELD`` (questions) and ``RESULT_FIELD`` (targets).
            Targets may be nested records holding the text under
            ``NESTED_RESULT_FIELD`` (the WikiSQL layout) or plain strings
            (the flat layout of self-created datasets).

    Returns:
        The tokenizer output for the prefixed questions (``input_ids`` and
        ``attention_mask``) with the target token ids added as ``"labels"``.
    """
    # Prefix each question in the batch with the task instruction.
    input_questions = prefix_input_batch(TASK_PREFIX, record_batch[INPUT_FIELD])

    # WikiSQL nests the expected SQL text inside a record; flat datasets
    # store the text directly, in which case it is used as-is.
    raw_targets = record_batch[RESULT_FIELD]
    if len(raw_targets) > 0 and isinstance(raw_targets[0], str):
        expected_results = list(raw_targets)
    else:
        expected_results = extract_from_nested_batch(NESTED_RESULT_FIELD, raw_targets)

    # Encode the questions to get input_ids and attention_mask.
    model_inputs = tokenize(input_questions)

    # Encode the targets with `text_target`, which replaces the deprecated
    # `as_target_tokenizer()` context manager. They are truncated to the
    # target limit rather than the much longer input limit.
    labels = tokenizer(
        text_target=expected_results,
        max_length=max_target_length,
        truncation=True,
    )

    # Only the label ids are attached. A per-example decoder_attention_mask
    # is variable-length, and the seq2seq collator pads only the inputs and
    # labels, so such a column could not be batched into a tensor.
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

def tokenize_dataset(dataset_to_tokenize):
    """Tokenize a dataset, or load a previously tokenized copy from disk.

    When ``USE_TKN_CACHE`` is set and ``ENCODED_DATA_PATH`` exists, the cached
    copy is loaded. Otherwise every batch goes through ``tokenize_batch``
    and the result is saved to ``ENCODED_DATA_PATH``.

    Args:
        dataset_to_tokenize: a dataset or a dict of splits exposing ``map``
            and ``column_names``.

    Returns:
        The tokenized dataset holding only the columns produced by
        ``tokenize_batch``.
    """
    if USE_TKN_CACHE and os.path.exists(ENCODED_DATA_PATH):
        print("Loading tokenized dataset from cache")
        return datasets.load_from_disk(ENCODED_DATA_PATH)

    print("Tokenizing dataset ... it will take some time")

    # `map` keeps the original columns unless told otherwise, which is why
    # the raw fields ('phase', 'question', 'table', 'sql') used to appear
    # next to the tokenized ones. Drop them explicitly.
    raw_columns = dataset_to_tokenize.column_names
    if isinstance(raw_columns, dict):
        # A dict of splits reports columns per split. Only columns present
        # in every split are removed, so no split is asked to drop a column
        # it does not have.
        per_split = list(raw_columns.values())
        raw_columns = [
            col for col in (per_split[0] if per_split else [])
            if all(col in cols for cols in per_split[1:])
        ]

    tokenized_dataset = dataset_to_tokenize.map(
        tokenize_batch,
        batched=True,
        remove_columns=raw_columns,
    )

    print("Saving tokenizing dataset ... ", ENCODED_DATA_PATH)
    tokenized_dataset.save_to_disk(ENCODED_DATA_PATH)

    return tokenized_dataset

# ==============================================================#
# ======================== MAIN PROGRAM ========================#

full_dataset = load_data()
print("-"*20, "Dataset Loaded ")

tokenized_dataset = tokenize_dataset(full_dataset)
print("-"*20, "Dataset Tokenized ")

#using trainer to train, preferred as the code is cleaner
training_args = get_training_args()
print("-"*20, "Training Args Ready ")

metric = evaluate.load("rouge")
print("-"*20, "Rouge Metric Ready ")

data_collator = DataCollatorForSeq2Seq(tokenizer)

trainer = get_trainer(tokenized_dataset, training_args)
print("Training Logs at : ", LOG_DIR)

# Start tensorboard
%load_ext tensorboard
%tensorboard --logdir LOG_DIR+'/runs'

trainer.train()
trainer.save_model()
    