In [None]:
# ===================== Training base T5 ==================== #
import os
import random
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
#from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, DatasetDict
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

#================== Global Constants ==================#
# Prompt prefix placed in front of every question, keyed by task mode.
TASK_PREFIX     = {
                    "translation":"Translate English to SQL: ",
                    "metadata":"Find data tables in question: "
                    }
# Name of the dataset field that holds the expected output for each task mode.
TASK_KEY        = {
                    "translation":"sql",
                    "metadata":"tables"
                    }
# Text inserted between the question and the table names, keyed by task mode.
TASK_HINT       = {
                    "translation":". Use following data tables - ",
                    "metadata":""
                    }
# Supported checkpoints mapped to their (model class, tokenizer class) pair.
MODEL_CLASSES   = {
                    "t5-small": (T5ForConditionalGeneration, T5Tokenizer),
                    "t5-base": (T5ForConditionalGeneration, T5Tokenizer),
                    }
# Optimizers selectable by name.
OPTIM_CLASSES   = {
                    "sgd": optim.SGD,
                    "adam": optim.Adam,
                    }
# Resolved from the current user's home directory rather than one developer's
# absolute path, so the notebook runs unchanged on other machines.
MODEL_BASE_DIR  = os.path.expanduser("~/.cache/huggingface/hub")
#LOG_DIR        = OUTPUT_DIR + "/logs"


