In [None]:
# ===================== Training base T5 ==================== #
import os
import random
import sys

import pandas as pd
import numpy as np
import torch
import tqdm
from datasets import load_dataset, DatasetDict
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
import matplotlib.pyplot as plt

# Supported checkpoint names -> (model class, tokenizer class).
MODEL_CLASSES = {
    checkpoint: (T5ForConditionalGeneration, T5Tokenizer)
    for checkpoint in ("t5-small", "t5-base")
}
# Optimizer classes keyed by short name (initialize_optimizer always builds Adam).
OPTIM_CLASSES = dict(sgd=optim.SGD, adam=optim.Adam)

#================== Adam Hyperparameters ==================#
ADAM_LR             = 3e-4
ADAM_EPS            = 1e-8

#================== Model, Data & Log Paths ==================#
# Hugging Face hub cache under the current user's home directory. The path was
# previously hard-coded to a single user's home, which broke it on any other machine.
MODEL_BASE_DIR      = os.path.expanduser("~/.cache/huggingface/hub")
#IN_MODEL_NAME       = "t5-base"
IN_MODEL_NAME       = "t5-small"
# NOTE(review): the name says "t5-base" although IN_MODEL_NAME is "t5-small".
OUT_MODEL_NAME      = "my--t5-base-finetuned-text-to-SQL" ##-finetuned-{source_lang}-to-{target_lang}"
OUTPUT_DIR          = os.path.join(MODEL_BASE_DIR, "models--" + OUT_MODEL_NAME)
#LOG_DIR             = os.path.join(OUTPUT_DIR, "logs")
DATA_PATH           = "./data/my_flat_sql_data_meta.json"

#================== Training Logic Properties ==================#
# Per task mode: text prepended to the question in the model input.
TASK_PREFIX = {
    "translation": "Translate English to SQL: ",
    "metadata": "Find data tables in question: ",
}
# Per task mode: record field used as the training target.
RESPONSE_KEY = dict(translation="sql", metadata="tables")
# Per task mode: text inserted after the question, before the record's table list.
TASK_HINT = {
    "translation": ". Use following data tables - ",
    "metadata": "",
}
# Active task mode; the main program reassigns it for each entry in task_modes.
TASK_MODE = "translation"

# Per-step training losses; process_data appends to it.
loss_log = []

