In [6]:
import langchain
from langchain_community.document_loaders import DataFrameLoader
import json
import pandas as pd
import getpass
import os
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore
from langchain_openai import OpenAIEmbeddings 
import openai 


In [2]:
from llama_index.core.retrievers import NLSQLRetriever

import sqlite3
import os
import llama_index


from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI


from sqlalchemy import text


In [6]:
from llama_index.core.schema import QueryBundle, TextNode, NodeWithScore
import logging


In [7]:
# Setting up keys: read credentials from the environment and fall back to an
# interactive prompt, so secrets are never hardcoded and a missing variable is
# caught here instead of surfacing later as an obscure `None` API-key error.
pinecone_api_key = (
    os.environ.get("pinecone_API")
    # PINECONE_API_KEY is the variable name documented by the Pinecone client.
    or os.environ.get("PINECONE_API_KEY")
    or getpass.getpass("Pinecone API key: ")
)

pc = Pinecone(api_key=pinecone_api_key)

# Library clients such as OpenAIEmbeddings look up OPENAI_API_KEY in the
# environment by default, so a prompted key is stored there as well as on the
# `openai` module.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")
api_key = os.environ["OPENAI_API_KEY"]
openai.api_key = api_key


In [26]:
# Embedding model used to turn query text into vectors for Pinecone search.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
)
# Low-temperature chat model; presumably intended for text-to-SQL generation.
text_sql_model = OpenAI(
    model="gpt-3.5-turbo",
    temperature=0.1,
)

In [None]:
# Pinecone index whose documents hold "DEPARTMENT NAME: ABBREV" strings.
index_abbrev = pc.Index("department-abbrev-db")

# Single source of truth for the course database location and queried table.
COURSES_DB_PATH = "courses_temp.db"
COURSES_TABLE = "class_data"

# Raw sqlite3 connection/cursor. The retriever below does not use these; it
# goes through its own SQLAlchemy engine on the same file.
conn = sqlite3.connect(COURSES_DB_PATH)
cursor = conn.cursor()

engine = create_engine(f"sqlite:///{COURSES_DB_PATH}")
sql_database = SQLDatabase(engine, include_tables=[COURSES_TABLE])
# Pass the configured low-temperature model explicitly. Without `llm=...`,
# NLSQLRetriever falls back to llama_index's global default LLM, and
# `text_sql_model` would go unused.
nl_sql_retriever = NLSQLRetriever(
    sql_database,
    tables=[COURSES_TABLE],
    llm=text_sql_model,
    return_raw=True,
)


In [None]:
"""
1. Query Occurs
2. Query is embedded
3. Query is sent to two pinecone indexes
4. Returns include the course_id for description matches and the department abbreviation for department matches
       -need to decide how to grab the department abbreviation
       -also need to decide whether to even use the course_id description match if it is poor. We can judge this with the cosine similarity score
5. Store course_id from description along with score
6. Insert department abbrv into prompt
7. text to sql
8. return sql results and sql query
9. if the SQL query returns no rows, then relax (remove) WHERE clauses and try again
10. aggregate and use relevance score to determine which results to show
11. return results


"""

In [12]:
query = 'computer science'

# Embed the query text once; the resulting vector is searched directly below.
embedding = embeddings.embed_query(query)

# Vector store over the department-abbreviation index, given the same
# `embeddings` object used for the query.
vector_store_department = PineconeVectorStore(index_abbrev, embedding=embeddings)

# Top-3 nearest department documents, each returned as a (Document, score) pair.
embedding_department = vector_store_department.similarity_search_by_vector_with_score(
    embedding,
    k=3,
)

In [15]:
# Inspect the third (Document, score) pair returned by the department search.
# NOTE(review): raises IndexError if the search returned fewer than 3 results.
print(embedding_department[2])

(Document(id='9b95de79-921d-41ba-88f7-3b307023eec3', metadata={}, page_content='COMP SCI AND ENGINEERING - DATA SCIENCE: CSE D'), 0.547315836)


In [22]:
# Keep only the Document half of the third (Document, score) pair and show its text.
doc = embedding_department[2][0]
print(doc.page_content)

COMP SCI AND ENGINEERING - DATA SCIENCE: CSE D


