In [1]:
import os
import json
import torch
import transformers
import argparse
import torch.optim as optim

from tqdm import tqdm
from copy import deepcopy
from tokenizers import AddedToken
from utils.classifier_metric.evaluator import cls_metric, auc_metric
from torch.utils.data import DataLoader
from transformers import RobertaTokenizerFast
from utils.classifier_model import MyClassifier
from utils.classifier_loss import ClassifierLoss
from transformers.trainer_utils import set_seed
from torch.utils.tensorboard import SummaryWriter
from load_dataset import ColumnAndTableClassifierDataset

In [2]:
# Inference configuration for the schema-item classifier.
batch_size = 1
seed = 42
save_path = "models/schema_item_classifier_v2_ep4/"  # fine-tuned checkpoint + tokenizer files
dev_filepath = "../data/resdsql_pre/preprocessed_dataset_test.json"  # preprocessed questions to rank
use_contents = True
add_fk_info = True
mode = "eval"  # one of "train", "eval", "test"; "eval" also prints table/column AUC after ranking
assert mode in {"train", "eval", "test"}, f"unknown mode: {mode!r}"

# Training-time settings, kept for reference only (not used by this notebook).
# No trailing commas: `learning_rate = 3e-5,` would bind a 1-tuple if uncommented.
# tensorboard_save_path = None
# train_filepath = "../data/pre-processing/preprocessed_train_spider.json"
# output_filepath = "../data/pre-processing/dataset_with_pred_probs.json"
# model_name_or_path = "roberta-large"
# gradient_descent_step = 4
# device = ''
# learning_rate = 3e-5
# gamma = 1.0
# alpha = 1.0
# epochs = 128
# patience = 32

In [3]:
# Files in the checkpoint directory, sorted because os.listdir returns entries in
# arbitrary order and the displayed listing should be reproducible.
sorted(os.listdir(save_path))

['added_tokens.json',
 'tokenizer_config.json',
 'special_tokens_map.json',
 'config.json',
 'tokenizer.json',
 'tb',
 'dense_classifier.pt',
 'loss.json',
 'merges.txt',
 'vocab.json']

In [4]:
# Flip to True to print intermediate batch contents and encoder inputs while debugging.
DEBUG_FLAG: bool = False

