In [1]:
import pandas as pd
import json
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.utils.data import Dataset, DataLoader
import torch

In [2]:
# torch.set_num_threads(38)

In [3]:
queryData = pd.read_csv('gs://data_tql/spider/processed/spiderQueryData.csv')
tableData = pd.read_csv('gs://data_tql/spider/processed/Schemas/tablesSchemaSpider.csv')

display(queryData.head(1))
display(tableData.head(2))

Unnamed: 0,db_id,TQL,SQL,dataset,fileName,filePath,result
0,department_management,How many heads of the departments are older th...,SELECT count(*) FROM head WHERE age > 56,train,department_management.sqlite,sqliteDB/department_management.sqlite,{'count(*)': {0: 5}}


Unnamed: 0,schema_id,table_name,table_name_original,primary_key,column_list,column_list_original,column_datatypes,foreign_keys
0,perpetrator,perpetrator,perpetrator,Perpetrator_ID,"['perpetrator id', 'people id', 'date', 'year'...","['Perpetrator_ID', 'People_ID', 'Date', 'Year'...","['number', 'number', 'text', 'number', 'text',...",[]
1,perpetrator,people,people,People_ID,"['people id', 'name', 'height', 'weight', 'hom...","['People_ID', 'Name', 'Height', 'Weight', 'Hom...","['number', 'text', 'number', 'number', 'text']","[['perpetrator', 'People_ID', 'people', 'Peopl..."


In [4]:
def create_schema_natural_language(row):

    schema_id = row['schema_id']
    table_name = row['table_name']
    primary_key = row['primary_key']
    column_list = eval(row['column_list_original'])
    datatype_list = eval(row['column_datatypes'])
    foreign_key = eval(row['foreign_keys'])

    column_list_with_datatype = []
    for column, datatype in zip(column_list, datatype_list):
        column_list_with_datatype.append(' has datatype '.join([column, datatype]))

    schema_natural_language = f"Given the Table {table_name} having columns as {', '.join(column_list_with_datatype)} which has {primary_key}"
    return schema_natural_language

In [5]:
tableData['schema_natural_language'] = tableData.apply(create_schema_natural_language, axis = 1)
tableData.head(3)

all_schemas = tableData['schema_id'].unique()
schema_table_query = {}
for schema in all_schemas:
    schema_details = ' and '.join(tableData[tableData['schema_id'] == schema]['schema_natural_language'].values)
    schema_table_query[schema] = schema_details

queryData['schema_natural_language'] = queryData['db_id'].map(schema_table_query)
queryData['final_TQL'] = queryData['TQL'] + ' ' + queryData['schema_natural_language']
queryData.head(2)

queryData['final_TQL'][0], queryData['SQL'][0]

('How many heads of the departments are older than 56 ? Given the Table department having columns as Department_ID has datatype number, Name has datatype text, Creation has datatype text, Ranking has datatype number, Budget_in_Billions has datatype number, Num_Employees has datatype number which has Department_ID and Given the Table head having columns as head_ID has datatype number, name has datatype text, born_state has datatype text, age has datatype number which has head_ID and Given the Table management having columns as department_ID has datatype number, head_ID has datatype number, temporary_acting has datatype text which has department_ID',
 'SELECT count(*) FROM head WHERE age  >  56')

In [20]:
# Load the pretrained T5 model and tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

In [21]:
# Define a custom dataset for training
class SQLDataset(Dataset):
    def __init__(self, input_texts, target_queries, tokenizer, task_prefix):
        self.input_texts = input_texts
        self.target_queries = target_queries
        self.tokenizer = tokenizer
        self.task_prefix = task_prefix

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, index):
        
        input_text = self.task_prefix + self.input_texts[index]
        target_query = self.target_queries[index]

        input_encoding = self.tokenizer([input_text], return_tensors="pt", max_length=512, truncation=True, padding="max_length")
        target_encoding = self.tokenizer([target_query], return_tensors="pt", max_length=512, truncation=True, padding="max_length")
        
        return {
            'input_ids': input_encoding.input_ids.squeeze(0),
            'attention_mask': input_encoding.attention_mask.squeeze(0),
            'labels': target_encoding.input_ids.squeeze(0),
        }