class MyT5Trainer:
    """Fine-tune a pretrained T5 checkpoint one record (or one batch) at a time.

    Usage: construct the trainer, register the data extractors with
    ``set_data_extractors``, choose ``set_data_mode`` / ``set_task_mode`` and
    call ``train_model``. The loss of every optimisation step is appended to
    ``loss_log`` and can be drawn with ``plot_loss``.
    """

    #================== Training Logic Properties ==================#
    train_mode          = "record"      #record|batch (kept in sync with data_mode)
    data_mode           = "record"      #record|batch - the value __process_data__ reads
    task_mode           = "translation" #translation|metadata
    loss_log            = []            # class-level placeholder; __init__ gives each instance its own list
    correctness         = 0             # number of correctly predicted label tokens seen in eval mode
    max_token_length    = 96            # truncation length used by __tokenize__
    label_ignore_index  = -100          # label value excluded from the loss and from accuracy counting

    #================== Adam Hyperparameters ==================#
    adam_lr             = 3e-4
    adam_eps            = 1e-8


    def __init__(self, model_name, seed):
        """Seed the RNGs, then load tokenizer, model and optimizer.

        Args:
            model_name: key of MODEL_CLASSES, e.g. "t5-small" or "t5-base".
            seed: integer seed applied to random, numpy and torch.
        """
        print("Initializing Trainer")
        self.__initialize_seed__(seed)

        # Per-instance state, so trainers never share the class-level list.
        self.loss_log = []
        self.correctness = 0

        #init - self.model_name, self.out_model_name, self.output_dir
        self.__initialize_model_dir__(model_name)

        #init - self.tokenizer, self.model
        self.__initialize_model__()

        #init - self.optimizer
        self.__initialize_optimizer__()

        #following to be setup before training:
        # mode, data_mode, task_mode, extract_input_from_ds, extract_expected_output_from_ds, skip_record

    def __initialize_seed__(self, seed):
        """Seed Python's random, numpy and torch with the same value."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def __initialize_model_dir__(self, model_name):
        """Record the model name and derive the fine-tuned model name and output directory."""
        self.model_name = model_name
        self.out_model_name = f"my--{self.model_name}--finetuned-text-to-SQL"
        self.output_dir = f"{MODEL_BASE_DIR}/models--{self.out_model_name}"

    def __initialize_model__(self):
        """Load the tokenizer and model registered for ``self.model_name``.

        Raises:
            KeyError: if ``self.model_name`` is not a key of MODEL_CLASSES.
        """
        try:
            model_class, tokenizer_class = MODEL_CLASSES[ self.model_name ]
        except KeyError:
            # The original message contained a bare "{}" that was never filled in.
            raise KeyError(
                f"the model {self.model_name} you specified is not supported. "
                "You are welcome to add it and open a PR :)") from None

        # model_max_length=512,
        self.tokenizer = tokenizer_class.from_pretrained( self.model_name )
        # NOTE(review): this overrides the model's configured pad token id with the EOS id.
        # T5 ships its own pad token, so the override may be unnecessary - confirm before changing.
        self.model = model_class.from_pretrained(self.model_name, pad_token_id=self.tokenizer.eos_token_id )

    def __initialize_optimizer_param__(self):
        """Split the model parameters into two optimizer groups.

        Parameters whose name contains an entry of ``no_decay`` form the second
        group. Both groups currently use weight_decay 0.0, so the split has no
        numerical effect.
        NOTE(review): "LayerNorm.weight" looks like a BERT-style name and may not
        match T5 parameter names - verify before enabling weight decay.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        decayed, undecayed = [], []
        for name, param in self.model.named_parameters():
            target = undecayed if any(nd in name for nd in no_decay) else decayed
            target.append(param)
        grouped_parameters = [
            {"params": decayed, "weight_decay": 0.0,},
            {"params": undecayed, "weight_decay": 0.0,},]
        return grouped_parameters

    def __initialize_optimizer__(self):
        """Create the Adam optimizer over the grouped model parameters."""
        grouped_parameters = self.__initialize_optimizer_param__()
        self.optimizer = optim.Adam(grouped_parameters, lr=self.adam_lr, eps=self.adam_eps)

    def set_data_extractors(self, extract_input, extract_expected_output, skip_record):
        """Register the callables that turn a dataset record/batch into model text.

        Args:
            extract_input: called as ``extract_input(task_mode, data)``.
            extract_expected_output: called as ``extract_expected_output(task_mode, data)``.
            skip_record: called as ``skip_record(data)`` in record mode; a truthy
                result skips the record. May be None.
        """
        self.extract_input_from_ds = extract_input
        self.extract_expected_output_from_ds = extract_expected_output
        self.skip_record = skip_record

    # task modes - translation | metadata
    def set_task_mode(self, task_mode):
        """Select the task whose prefix/hint/target the extractors should use."""
        # __encode_data__ reads self.task_mode; the original setter only wrote
        # self.TASK_MODE, so the chosen task was silently ignored.
        self.task_mode = task_mode
        self.TASK_MODE = task_mode  # kept for code that reads the old attribute name

    # train modes - record | batch
    def set_data_mode(self, data_mode):
        """Select whether datasets are processed per record or per batch."""
        self.data_mode = data_mode
        self.train_mode = data_mode

    # mode - train | eval
    def set_mode(self, mode):
        """Record whether __process_data__ should train or evaluate."""
        self.mode = mode

    def reset_accuracy_log(self):
        """Clear the recorded losses and the correct-token counter."""
        self.loss_log = []
        self.correctness = 0

    def __train__(self):
        """Switch trainer and model to training mode."""
        self.set_mode("train")
        self.model.train(mode=True)

    def __eval__(self):
        """Switch trainer and model to evaluation mode."""
        self.set_mode("eval")
        self.model.eval()

    def __tokenize__(self, text):
        """Tokenize one string or a list of strings into padded/truncated tensors.

        The tokenizer is invoked through its ``__call__`` interface so that a
        list of strings is encoded as a batch; ``encode_plus`` interprets a list
        of strings as the tokens of a single sequence.
        """
        return self.tokenizer(
            text, max_length=self.max_token_length, padding=True, truncation=True, return_tensors="pt")

    def __encode_data__(self, data):
        """Extract input/target text from ``data`` and tokenize both.

        Returns:
            Tuple of (input ids, input attention mask, target ids, target attention mask).
        """
        # preprocessing data - extracting from input data file, prefixing & cleaning up
        input_text = self.extract_input_from_ds(self.task_mode, data)
        expected_output = self.extract_expected_output_from_ds(self.task_mode, data)

        # tokenizing input & expected output data
        tokenized_input = self.__tokenize__(input_text)
        tokenized_output = self.__tokenize__(expected_output)

        return (tokenized_input["input_ids"], tokenized_input["attention_mask"],
        tokenized_output["input_ids"], tokenized_output["attention_mask"])

    def __process_data__(self, data):
        """Run one forward pass on ``data`` and either train on it or score it.

        In 'train' mode the loss is logged and one optimizer step is taken.
        In 'eval' mode the number of correctly predicted label tokens is added
        to ``self.correctness``. Always returns None.
        """
        #if training in single record mode, check for empty or comment records
        if self.data_mode == 'record' and self.skip_record is not None and self.skip_record(data):
            return

        # parse, cleanse & tokenize input data record or batch records (based on data_mode)
        start = time.time()
        input_ids, attention_mask, lm_labels, decoder_attention_mask = self.__encode_data__(data)
        end = time.time()
        print("Time (sec) to Encode: ", end-start)

        # Padding positions in the labels must not contribute to the loss.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is not None:
            lm_labels = lm_labels.masked_fill(lm_labels.eq(pad_id), self.label_ignore_index)

        # forward pass - predict
        start = time.time()
        output = self.model(
            input_ids = input_ids, attention_mask = attention_mask,
            labels = lm_labels, decoder_attention_mask = decoder_attention_mask)
        end = time.time()
        print("Time (sec) to Predict: ", end-start)

        if(self.mode == 'train'):
            # forward pass - the model computes the loss because labels were supplied
            loss = output.loss

            #record the loss for plotting
            self.loss_log.append(loss.item())

            #zero all gradients before the backward pass
            self.optimizer.zero_grad()

            # backward pass
            loss.backward()

            self.optimizer.step()

        if(self.mode == 'eval'):
            print("Next iteration - EVAL")
            # Most likely token at every target position, taken over the vocabulary axis.
            pred = output.logits.argmax(dim=-1)
            scored = lm_labels.ne(self.label_ignore_index)
            self.correctness += (pred.eq(lm_labels) & scored).sum().item()


    def plot_loss(self):
        """Plot the per-step training loss collected in ``loss_log``."""
        # The optimizer is Adam and one loss value is logged per processed record/batch.
        plt.plot(self.loss_log, label = "Adam Optimizer")
        plt.xlabel('training step')
        plt.ylabel('Cost/ total loss')
        plt.legend()
        plt.show()

    def train_model(self, epochs, train_ds, eval_ds):
        """Train for ``epochs`` passes over ``train_ds``.

        Args:
            epochs: number of passes over the training data.
            train_ds: object exposing ``map(function, batched=...)``.
            eval_ds: currently unused; validation is not wired up yet.
        """
        if getattr(self, "extract_input_from_ds", None) is None:
            print("Set data extraction logic before training model")
            return

        # Reset this trainer's own logs (the original reached for a global `trainer`).
        self.reset_accuracy_log()
        batched = self.data_mode == 'batch'
        for epoch in range(epochs):
            print("epoch ",epoch)

            # Model training - map() is used only for its side effects;
            # __process_data__ returns None.
            self.__train__()
            train_ds.map(self.__process_data__, batched=batched)

            # Model validation
            # TODO: enable once a separate evaluation dataset is available:
            # self.__eval__()
            # with torch.no_grad():
            #    eval_ds.map(self.__process_data__, batched=batched)