In [5]:
def prepare_batch_inputs_and_labels(batch, tokenizer):
    """Serialise a batch of (question, schema) samples and align labels to subword tokens.

    Each sample is indexed positionally: ``sample[0]`` is the question, ``sample[1]``
    the table names, ``sample[2]`` one label per table, ``sample[3]`` a list of
    column-info strings per table, and ``sample[4]`` the column labels as ONE flat list
    in schema order (table 0's columns, then table 1's, ...). Each sample becomes the
    word sequence ``question | table : col , col | table : col ...``, which is tokenized
    as pre-split words and padded/truncated to 512 tokens.

    Returns an 8-tuple:
        encoder_input_ids, encoder_input_attention_mask -- tokenizer tensors
            (moved to the GPU when one is available)
        batch_aligned_column_labels, batch_aligned_table_labels -- per-sample
            LongTensors restricted to columns/tables that survived truncation
        batch_aligned_question_ids -- per-sample token positions of the question
        batch_aligned_column_info_ids, batch_aligned_table_name_ids -- per-sample lists
            of token-position lists, one per surviving column / table name
        batch_column_number_in_each_table -- per-sample column counts of each
            surviving table, adjusted for truncation

    Prints intermediate values when the module-level ``DEBUG_FLAG`` is set.
    """
    batch_size = len(batch)
    
    batch_questions = [data[0] for data in batch]
    
    # sample[1]: table names; sample[2]: one label per table
    batch_table_names = [data[1] for data in batch]
    batch_table_labels = [data[2] for data in batch]

    # sample[3]: column-info strings grouped per table; sample[4]: flat list of column labels
    batch_column_infos = [data[3] for data in batch]
    batch_column_labels = [data[4] for data in batch]
    
    if DEBUG_FLAG: print(f"batch_questions - {batch_questions}")
    if DEBUG_FLAG: print(f"batch_table_names - {batch_table_names}")
    if DEBUG_FLAG: print(f"batch_table_labels - {batch_table_labels}")
    if DEBUG_FLAG: print(f"batch_column_infos - {batch_column_infos}")
    if DEBUG_FLAG: print(f"batch_column_labels - {batch_column_labels}")
    
    batch_input_tokens, batch_column_info_ids, batch_table_name_ids, batch_column_number_in_each_table = [], [], [], []
    for batch_id in range(batch_size):
        # word 0 is the whole question string; the tokenizer splits it into subwords later
        input_tokens = [batch_questions[batch_id]]
        table_names_in_one_db = batch_table_names[batch_id]
        column_infos_in_one_db = batch_column_infos[batch_id]

        # column count of every table, before any truncation
        batch_column_number_in_each_table.append([len(column_infos_in_one_table) for column_infos_in_one_table in column_infos_in_one_db])

        # word indices (positions in input_tokens) of every column info / table name
        column_info_ids, table_name_ids = [], []
        
        # append "| table : col , col" for each table
        for table_id, table_name in enumerate(table_names_in_one_db):
            input_tokens.append("|")
            input_tokens.append(table_name)
            table_name_ids.append(len(input_tokens) - 1)
            input_tokens.append(":")
            
            for column_info in column_infos_in_one_db[table_id]:
                input_tokens.append(column_info)
                column_info_ids.append(len(input_tokens) - 1)
                input_tokens.append(",")
            
            # drop the trailing "," after the last column
            # (NOTE(review): for a table with no columns this drops its ":" instead)
            input_tokens = input_tokens[:-1]
            
        if DEBUG_FLAG and batch_id == 0: print(f"input_tokens - {input_tokens}")
        
        batch_input_tokens.append(input_tokens)
        batch_column_info_ids.append(column_info_ids)
        batch_table_name_ids.append(table_name_ids)

    # notice: the truncation operation will discard the tables and columns that exceed the max length
    tokenized_inputs = tokenizer(
        batch_input_tokens, 
        return_tensors="pt", 
        is_split_into_words = True, 
        padding = "max_length",
        max_length = 512,
        truncation = True
    )

    batch_aligned_question_ids, batch_aligned_column_info_ids, batch_aligned_table_name_ids = [], [], []
    batch_aligned_table_labels, batch_aligned_column_labels = [], []
    
    # map word indices (question / table name / column info) to the token positions they produced
    for batch_id in range(batch_size):
        # word_ids[token_position] == index into input_tokens that produced that token
        word_ids = tokenized_inputs.word_ids(batch_index = batch_id)

        aligned_question_ids, aligned_table_name_ids, aligned_column_info_ids = [], [], []
        aligned_table_labels, aligned_column_labels = [], []

        # align question tokens (the question is word 0)
        for token_id, word_id in enumerate(word_ids):
            if word_id == 0:
                aligned_question_ids.append(token_id)

        # align table names
        for t_id, table_name_id in enumerate(batch_table_name_ids[batch_id]):
            temp_list = []
            for token_id, word_id in enumerate(word_ids):
                if table_name_id == word_id:
                    temp_list.append(token_id)
            # keep the table (and its label) only if the tokenizer didn't truncate it away
            if len(temp_list) != 0:
                aligned_table_name_ids.append(temp_list)
                aligned_table_labels.append(batch_table_labels[batch_id][t_id])

        # align column names
        for c_id, column_id in enumerate(batch_column_info_ids[batch_id]):
            temp_list = []
            for token_id, word_id in enumerate(word_ids):
                if column_id == word_id:
                    temp_list.append(token_id)
            # keep the column (and its label) only if the tokenizer didn't truncate it away
            if len(temp_list) != 0:
                aligned_column_info_ids.append(temp_list)
                aligned_column_labels.append(batch_column_labels[batch_id][c_id])

        batch_aligned_question_ids.append(aligned_question_ids)
        batch_aligned_table_name_ids.append(aligned_table_name_ids)
        batch_aligned_column_info_ids.append(aligned_column_info_ids)
        batch_aligned_table_labels.append(aligned_table_labels)
        batch_aligned_column_labels.append(aligned_column_labels)

    # update column number in each table (because some tables and columns are discarded)
    for batch_id in range(batch_size):
        # keep counts only for tables whose name survived
        if len(batch_column_number_in_each_table[batch_id]) > len(batch_aligned_table_labels[batch_id]):
            batch_column_number_in_each_table[batch_id] = batch_column_number_in_each_table[batch_id][ : len(batch_aligned_table_labels[batch_id])]
        
        # assumes truncation removes words from the end of the sequence, so only the
        # last surviving table can have lost columns -- subtract the missing ones there
        if sum(batch_column_number_in_each_table[batch_id]) > len(batch_aligned_column_labels[batch_id]):
            truncated_column_number = sum(batch_column_number_in_each_table[batch_id]) - len(batch_aligned_column_labels[batch_id])
            batch_column_number_in_each_table[batch_id][-1] -= truncated_column_number

    encoder_input_ids = tokenized_inputs["input_ids"]
    encoder_input_attention_mask = tokenized_inputs["attention_mask"]
    batch_aligned_column_labels = [torch.LongTensor(column_labels) for column_labels in batch_aligned_column_labels]
    batch_aligned_table_labels = [torch.LongTensor(table_labels) for table_labels in batch_aligned_table_labels]

    # print("\n".join(tokenizer.batch_decode(encoder_input_ids, skip_special_tokens = True)))

    if torch.cuda.is_available():
        encoder_input_ids = encoder_input_ids.cuda()
        encoder_input_attention_mask = encoder_input_attention_mask.cuda()
        batch_aligned_column_labels = [column_labels.cuda() for column_labels in batch_aligned_column_labels]
        batch_aligned_table_labels = [table_labels.cuda() for table_labels in batch_aligned_table_labels]

    return encoder_input_ids, encoder_input_attention_mask, \
        batch_aligned_column_labels, batch_aligned_table_labels, \
        batch_aligned_question_ids, batch_aligned_column_info_ids, \
        batch_aligned_table_name_ids, batch_column_number_in_each_table

