In [None]:
import langchain
from langchain_community.document_loaders import DataFrameLoader
import json
import pandas as pd
import getpass
import os
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore
from langchain_openai import OpenAIEmbeddings 
import openai 
from llama_index.core.retrievers import NLSQLRetriever

import sqlite3
import os
import llama_index


from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI


from sqlalchemy import text


In [None]:
"""
1. The number of WHERE conditions equals the number of filters.
2. How should each filter be weighted by importance?
    2.1 Dynamically?
3. When sorting on columns that contain NA values, maybe append those rows (at the end?)
"""

In [None]:
# API credentials come from the environment; either may be None when the variable is unset.
pinecone_api_key = os.getenv("pinecone_API")
api_key = os.environ.get("OPENAI_API_KEY")

# Pinecone client used to open the vector indexes.
pc = Pinecone(api_key=pinecone_api_key)

# The openai module-level key is set explicitly.
openai.api_key = api_key

In [None]:
# Embedding model; the Query class below builds its own embedder with the same model name.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
# Low-temperature chat model (gpt-3.5-turbo) intended for text-to-SQL generation.
text_sql_model = OpenAI(model="gpt-3.5-turbo", temperature=0.1)

In [None]:
# Pinecone index of department abbreviations (the Query class opens the same index itself).
index_abbrev = pc.Index("department-abbrev-db")

# Path of the small SQLite database and the table the text-to-SQL step may query.
# Defined once so the sqlite3 connection and the SQLAlchemy engine cannot drift apart.
COURSES_DB_PATH = "courses_temp.db"
CLASS_DATA_TABLE = "class_data"

# Raw sqlite3 handle. It is not used in the visible cells and is never closed;
# NOTE(review): presumably kept for interactive use -- confirm before removing.
conn = sqlite3.connect(COURSES_DB_PATH)
cursor = conn.cursor()

engine = create_engine(f"sqlite:///{COURSES_DB_PATH}")
sql_database = SQLDatabase(engine, include_tables=[CLASS_DATA_TABLE])
# return_raw=True: the retriever returns the raw SQL rows (in node metadata)
# instead of a synthesised text answer.
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=[CLASS_DATA_TABLE], return_raw=True
)


