In [None]:
# ===================== Training base T5 ==================== # 
import os
import random

import numpy as np
import pandas as pd
import torch
import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to `random.seed`, `np.random.seed`
            and `torch.manual_seed`, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


# Fix the seed once, up front, for the rest of the notebook.
set_seed(42)

# Names and locations for the base checkpoint and the fine-tuned output.
# The cache directory is derived from the current user's home directory instead of
# a machine-specific absolute path, so the notebook is portable across machines.
MODEL_BASE_DIR      = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
IN_MODEL_NAME       = "t5-base"
OUT_MODEL_NAME      = "my--t5-base-finetuned-text-to-SQL"
OUTPUT_DIR          = os.path.join(MODEL_BASE_DIR, "models--" + OUT_MODEL_NAME)
# Prefix placed in front of every English question given to the model.
TASK_PREFIX         = "Translate English to SQL: "

# initialize model & tokenizer from the pretrained checkpoint named by IN_MODEL_NAME
tokenizer           = T5Tokenizer.from_pretrained(IN_MODEL_NAME)
model               = T5ForConditionalGeneration.from_pretrained(IN_MODEL_NAME)


# initialize optimizer
# Hyper-parameters are named so they can be tuned in one place. WEIGHT_DECAY is
# kept at 0.0, so the optimizer behaves exactly as before; raise it to regularize
# the "decay" group below.
LEARNING_RATE       = 3e-4
ADAM_EPSILON        = 1e-8
WEIGHT_DECAY        = 0.0

# Parameters whose name contains one of these substrings are exempt from weight decay.
# T5 names its layer-norm scales `...layer_norm.weight` / `final_layer_norm.weight`,
# which the BERT-style "LayerNorm.weight" pattern alone does not match, so the T5
# spelling is listed as well.
no_decay            = ["bias", "LayerNorm.weight", "layer_norm.weight"]

# Split the parameters in a single pass over named_parameters().
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

# Group order is unchanged: decayed parameters first, exempt parameters second.
grouped_parameters  = [
    {"params": decay_params, "weight_decay": WEIGHT_DECAY},
    {"params": no_decay_params, "weight_decay": 0.0},
]
optimizer           = AdamW(grouped_parameters, lr=LEARNING_RATE, eps=ADAM_EPSILON)

# sample dataset - single table, single where column
# Each entry is an (English question, target SQL) pair. An apostrophe inside a SQL
# string literal is escaped by doubling it ('') so that every target is a
# well-formed statement; the natural-language questions keep their plain apostrophes.
text_to_sql_tuples = [
    ("What school did Patrick O'Bryant play for?", "SELECT school FROM table WHERE player = 'Patrick O''Bryant'"),
    ("What club did Patrick O'Bryant play for?", "SELECT club FROM table WHERE player = 'Patrick O''Bryant'"),
    ("How many cars sold in the year 1997?", "SELECT count(*) FROM table WHERE year = '1997'"),
    ("What's Dell Curry nationality?", "SELECT nationality FROM table WHERE player = 'Dell Curry'"),
    ("which player is from georgia", "SELECT player FROM table WHERE state = 'Georgia'"),
    ("What nationality is the player Muggsy Bogues?", "SELECT nationality FROM table WHERE player = 'Muggsy Bogues'"),
]

# Fine-tune on the sample pairs, one example per optimizer step, in list order.
model.train(mode=True)
epochs = 5

for epoch in range(epochs):
    # The loop variables are deliberately not called `input` / `output`: `input`
    # would shadow the Python builtin for the rest of the notebook session, and
    # `output` would be reused for two unrelated things (target string, model result).
    for question, target_sql in text_to_sql_tuples:

        # NOTE(review): recent T5Tokenizer releases append </s> on their own;
        # confirm the manual suffix is still wanted for the installed version.
        input_sent = TASK_PREFIX + question + " </s>"
        output_sent = target_sql + " </s>"

        # Encode the source sentence.
        tokenized_inp = tokenizer.encode_plus(input_sent, max_length=96,
                            padding=True, truncation=True, return_tensors="pt")

        input_ids  = tokenized_inp["input_ids"]
        attention_mask = tokenized_inp["attention_mask"]

        # Encode the target SQL; its token ids serve as the labels.
        tokenized_output = tokenizer.encode_plus(output_sent, max_length=96,
                            padding=True, truncation=True, return_tensors="pt")

        lm_labels = tokenized_output["input_ids"]
        decoder_attention_mask = tokenized_output["attention_mask"]

        model_output = model(input_ids=input_ids, attention_mask=attention_mask,
                             labels=lm_labels, decoder_attention_mask=decoder_attention_mask)

        # The first element of the model result is used as the training loss.
        loss = model_output[0]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Encode one unseen question, prefixed with TASK_PREFIX and suffixed with the EOS marker.
question = "how many wins did Lebron James have in 2020?"
encoded_prompt = tokenizer.encode_plus(TASK_PREFIX + question + "</s>", return_tensors="pt")

# Beam-search settings, collected in one place.
generation_kwargs = dict(
    max_length=64,
    early_stopping=True,
    num_beams=10,
    num_return_sequences=3,
    no_repeat_ngram_size=2,
)

# Switch the model to evaluation mode before generating.
model.eval()
candidates = model.generate(
    input_ids=encoded_prompt["input_ids"],
    attention_mask=encoded_prompt["attention_mask"],
    **generation_kwargs,
)

# Print each returned sequence as plain text, one per line.
for candidate in candidates:
    print(tokenizer.decode(candidate, skip_special_tokens=True, clean_up_tokenization_spaces=True))