In [None]:
# ===================== Module Imports ==================== # 
import os

import torch.optim as optim
#from torch.utils.data import Dataset, DataLoader

from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

#================== Global Constants ==================#
# Prompt prefix placed in front of every model input, keyed by task mode.
TASK_PREFIX = dict(
    translation="Translate English to SQL: ",
    metadata="Find data tables in question: ",
)

# Dataset column that holds the expected model output, keyed by task mode.
TASK_KEY = dict(
    translation="sql",
    metadata="tables",
)

# Text placed between the question and the table list, keyed by task mode.
TASK_HINT = dict(
    translation=". Use following data tables - ",
    metadata="",
)
# Supported checkpoint names -> (model class, tokenizer class).
MODEL_CLASSES   = {
                    "t5-small": (T5ForConditionalGeneration, T5Tokenizer),
                    "t5-base": (T5ForConditionalGeneration, T5Tokenizer),
                    }
# Optimizer name -> torch.optim class (not referenced elsewhere in this notebook;
# the trainer builds optim.Adam directly).
OPTIM_CLASSES   = {
                    "sgd": optim.SGD,
                    "adam": optim.Adam,
                    }
# Hugging Face hub cache directory, resolved from the current user's home
# directory rather than a hard-coded /Users/<name> path so the notebook is portable.
MODEL_BASE_DIR  = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
#LOG_DIR        = OUTPUT_DIR + "/logs"

In [None]:
from datasets import load_dataset, DatasetDict

class MyDataLoader:
    """Builds text-to-SQL model inputs and targets from JSON records.

    Data modes:
      * ``"record"`` - extractors take one row (a mapping column -> value) and
        return a single string.
      * ``"batch"``  - extractors take a column-oriented batch (a mapping
        column -> list of values, as passed by ``Dataset.map(batched=True)``)
        and return a list of strings.

    The task mode (``"translation"`` | ``"metadata"``) selects the prompt
    prefix, the hint and the target column via TASK_PREFIX / TASK_HINT /
    TASK_KEY.
    """

    DATA_MODES      = ("record", "batch")

    data_mode       = "record" #record|batch
    task_mode       = "translation" #translation|metadata
    given_input     = None
    expected_output = None
    skip_record     = None

    def __init__(self, data_mode):
        # set_data_mode validates the mode and builds the extractor callables.
        self.set_data_mode(data_mode)

    # train modes - record | batch
    def set_data_mode(self, data_mode):
        """Switch between per-record and batched extraction.

        The extractor callables are rebuilt here so they always match the
        current ``data_mode``.

        Raises:
            ValueError: if ``data_mode`` is not one of DATA_MODES.
        """
        if data_mode not in self.DATA_MODES:
            raise ValueError(
                f"unsupported data_mode {data_mode!r}; expected one of {self.DATA_MODES}")
        self.data_mode = data_mode
        self.__init_data_extractors__()

    # task modes - translation | metadata
    def set_task_mode(self, task_mode):
        """Select the task whose prefix, hint and target column are used."""
        self.task_mode = task_mode

    def load_data(self, data_path="./data/my_flat_sql_data_meta.json"):
        """Load the JSON dataset stored at ``data_path``.

        Returns:
            (train_ds, eval_ds) - currently the same object; a separate
            evaluation dataset still has to be loaded.
        """
        train_ds = load_dataset('json', data_files=data_path)
        return train_ds, train_ds

    def __init_data_extractors__(self):
        """(Re)build the input / target / skip callables for ``data_mode``."""

        def build_input(task_mode, question, tables):
            # prefix + question, followed by hint + tables when the task has a
            # hint. The metadata task has an empty hint and predicts the tables
            # itself, so appending them would leak the target into the input.
            hint = TASK_HINT[task_mode]
            text = TASK_PREFIX[task_mode] + question
            return (text + hint + tables) if hint else text

        if self.data_mode == 'batch':
            def aligned_rows(task_mode, batch):
                # Filter question, tables and target *together* so the input
                # list and the target list have equal length and pair up row
                # by row.
                rows = zip(batch['question'], batch['tables'], batch[TASK_KEY[task_mode]])
                return [row for row in rows if all(row)]

            self.given_input = lambda task_mode, batch: [
                build_input(task_mode, question, tables)
                for question, tables, _ in aligned_rows(task_mode, batch)]
            self.expected_output = lambda task_mode, batch: [
                target for _, _, target in aligned_rows(task_mode, batch)]
            self.skip_record = None

        elif self.data_mode == 'record':
            def should_skip(record):
                # Skip commented records and records missing any field the
                # current task needs (question, tables or the target column).
                target = record[TASK_KEY[self.task_mode]]
                return bool(record.get('comment')) or not all(
                    (record['question'], record['tables'], target))

            self.given_input = lambda task_mode, record: build_input(
                task_mode, record['question'], record['tables'])
            self.expected_output = lambda task_mode, record: record[TASK_KEY[task_mode]]
            self.skip_record = should_skip

    def extract_input(self, data):
        """Return the prompt(s) for ``data`` (a record or a batch, per data mode)."""
        return self.given_input(self.task_mode, data)

    # Original (misspelled) name, kept for existing callers.
    exract_input = extract_input

    def extract_expected_output(self, data):
        """Return the target text(s) for ``data`` under the current task."""
        return self.expected_output(self.task_mode, data)

    def is_skip_record(self, data):
        """True when ``data`` is a single record that should not be used."""
        return self.data_mode == 'record' and self.skip_record(data)

    def is_read_batched(self):
        """True when datasets should be mapped with ``batched=True``."""
        return self.data_mode == 'batch'

    # Original (misspelled) name, kept for existing callers.
    is_read_bacthed = is_read_batched