def compile_data_extractors(train_mode):
    """Build the callables that pull model input/target text out of the dataset.

    Args:
        train_mode: 'batch' for extractors that take a batch (a dict of
            column lists); any other value yields single-record extractors.

    Returns:
        Tuple ``(given_input, expected_output, skip_record)``.
        ``given_input(task_mode, data)`` and ``expected_output(task_mode, data)``
        return a string in record mode and a list of strings in batch mode.
        ``skip_record(record)`` returns the record's 'comment' field; it is
        None in batch mode.
    """
    if train_mode == 'batch':
        def usable_rows(task_mode, batch):
            # Inputs and targets are filtered with the SAME predicate so both
            # lists keep identical length and order. Filtering them separately
            # misaligns the pairs as soon as one field of a row is empty.
            rows = zip(batch['question'], batch['tables'], batch[TASK_KEY[task_mode]])
            return [row for row in rows if all(row)]

        def given_input(task_mode, batch):
            return [TASK_PREFIX[task_mode] + question + TASK_HINT[task_mode] + tables
                    for question, tables, _ in usable_rows(task_mode, batch)]

        def expected_output(task_mode, batch):
            return [target for _, _, target in usable_rows(task_mode, batch)]

        skip_record = None
    else:
        def given_input(task_mode, record):
            return (TASK_PREFIX[task_mode] + record['question']
                    + TASK_HINT[task_mode] + record['tables'])

        def expected_output(task_mode, record):
            return record[ TASK_KEY[task_mode] ]

        def skip_record(record):
            return record['comment']

    return (given_input, expected_output, skip_record)