In [6]:
# Fix random seeds (transformers.set_seed) before anything stochastic is constructed.
set_seed(seed)

# The tokenizer, including its added tokens, is saved next to the fine-tuned checkpoint.
# add_prefix_space is required for RoBERTa tokenizers fed pre-split words.
tokenizer = RobertaTokenizerFast.from_pretrained(save_path, add_prefix_space=True)

dataset = ColumnAndTableClassifierDataset(
    dir_=dev_filepath,
    use_contents=use_contents,
    add_fk_info=add_fk_info,
)

# Identity collate: every batch stays a plain list of samples, and
# prepare_batch_inputs_and_labels does the real batching.
# (The `dataloder` spelling is kept because later cells refer to it.)
dataloder = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda samples: samples,
)

# Classifier around the saved encoder; vocab_size=len(tokenizer) presumably sizes the
# embedding matrix to include the added tokens.
model = MyClassifier(model_name_or_path=save_path, vocab_size=len(tokenizer), mode=mode)

In [14]:
# load fine-tuned params
model.load_state_dict(torch.load(os.path.join(save_path,"dense_classifier.pt"),
                                 map_location=torch.device('cpu')))
if torch.cuda.is_available():
    model = model.cuda()
print(model.eval())

table_labels_for_auc, column_labels_for_auc = [], []
table_pred_probs_for_auc, column_pred_probs_for_auc = [], []

returned_table_pred_probs, returned_column_pred_probs = [], []

