In [1]:
import boto3
import pandas as pd
import ast
import json
import matplotlib.pyplot as plt
import time
import torch
import torch.nn.functional as F
import numpy as np
import copy
# from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForTokenClassification, AutoModel
import os
# from huggingface_hub import login
# from sentence_transformers import SentenceTransformer


# Configure pandas to display all columns and their full content without truncation
pd.set_option('display.max_columns', None)  # Show all columns
pd.set_option('display.max_colwidth', None)  # Do not truncate column contents
pd.set_option('display.expand_frame_repr', False)  # Avoid wrapping to the next line

## load questions and sql queries in spider train

In [2]:
# Define your S3 bucket and file key
bucket_name = 'sagemaker-studio-423623869859-3no3d9ie4hx'

# Initialize the S3 client
s3_client = boto3.client('s3')

# Load the file from S3
obj = s3_client.get_object(Bucket=bucket_name, Key='df_question_entities_tables.csv')
df_question_entities_tables = pd.read_csv(obj['Body'])

obj = s3_client.get_object(Bucket=bucket_name, Key='df_schema_table.csv')
df_schema_table = pd.read_csv(obj['Body'])

obj = s3_client.get_object(Bucket=bucket_name, Key='embeddings.json')
embedded_dict = json.load(obj['Body'])
for key in embedded_dict:
    embedded_dict[key] = np.array(embedded_dict[key])

In [3]:
ct = 0
ct_max = 5
for key in embedded_dict:
    print(key)
    print(len(embedded_dict[key]))
    ct += 1
    if ct == ct_max:
        break

phone numbers
768
end station
768
musical actor
768
follower
768
egg
768


In [4]:
df_question_entities_tables['entities_for_tables'] = df_question_entities_tables['entities_for_tables'].apply(ast.literal_eval)
df_question_entities_tables['entities_for_columns'] = df_question_entities_tables['entities_for_columns'].apply(ast.literal_eval)
df_question_entities_tables['tables'] = df_question_entities_tables['tables'].apply(ast.literal_eval)
df_question_entities_tables['entities'] = df_question_entities_tables['entities'].apply(ast.literal_eval)

In [5]:
df_question_entities_tables.shape

(7000, 6)

In [6]:
df_question_entities_tables.sample(frac=1).head(10)

Unnamed: 0,question,entities_for_tables,entities_for_columns,query,tables,entities
674,Give the full name and phone of the customer who has the account name 162.,"[customer, account]","[full name, phone, account name]","SELECT T2.customer_first_name , T2.customer_last_name , T2.customer_phone FROM Accounts AS T1 JOIN Customers AS T2 ON T1.customer_id = T2.customer_id WHERE T1.account_name = ""162""","[Customers, Accounts]","[customer, account, full name, phone, account name]"
5770,"Find the emails and phone numbers of all the customers, ordered by email address and phone number.",[customers],"[emails, phone numbers, email address]","SELECT email_address , phone_number FROM customers ORDER BY email_address , phone_number",[customers],"[customers, emails, phone numbers, email address]"
6233,Give the classes that have more than two captains.,[classes],[captains],SELECT CLASS FROM captain GROUP BY CLASS HAVING count(*) > 2,[captain],"[classes, captains]"
5289,Who is the founders of companies whose first letter is S?,[companies],"[founders, first letter]",SELECT founder FROM manufacturers WHERE name LIKE 'S%',[manufacturers],"[companies, founders, first letter]"
2866,"Find all the papers published by ""Aaron Turon"".",[papers],[Aaron Turon],"SELECT t3.title FROM authors AS t1 JOIN authorship AS t2 ON t1.authid = t2.authid JOIN papers AS t3 ON t2.paperid = t3.paperid WHERE t1.fname = ""Aaron"" AND t1.lname = ""Turon""","[authorship, authors, papers]","[papers, Aaron Turon]"
3635,"What are the first name, last name and id of the player with the most all star game experiences? Also list the count.",[player],"[first name, last name, id, all star game experiences, count]","SELECT T1.name_first , T1.name_last , T1.player_id , count(*) FROM player AS T1 JOIN all_star AS T2 ON T1.player_id = T2.player_id GROUP BY T1.player_id ORDER BY count(*) DESC LIMIT 1;","[player, all_star]","[player, first name, last name, id, all star game experiences, count]"
2775,"For each county, find the name of the county and the number of delegates from that county.",[county],"[name, number of delegates]","SELECT T1.County_name , COUNT(*) FROM county AS T1 JOIN election AS T2 ON T1.County_id = T2.District GROUP BY T1.County_id","[election, county]","[county, name, number of delegates]"
1400,Find the minimum salary for the departments whose average salary is above the average payment of all instructors.,"[departments, instructors]","[salary, average salary, minimum salary, average payment]","SELECT min(salary) , dept_name FROM instructor GROUP BY dept_name HAVING avg(salary) > (SELECT avg(salary) FROM instructor)",[instructor],"[departments, instructors, salary, average salary, minimum salary, average payment]"
439,How many allergies are there?,[allergies],[count],SELECT count(DISTINCT allergy) FROM Allergy_type,[Allergy_type],"[allergies, count]"
5983,Show all video game types and the number of video games in each type.,[video games],"[types, number]","SELECT gtype , count(*) FROM Video_games GROUP BY gtype",[Video_games],"[video games, types, number]"