#================== Functions ==================#
def set_seed(seed):
    """Seed the global RNGs of Python's ``random``, NumPy and PyTorch with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


set_seed(42)

def initialize_model(model_name):
    """Load a pretrained model and its tokenizer by checkpoint name.

    Args:
        model_name: a key of ``MODEL_CLASSES`` (e.g. ``"t5-small"``), also
            passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODEL_CLASSES``.
    """
    try:
        model_class, tokenizer_class = MODEL_CLASSES[model_name]
    except KeyError:
        # The message previously kept a literal "{}" because it was never formatted.
        raise KeyError(
            "the model {} you specified is not supported. You are welcome to add it "
            "and open a PR :)".format(model_name)
        ) from None

    tokenizer = tokenizer_class.from_pretrained(model_name)
    # Use the tokenizer's real pad token. Passing eos_token_id as the pad id is a
    # GPT-2 idiom (GPT-2 has no pad token); T5 has a dedicated <pad> token.
    model = model_class.from_pretrained(model_name, pad_token_id=tokenizer.pad_token_id)
    return model, tokenizer

def initialize_optimizer_param(model, weight_decay=0.0, verbose=True):
    """Split ``model``'s parameters into a decayed and an undecayed group.

    Args:
        model: a ``torch.nn.Module``.
        weight_decay: decay for the first group, i.e. every parameter whose name
            matches none of the no-decay patterns. Defaults to 0.0, the previous
            behaviour, so neither group is decayed unless a value is given.
        verbose: if true, print each parameter's name and shape. Previously the
            whole parameter tensors were printed.

    Returns:
        A list of two optimizer param-group dicts. The first holds the decayed
        parameters; the second holds biases and layer-norm weights, with
        weight_decay always 0.0.
    """
    # "LayerNorm.weight" matches BERT-style names. T5 names its norms
    # "layer_norm" / "final_layer_norm", which that pattern never matched.
    no_decay = ("bias", "LayerNorm.weight", "layer_norm")
    decay_params, no_decay_params = [], []

    if verbose:
        print("Named Params: ")
    for name, param in model.named_parameters():
        if verbose:
            print(name, tuple(param.shape))
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

def initialize_optimizer(model):
    """Create the Adam optimizer used to fine-tune ``model``.

    Parameter groups come from ``initialize_optimizer_param``. The learning
    rate and epsilon come from the module-level ``ADAM_LR`` / ``ADAM_EPS``.
    """
    adam_settings = dict(lr=ADAM_LR, eps=ADAM_EPS)
    # Alternative previously tried: AdamW(grouped_parameters, lr=3e-4, eps=1e-8)
    return optim.Adam(initialize_optimizer_param(model), **adam_settings)


def tokenize(tokenizer, text, max_length=96):
    """Encode ``text`` with ``tokenizer`` and return PyTorch tensors.

    Args:
        tokenizer: a tokenizer exposing ``encode_plus``.
        text: the string to encode.
        max_length: longer sequences are truncated to this many tokens.
            Defaults to 96, the previously hard-coded limit.

    Returns:
        The ``encode_plus`` result (``input_ids`` / ``attention_mask``).
    """
    # padding=True pads to the longest sequence in the batch, so encoding a
    # single string adds no pad tokens.
    return tokenizer.encode_plus(
        text, max_length=max_length, padding=True, truncation=True, return_tensors="pt")

def process_data(record):
    """Run one training step (forward, backward, optimizer update) on one record.

    The model input is ``TASK_PREFIX[TASK_MODE] + question``, plus
    ``TASK_HINT[TASK_MODE] + tables`` when the mode defines a hint. The target
    is ``record[RESPONSE_KEY[TASK_MODE]]``. The loss is appended to the global
    ``loss_log``, and the global ``model`` is updated through ``optimizer``.

    Relies on module-level globals: ``tokenizer``, ``model``, ``optimizer``,
    ``loss_log``, ``TASK_MODE``, ``TASK_PREFIX``, ``TASK_HINT`` and
    ``RESPONSE_KEY``.

    Args:
        record: one dataset row, a mapping with ``comment``, ``question``,
            ``tables`` and the response field. The type of ``tables`` is assumed
            to be ``str`` (it is string-concatenated); confirm against the data.

    Returns:
        None. Rows whose ``comment`` field is truthy are skipped.
    """
    if record.get('comment'):
        return  # skip comment lines in the data file

    hint_text = TASK_HINT[TASK_MODE]
    # Append the tables only when the mode defines a hint. In "metadata" mode the
    # hint is empty and the tables are the target (RESPONSE_KEY), so appending
    # them would put the answer into the input.
    hint = (hint_text + record['tables']) if hint_text else ""
    input_text = TASK_PREFIX[TASK_MODE] + record['question'] + hint
    target_text = record[RESPONSE_KEY[TASK_MODE]]

    encoded_input = tokenize(tokenizer, input_text)
    # One sequence per call, so padding adds no pad tokens and the labels need
    # no -100 masking.
    encoded_target = tokenize(tokenizer, target_text)

    output = model(
        input_ids=encoded_input["input_ids"],
        attention_mask=encoded_input["attention_mask"],
        labels=encoded_target["input_ids"],
        decoder_attention_mask=encoded_target["attention_mask"],
    )

    loss = output[0]  # the loss comes first when labels are passed
    loss_log.append(loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

def model_train(epochs):
    """Fine-tune the global ``model`` for ``epochs`` passes over ``my_ds``.

    Every row of every split in the global ``my_ds`` is passed to
    ``process_data``, which performs one optimizer step per row.

    Args:
        epochs: number of passes over the data.
    """
    print("Training ", TASK_MODE)
    model.train(mode=True)
    for epoch in range(epochs):
        print("epoch ", epoch)
        # Iterate rows directly instead of calling ``my_ds.map(process_data)``.
        # ``map`` is a cached transformation: when it finds a matching cache
        # file it reuses it instead of calling the function, which would
        # silently skip training. It also writes a throwaway cache file per call.
        for split in my_ds.values():
            for record in split:
                process_data(record)

def plot_loss(loss_log, label="Adam Optimizer"):
    """Plot the training losses collected in ``loss_log`` and show the figure.

    Args:
        loss_log: sequence of loss values, one per optimizer step (process_data
            appends one per training record), not one per epoch.
        label: legend label. The default is "Adam Optimizer" because
            initialize_optimizer builds ``optim.Adam``. The previous hard-coded
            "Stochastic Gradient Descent" label misnamed the optimizer.
    """
    plt.plot(loss_log, label=label)
    plt.xlabel('training step')
    plt.ylabel('loss')
    plt.legend()
    plt.show()


#================== Main Program ==================#
# Script entry: load the model, build the optimizer, load the data, then train
# once per task mode. Statement order matters because the functions above read
# these module-level names (model, tokenizer, optimizer, my_ds, loss_log, TASK_MODE).

# initialize model & tokenizer
model, tokenizer = initialize_model(IN_MODEL_NAME)

# initialize optimizer
optimizer = initialize_optimizer(model)

# Load the JSON training data from DATA_PATH.
my_ds = load_dataset('json', data_files = DATA_PATH)
print(my_ds)

# Task modes to train, in order; 'metadata' (table detection) is currently disabled.
task_modes = ['translation'] #'metadata'

# For each mode: reset the module-level loss history, switch TASK_MODE, then
# train. process_data reads TASK_MODE and appends to loss_log. The same model
# and optimizer are reused across modes.
for mode in task_modes:
    loss_log = []
    TASK_MODE = mode
    model_train(8)  # number of epochs
    #plot_loss(loss_log)
    print(loss_log)


In [None]:
# Quick inference check: sample one SQL translation for a sample question.
TASK_MODE = 'translation'
# NOTE(review): training inputs also append TASK_HINT[TASK_MODE] plus the
# record's tables. This prompt omits them, so it does not match the training
# input format; consider appending the relevant table list.
test_sent = TASK_PREFIX[TASK_MODE] + "Which customers ordered in 2020?"
print(test_sent)
test_tokenized = tokenizer.encode_plus(test_sent, return_tensors="pt")

test_input_ids = test_tokenized["input_ids"]
test_attention_mask = test_tokenized["attention_mask"]

model.eval()  # disable dropout for inference
beam_outputs = model.generate(
    input_ids=test_input_ids,
    attention_mask=test_attention_mask,
    max_new_tokens=64,

    # ----- Top-k & top-p (nucleus) sampling ----- ANALYSIS - much faster than beam search
    do_sample=True,
    temperature=0.96,
    top_k=5,
    # top_p is a cumulative probability and must be in (0, 1]. The previous
    # value of 3 disabled nucleus filtering, which is applied only when top_p < 1.
    top_p=0.95,
    num_return_sequences=1,

    # ----- Alternative: beam search with several returned sequences -----
    # num_beams=10, early_stopping=True, no_repeat_ngram_size=2,
    # num_return_sequences=5,  # must be <= num_beams
)

# Despite the variable name, these are sampled sequences (num_return_sequences of them).
for beam_output in beam_outputs:
    output = tokenizer.decode(beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    print(output)