In [1]:
import numpy as np
import pandas as pd

## Building knowledge base

In [2]:
import os
from dotenv import load_dotenv
from pinecone import Pinecone

load_dotenv()

os.environ['OPENAI_API_KEY'] = os.getenv('openai_key')  
os.environ['PINECONE_API_KEY'] = os.getenv('pinecone_key')  

# configure client
pc = Pinecone(api_key=os.getenv('pinecone_key') )

  from tqdm.autonotebook import tqdm


In [3]:
# serverlessSpec typically refers to specifications or configurations related to serverless computing resources. 
from pinecone import ServerlessSpec

spec = ServerlessSpec(
    cloud="aws", region="us-west-2"
)

In [59]:
pc.list_indexes()
pc.delete_index('sql-sample-rag-test')

In [60]:
import time

# In Pinecone, an index refers to the data structure used to store and organize the vectors for efficient retrieval. 
# When you create an index in Pinecone, you're essentially creating a database tailored for vector search.

index_name = 'sql-sample-rag-test'
existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
]

# check if index already exists (it shouldn't if this is first time)
if index_name not in existing_indexes:
    # if does not exist, create index
    pc.create_index(
        index_name,
        dimension=1536,  # dimensionality of ada 002
        metric='dotproduct',
        spec=spec
    )
    # wait for index to be initialized
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats
index.describe_index_stats()


{'dimension': 1536,
 'index_fullness': 0.0,
 'namespaces': {},
 'total_vector_count': 0}

Our index is now ready but it's empty. It is a vector index, so it needs vectors. As mentioned, to create these vector embeddings we will OpenAI's text-embedding-ada-002 model — we can access it via LangChain like so:

In [61]:
from langchain.embeddings.openai import OpenAIEmbeddings

embed_model = OpenAIEmbeddings(model="text-embedding-ada-002")

In [62]:
df = pd.read_csv('spider_data_with_masked_questions_v2.csv')
df.head()

num = 100
print(df.iloc[num].question)
print(df.iloc[num].masked_question)
print(df.iloc[num].schema_toks)

What are the ids of the students who either registered or attended a course?
What are the [MASK] of the [MASK] who either [MASK] or [MASK] a [MASK]
['student_course_registrations', 'student_course_attendance', 'student_id']


In [64]:
df.head()