In [7]:
df_schema_table.shape

(749, 5)

In [8]:
df_schema_table.head(3)

Unnamed: 0,database,table,processed_database,processed_table,database_and_table
0,academic,author,academic,author,academic author
1,academic,conference,academic,conference,academic conference
2,academic,domain,academic,domain,academic domain


In [9]:
df_schema_table['embedding'] = df_schema_table['database_and_table'].apply(lambda x : x.lower()).map(embedded_dict)

In [10]:
df_schema_table.isna().sum()

database              0
table                 0
processed_database    0
processed_table       0
database_and_table    0
embedding             0
dtype: int64

In [11]:
df_schema_table.head(1)

Unnamed: 0,database,table,processed_database,processed_table,database_and_table,embedding
0,academic,author,academic,author,academic author,"[0.025645531713962555, 0.0742696076631546, -0.004432312212884426, 0.012715279124677181, -0.014841729775071144, -0.014733493328094482, 0.09128323197364807, 0.03157993406057358, -0.015109768137335777, 0.019560782238841057, 0.05157171189785004, 0.05434355512261391, -0.03596740588545799, 0.036288950592279434, 0.02790152095258236, 0.00871801096946001, -0.01697620376944542, 0.03884623944759369, 0.1441171020269394, -0.035446859896183014, -0.05303538963198662, 0.01372064184397459, -0.003383792471140623, 0.033779315650463104, 0.018980078399181366, -0.03765243664383888, 0.04184789955615997, -0.001720184925943613, 0.012179960496723652, 0.0810307189822197, 0.045781951397657394, 0.016897069290280342, -0.03842933103442192, -0.030939439311623573, 1.168659764516633e-06, 0.029574420303106308, -0.004737870302051306, -0.0009187128162011504, 0.04264272376894951, -0.02760082669556141, 0.044564224779605865, 0.09870371967554092, -0.05626872181892395, 0.023318400606513023, -0.04494043067097664, 0.005029047839343548, 0.015962352976202965, 0.04478465020656586, -0.06643695384263992, 0.014912978745996952, 0.0077165234833955765, -0.03825429826974869, 0.05210268124938011, -0.02390531450510025, -0.02450334094464779, 0.014174302108585835, 0.00895096268504858, 0.008809899911284447, -0.01938359998166561, 0.09609752893447876, 0.03211810439825058, 0.043194565922021866, -0.03939264267683029, -0.006576254032552242, 0.03727315366268158, 0.04458096623420715, 0.07531660795211792, 0.0060681491158902645, 0.06635133177042007, -0.006314987316727638, 0.1002747192978859, 8.970334602054209e-05, 0.04085642844438553, 0.08544538170099258, -0.00057821354130283, 0.08088438212871552, 0.0416874997317791, 0.032923661172389984, -0.03228451684117317, -0.031806766986846924, 0.049014829099178314, 0.006755518261343241, 0.0030282107181847095, 0.09359920769929886, -0.012934922240674496, -0.001040241215378046, -0.029447052627801895, -0.005058777052909136, 0.010800394229590893, -0.02510487660765648, 0.0500672310590744, -0.009793831966817379, 0.03679483383893967, 0.02442980371415615, -0.004229494836181402, -0.030768124386668205, 0.02584119886159897, -0.07453466951847076, -0.0545203723013401, 0.028709612786769867, ...]"


In [13]:
search_corpus = df_schema_table[['embedding', 'database', 'table']].values.tolist()