In [None]:
class Query:
    """
    A natural-language course query and its lifecycle for finding relevant courses.

    The steps are:
      1. text-to-SQL;
      2. SQL execution;
      3. widening of the WHERE clause when nothing matches;
      4. relevance re-ordering using vector-store similarity scores.

    The class attributes below are shared by every query. ``database``,
    ``database_final``, ``sql_retriever``, ``descr_vector_store`` and
    ``credit_vector_score`` are copied from module-level globals of the same
    name. Those globals are not defined in this part of the notebook; they must
    exist before this cell runs, or the class definition raises NameError.
    """

    database = database  # SQLAlchemy engine for the small database the generated SQL runs against
    database_final = database_final  # SQLAlchemy engine for the database of all course attributes (table ``courses``)
    embedder = OpenAIEmbeddings(model="text-embedding-3-small")
    # NOTE(review): the setup cell builds ``nl_sql_retriever``; confirm this global is the intended retriever.
    sql_retriever = sql_retriever
    abbrev_vector_store = PineconeVectorStore(pc.Index("department-abbrev-db"), embedding=embedder)
    descr_vector_store = descr_vector_store
    credit_vector_score = credit_vector_score  # despite the name, used as a vector store

    # Number of nearest neighbours fetched from each vector store (temporary choice, not tuned).
    TOP_K = 3
    # Best description similarity above which the description is treated as relevant (may need tuning).
    DESCR_RELEVANCE_THRESHOLD = 0.5

    def __init__(self, query):
        """
        Embed ``query`` and fetch the top-k matches from each vector store.

        Args:
            query: the user's natural-language request.
        """
        self.query = query
        self.embedding = self.embedder.embed_query(self.query)
        self.weights = {}
        # Each of these holds whatever ``similarity_search_by_vector_with_score`` returns.
        # For LangChain vector stores that is a list of (Document, score) pairs.
        self.abbrev_and_scores = self.abbrev_vector_store.similarity_search_by_vector_with_score(self.embedding, k=self.TOP_K)
        self.descr_and_scores = self.descr_vector_store.similarity_search_by_vector_with_score(self.embedding, k=self.TOP_K)
        self.credit_and_scores = self.credit_vector_score.similarity_search_by_vector_with_score(self.embedding, k=self.TOP_K)

        # Generated SQL; set by find_relevant_course and rewritten by widen_search.
        self.sql = None
        # Column names of the most recent SQL result, when known; used to locate course_id.
        self.col_keys = None

    def text_to_sql(self):
        """
        Run the text-to-SQL retriever on the query.

        Returns:
            The retriever's response. ``find_relevant_course`` reads
            ``response[0].metadata['sql_query']`` and ``['result']`` from it.
        """
        # NOTE(review): llama_index's BaseRetriever.retrieve takes a single query argument.
        # If ``sql_retriever`` is a plain NLSQLRetriever, the two extra arguments raise
        # TypeError. Confirm the retriever's signature.
        return self.sql_retriever.retrieve(self.query, self.abbrev_and_scores, self.credit_and_scores)

    def run_sql(self):
        """
        Execute ``self.sql`` against ``self.database`` and return the rows directly.

        Unlike llama_index's NLSQLRetriever, no node/metadata wrapper is returned.
        Rows are fetched before the connection closes, so the result stays usable
        and supports ``len()``. Also records the result's column names in ``self.col_keys``.

        Returns:
            list of SQLAlchemy ``Row`` objects (tuple-like).
        """
        with self.database.connect() as con:
            result = con.execute(text(self.sql))
            self.col_keys = list(result.keys())
            return result.fetchall()

    def _select_list_match(self):
        """Return a regex match whose ``cols`` group is the SELECT list of ``self.sql``, or None."""
        import re

        return re.search(r"\bSELECT\s+(?P<cols>.*?)\s+FROM\b", self.sql, re.IGNORECASE | re.DOTALL)

    def _course_id_index(self):
        """
        Return the position of the course_id column in result rows, or None if unknown.

        Known column names (``self.col_keys``) are used first. Otherwise the SELECT
        list of ``self.sql`` is split on commas. That split is textual, so function
        calls with commas in the SELECT list can confuse it.
        """
        if self.col_keys and "course_id" in self.col_keys:
            return list(self.col_keys).index("course_id")
        match = self._select_list_match()
        if match is None:
            return None
        for idx, column in enumerate(match.group("cols").split(",")):
            if "course_id" in column:
                return idx
        return None

    def _ensure_course_id_selected(self):
        """
        Append ``course_id`` to the SELECT list of ``self.sql`` if it is not already selected.

        Returns:
            True if ``self.sql`` was rewritten (the caller must re-run it), else False.

        Raises:
            ValueError: if no SELECT ... FROM clause can be found.
        """
        if self._course_id_index() is not None:
            return False
        match = self._select_list_match()
        if match is None:
            raise ValueError(f"Cannot locate the SELECT list in: {self.sql}")
        end = match.end("cols")
        self.sql = self.sql[:end] + ", course_id" + self.sql[end:]
        self.col_keys = None  # stale until the rewritten SQL is executed
        return True

    def grab_course_ids(self, returns):
        """
        Extract the course_id value from each row of an SQL result.

        NLSQLRetriever (``metadata['result']``) and ``run_sql`` both return
        tuple-like rows. The course_id position is therefore taken from the known
        column names, or from the SELECT list of ``self.sql``.

        Args:
            returns: iterable of tuple-like rows produced by ``self.sql``.

        Returns:
            list of course ids, in row order.

        Raises:
            ValueError: if course_id is not one of the selected columns.
        """
        idx_course_id = self._course_id_index()
        if idx_course_id is None:
            raise ValueError(f"course_id is not among the selected columns of: {self.sql}")
        return [row[idx_course_id] for row in returns]

    @staticmethod
    def _scores_by_key(results):
        """
        Turn vector-store search results into a ``{key: best score}`` dict.

        ``results`` may already be a dict. Otherwise it is a list of
        ``(document, score)`` pairs as returned by
        ``similarity_search_by_vector_with_score``; for such documents the key is
        ``page_content``.

        NOTE(review): this assumes page_content holds the exact value stored in the
        SQL rows (e.g. the department abbreviation). Confirm against how the indexes
        were built.
        """
        if isinstance(results, dict):
            return dict(results)
        scores = {}
        for doc, score in results:
            key = getattr(doc, "page_content", doc)
            if key not in scores or score > scores[key]:
                scores[key] = score
        return scores

    @staticmethod
    def _row_rank(row, order):
        """
        Compute the relevance of one row.

        The result is the sum, over every key of every score dict that appears in
        the row, of ``weight * score``. The weight is ``10 ** (len(order) - position)``,
        so score dicts earlier in ``order`` carry more weight.
        """
        total = 0.0
        for position, scores in enumerate(order):
            weight = 10 ** (len(order) - position)
            for key, score in scores.items():
                if key in row:
                    total += weight * score
        return total

    def sort_relevance(self, returns):
        """
        Order SQL results by relevance, mixing dynamic and predefined weights.

        For now every row returned by the SQL is assumed to be relevant. Rows with
        no vector-store match are kept, after the matched ones.

        Cases:
            - The SQL already has an ORDER BY: ``returns`` is returned as-is.
            - The SQL filters on dept_abbrev or credit_type: rank by those two
              signals first (the one with the stronger best match takes precedence),
              then by description.
            - Otherwise, if the best description similarity exceeds
              ``DESCR_RELEVANCE_THRESHOLD``: rank by description only.
            - Otherwise: ``returns`` is returned as-is.

        When ranking, course_id is added to the SELECT list if it is missing (and
        the SQL is re-run). The full course rows are then fetched from
        ``database_final`` in relevance order.

        Args:
            returns: tuple-like rows produced by ``self.sql``.

        Returns:
            ``returns`` unchanged in the pass-through cases. Otherwise a list of rows
            from the ``courses`` table, ordered by relevance.
        """
        if not returns:
            return []
        if "ORDER BY" in self.sql.upper():
            return returns

        abbrev_scores = self._scores_by_key(self.abbrev_and_scores)
        credit_scores = self._scores_by_key(self.credit_and_scores)
        descr_scores = self._scores_by_key(self.descr_and_scores)

        if "dept_abbrev" in self.sql or "credit_type" in self.sql:
            order = sorted(
                [credit_scores, abbrev_scores],
                key=lambda scores: max(scores.values(), default=0),
                reverse=True,
            )
            order.append(descr_scores)
        elif max(descr_scores.values(), default=0) > self.DESCR_RELEVANCE_THRESHOLD:
            order = [descr_scores]
        else:
            return returns

        if self._ensure_course_id_selected():
            returns = self.run_sql()

        ranks = [self._row_rank(row, order) for row in returns]
        # sorted() is stable even with reverse=True, so equally ranked rows keep SQL order.
        ranked_rows = [row for _, row in sorted(zip(ranks, returns), key=lambda pair: pair[0], reverse=True)]

        # Deduplicate while preserving relevance order.
        relevance_ordered_course_ids = list(dict.fromkeys(self.grab_course_ids(ranked_rows)))
        final_rows = self.sql_final(relevance_ordered_course_ids)

        rows_by_id = {}
        for row in final_rows:
            rows_by_id.setdefault(str(row._mapping["course_id"]), []).append(row)
        return [
            row
            for course_id in relevance_ordered_course_ids
            for row in rows_by_id.get(str(course_id), [])
        ]

    def sql_final(self, course_ids):
        """
        Fetch every column of the given courses from the ``courses`` table of ``database_final``.

        The ids are sent as bound parameters (an expanding IN list) rather than
        formatted into the SQL text, so quotes in an id cannot break or inject SQL.
        Ids are passed as strings, matching the original quoted literals.

        Args:
            course_ids: iterable of course ids.

        Returns:
            list of SQLAlchemy ``Row`` objects in database order; empty if no ids are given.
        """
        from sqlalchemy import bindparam

        ids = [str(course_id) for course_id in course_ids]
        if not ids:
            return []
        statement = text("SELECT * FROM courses WHERE course_id IN :course_ids").bindparams(
            bindparam("course_ids", expanding=True)
        )
        with self.database_final.connect() as con:
            return con.execute(statement, {"course_ids": ids}).fetchall()

    def _drop_last_where_condition(self):
        """
        Remove the last AND-separated condition from the WHERE clause of ``self.sql``.

        At least one condition is always kept. The split is textual, so
        ``BETWEEN ... AND ...`` or parenthesised AND/OR groups are not handled
        correctly.

        Returns:
            True if ``self.sql`` was changed, False if there was nothing to drop.
        """
        import re

        match = re.search(
            r"\bWHERE\b(?P<cond>.*?)(?=\bGROUP\s+BY\b|\bORDER\s+BY\b|\bHAVING\b|\bLIMIT\b|;|$)",
            self.sql,
            re.IGNORECASE | re.DOTALL,
        )
        if match is None:
            return False
        cond = match.group("cond")
        conditions = re.split(r"\s+AND\s+", cond.strip(), flags=re.IGNORECASE)
        if len(conditions) < 2:
            return False
        # Preserve the whitespace around the condition so the surrounding SQL stays well-formed.
        leading = cond[: len(cond) - len(cond.lstrip())]
        trailing = cond[len(cond.rstrip()):]
        new_cond = leading + " AND ".join(conditions[:-1]) + trailing
        self.sql = self.sql[: match.start("cond")] + new_cond + self.sql[match.end("cond"):]
        return True

    def widen_search(self):
        """
        Widen the query by repeatedly dropping the last WHERE condition until the SQL returns rows.

        Returns:
            The relevance-sorted rows (see ``sort_relevance``). Returns an empty list
            when no further condition can be dropped and the query still returns nothing.
        """
        while self._drop_last_where_condition():
            returns = self.run_sql()
            if returns:
                return self.sort_relevance(returns)
        print('No more where clauses to drop')
        return []

    def find_relevant_course(self):
        """
        Entry point: turn the query into SQL, run it, and return relevance-ordered results.

        If the generated SQL returns no rows, the WHERE clause is widened (see
        ``widen_search``).

        Returns:
            list of result rows; empty if the retriever returned nothing.
        """
        response = self.text_to_sql()
        if not response:
            return []
        metadata = response[0].metadata
        # Keep the generated SQL: sort_relevance, widen_search and grab_course_ids all read self.sql.
        self.sql = metadata['sql_query']
        self.col_keys = metadata.get('col_keys')
        returns = metadata['result']

        if returns:
            return self.sort_relevance(returns)
        return self.widen_search()