In [None]:
import sys, time
#import pandas as pd
import numpy as np
import random
import torch
import concurrent

import concurrent.futures
import matplotlib.pyplot as plt

class MyT5Trainer:
    """Fine-tunes a T5 checkpoint on text-to-SQL data provided by a loader.

    The loader must offer ``exract_input``, ``extract_expected_output``,
    ``is_skip_record`` and ``is_read_bacthed`` (as MyDataLoader does).
    """

    #================== Contained Objects ==================#
    # loader, tokenizer, model, optimizer

    #================== Model Info Attributes ==================#
    #out_model_name     = None
    #output_dir         = None

    #================== Trainer State Control attributes ==================#
    threads             = 1
    epochs              = 1
    # loss_log / correctness are created per instance by reset_accuracy_log();
    # a class-level list would be shared (and appended to) by every trainer.

    #================== Tokenization ==================#
    max_seq_length      = 96    # token limit applied to prompts and targets

    #================== Adam Optimizer Hyperparameters ==================#
    adam_lr             = 3e-4
    adam_eps            = 1e-8


    def __init__(self, model_name, loader):
        """Load tokenizer and model for ``model_name`` and build the optimizer.

        Raises:
            KeyError: if ``model_name`` is not a key of MODEL_CLASSES.
        """
        print("Initializing Trainer")
        self.set_loader(loader)
        self.reset_accuracy_log()

        # will be enabled in future when saving model checkpointing
        #self.out_model_name = f"my--{model_name}--finetuned-text-to-SQL"
        #self.output_dir = f"{MODEL_BASE_DIR}/models--{self.out_model_name}"

        try:
            model_class, tokenizer_class = MODEL_CLASSES[model_name]
        except KeyError:
            raise KeyError(
                "the model {} you specified is not supported. You are welcome to add it and open a PR :)"
                .format(model_name)) from None

        self.tokenizer = tokenizer_class.from_pretrained(model_name)
        # Use the tokenizer's own pad token (T5 has one distinct from EOS) so
        # the model config and the tokenizer agree on the padding id.
        self.model = model_class.from_pretrained(model_name, pad_token_id=self.tokenizer.pad_token_id)

        # Two parameter groups kept for a future decay split; both currently
        # use weight_decay=0.0, so this is equivalent to a single group.
        no_decay = ["bias", "LayerNorm.weight"]
        named_params = list(self.model.named_parameters())
        grouped_parameters = [
            {"params": [p for n, p in named_params if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,},
            {"params": [p for n, p in named_params if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,},]
        self.optimizer = optim.Adam(grouped_parameters, lr=self.adam_lr, eps=self.adam_eps)


    #======================== Getters & Setters =========================#
    #====================================================================#

    def set_seed(self, seed):
        """Seed Python's, NumPy's and PyTorch's global RNGs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def set_loader(self, loader):
        self.loader = loader

    def set_threads(self, max_threads):
        self.threads = max_threads

    def set_epochs(self, epochs):
        self.epochs = epochs

    def reset_accuracy_log(self):
        """Start a fresh per-instance loss log and correctness counter."""
        self.loss_log = []
        self.correctness = 0

    #==================== Functional Methods ===================#
    #===========================================================#

    def __tokenize__(self, text):
        """Tokenize one string or a list of strings into padded PyTorch tensors.

        ``tokenizer(...)`` is used rather than ``encode_plus``: ``encode_plus``
        encodes a single sequence and, given a list of strings, treats each
        string as one token - which broke batch mode.
        """
        return self.tokenizer(
            text, max_length=self.max_seq_length, padding=True, truncation=True, return_tensors="pt")

    def __encode_data__(self, data):
        """Extract, prefix and tokenize the prompt(s) and target(s) in ``data``.

        Returns:
            (input_ids, attention_mask, labels, decoder_attention_mask), or
            None when a batch yields no usable prompt/target pair.

        Raises:
            ValueError: if a batch yields different numbers of prompts and targets.
        """
        # preprocessing data - extracting from input data, prefixing & cleaning up
        prompt = self.loader.exract_input(data)
        target_text = self.loader.extract_expected_output(data)

        if isinstance(prompt, list):
            if len(prompt) != len(target_text):
                raise ValueError(
                    f"batch produced {len(prompt)} inputs but {len(target_text)} targets")
            if not prompt:
                return None

        tokenized_input = self.__tokenize__(prompt)
        tokenized_output = self.__tokenize__(target_text)

        # Padding positions must not contribute to the loss: label id -100 is
        # ignored by the model's loss.
        labels = tokenized_output["input_ids"].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100

        return (tokenized_input["input_ids"], tokenized_input["attention_mask"],
                labels, tokenized_output["attention_mask"])


    def __generate_prediction__(self, data):
        """Run a forward pass on ``data``; None for skipped or empty data."""
        # if training in single record mode, check for empty or comment records
        if self.loader.is_skip_record(data):
            return None

        encoded = self.__encode_data__(data)
        if encoded is None:
            return None
        input_ids, attention_mask, labels, decoder_attention_mask = encoded

        # forward pass - predict
        return self.model(
            input_ids = input_ids, attention_mask = attention_mask,
            labels = labels, decoder_attention_mask = decoder_attention_mask)


    def __process_train_data__(self, data):
        """One optimization step on ``data``; the loss is appended to loss_log."""
        output = self.__generate_prediction__(data)

        # in case of skip records
        if output is None:
            return

        # first element of the model output is the loss (labels were passed)
        loss = output[0]

        # record the loss for plotting
        self.loss_log.append(loss.item())

        # zero all gradients before the backward pass
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


    def __process_eval_data__(self, data):
        """Placeholder evaluation step: runs a forward pass, no metric yet."""
        output = self.__generate_prediction__(data)

        # in case of skip records
        if output is None:
            return

        print("Next iteration - EVAL")
        # Get the index of the max log-probability.
        #pred = output.argmax(dim=1, keepdim=True)
        #self.correctness += pred.eq(lm_labels.view_as(pred)).sum().item()


    def train_model(self, train_ds, eval_ds):
        """Train for ``self.epochs`` epochs over ``train_ds``.

        ``eval_ds`` is accepted for the validation step, which is not
        implemented yet.
        """
        batched = self.loader.is_read_bacthed()
        self.reset_accuracy_log()
        # from_pretrained() returns the model in eval mode (dropout disabled);
        # switch to training mode before optimizing.
        self.model.train()

        for epoch in range(self.epochs):
            print("epoch ", epoch)

            # Training is driven through Dataset.map purely for its side
            # effects (one optimizer step per record/batch). Only one task is
            # submitted, so threads > 1 adds no parallelism.
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
                future = executor.submit(train_ds.map, self.__process_train_data__, batched=batched)
                # result() re-raises any exception raised in the worker thread,
                # which would otherwise be silently discarded.
                future.result()

            # Model validation (not implemented yet)
            # with torch.no_grad():
            #    eval_ds.map(self.__process_eval_data__, batched=batched)
            # accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)

    def plot_loss(self):
        """Plot the training loss recorded after every optimizer step."""
        fig, ax = plt.subplots()
        ax.plot(self.loss_log, label="Adam Optimizer")
        ax.set_xlabel('training step')
        ax.set_ylabel('Cost/ total loss')
        ax.set_title('Training loss')
        ax.legend()
        plt.show()


#=========================== Main Program ===========================#
#====================================================================#

# Data mode:
#   record - train on one record at a time
#   batch  - train on 1K-record batches (not yielding the expected results yet - needs investigation)
data_mode = "record"  # record | batch
dataLoader = MyDataLoader(data_mode)

# Training and evaluation/validation datasets.
train_ds, eval_ds = dataLoader.load_data()

# The trainer creates the tokenizer/encoder, the model and the optimizer.
trainer = MyT5Trainer("t5-base", dataLoader)  # t5-base | t5-small
trainer.set_seed(42)
trainer.set_epochs(1)
trainer.set_threads(1)

# Train, plot and dump the loss log once per task mode.
task_modes = ["translation"]  # "metadata"
for task_mode in task_modes:
    dataLoader.set_task_mode(task_mode)
    trainer.train_model(train_ds, eval_ds)
    trainer.plot_loss()
    print(trainer.loss_log)

In [None]:
# Smoke-test the fine-tuned model on one hand-written question.
TASK_MODE = 'translation'
input_text = TASK_PREFIX[TASK_MODE] + "Which teams played in 2022?" #+  "</s>"
hint_text = TASK_HINT[TASK_MODE] + "game_stats"
expected_output = "SELECT team_name FROM game_stats WHERE DATE_PART('YEAR', pay_date)= 2022"
print(input_text+hint_text)

# Data/task modes live on the data loader; MyT5Trainer itself has no
# set_data_mode / set_task_mode / set_mode methods.
trainer.loader.set_data_mode('record')
trainer.loader.set_task_mode(TASK_MODE)

test_tokenized = trainer.tokenizer.encode_plus(input_text+hint_text, return_tensors="pt")
test_input_ids  = test_tokenized["input_ids"]
test_attention_mask = test_tokenized["attention_mask"]

output_tokenized = trainer.tokenizer.encode_plus(expected_output, return_tensors="pt")
labels = output_tokenized["input_ids"]

trainer.model.eval()
outputs = trainer.model.generate(
    input_ids=test_input_ids,
    attention_mask=test_attention_mask,
    max_new_tokens=64,

    # ----- Beam search w/ return sequences -----#
    # temperature / top_k / top_p only apply when do_sample=True, so they are
    # omitted here. Alternatives:
    #   top-k / top-p sampling (faster than beam search):
    #       do_sample=True, top_k=5, top_p=0.95 (top_p must be in (0, 1])
    #   greedy search: num_beams=1 without sampling
    early_stopping=True,
    num_beams=10,
    no_repeat_ngram_size=2,
    num_return_sequences=5, #num_return_sequences<=num_beams
)

# Special tokens (decoder start / padding, EOS, ...) are dropped before
# comparing ids. Generated sequences and labels generally differ in length,
# so an elementwise tensor .eq() between them can fail to broadcast.
special_ids = set(trainer.tokenizer.all_special_ids)
expected_ids = [tok for tok in labels[0].tolist() if tok not in special_ids]

print("EXPected: ", expected_output)
print("LABELS: ", labels)
print("LABEL SHAPE: ", labels.shape)

for beam_output in outputs:

    output = trainer.tokenizer.decode(beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print("BEAM: ", beam_output)
    print(output)

    # Position-wise token agreement with the expected SQL, plus exact match
    # on the decoded text.
    generated_ids = [tok for tok in beam_output.tolist() if tok not in special_ids]
    correct = sum(gen == exp for gen, exp in zip(generated_ids, expected_ids))
    print("CORRECT: ", correct, "/", len(expected_ids))
    print("EXACT MATCH: ", output.strip() == expected_output.strip())

In [None]:
import torch
# Scratch check of tensor shapes and broadcast equality.
t1 = torch.empty(1, 2)                            # uninitialized 1x2; only its shape is used
t2 = torch.tensor([20.0])                         # shape (1,)
t3 = torch.tensor([[2.0, 20.0, 3.0, 4.0, 5.0]])   # shape (1, 5)
#t4 = torch.Tensor([],[])

print(t1.shape)
print(t2.shape)
print(t3.shape)
print(t3.shape[0])
print(t3.shape[1])
#print("VIEW:", t3.view(-1,-1))

# t2 broadcasts across every element of t3.
eqt = t3 == t2
print("EQ: ", eqt)
print("SUM: ", torch.sum(eqt))
print("ITEM: ", torch.sum(eqt).item())


# mode - train | eval
#def set_mode(self, mode):
#    self.mode = mode

#def __train__(self):
    #    self.set_mode("train")
    #    self.model.train(mode=True)

    #def __eval__(self):
    #    self.set_mode("eval")
    #    self.model.eval()