In [1]:
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
import streamlit as st
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough,RunnableMap
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from prompts import column_name_retriver_prompt, encoded_values_retriver_prompt, column_desc_retriver_prompt
import re
import ast
import openai
from pymongo import MongoClient
import os
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
LANGCHAIN_TRACING_V2 = os.getenv("LANGCHAIN_TRACING_V2")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")
LANGCHAIN_PROJECT=os.getenv("LANGCHAIN_PROJECT")
db_url = os.getenv("DB_URL")

In [3]:
client = MongoClient(os.getenv('MONGODB_URI'), tls=True,
    tlsAllowInvalidCertificates=True)
mongo_db = client.get_database(os.getenv('DB_NAME'))
c_name = os.getenv('COLLECTION_NAME')
collection= mongo_db.Data

In [4]:
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

column_desc_retriver_prompt = """
Retrieve the description of a specific column as a string, based on the column name mentioned in the question, using the provided context.

Context:
{context}

- `Column_Description`: A dictionary where keys are column descriptions and values are column names.

Tasks:
Identify one specific description corresponding to the one column name mentioned in the question. Provide a single column description to assist the downstream Text-to-SQL Agent in formulating SQL queries involving JOINs, filtering, and subqueries.
Question:
{question}

Output format:

'Column_Description key'
    """
    
column_name_retriver_prompt = """
Retrieve the specific column names that are relevant to the question variables within the provided context. Use the following details:

Context:
{context}

Details:

- `Column_Description`: This dictionary contains column descriptions as keys and their respective column names as values.

Tasks to accomplish:
1. Identify column names.

Scan the context for column names pertinent to the question. Include the 'unitid' column of the relevant table to aid the downstream Text-to-SQL Agent in constructing SQL queries involving JOINs, filters, and subqueries.

Question: {question}

Output format:

Column names related to the question: [column_name1, column_name2, ...]
    """

encoded_values_retriver_prompt = """Retrieve the specific one encoded value based on the provided context. Answer the question using only the following context:

{context}

Details:
- `Encoded_Values`: A dictionary containing code descriptions as keys and corresponding encoded values as integers, float or string representing the discrete column values.

Tasks to perform:
1. Identify the encoded value.

Search the context for one encoded value that answers the question. Include a specific one encoded value to help the downstream Text-to-SQL Agent in forming SQL queries involving JOINs, filtering, and subqueries. Note that `unitid` is the primary key for JOIN operations.

Question: {question}
Output format:
Encoded_Values:
"""

In [5]:
def get_embedding(text):
    
    EMBEDDING_MODEL = "text-embedding-ada-002"
    """Generate an embedding for the given text using OpenAI's API."""

    # Check for valid input
    if not text or not isinstance(text, str):
        return None

    try:
        # Call OpenAI API to get the embedding
        embedding = openai.embeddings.create(input=text, model=EMBEDDING_MODEL).data[0].embedding
        return embedding
    except Exception as e:
        print(f"Error in get_embedding: {e}")
        return None

In [6]:
def vector_search(user_query, collection):
  
    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    if query_embedding is None:
        return "Invalid query or embedding generation failed."

    # Define the vector search pipeline
    pipeline = [
                {
                    "$vectorSearch":{
                                        "index": "vector_index",
                                        "path": "embedding",
                                        "queryVector": query_embedding,
                                        "numCandidates": 20,
                                        "limit": 4
                                    }
                },
                {
                    "$project": {
                                    "_id": 0,  # Exclude the _id field
                                    "text": 1,
                                    "Table_Description": 1, # Include the Table_Description field
                                    "Table_Name": 1,
                                    "Encoded_Values": 1,  # Include the Encoded_Values field
                                    "Column_Description": 1, # Include the Column_Description field
                                    "score": {
                                                "$meta": "vectorSearchScore"  # Include the search score
                                            }
                                }
                }
        
            ]
    # Execute the search
    results = collection.aggregate(pipeline)
    return list(results)

In [7]:
def get_table_info(question: str, template: str, context: dict):
    prompt = ChatPromptTemplate.from_template(template)

    model = ChatOpenAI()

    table_chain = (
        RunnableMap({"context": RunnablePassthrough(), "question": RunnablePassthrough()})
        | prompt
        | model
        | StrOutputParser()
    )
    return table_chain.invoke({"context": context, "question": question})

In [8]:
patterns = [
    "(\[.*?\])",  # Pattern 1
    "'(\w+)'",    # Pattern 2
]
def retrieve_list_objects(pattern, text):
    return re.findall(pattern, text)

In [9]:
def get_substring_before_colon(input_string):
    result = input_string.split(':', 1)[0]
    return result.strip()

In [10]:
def fetch_value(input_string):
  
    last_colon_index = input_string.rfind(':')
    
    if last_colon_index != -1:
        value_string = input_string[last_colon_index+1:].lstrip()
        
        try:
            return int(value_string)
        except ValueError:
            try:
                return float(value_string)
            except ValueError:
                return value_string
    else:
        return None