def load_data(data_path="./data/my_flat_sql_data_meta.json", eval_data_path=None):
    """Load the training and evaluation datasets from JSON files.

    Args:
        data_path: JSON file with the training records.
        eval_data_path: optional JSON file with the evaluation records. When
            omitted, the training dataset is returned for both roles, which is
            what the notebook did before a separate evaluation file existed.

    Returns:
        Tuple ``(train_ds, eval_ds)`` as returned by ``load_dataset``.
    """
    train_ds = load_dataset('json', data_files = data_path)
    #print(train_ds)

    if eval_data_path is None:
        # No dedicated evaluation file yet - reuse the training data.
        return train_ds, train_ds

    eval_ds = load_dataset('json', data_files = eval_data_path)
    return train_ds, eval_ds

#================== Main Program ==================#

# ---- run configuration ----
# data_mode "record": the model is trained on one record at a time.
# data_mode "batch" : records are trained in batches of 1K - results are not
#                     as expected yet and still need investigation.
data_mode = "record" #record|batch
epoch = 1
task_modes = ['translation'] #'metadata'

# ---- data ----
# training and evaluation/validation datasets
train_ds, eval_ds = load_data()

# callables used to extract input text, target text and the skip flag
given_input, expected_output, skip_record = compile_data_extractors(data_mode)

# ---- trainer ----
# builds the tokenizer/encoder, the model and the optimizer
trainer = MyT5Trainer("t5-base", 42) # t5-base | t5-small

# the trainer invokes these callables itself while it walks the dataset
trainer.set_data_extractors(given_input, expected_output, skip_record)

# ---- training: one full run per task mode ----
for task_mode in task_modes:
    trainer.set_data_mode(data_mode)
    trainer.set_task_mode(task_mode)

    trainer.train_model(epoch, train_ds, eval_ds)
    trainer.plot_loss()
    print(trainer.loss_log)

In [None]:
# Smoke test: generate SQL for one hand-written question with beam search and
# compare every returned beam against the expected statement.
TASK_MODE = 'translation'
input_text = TASK_PREFIX[TASK_MODE] + "Which teams played in 2022?" #+  "</s>"
hint_text = TASK_HINT[TASK_MODE] + "game_stats"
expected_output = "SELECT team_name FROM game_stats WHERE DATE_PART('YEAR', pay_date)= 2022"
print(input_text+hint_text)

trainer.set_data_mode('record')
trainer.set_task_mode('translation')
trainer.set_mode('eval')

test_tokenized = trainer.tokenizer(input_text+hint_text, return_tensors="pt")
test_input_ids  = test_tokenized["input_ids"]
test_attention_mask = test_tokenized["attention_mask"]

output_tokenized = trainer.tokenizer(expected_output, return_tensors="pt")
labels = output_tokenized["input_ids"]

