In [90]:
import pandas as pd
import numpy as np

from dotenv import load_dotenv
import os
# from openai import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings

from sklearn.metrics.pairwise import cosine_similarity

In [None]:
! pip install langchain
! pip install tiktoken

In [153]:
query_df = pd.read_csv('./git_pull/GS_BKMS2/NL2SQL/rawdata/SPIDER_SELECTED.csv')
ddl_df = pd.read_csv('./git_pull/GS_BKMS2/NL2SQL/rawdata/DDL_SELECTED.csv')

In [155]:
query_df = query_df.loc[:,['original_idx', 'db_id','question','query']]

In [156]:
query_df.head(3)

Unnamed: 0,original_idx,db_id,question,query
0,11,department_management,How many departments are led by heads who are ...,SELECT count(*) FROM department WHERE departme...
1,13,department_management,List the states where both the secretary of 'T...,SELECT T3.born_state FROM department AS T1 JOI...
2,14,department_management,Which department has more than 1 head at a tim...,"SELECT T1.department_id , T1.name , count(*)..."


In [157]:
ddl_df.shape

(434, 4)

In [158]:
qNL = query_df.loc[0, 'question']

In [159]:
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

embeddings = OpenAIEmbeddings(
    model = "text-embedding-ada-002",
    openai_api_key = OPENAI_API_KEY 
)

In [160]:
query_df.head(3)

Unnamed: 0,original_idx,db_id,question,query
0,11,department_management,How many departments are led by heads who are ...,SELECT count(*) FROM department WHERE departme...
1,13,department_management,List the states where both the secretary of 'T...,SELECT T3.born_state FROM department AS T1 JOI...
2,14,department_management,Which department has more than 1 head at a tim...,"SELECT T1.department_id , T1.name , count(*)..."


In [161]:
# qNL_vec = embeddings.embed_query(qNL)
query_df['qNL_vec'] = embeddings.embed_documents(texts = query_df['question'])

In [162]:
query_df.head(3)

Unnamed: 0,original_idx,db_id,question,query,qNL_vec
0,11,department_management,How many departments are led by heads who are ...,SELECT count(*) FROM department WHERE departme...,"[7.151246667418706e-05, 0.005482243697450374, ..."
1,13,department_management,List the states where both the secretary of 'T...,SELECT T3.born_state FROM department AS T1 JOI...,"[0.004536103547238757, -0.02295742974544512, 0..."
2,14,department_management,Which department has more than 1 head at a tim...,"SELECT T1.department_id , T1.name , count(*)...","[0.00690121912939286, 0.0009693379549008529, 0..."


In [163]:
ddl_df['CREATE_cleaned'] = ddl_df['CREATE'].apply(lambda x: x.replace('\n', ' ').replace('\t', ' '))

In [164]:
ddl_df['DDL_vec'] = embeddings.embed_documents(texts = ddl_df['CREATE_cleaned'])

In [165]:
len(ddl_df.loc[0, 'DDL_vec'])

1536

In [166]:
ddl_df.head(3)

Unnamed: 0,db_id,table_name,CREATE,INSERT,CREATE_cleaned,DDL_vec
0,department_management,department,CREATE TABLE IF NOT EXISTS department (\n\tDep...,"INSERT INTO department VALUES(1,'State','1789'...",CREATE TABLE IF NOT EXISTS department ( Depar...,"[-0.003998750327493227, 0.01744909199948951, -..."
1,department_management,head,CREATE TABLE IF NOT EXISTS head (\n\thead_ID i...,"INSERT INTO head VALUES(1,'Tiger Woods','Alaba...",CREATE TABLE IF NOT EXISTS head ( head_ID int...,"[-0.0015511123537666942, 0.01737107478940505, ..."
2,department_management,management,CREATE TABLE IF NOT EXISTS management (\n\tDep...,"INSERT INTO management VALUES(2,5,'Yes');\nINS...",CREATE TABLE IF NOT EXISTS management ( Depar...,"[-0.013841755627047652, 0.011188296730288663, ..."


In [173]:
qNL_vec = np.array(query_df['qNL_vec'].tolist())
DDL_vec = np.array(ddl_df['DDL_vec'].tolist())

similarity_matrix = cosine_similarity(qNL_vec, DDL_vec)

result_list = []

# Iterating through each query to find the top 5 similar ddl entry
for idx, qNL in enumerate(query_df['question']):
    original_idx, db_id_query = query_df.loc[idx, ['original_idx','db_id']]
    similarities = similarity_matrix[idx]

    # Pairing each similarity score with the corresponding ddl entry
    ddl_similarities = [(ddl_df.loc[i, 'db_id'], ddl_df.loc[i, 'table_name'], sim) 
                        for i, sim in enumerate(similarities)]

    # Sorting based on similarity score and selecting the top 5
    top_5_similar = sorted(ddl_similarities, key=lambda x: x[2], reverse=True)[:5]

    # Adding the results to the result list
    for db_id_ddl, table_name, sim_score in top_5_similar:
        result_list.append([original_idx, qNL, db_id_query, db_id_ddl, table_name])

# Creating the result DataFrame
result_df = pd.DataFrame(result_list, columns=['original_idx', 'qNL', 'db_id_query', 'db_id_ddl', 'table_name'])

In [174]:
result_df

Unnamed: 0,original_idx,qNL,db_id_query,db_id_ddl,table_name
0,11,How many departments are led by heads who are ...,department_management,department_store,Departments
1,11,How many departments are led by heads who are ...,department_management,department_management,management
2,11,How many departments are led by heads who are ...,department_management,department_management,department
3,11,How many departments are led by heads who are ...,department_management,hr_1,departments
4,11,How many departments are led by heads who are ...,department_management,college_2,department
...,...,...,...,...,...
1115,7955,Which dogs are of the rarest breed? Show their...,dog_kennels,dog_kennels,Breeds
1116,7955,Which dogs are of the rarest breed? Show their...,dog_kennels,dog_kennels,Treatments
1117,7955,Which dogs are of the rarest breed? Show their...,dog_kennels,dog_kennels,Dogs
1118,7955,Which dogs are of the rarest breed? Show their...,dog_kennels,pets_1,Pets


In [177]:
# Comparing db_id_query and db_id_ddl and counting the number of mismatches for each db_id_query

# Creating a column that indicates a mismatch (1) or match (0)
result_df['mismatch'] = (result_df['db_id_query'] != result_df['db_id_ddl']).astype(int)

# Counting the number of mismatches for each db_id_query
mismatch_count = result_df.groupby('original_idx')['mismatch'].sum()

mismatch_count

original_idx
11      3
13      5
14      3
36      3
37      3
       ..
7937    1
7940    2
7941    2
7954    1
7955    1
Name: mismatch, Length: 224, dtype: int64