In [22]:
# Load the labeled dataset
input_texts = queryData['final_TQL'].values # List of input texts
target_queries = queryData['SQL'].values  # List of corresponding target SQL queries

# Split the dataset into train and validation sets
train_input_texts, val_input_texts, train_target_queries, val_target_queries = train_test_split(input_texts, target_queries, test_size=0.2, random_state=42)

In [23]:
# Create instances of the custom dataset
task_prefix = 'Generate an SQL Query for '
train_dataset = SQLDataset(train_input_texts, train_target_queries, tokenizer, task_prefix)
val_dataset = SQLDataset(val_input_texts, val_target_queries, tokenizer, task_prefix)

In [24]:
# Define the training hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 10
LEARNING_RATE = 0.01

# Define the optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.CrossEntropyLoss()

In [25]:
# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Training loop
for epoch in tqdm(range(NUM_EPOCHS)):
    model.train()
    for batch in tqdm(train_dataloader):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    # Evaluation on validation set
    model.eval()
    total_val_loss = 0
    for batch in val_dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            val_loss = outputs.loss
            total_val_loss += val_loss.item()

    avg_val_loss = total_val_loss / len(val_dataloader)

    print(f'Epoch: {epoch+1}, Validation Loss: {avg_val_loss:.4f}')

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/122 [00:00<?, ?it/s][A
  1%|          | 1/122 [00:16<34:03, 16.89s/it][A
  2%|▏         | 2/122 [00:34<34:10, 17.09s/it][A
  2%|▏         | 3/122 [00:49<32:14, 16.26s/it][A
  3%|▎         | 4/122 [01:06<32:18, 16.43s/it][A
  4%|▍         | 5/122 [01:22<32:12, 16.52s/it][A
  5%|▍         | 6/122 [01:40<32:38, 16.89s/it][A
  6%|▌         | 7/122 [01:57<32:36, 17.02s/it][A
  7%|▋         | 8/122 [02:14<32:31, 17.12s/it][A
  7%|▋         | 9/122 [02:33<33:01, 17.53s/it][A
  8%|▊         | 10/122 [02:51<33:16, 17.83s/it][A
  9%|▉         | 11/122 [03:10<33:12, 17.95s/it][A
 10%|▉         | 12/122 [03:28<33:06, 18.06s/it][A
 11%|█         | 13/122 [03:47<33:35, 18.49s/it][A
 11%|█▏        | 14/122 [04:07<33:50, 18.80s/it][A
 12%|█▏        | 15/122 [04:24<32:50, 18.41s/it][A
 13%|█▎        | 16/122 [04:44<33:13, 18.81s/it][A
 14%|█▍        | 17/122 [05:05<33:49, 19.32s/it][A
 15%|█▍        | 18/122 [05:27<34:49, 20.10

Epoch: 2, Validation Loss: 0.0590



  0%|          | 0/122 [00:00<?, ?it/s][A
  1%|          | 1/122 [00:19<39:38, 19.65s/it][A
 43%|████▎     | 52/122 [16:23<22:11, 19.01s/it][A
 90%|█████████ | 110/122 [34:17<03:41, 18.43s/it][A
 91%|█████████ | 111/122 [34:35<03:20, 18.25s/it][A
 92%|█████████▏| 112/122 [34:54<03:04, 18.50s/it][A
 93%|█████████▎| 113/122 [35:12<02:45, 18.44s/it][A
 93%|█████████▎| 114/122 [35:31<02:28, 18.57s/it][A
 94%|█████████▍| 115/122 [35:50<02:10, 18.61s/it][A
 95%|█████████▌| 116/122 [36:06<01:47, 17.84s/it][A
 40%|████      | 4/10 [2:43:40<4:02:47, 2427.90s/it]

Epoch: 4, Validation Loss: 0.0292



  0%|          | 0/122 [00:00<?, ?it/s][A
  1%|          | 1/122 [00:19<38:49, 19.25s/it][A
 49%|████▉     | 60/122 [18:05<18:30, 17.91s/it][A
 50%|█████     | 61/122 [18:23<18:12, 17.91s/it][A
 51%|█████     | 62/122 [18:40<17:45, 17.76s/it][A
 52%|█████▏    | 63/122 [18:58<17:31, 17.83s/it][A
 52%|█████▏    | 64/122 [19:18<17:47, 18.40s/it][A
 53%|█████▎    | 65/122 [19:35<17:12, 18.12s/it][A
 54%|█████▍    | 66/122 [19:54<17:06, 18.32s/it][A
 55%|█████▍    | 67/122 [20:14<17:10, 18.74s/it][A
 56%|█████▌    | 68/122 [20:32<16:48, 18.68s/it][A
 57%|█████▋    | 69/122 [20:49<16:03, 18.18s/it][A
 57%|█████▋    | 70/122 [21:07<15:38, 18.05s/it][A
 58%|█████▊    | 71/122 [21:27<15:45, 18.54s/it][A
 59%|█████▉    | 72/122 [21:43<14:52, 17.85s/it][A
 60%|█████▉    | 73/122 [22:00<14:20, 17.55s/it][A
 61%|██████    | 74/122 [22:20<14:33, 18.20s/it][A
 61%|██████▏   | 75/122 [22:39<14:34, 18.61s/it][A
 62%|██████▏   | 76/122 [22:58<14:23, 18.78s/it][A
 63%|██████▎   | 77/1

Epoch: 9, Validation Loss: 0.0165



  0%|          | 0/122 [00:00<?, ?it/s][A
  1%|          | 1/122 [00:19<39:37, 19.65s/it][A
  2%|▏         | 2/122 [00:38<38:27, 19.23s/it][A
  2%|▏         | 3/122 [00:56<37:13, 18.77s/it][A
  3%|▎         | 4/122 [01:16<37:25, 19.03s/it][A
  4%|▍         | 5/122 [01:34<36:33, 18.74s/it][A
  5%|▍         | 6/122 [01:53<36:17, 18.77s/it][A
  6%|▌         | 7/122 [02:11<35:24, 18.48s/it][A
  7%|▋         | 8/122 [02:29<35:00, 18.42s/it][A
  7%|▋         | 9/122 [02:47<34:36, 18.37s/it][A
  8%|▊         | 10/122 [03:05<34:11, 18.31s/it][A
  9%|▉         | 11/122 [03:22<33:06, 17.90s/it][A
 10%|▉         | 12/122 [03:42<33:39, 18.36s/it][A
 11%|█         | 13/122 [04:00<33:12, 18.28s/it][A

In [85]:
torch.save(model, '../model.pt')

In [86]:
saved_model = torch.load('../model.pt')
#model_state_dict

In [91]:
# Preprocess input text
input_text = val_input_texts[0]
sql = val_target_queries[0]
# Tokenize and encode input text
tokens = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
# tokens_label = tokenizer(label_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")

In [92]:
# Forward pass
outputs = saved_model.generate(input_ids=tokens.input_ids, max_new_tokens = 512)
predicted_query = tokenizer.decode(outputs[0], skip_special_tokens=True)

# print("Input Text: ", input_text)
print('-'*100)
print("Predicted Query: ", predicted_query)
print('-'*100)
print("Actual Query: ", sql)

----------------------------------------------------------------------------------------------------
Predicted Query:  SELECT T1.first_name, T1.middle_name, T1.last_name, count(*) FROM Students AS T1 JOIN Student_Enrolment_Enrolment AS T2 ON T1.student_id = T2.student_id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1
----------------------------------------------------------------------------------------------------
Actual Query:  SELECT T1.student_id ,  T1.first_name ,  T1.middle_name ,  T1.last_name ,  count(*) ,  T1.student_id FROM Students AS T1 JOIN Student_Enrolment AS T2 ON T1.student_id  =  T2.student_id GROUP BY T1.student_id ORDER BY count(*) DESC LIMIT 1