In [None]:
def extract_abbrev_scores(department_hits):
    """Map each department abbreviation to its best similarity score.

    Args:
        department_hits: Iterable of ``(Document, score)`` pairs, as returned
            by ``similarity_search_by_vector_with_score``. Each document's
            ``page_content`` is expected to look like ``"NAME: ABBREV"``.
            The abbreviation is the text after the first ``:``.

    Returns:
        dict: Abbreviation -> highest score seen for it, in order of first
        appearance.
    """
    scores = {}
    for document, score in department_hits:
        _name, sep, abbrev = document.page_content.partition(":")
        if not sep:
            # A malformed entry should not abort the whole lookup with an
            # IndexError; report it and move on.
            logging.getLogger(__name__).warning(
                "Skipping department entry without ':' separator: %r",
                document.page_content,
            )
            continue
        abbrev = abbrev.strip()
        # The same abbreviation can come back more than once. Keep the best
        # score rather than letting a later, lower-ranked hit overwrite it.
        if abbrev not in scores or score > scores[abbrev]:
            scores[abbrev] = score
    return scores


abbrev_and_scores = extract_abbrev_scores(embedding_department)

print(abbrev_and_scores)



{'CSE': 0.631870806, 'CSS': 0.57072264, 'CSE D': 0.547315836}


In [10]:


# Module-level logger used by the retriever below (it was previously
# referenced without ever being defined, causing a NameError on first call).
logger = logging.getLogger(__name__)


class CustomNLSQLRetriever(NLSQLRetriever):
    """NLSQLRetriever variant whose entry point also accepts department scores.

    NOTE(review): ``abbrev_and_scores`` is accepted but not yet used. The
    notebook's plan is to insert the matched department abbreviations into the
    text-to-SQL prompt. TODO implement.
    """

    def retrieve_with_metadata(
        self, str_or_query_bundle, abbrev_and_scores=None
    ):
        """Translate a natural-language query to SQL and (optionally) run it.

        Args:
            str_or_query_bundle: The user query, either as a plain string or
                as an existing ``QueryBundle``.
            abbrev_and_scores: Optional mapping of department abbreviation ->
                similarity score (e.g. ``{"CSE": 0.63}``). Currently unused.

        Returns:
            A ``(nodes, metadata)`` tuple. ``metadata`` always contains
            ``"sql_query"``, merged with whatever the underlying SQL retriever
            returned. When ``self._sql_only`` is set, no SQL is executed, and
            the single node holds the generated SQL text.

        Raises:
            Exception: Any error raised while executing the SQL, unless
                ``self._handle_sql_errors`` is set. In that case, the error is
                returned as an ``"Error: ..."`` node instead.
        """
        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle

        table_desc_str = self._get_table_context(query_bundle)
        logger.info(f"> Table desc str: {table_desc_str}")
        if self._verbose:
            print(f"> Table desc str: {table_desc_str}")

        response_str = self._llm.predict(
            self._text_to_sql_prompt,
            query_str=query_bundle.query_str,
            schema=table_desc_str,
            dialect=self._sql_database.dialect,
        )

        sql_query_str = self._sql_parser.parse_response_to_sql(
            response_str, query_bundle
        )
        # The LLM output is assumed to be valid SQL; it is not validated here.
        logger.debug(f"> Predicted SQL query: {sql_query_str}")
        if self._verbose:
            print(f"> Predicted SQL query: {sql_query_str}")

        if self._sql_only:
            sql_only_node = TextNode(text=f"{sql_query_str}")
            retrieved_nodes = [NodeWithScore(node=sql_only_node)]
            metadata = {"result": sql_query_str}
        else:
            try:
                retrieved_nodes, metadata = self._sql_retriever.retrieve_with_metadata(
                    sql_query_str
                )
            except Exception as e:
                # Catch Exception rather than BaseException so that
                # KeyboardInterrupt / SystemExit still abort the run instead of
                # being turned into an "Error: ..." node.
                if self._handle_sql_errors:
                    err_node = TextNode(text=f"Error: {e!s}")
                    retrieved_nodes = [NodeWithScore(node=err_node)]
                    metadata = {}
                else:
                    raise

        return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

[]