In [42]:
def cosine_similarity(vec1, vec2):
    # Ensure inputs are 1D vectors
    if vec1.ndim != 1 or vec2.ndim != 1:
        print("Input vectors are not 1D:")
        print("vec1:", vec1)
        print("vec2:", vec2)
        raise ValueError("Both input vectors must be 1D.")
    
    dot_product = np.dot(vec1, vec2)
    norm_a = np.linalg.norm(vec1)
    norm_b = np.linalg.norm(vec2)
    
    # Calculate cosine similarity
    result = dot_product / (norm_a * norm_b)
    
    # Ensure result is a scalar
    if not np.isscalar(result):
        print("Result is not a scalar:")
        print("vec1:", vec1)
        print("vec2:", vec2)
        print("Result:", result)
        raise ValueError("Cosine similarity result must be a scalar.")
    
    return result



def search_for_a_single_entity(entity, search_corpus, embedded_dict):
    entity_embedding = embedded_dict[entity]
    result = []
    search_corpus_copy = copy.deepcopy(search_corpus)
    entity_embedding = embedded_dict[entity]
    for ind, ele in enumerate(search_corpus_copy):
        document_embedding = ele[0]
        search_corpus_copy[ind].append(cosine_similarity(entity_embedding, document_embedding))
    return search_corpus_copy

def merge_and_sort_multiple_search_corpus_copies(list_of_search_corpus_copies):
    if list_of_search_corpus_copies == []:
        return None
    
    result = []
    for ind, ele in enumerate(list_of_search_corpus_copies[0]):
        result.append([])
        for search_corpus_copy in list_of_search_corpus_copies:
            result[-1].append(search_corpus_copy[ind][-1])
            
    search_corpus_copy_merged_and_sorted = []
    #print(result)
    for ind, ele in enumerate(list_of_search_corpus_copies[0]):
        search_corpus_copy_merged_and_sorted.append(ele[:3]+[max(result[ind])])
        
    return search_corpus_copy_merged_and_sorted
        

def find_index_of_a_table_from_ranked_list(target_table, search_corpus_copy_merged_and_sorted):
    if not search_corpus_copy_merged_and_sorted:
        return None
    target_table = target_table.lower()
    list_of_sorted_tables = [ele[2] for ele in search_corpus_copy_merged_and_sorted]
    list_of_sorted_tables = [ele.lower() for ele in list_of_sorted_tables]
    #print(target_table)
    #print(list_of_sorted_tables)
    index = list_of_sorted_tables.index(target_table)
    return index



ct1 = 0
ct2 = 0
for ind, row in df_question_entities_tables.iterrows():
    if ind % 10 == 9:
        print(ind, ct1, ct2)
    ground_truth_tables = row['tables']
    entities_for_tables = row['entities_for_tables']
    list_of_search_corpus_copies = []
    if entities_for_tables == []:
        print('no entities for tables')
    for entity in entities_for_tables:
        list_of_search_corpus_copies.append(search_for_a_single_entity(entity.lower(), search_corpus, embedded_dict))
    search_corpus_copy_merged_and_sorted = merge_and_sort_multiple_search_corpus_copies(list_of_search_corpus_copies)
    #print(type(ground_truth_tables))
    
    ground_truth_table_indices = []
    #print(ground_truth_tables)
    for ground_truth_table in ground_truth_tables:
        try:
            ground_truth_table_indices.append(find_index_of_a_table_from_ranked_list(ground_truth_table, search_corpus_copy_merged_and_sorted))
            ct1 += 1
        except:
            ct2 += 1
            
print(ct1)
print(ct2)

9 5 6
19 12 15
29 22 15
no entities for tables
39 37 15
49 51 15
59 64 15
69 81 15
79 99 15
89 119 15
99 135 15
109 155 15
119 165 15
129 175 15
139 187 15
149 203 15
159 219 15
169 235 15
179 249 15
189 259 15
199 273 15
209 289 15
219 303 15
229 316 15
239 327 15
249 337 15
259 356 15
269 368 15
279 374 20
289 374 34
299 374 49
no entities for tables
no entities for tables
309 382 52
319 396 52
329 408 52
339 422 52
349 432 52
359 442 52
369 452 52
379 462 52
389 472 52
399 482 52
409 492 52
419 510 52
429 534 52
439 558 52
449 568 52
459 578 52
469 588 52
479 598 52
489 608 52
499 618 52
509 630 52
519 646 52
529 672 52
539 694 52
549 710 52
559 724 52
569 734 52
579 748 52
589 762 52
599 774 52
609 786 52
619 798 52
629 816 52
639 842 52
649 868 52
659 870 62
669 877 76
679 891 76
689 909 76
699 919 76
709 933 76
719 949 76
729 961 76
739 971 76
749 983 76
759 993 76
769 1003 76
779 1017 76
789 1035 76
799 1045 76
809 1055 78
819 1063 80
829 1073 88
839 1085 92
849 1099 94
859 1113