Unnamed: 0.1,Unnamed: 0,question,masked_question,masked_question_toks,query,schema_toks,tables,columns,sql_keywords,join_involved
0,0,How many heads of the departments are older th...,How many [MASK] of the [MASK] are older than 56 ?,"['How', 'many', '[MASK]', 'of', 'the', '[MASK]...",SELECT count(*) FROM head WHERE age > 56,"['head', 'age']","[('head', '')]",['age'],"select,count,from,where",False
1,1,"List the name, born state and age of the heads...",List the [MASK] [MASK] [MASK] and [MASK] of th...,"['List', 'the', '[MASK]', '[MASK]', '[MASK]', ...","SELECT name , born_state , age FROM head ORD...","['head', 'born_state', 'age', 'name']","[('head', '')]","['born_state', 'age', 'name']","select,order,from",False
2,2,"List the creation year, name and budget of eac...","List the [MASK] year, [MASK] and [MASK] of eac...","['List', 'the', '[MASK]', 'year,', '[MASK]', '...","SELECT creation , name , budget_in_billions ...","['department', 'creation', 'budget_in_billions...","[('department', '')]","['creation', 'budget_in_billions', 'name']","select,from",False
3,3,What are the maximum and minimum budget of the...,What are the maximum and minimum [MASK] of the...,"['What', 'are', 'the', 'maximum', 'and', 'mini...","SELECT max(budget_in_billions) , min(budget_i...","['department', 'budget_in_billions']","[('department', '')]",['budget_in_billions'],"select,min,from,max",False
4,4,What is the average number of employees of the...,What is the average number of [MASK] of the [M...,"['What', 'is', 'the', 'average', 'number', 'of...",SELECT avg(num_employees) FROM department WHER...,"['department', 'num_employees', 'ranking']","[('department', '')]","['num_employees', 'ranking']","from,and,between,avg,select,where",False


In [41]:
np.random.seed(99)
split_index = np.random.permutation(len(df))
train_size = int(len(df) * 0.8) 
trainset = df.iloc[split_index[:train_size], :]
testset = df.iloc[split_index[train_size:], :]

In [42]:
trainset.head()

Unnamed: 0.1,Unnamed: 0,question,masked_question,masked_question_toks,query,schema_toks,tables,columns,sql_keywords
2282,2282,Give the total money requested by entrepreneur...,Give the total [MASK] [MASK] by [MASK] who are...,"['Give', 'the', 'total', '[MASK]', '[MASK]', '...",SELECT sum(T1.Money_Requested) FROM entreprene...,"['entrepreneur', 'people', 'Money_Requested', ...","[('entrepreneur', 'T1'), ('people', 'T2')]","['Money_Requested', 'People_ID', 'Height']","as,select,where,from,sum,join"
7783,7783,Return the codes of countries that do not spea...,Return the [MASK] of [MASK] that do not speak ...,"['Return', 'the', '[MASK]', 'of', '[MASK]', 't...",SELECT Code FROM country WHERE GovernmentForm ...,"['country', 'countrylanguage', 'LANGUAGE', 'En...","[('country', ''), ('countrylanguage', '')]","['LANGUAGE', 'English', 'Republic', 'Governmen...","select,from,where,except"
3535,3535,What are the id of songs whose format is mp3.,What are the [MASK] of songs whose [MASK] is mp3.,"['What', 'are', 'the', '[MASK]', 'of', 'songs'...","SELECT f_id FROM files WHERE formats = ""mp3""","['files', 'formats', 'mp3', 'f_id']","[('files', '')]","['formats', 'mp3', 'f_id']","select,from,where"
2013,2013,"Show gas station id, location, and manager_nam...","Show [MASK] [MASK] id, [MASK] and [MASK] for a...","['Show', '[MASK]', '[MASK]', 'id,', '[MASK]', ...","SELECT station_id , LOCATION , manager_name ...","['gas_station', 'manager_name', 'station_id', ...","[('gas_station', '')]","['manager_name', 'station_id', 'open_year', 'L...","select,from,order"
6176,6176,Return the famous titles for artists that have...,Return the [MASK] [MASK] for [MASK] that have ...,"['Return', 'the', '[MASK]', '[MASK]', 'for', '...",SELECT T1.Famous_Title FROM artist AS T1 JOIN ...,"['artist', 'volume', 'Famous_Title', 'Weeks_on...","[('artist', 'T1'), ('volume', 'T2')]","['Famous_Title', 'Weeks_on_Top', 'Artist_ID']","as,select,where,from,join"


In [65]:
from tqdm.auto import tqdm  # for progress bar

batch_size = 100
df = df.reset_index()

for i in tqdm(range(0, len(df), batch_size)):
    i_end = min(len(df), i+batch_size)
    # get batch of data
    batch = df.iloc[i:i_end]
    # generate unique ids for each chunk
    ids = [f"question_{x['index']}" for _, x in batch.iterrows()]
    # get text to embed
    texts = [x['masked_question'] for _, x in batch.iterrows()]
    # embed text
    embeds = embed_model.embed_documents(texts)
    # get metadata to store in Pinecone
    metadata = [
        {'question': x['question'],
         'masked_question': x['masked_question'],
         'query': x['query'],
         'tables': x['tables'],
         'columns': x['columns'],
         'sql_keywords': x['sql_keywords'],
         'join_involved': x['join_involved']} for i, x in batch.iterrows()
    ]
    # add to Pinecone
    index.upsert(vectors=zip(ids, embeds, metadata))

100%|██████████| 81/81 [06:35<00:00,  4.88s/it]


In [66]:
index.describe_index_stats()

{'dimension': 1536,
 'index_fullness': 0.0,
 'namespaces': {'': {'vector_count': 8034}},
 'total_vector_count': 8034}

Retrieval Augmented Generation

We've built a fully-fledged knowledge base. Now it's time to connect that knowledge base to our chatbot. To do that we'll be diving back into LangChain and reusing our template prompt from earlier.

To use LangChain here we need to load the LangChain abstraction for a vector index, called a vectorstore. We pass in our vector index to initialize the object.

In [67]:
from langchain.vectorstores import Pinecone

text_field = "query"  # the metadata field that contains our text

# initialize the vector store object
vectorstore = Pinecone(
    index, embed_model.embed_query
    ,text_field
)



In [68]:
query = "What is the [MASK] that has third largest [MASK]?"

result = vectorstore.similarity_search(query, k=3)

In [69]:

result[0].metadata['sql_keywords']

'limit,order,select,desc,from'

In [None]:
inspect_num = 21

print(f"Question: {testset.iloc[inspect_num]['question']}")
print(f"Masked Question: {testset.iloc[inspect_num]['masked_question']}")
print(f"Query: {testset.iloc[inspect_num]['query']}")
print(f"SQL Keywords: {testset.iloc[inspect_num]['sql_keywords']}")

In [71]:
query = 'Give me the [MASK] [MASK] and [MASK] [MASK] for the [MASK] with the three oldest id.'
result = vectorstore.similarity_search(query, k=3)
print(result[0].metadata['sql_keywords'])
print(result[1].metadata['sql_keywords'])
print(result[2].metadata['sql_keywords'])

select,limit,order,from
as,join,from,select,where
select,from,where


In [104]:
for item in result:
    print(item.metadata)

{'columns': "['product_name', 'product_type_code', 'product_id', 'product_price', 'supplier_id']", 'masked_question': 'Give me the [MASK] [MASK] [MASK] and [MASK] for all the [MASK] [MASK] by [MASK] [MASK] 3.', 'question': 'Give me the product type, name and price for all the products supplied by supplier id 3.', 'sql_keywords': 'from,join,as,where,select', 'tables': "[('product_suppliers', 'T1'), ('products', 'T2')]"}
{'columns': "['Name', 'Brazil', 'LifeExpectancy', 'Population']", 'masked_question': 'Give me [MASK] [MASK] and [MASK] [MASK]', 'question': 'Give me Brazil’s population and life expectancies.', 'sql_keywords': 'select,from,where', 'tables': "[('country', '')]"}
{'columns': "['id', 'dock_count', 'start_station_id']", 'masked_question': 'Which [MASK] [MASK] from the [MASK] with the largest [MASK] [MASK] Give me the [MASK] id.', 'question': 'Which trip started from the station with the largest dock count? Give me the trip id.', 'sql_keywords': 'from,order,join,as,select,des

In [116]:
import numpy as np
import pandas as pd
from collections import Counter

def summarise_keywords_from_result(index, embed_model, masked_question, join_involved = False):
    
    result = get_query_result(index = index, embed_model = embed_model, masked_question = masked_question, k = 3, join_involved = join_involved, verbose = False)
    keywords = [item['sql_keywords'] for item in result]
    keywords = [item.split(',') for item in keywords]
    keywords = [words for lst in keywords for words in lst]
    word_counts = dict(Counter(keywords))

    # If the keywords appears twice, it is considered relevant
    keywords_selected = [word for word, count in word_counts.items() if count >= 2]
    
    ignore_keywords = ['select', 'from']
    
    return [keyword for keyword in keywords_selected if keyword not in ignore_keywords]
    

def get_query_result(index, embed_model, masked_question:str, k:int, join_involved:bool = None, verbose:bool = True):
    '''
    index: pinecone index to search from
    embed_model: text embedding model
    masked_question: masked natural language question
    k: top k searches to return
    join_involved: if NONE, return results without filtering; if True, return results with JOIN keywords; if False, return results without JOIN keyword
    verbose: whether or not to print results
    '''
    
    if join_involved == None:
        query_results = index.query(
            vector=embed_model.embed_documents([masked_question]), 
            top_k=k, 
            include_metadata=True)
    else:
        query_results = index.query(
            vector=embed_model.embed_documents([masked_question]), 
            top_k=k, 
            filter={"join_involved": join_involved},
            include_metadata=True)
    
    if verbose:
        for match in query_results.matches:
            print(f"Question: {match.metadata['question']}")
            print(f"Masked question: {match.metadata['masked_question']}")
            print(f"Query: {match.metadata['query']}")
            print(f"SQL keywords: {match.metadata['sql_keywords']}")
            print(f"Join involved: {match.metadata['join_involved']}")
            print(' ')
    
    return [match.metadata for match in query_results.matches]
    

def search_in_document(df, column_name, keyword):

    if column_name in df.columns:
        mask = df[column_name].apply(lambda x: keyword in str(x).lower())
        
        result_dict = {}

        for index, row in df[mask].iterrows():
            result_dict[row['title']] = {'summary': row['summary'],
                                        'content': row['content']}
            
        return result_dict

    else:
        raise ValueError(f"Column '{column_name}' is not found in DataFrame.")
    



In [50]:
# Create a subset for small scale testing
testset_sub = testset

In [51]:
testset_sub['keywords_pred'] = testset_sub['masked_question'].apply(summarise_keywords_from_result)

testset_sub

Unnamed: 0.1,Unnamed: 0,question,masked_question,masked_question_toks,query,schema_toks,tables,columns,sql_keywords,keywords_pred
5044,5044,What is the total number of enrollment of scho...,What is the total [MASK] of [MASK] of schools ...,"['What', 'is', 'the', 'total', '[MASK]', 'of',...",SELECT sum(enr) FROM college WHERE cName NOT I...,"['college', 'tryout', 'pPos', 'enr', 'cName', ...","[('college', ''), ('tryout', '')]","['pPos', 'enr', 'cName', 'goalie']","select,not,in,where,from,sum","select,not,in,where,count,from"
6497,6497,Find the number of scientists involved for the...,Find the [MASK] of scientists involved for the...,"['Find', 'the', '[MASK]', 'of', 'scientists', ...","SELECT count(*) , T1.name FROM projects AS T1...","['assignedto', 'projects', 'code', 'project', ...","[('assignedto', 'T2'), ('projects', 'T1')]","['code', 'project', 'name', 'hours']","as,select,where,count,group,from,join","as,select,count,group,from,join"
2218,2218,What are the names of all the Japanese constru...,What are the [MASK] of all the [MASK] [MASK] t...,"['What', 'are', 'the', '[MASK]', 'of', 'all', ...",SELECT T1.name FROM constructors AS T1 JOIN co...,"['constructors', 'constructorstandings', 'poin...","[('constructors', 'T1'), ('constructorstanding...","['points', 'nationality', 'Japanese', 'name', ...","as,select,where,and,from,join","having,select,count,group,from"
355,355,Show the id and name of the aircraft with the ...,Show the [MASK] and [MASK] of the [MASK] with ...,"['Show', 'the', '[MASK]', 'and', '[MASK]', 'of...","SELECT aid , name FROM Aircraft ORDER BY dist...","['Aircraft', 'aid', 'distance', 'name']","[('Aircraft', '')]","['aid', 'distance', 'name']","select,desc,from,order,limit","select,desc,from,order,limit"
5786,5786,"What is the description of the product named ""...",What is the [MASK] of the [MASK] [MASK] [MASK],"['What', 'is', 'the', '[MASK]', 'of', 'the', '...",SELECT product_description FROM products WHERE...,"['products', 'Chocolate', 'product_name', 'pro...","[('products', '')]","['Chocolate', 'product_name', 'product_descrip...","select,from,where","select,from,where"
...,...,...,...,...,...,...,...,...,...,...
1737,1737,Count the number of gymnasts.,Count the [MASK] of [MASK],"['Count', 'the', '[MASK]', 'of', '[MASK]']",SELECT count(*) FROM gymnast,['gymnast'],"[('gymnast', '')]",[],"select,from,count","select,from,count"
3240,3240,Which department offers the most credits all t...,Which [MASK] offers the most [MASK] all together?,"['Which', '[MASK]', 'offers', 'the', 'most', '...",SELECT T3.dept_name FROM course AS T1 JOIN CLA...,"['CLASS', 'course', 'department', 'crs_code', ...","[('CLASS', 'T2'), ('course', 'T1'), ('departme...","['crs_code', 'dept_name', 'crs_credit', 'dept_...","as,select,group,desc,from,order,sum,limit,join","select,count,group,desc,from,order,limit"
5305,5305,What is the total revenue of companies with re...,What is the total [MASK] of companies with [MA...,"['What', 'is', 'the', 'total', '[MASK]', 'of',...",SELECT sum(revenue) FROM manufacturers WHERE r...,"['manufacturers', 'headquarter', 'revenue']","[('manufacturers', '')]","['headquarter', 'revenue']","select,where,min,from,sum","select,min,from,where,sum"
7203,7203,How many flights depart from 'APG'?,How many [MASK] [MASK] from 'APG'?,"['How', 'many', '[MASK]', '[MASK]', 'from', ""'...",SELECT count(*) FROM FLIGHTS WHERE SourceAirpo...,"['FLIGHTS', 'SourceAirport', 'APG']","[('FLIGHTS', '')]","['SourceAirport', 'APG']","select,from,where,count","select,from,where"


In [52]:
def prediction_accuracy(row):
    target = row['sql_keywords'].split(',')
    prediction = row['keywords_pred'].split(',')
    correct = 0
            
    for item in prediction:
        if item in target:
            correct += 1
            
    # How many keywords in the query are successfully predicted
    # How many of the predictions made are correct        
    return correct/len(target), correct/len(prediction) 

In [53]:
testset_sub = testset_sub.copy()

result = testset_sub.apply(prediction_accuracy, axis = 1)
testset_sub['accuracy_on_target'], testset_sub['accuracy_on_prediction'] = zip(*result)

In [54]:
testset_sub[['accuracy_on_target', 'accuracy_on_prediction']].describe()

Unnamed: 0,accuracy_on_target,accuracy_on_prediction
count,1607.0,1607.0
mean,0.841292,0.914419
std,0.195612,0.162986
min,0.222222,0.285714
25%,0.666667,0.875
50%,1.0,1.0
75%,1.0,1.0
max,1.0,1.0


In [55]:

# testset_sub.to_csv('keywords_prediction_result.csv')

In [48]:
from question_masking import *

# Function to search for the keyword in a specific column
def search_in_document(df, column_name, keyword):

    if column_name in df.columns:
        mask = df[column_name].apply(lambda x: keyword in str(x).lower())
        
        result_dict = {}

        for index, row in df[mask].iterrows():
            result_dict[row['title']] = {'summary': row['summary'],
                                        'content': row['content']}
            
        return result_dict

    else:
        raise ValueError(f"Column '{column_name}' is not found in DataFrame.")



In [121]:

common_schema_related_toks = ['student', 'course', 'department', 'age', 'course', 'ids', 'car', 'player', 'class', 'cities', 'member', 'employee']

db_schema = {
        'department': ['id', 'name', 'num_employees', 'creation', 'budget_billions', 'head'],
        'course': ['id', 'math', 'english', 'computer_science']
        }

# Example usage
# question = "Which head's name has the substring 'Ha'? List the id and name."What are the name of math students are there in department computer science
question = 'For each start station id, what is its name, longitude and average duration of trips started there'
masked_question = mask_question(question, db_schema, common_schema_related_toks)
print(masked_question)
get_query_result(index, embed_model, masked_question, 3, False)
summarise_keywords_from_result(index, embed_model, masked_question)

For each start [MASK] id, what is its [MASK] [MASK] and average [MASK] of trips started there
Question: For each type, what is the average tonnage?
Masked question: For each [MASK] what is the average [MASK]
Query: SELECT TYPE ,  avg(Tonnage) FROM ship GROUP BY TYPE
SQL keywords: group,select,from,avg
Join involved: False
 
Question: For each country, what is the average elevation of that country's airports?
Masked question: For each [MASK] what is the average [MASK] of that [MASK] [MASK]
Query: SELECT avg(elevation) ,  country FROM airports GROUP BY country
SQL keywords: group,select,from,avg
Join involved: False
 
Question: For each position, what is the average number of points for players in that position?
Masked question: For each [MASK] what is the average number of [MASK] for [MASK] in that [MASK]
Query: SELECT POSITION ,  avg(Points) FROM player GROUP BY POSITION
SQL keywords: group,select,from,avg
Join involved: False
 


['group', 'avg']

In [67]:
doc = pd.read_csv('md_data.csv')
doc.head()

Unnamed: 0,title,summary,content
0,SELECT,SELECT: used to select data from a database,SELECT: used to select data from a database.\n...
1,DISTINCT,DISTINCT: filters away duplicate values and re...,DISTINCT: filters away duplicate values and re...
2,WHERE,WHERE: used to filter records/rows,WHERE: used to filter records/rows.\nSELECT co...
3,ORDER BY,ORDER BY: used to sort the result-set in ascen...,ORDER BY: used to sort the result-set in ascen...
4,SELECT TOP,SELECT TOP: used to specify the number of reco...,SELECT TOP: used to specify the number of reco...


In [69]:
for keyword in keywords_predicted:
    print(search_in_document(doc, 'title', keyword))

{'AS': {'summary': 'AS: aliases are used to assign a temporary name to a table or column', 'content': 'AS: aliases are used to assign a temporary name to a table or column.\nSELECT column_name AS alias_name FROM table_name;\nSELECT column_name FROM table_name AS alias_name;\nSELECT column_name AS alias_name1, column_name2 AS alias_name2;\nSELECT column_name1, column_name2 + ‘, ‘ + column_name3 AS alias_name;\n'}}
{'SELECT': {'summary': 'SELECT: retrieve a view', 'content': 'SELECT: retrieve a view.\nSELECT * FROM view_name;\n'}, 'SELECT TOP': {'summary': 'SELECT TOP: used to specify the number of records to return from top of table', 'content': 'SELECT TOP: used to specify the number of records to return from top of table.\nSELECT TOP number columns_names FROM table_name WHERE condition;\nSELECT TOP percent columns_names FROM table_name WHERE condition;\nNot all database systems support SELECT TOP. The MySQL equivalent is the LIMIT clause\nSELECT column_names FROM table_name LIMIT offs