In [None]:
# ===================== Training base T5 ==================== # 
import sys
import random
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset, DatasetDict
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
import matplotlib.pyplot as plt

#================== Global Constants ==================#
# Instruction text placed in front of the question, keyed by task mode.
TASK_PREFIX = dict(
    translation="Translate English to SQL: ",
    metadata="Find data tables in question: ",
)

# Name of the record field looked up for each task mode.
TASK_KEY = dict(
    translation="sql",
    metadata="tables",
)

# Extra text keyed by task mode (empty for the metadata task).
TASK_HINT = dict(
    translation=". Use following data tables - ",
    metadata="",
)
# Supported checkpoint names, each mapped to (model class, tokenizer class).
MODEL_CLASSES = {
    checkpoint: (T5ForConditionalGeneration, T5Tokenizer)
    for checkpoint in ("t5-small", "t5-base")
}

# Optimizer constructors keyed by a short name.
OPTIM_CLASSES = dict(sgd=optim.SGD, adam=optim.Adam)

# NOTE(review): absolute, user-specific path -- adjust before running on
# another machine.
MODEL_BASE_DIR = "/Users/sree/.cache/huggingface/hub"
#LOG_DIR        = OUTPUT_DIR + "/logs"


class MyT5Trainer:
    """Fine-tune a pretrained T5 checkpoint, one record per optimizer step.

    Construction seeds the random number generators, loads the tokenizer and
    model named by ``model_name`` (which must be a key of ``MODEL_CLASSES``)
    and builds an Adam optimizer. Before calling :meth:`train_model`, register
    the three data-extraction callables with
    :meth:`set_data_extration_fuctions`.

    Attributes:
        task_mode: Task name handed to the extraction callables; overwritten
            by :meth:`train_model`.
        loss_log: Loss value of every processed (non-skipped) record since
            the last :meth:`train_model` call.
        adam_lr: Learning rate passed to Adam.
        adam_eps: Epsilon passed to Adam.
    """

    #================== Training Logic Properties ==================#
    task_mode           = "translation"

    # Data-extraction callables; they stay None until
    # set_data_extration_fuctions() is called, which is what train_model()
    # checks for.
    extract_input_from_ds           = None
    extract_expected_output_from_ds = None
    skip_record                     = None

    #================== Adam Hyperparameters ==================#
    adam_lr             = 3e-4
    adam_eps            = 1e-8

    def __init__(self, model_name, seed):
        """Seed the RNGs, then load tokenizer, model and optimizer.

        Args:
            model_name: Key of ``MODEL_CLASSES`` naming the checkpoint to load.
            seed: Seed applied to ``random``, NumPy and PyTorch.

        Raises:
            KeyError: If ``model_name`` is not a key of ``MODEL_CLASSES``.
        """
        print("Initializing Trainer")

        # Per-instance list so that trainers never share one loss history.
        self.loss_log = []

        self.__initialize_seed__(seed)

        #init - self.model_name, self.out_model_name, self.output_dir
        self.__initialize_model_dir__(model_name)

        #init - self.tokenizer, self.model
        self.__initialize_model__()

        #init - self.optimizer
        self.__initialize_optimizer__()

    def __initialize_seed__(self, seed):
        """Apply ``seed`` to the ``random``, NumPy and PyTorch generators."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def __initialize_model_dir__(self, model_name):
        """Store the model name and derive the output name and directory."""
        self.model_name = model_name
        self.out_model_name = f"my--{self.model_name}--finetuned-text-to-SQL"
        self.output_dir = f"{MODEL_BASE_DIR}/models--{self.out_model_name}"

    def __initialize_model__(self):
        """Load ``self.tokenizer`` and ``self.model`` for ``self.model_name``.

        Raises:
            KeyError: If ``self.model_name`` is not a key of ``MODEL_CLASSES``.
        """
        try:
            model_class, tokenizer_class = MODEL_CLASSES[self.model_name]
        except KeyError:
            # Name the offending model in the message; the bare lookup error
            # adds nothing, so its context is suppressed.
            raise KeyError(
                f"the model {self.model_name} you specified is not supported. "
                "You are welcome to add it and open a PR :)"
            ) from None

        # model_max_length=512,
        self.tokenizer = tokenizer_class.from_pretrained(self.model_name)
        self.model = model_class.from_pretrained(
            self.model_name, pad_token_id=self.tokenizer.eos_token_id)

    def __initialize_optimizer_param__(self):
        """Return the model parameters split into two optimizer groups.

        The first group holds parameters whose name contains none of the
        ``no_decay`` markers, the second group holds the rest. Both groups
        currently use a weight decay of 0.0.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in self.model.named_parameters():
            if any(marker in name for marker in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        return [
            {"params": decay_params, "weight_decay": 0.0},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

    def __initialize_optimizer__(self):
        """Create ``self.optimizer`` (Adam) over the grouped parameters."""
        grouped_parameters = self.__initialize_optimizer_param__()
        self.optimizer = optim.Adam(
            grouped_parameters, lr=self.adam_lr, eps=self.adam_eps)

    def __tokenize__(self, text):
        """Encode ``text`` with the tokenizer (at most 96 tokens, "pt" tensors)."""
        return self.tokenizer.encode_plus(
            text, max_length=96, padding=True, truncation=True, return_tensors="pt")

    def __process_data__(self, record):
        """Run one training step on ``record`` unless it is to be skipped.

        The loss of the step is appended to ``self.loss_log``. Always returns
        None.
        """
        if self.skip_record(record):
            return

        source_text = self.extract_input_from_ds(self.task_mode, record)
        tokenized_input = self.__tokenize__(source_text)
        input_ids = tokenized_input["input_ids"]
        attention_mask = tokenized_input["attention_mask"]

        expected_output = self.extract_expected_output_from_ds(self.task_mode, record)
        tokenized_output = self.__tokenize__(expected_output)
        lm_labels = tokenized_output["input_ids"]
        decoder_attention_mask = tokenized_output["attention_mask"]

        # forward pass - predict
        output = self.model(
            input_ids=input_ids, attention_mask=attention_mask,
            labels=lm_labels, decoder_attention_mask=decoder_attention_mask)

        # forward pass - the first element of the model output is the loss
        loss = output[0]

        # record the loss for plotting
        self.loss_log.append(loss.item())

        # zero all gradients before the backward pass
        self.optimizer.zero_grad()

        # backward pass
        loss.backward()

        self.optimizer.step()

    def set_data_extration_fuctions(self, extract_input, extract_expected_output, skip_record):
        """Register the callables used to read the training records.

        Args:
            extract_input: Called as ``extract_input(task_mode, record)``;
                its result is tokenized as the model input.
            extract_expected_output: Called as
                ``extract_expected_output(task_mode, record)``; its result is
                tokenized as the labels.
            skip_record: Called as ``skip_record(record)``; a truthy result
                makes the trainer ignore that record.
        """
        self.extract_input_from_ds = extract_input
        self.extract_expected_output_from_ds = extract_expected_output
        self.skip_record = skip_record

    def plot_loss(self):
        """Plot the logged losses with matplotlib and show the figure."""
        # Label the curve with the optimizer that actually produced it.
        plt.plot(self.loss_log, label=type(self.optimizer).__name__)
        # One loss value is logged per processed record, not per epoch.
        plt.xlabel('training step')
        plt.ylabel('Cost/ total loss')
        plt.legend()
        plt.show()

    def train_model(self, epochs, ds, task_mode):
        """Train for ``epochs`` passes over ``ds`` on the given task.

        Args:
            epochs: Number of passes over ``ds``.
            ds: Object exposing ``map(function)``; every record it passes to
                the function is used for one training step.
            task_mode: Task name forwarded to the extraction callables.

        Returns:
            None. If no extraction callables were registered, a hint is
            printed and nothing is trained.
        """
        if self.extract_input_from_ds is None:
            print("Set data extraction logic before training model")
            return

        self.loss_log = []
        # __process_data__ reads self.task_mode, so the requested task has to
        # be stored there; TASK_MODE is kept as an alias for existing readers.
        self.task_mode = task_mode
        self.TASK_MODE = task_mode
        print("Training ", self.task_mode)
        self.model.train(mode=True)
        for epoch in range(epochs):
            print("epoch ", epoch)
            # map() is used only to visit the records; the callback returns
            # None.
            ds.map(self.__process_data__)


#================== Main Program ==================#

#extract_input = lambda task_mode, record: TASK_PREFIX[task_mode] + record['question'] # No Hint
def extract_input(task_mode, record):
    """Build the model input: prefix + question + hint + table list."""
    prefix = TASK_PREFIX[task_mode]
    hint = TASK_HINT[task_mode]
    return prefix + record['question'] + hint + record['tables']


def extract_expected_output(task_mode, record):
    """Return the record field that holds the target text for the task."""
    target_field = TASK_KEY[task_mode]
    return record[target_field]


def skip_record(record):
    """Return the record's 'comment' value; a truthy value means skip."""
    return record['comment']


data_path = "./data/my_flat_sql_data_meta.json"
my_ds = load_dataset('json', data_files=data_path)
print(my_ds)

# initialize tokenizer/encoder, model & optimizer
trainer = MyT5Trainer(model_name="t5-small", seed=42)
trainer.set_data_extration_fuctions(
    extract_input=extract_input,
    extract_expected_output=extract_expected_output,
    skip_record=skip_record,
)

task_modes = ['translation']  # 'metadata'
for task_mode in task_modes:
    trainer.train_model(epochs=8, ds=my_ds, task_mode=task_mode)
    print(trainer.loss_log)