MyClassifier(
  (plm_encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50266, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (La

In [15]:
def sort_names_based_on_probs(probs, names):
    """Rank ``names`` by their predicted probability, highest first.

    ``probs[i]`` is the score of ``names[i]``. The sort is stable on the probability
    alone, so names with equal scores keep their original (schema) order instead of
    being ordered by their text. Names beyond ``len(probs)`` (e.g. tables or columns
    the tokenizer truncated away) have no score; they are appended at the end in
    original order rather than silently dropped. Extra probabilities are ignored.

    Returns a new list of names.
    """
    names = list(names)
    scored = list(zip(probs, names))
    # sorted() stays stable with reverse=True, so ties keep their input order
    ranked = [name for _, name in sorted(scored, key=lambda pair: pair[0], reverse=True)]
    ranked.extend(names[len(scored):])
    return ranked

In [16]:
# State for the ranking loop. Every accumulator the loop appends to is reset here, so
# re-running this cell and then the loop does not double-count AUC inputs.
counter = 0        # questions/batches seen so far, skipped duplicates included
output_final = {}  # question text -> ranked schema string
max_len = 0        # longest "question | schema", in whitespace-separated words

table_labels_for_auc, column_labels_for_auc = [], []
table_pred_probs_for_auc, column_pred_probs_for_auc = [], []
returned_table_pred_probs, returned_column_pred_probs = [], []

In [17]:
# Ranked schemas are written here every SAVE_EVERY questions and once more after the loop.
output_filepath = "../data/resdsql_pre/preprocessed_dataset_test_db_schema.json"
SAVE_EVERY = 50


def _dump_schemas(path, schemas):
    """Write the question -> ranked-schema mapping to ``path`` as indented JSON."""
    with open(path, "w") as fp:
        json.dump(schemas, fp, indent=4)


def _pad_probs(probs, length):
    """Return ``probs`` as a list padded with -inf up to ``length`` entries.

    Tokenizer truncation can leave trailing tables/columns without a prediction;
    padding ranks them last instead of letting them vanish from the schema.
    """
    probs = list(probs)
    return probs + [float("-inf")] * (length - len(probs))


def _build_ranked_schema(table_names, col_names, table_probs, column_probs):
    """Serialise one database schema ordered by predicted relevance.

    Produces ``"table : col , col | table : col ..."`` with tables, and the columns
    within each table, sorted by probability (highest first). ``table_probs[i]``
    scores ``table_names[i]``; ``column_probs`` is the flat list of column scores in
    schema order (all of table 0's columns, then table 1's, ...).
    """
    ranked_tables = sort_names_based_on_probs(_pad_probs(table_probs, len(table_names)), table_names)

    ranked_columns = {}
    offset = 0
    for table_name, columns in zip(table_names, col_names):
        probs = _pad_probs(column_probs[offset : offset + len(columns)], len(columns))
        ranked_columns[table_name] = sort_names_based_on_probs(probs, columns)
        # Cumulative offset: resetting it to len(columns) would give every table after
        # the second the wrong slice of column_probs.
        offset += len(columns)

    schema = " | ".join(f"{table} : " + " , ".join(ranked_columns[table]) for table in ranked_tables)
    # Collapse runs of spaces (e.g. from names with trailing blanks) to a single space.
    return " ".join(part for part in schema.split(" ") if part)


for batch in tqdm(dataloder):
    # Skip questions that already have a ranked schema. Results are keyed by question
    # text, so only the first occurrence of a repeated question is ranked.
    fresh_samples = {}
    for sample in batch:
        if sample[0] not in output_final:
            fresh_samples.setdefault(sample[0], sample)
    counter += len(batch) - len(fresh_samples)
    pending = list(fresh_samples.values())
    if not pending:
        continue

    encoder_input_ids, encoder_input_attention_mask, \
        batch_column_labels, batch_table_labels, batch_aligned_question_ids, \
        batch_aligned_column_info_ids, batch_aligned_table_name_ids, \
        batch_column_number_in_each_table = prepare_batch_inputs_and_labels(pending, tokenizer)

    if DEBUG_FLAG: print(f"encoder_input_ids - {encoder_input_ids}")
    if DEBUG_FLAG: print(f"encoder_input_attention_mask - {encoder_input_attention_mask}")

    with torch.no_grad():
        model_outputs = model(
            encoder_input_ids,
            encoder_input_attention_mask,
            batch_aligned_question_ids,
            batch_aligned_column_info_ids,
            batch_aligned_table_name_ids,
            batch_column_number_in_each_table
        )

    # Positive-class (index 1) probability for every surviving table, one list per sample.
    table_probs_per_sample = []
    for batch_id, table_logits in enumerate(model_outputs["batch_table_name_cls_logits"]):
        table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)
        get_table_pred_probs = table_pred_probs[:, 1].cpu().tolist()
        table_probs_per_sample.append(get_table_pred_probs)
        returned_table_pred_probs.append(list(get_table_pred_probs))

        table_pred_probs_for_auc.extend(get_table_pred_probs)
        table_labels_for_auc.extend(batch_table_labels[batch_id].cpu().tolist())

    # Positive-class probability for every surviving column (flat, schema order), per sample.
    column_probs_per_sample = []
    for batch_id, column_logits in enumerate(model_outputs["batch_column_info_cls_logits"]):
        column_number_in_each_table = batch_column_number_in_each_table[batch_id]
        column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)
        get_col_pred_probs = column_pred_probs[:, 1].cpu().tolist()
        column_probs_per_sample.append(get_col_pred_probs)

        # Split the flat column probabilities into one list per surviving table.
        bounds = [0]
        for n_columns in column_number_in_each_table:
            bounds.append(bounds[-1] + n_columns)
        returned_column_pred_probs.append(
            [get_col_pred_probs[start:end] for start, end in zip(bounds, bounds[1:])]
        )

        column_pred_probs_for_auc.extend(get_col_pred_probs)
        column_labels_for_auc.extend(batch_column_labels[batch_id].cpu().tolist())

    # Build the ranked schema for every sample in the batch (not only the first).
    for batch_id, sample in enumerate(pending):
        question, table_names, col_names = sample[0], sample[1], sample[3]
        get_table_pred_probs = table_probs_per_sample[batch_id]
        get_col_pred_probs = column_probs_per_sample[batch_id]

        db_scheme = _build_ranked_schema(table_names, col_names, get_table_pred_probs, get_col_pred_probs)
        output_final[question] = db_scheme

        check_len = question + " | " + db_scheme
        max_len = max(max_len, len(check_len.split()))

        counter += 1
        if counter % SAVE_EVERY == 0:
            _dump_schemas(output_filepath, output_final)

# The periodic dump misses every question after the last multiple of SAVE_EVERY,
# so always write the complete mapping once the loop finishes.
_dump_schemas(output_filepath, output_final)

if mode == "eval":
    # calculate AUC score for table classification
    table_auc = auc_metric(table_labels_for_auc, table_pred_probs_for_auc)
    # calculate AUC score for column classification
    column_auc = auc_metric(column_labels_for_auc, column_pred_probs_for_auc)
    print("table auc:", table_auc)
    print("column auc:", column_auc)
    print("total auc:", table_auc+column_auc)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 696/696 [15:49<00:00,  1.36s/it]

table auc: 0.6382781766775827
column auc: 0.42589333719621886
total auc: 1.0641715138738015





In [21]:
# Question text (sample[0]) of the first sample in the last batch drawn from the dataloader.
batch[0][0]

'How many heads of the departments are older than 56 ?'

In [15]:
# Predicted table probabilities next to the gold table labels for the last ranked sample.
(get_table_pred_probs, batch_table_labels[batch_id])

([0.4917433559894562, 0.4917435646057129], tensor([1, 0]))

In [42]:
# Longest "question | schema" (in words) and the number of questions processed.
(max_len, counter)

(261, 696)

In [43]:
# "question | ranked schema" string built for the last ranked sample.
check_len

'What are all company names that have a corresponding movie directed in the year 1999? | movie : year , title , director , gross worldwide , budget million , movie id ( [FK] ) | book club : year , result , publisher , category , book title , author or editor , book club id ( [FK] ) | culture company : incorporated in , group equity shareholding , company name , book club id ( [FK] ) , movie id ( [FK] ) , type'

In [11]:
# Number of ranked questions collected so far, and the container type.
(len(output_final), type(output_final))

(263, dict)

In [16]:
# Flat per-column probabilities (schema order) for the last ranked sample.
get_col_pred_probs

[0.24222856760025024,
 0.2792842388153076,
 0.2792842388153076,
 0.2792842388153076,
 0.2792842388153076,
 0.2792842388153076,
 0.23681582510471344,
 0.23681406676769257,
 0.2792842388153076,
 0.23681406676769257,
 0.2792842388153076,
 0.2792842388153076,
 0.2792842388153076]

In [50]:
# import json

# output_filepath = "../data/resdsql_pre/preprocessed_dataset_train_db_schema.json"
# with open(output_filepath, "w") as fp:
#     json.dump(output_final, fp, indent=4)

In [None]:
# Scratch: rebuild the ranked schema for the last sample seen by the loop.
# NOTE(review): assumes the last batch was ranked, not skipped as a duplicate question.
table_names = batch[0][1]
col_names = batch[0][3]

# Use the per-sample probability lists. `table_pred_probs` / `column_pred_probs` are the
# raw 2-D softmax tensors, and sorting their rows raises an ambiguous-truth-value error.
sorted_tables = sort_names_based_on_probs(get_table_pred_probs, table_names)

table_col_dict = {}
init_id = 0
for idx, tab in enumerate(table_names):
    columns = col_names[idx]
    probs = get_col_pred_probs[init_id : (init_id + len(columns))]

    sorted_columns = sort_names_based_on_probs(probs, columns)
    table_col_dict[tab] = sorted_columns

    # Cumulative offset into the flat column probabilities (not reset per table).
    init_id += len(columns)

db_scheme = " | ".join(table + " : " + " , ".join(table_col_dict[table]) for table in sorted_tables)
db_scheme = db_scheme.replace("  ", " ")

In [None]:
# Table names ranked by predicted probability, highest first.
sorted_tables

In [None]:

    
# The serialised ranked schema string.
db_scheme

In [None]:
# Schema without a trailing " | " separator. Only strip when the separator is actually
# there: db_scheme is normally already stripped, and an unconditional [:-3] would cut
# the last three characters of the final column name.
db_scheme[:-3] if db_scheme.endswith(" | ") else db_scheme

In [23]:
# Number of rows (surviving tables) in the last table softmax tensor.
len(table_pred_probs)

3

In [None]:
# model_outputs
# encoder_input_ids[0]
# encoder_input_attention_mask
# batch_aligned_question_ids
# batch_aligned_column_info_ids
# batch_aligned_table_name_ids
# batch_column_number_in_each_table

In [None]:


# `return` outside a function is a SyntaxError in a notebook cell. Summarise the
# collected per-sample probability lists instead; the lists themselves stay in scope.
len(returned_table_pred_probs), len(returned_column_pred_probs)