In [11]:
def handle_user_query(question, collection):

  get_knowledge = vector_search(question, collection)
  output = "["
  count_col = 0
  for result in get_knowledge:
    context = {}
    column_details = ''
    if result.get('Encoded_Values', '-1')!= '-1':
        
      context["Table_Name"] = result.get('Table_Name')
      context["Column_Description"]= result.get('Column_Description')
      column_details+=get_table_info(question, column_name_retriver_prompt, context)
      print(' Line 14', column_details)
      cdesc = ''
      if "Column names related to the question" in column_details:
        ls = retrieve_list_objects(patterns[1], column_details)
        encoded_values=result.get('Encoded_Values')
        encoded_values=ast.literal_eval(encoded_values)
        code_value =''
        for i in ls:
          if encoded_values.get(i, 'N/A')!= 'N/A':
            if isinstance(encoded_values.get(i), str):
              code_value= ast.literal_eval(encoded_values.get(i))
            elif isinstance(encoded_values.get(i), dict):
              code_value = encoded_values.get(i)
            code_value = get_table_info(question, encoded_values_retriver_prompt, code_value)
            print("=============25 Line========", code_value)
            cdesc=get_table_info(i, column_desc_retriver_prompt ,context["Column_Description"])
            print("============27 Line==========", cdesc)
            if count_col == 0:
              count_col+=1
              output+= "{"+ f"'Table_Name': '{result.get('Table_Name')}' ,'Column_Name': '{i}', 'Column_Description': {get_substring_before_colon(cdesc)} ,'Encoded_Values': {fetch_value(code_value)}" + "}"
            else:
              output+= ", {"+ f"'Table_Name': '{result.get('Table_Name')}' , 'Column_Name': '{i}', 'Column_Description': {get_substring_before_colon(cdesc)}, 'Encoded_Values': {fetch_value(code_value)}" + "}"
          else:
            cdesc=get_table_info(i, column_desc_retriver_prompt ,context["Column_Description"])
            print("=============35 Line========", cdesc)
            if count_col == 0:
              count_col+=1
              output+= "{" + f"'Table_Name': '{result.get('Table_Name')}' ,'Column_Name': '{i}', 'Column_Description': {get_substring_before_colon(cdesc)}" + "}"
            else:
              output+= ", {" + f"'Table_Name': '{result.get('Table_Name')}' ,'Column_Name': '{i}', 'Column_Description': {get_substring_before_colon(cdesc)}" + "}"
        
      else:
        ls = ast.literal_eval(column_details)
        encoded_values=result.get('Encoded_Values')
        encoded_values=ast.literal_eval(encoded_values)
        code_value =''
        for i in ls:
          if encoded_values.get(i, 'N/A')!= 'N/A':
            if isinstance(encoded_values.get(i), str):
              code_value= ast.literal_eval(encoded_values.get(i))
            elif isinstance(encoded_values.get(i), dict):
              code_value = encoded_values.get(i)
            code_value = get_table_info(question, encoded_values_retriver_prompt, code_value)
            print("=============51 Line========", code_value)
            cdesc=get_table_info(i, column_desc_retriver_prompt ,context["Column_Description"])
            print("=============53 Line========", cdesc)
            if count_col == 0:
              count_col+=1
              output+= "{"+ f"'Table_Name': '{result.get('Table_Name')}' , 'Column_Name': '{i}', 'Column_Description': {get_substring_before_colon(cdesc)} , 'Encoded_Values': {fetch_value(code_value)}" + "}"
            else:
              output+= ", {"+ f"'Table_Name': '{result.get('Table_Name')}' , 'Column_Name': '{i}', 'Column_Description': {get_substring_before_colon(cdesc)}, 'Encoded_Values': {fetch_value(code_value)}" + "}"
          else:
            cdesc=get_table_info(i, column_desc_retriver_prompt ,context["Column_Description"])
            print("=============61 Line========", cdesc)
            if count_col == 0:
              count_col+=1
              output+= "{" + f"'Table_Name': '{result.get('Table_Name')}' ,'Column_Name': '{i}', 'Column_Description': {get_substring_before_colon(cdesc)}" + "}"
            else:
              output+= ", {" + f"'Table_Name': '{result.get('Table_Name')}' , 'Column_Name': '{i}', 'Column_Description': {get_substring_before_colon(cdesc)}" + "}"
  output += "]"
  return output

In [12]:
question = "Total Number of institutes in Boston"
ls = handle_user_query(question, collection)

 Line 14 Column names related to the question: ['unitid', 'pcinstnm', 'pccity', 'pcstabbr', 'pczip']
- Boston: Not Found
 Line 14 Column names related to the question: ['unitid', 'instnm', 'city', 'stabbr', 'zip']
    
 Line 14 Column names related to the question: ['unitid']
 Line 14 Column names related to the question: ['unitid']


In [14]:
ls

"[{'Table_Name': 'ic2022campuses' ,'Column_Name': 'unitid', 'Column_Description': 'Unique identification number of the institution'}, {'Table_Name': 'ic2022campuses' ,'Column_Name': 'pcinstnm', 'Column_Description': 'Branch Campus Name'}, {'Table_Name': 'ic2022campuses' ,'Column_Name': 'pccity', 'Column_Description': 'City location of institution'}, {'Table_Name': 'ic2022campuses' , 'Column_Name': 'pcstabbr', 'Column_Description': 'Column_Description', 'Encoded_Values': Not Found}, {'Table_Name': 'ic2022campuses' ,'Column_Name': 'pczip', 'Column_Description': 'ZIP code'}, {'Table_Name': 'hd2022' ,'Column_Name': 'unitid', 'Column_Description': 'Unique identification number of the institution'}, {'Table_Name': 'hd2022' ,'Column_Name': 'instnm', 'Column_Description': 'Institution (entity) name'}, {'Table_Name': 'hd2022' ,'Column_Name': 'city', 'Column_Description': 'City location of institution'}, {'Table_Name': 'hd2022' , 'Column_Name': 'stabbr', 'Column_Description': 'State abbreviation