trainer.model.eval()
outputs = trainer.model.generate(
    input_ids=test_input_ids,
    attention_mask=test_attention_mask,
    max_new_tokens=64,

    # ----- Beam Search w/ return sequences -----#
    # `temperature` was dropped: it only affects sampling, which is disabled here.
    early_stopping=True,
    num_beams=10,
    no_repeat_ngram_size=2,
    num_return_sequences=5, #num_return_sequences<=num_beams

    # ----- Top P & Top K sampling -----# ANALYSIS - much faster than beam search
    #do_sample=True,
    #top_k=5,
    #top_p=3,
    #num_return_sequences=1
)

# generate() returns token ids, not logits, so beams are compared with the
# labels as id sequences. Special tokens (pad/eos/...) are removed from both
# sides because the generated sequence and the labels frame the text differently.
special_ids = set(trainer.tokenizer.all_special_ids)
label_ids = [tok for tok in labels[0].tolist() if tok not in special_ids]
expected_decoded = trainer.tokenizer.decode(
    labels[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

print("EXPected: ",expected_output)
print("LABELS: ",labels)
print("LABEL SHAPE: ", labels.shape)

for beam_output in outputs:

    output = trainer.tokenizer.decode(beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print("BEAM: ", beam_output)
    print(output)

    # Position-wise token agreement over the part both sequences share.
    beam_ids = [tok for tok in beam_output.tolist() if tok not in special_ids]
    correct = sum(1 for got, want in zip(beam_ids, label_ids) if got == want)
    print("CORRECT: ", correct, "of", len(label_ids))

    # The expected text is compared after the same tokenizer round trip, so
    # characters the vocabulary cannot represent do not cause false mismatches.
    print("EXACT MATCH: ", output == expected_decoded)

In [None]:

#extract_input = lambda task_mode, record: TASK_PREFIX[task_mode] + record['question']
#exract_hint = lambda task_mode, record: TASK_HINT[task_mode] + record['tables'] 
#extract_input_with_hint = lambda task_mode, record: extract_input(task_mode, record) + exract_hint(task_mode, record)
#expected_output = lambda task_mode, record: record[ TASK_KEY[task_mode] ]
#skip_record = lambda record: record['comment']

#join_input_hint = lambda task_mode, tuple: TASK_PREFIX[task_mode] + tuple[0]+ TASK_HINT[task_mode] + tuple[1]
#input_batch = lambda task_mode, batch: [join_input_hint(task_mode, row) for row in zip(batch['question'], batch['tables']) if all(row)]
#expected_output_batch = lambda task_mode, batch: [row for row in batch[TASK_KEY[task_mode]] if row]
#skip_record_batch = None

# one record at a time training 
#trainer.set_data_extractors(train_mode, given_input, expected_output, skip_record)

# batch training is not yielding expected results - need to investigate
#trainer.set_data_extractors(train_mode, given_input, expected_output, skip_record)


    #if(self.train_mode =='record'):
        #    ds.map(self.__process_data__, batched=False)
        #if (self.train_mode =='batch'):
        #    ds.map(self.__process_data__, batched=True)


        #input_ids  = tokenized_input["input_ids"]
        #attention_mask = tokenized_input["attention_mask"]
        #lm_labels = tokenized_output["input_ids"]
        #decoder_attention_mask=  tokenized_output["attention_mask"]

In [None]:
import torch
# Scratch check: tensor shapes, and how Tensor.eq broadcasts a one-element
# tensor against a row of values.
t1 = torch.empty(1, 2)                              # uninitialised values, only the shape matters
t2 = torch.tensor([20.0])
t3 = torch.tensor([[2.0, 20.0, 3.0, 4.0, 5.0]])
#t4 = torch.Tensor([],[])

for sample in (t1, t2, t3):
    print(sample.shape)
for axis in (0, 1):
    print(t3.size(axis))
#print("VIEW:", t3.view(-1,-1))

# element-wise comparison; t2 is broadcast across the columns of t3
eqt = torch.eq(t3, t2)
match_count = eqt.sum()
print("EQ: ", eqt )
print("SUM: ", match_count )
print("ITEM: ", match_count.item() )
