In [None]:
import langchain
from langchain_community.document_loaders import DataFrameLoader
import json
import pandas as pd
import getpass
import os
from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore
from langchain_openai import OpenAIEmbeddings 
import openai 
from llama_index.core.retrievers import NLSQLRetriever

import sqlite3
import os
import llama_index


from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI


from sqlalchemy import text


In [None]:
"""
1. The number of WHERE conditions equals the number of filters.
2. How should each filter be weighted by importance?
    2.1 Dynamically?
3. If sorting on columns that contain NA values, maybe append ... (TODO: note unfinished)
"""

In [None]:
# Pinecone credential comes from the environment; None when the variable is unset.
pinecone_api_key = os.getenv("pinecone_API")

# Shared Pinecone client for the rest of the notebook.
pc = Pinecone(api_key=pinecone_api_key)

# Configure the module-level `openai` key from the environment (None when unset).
api_key = os.environ.get("OPENAI_API_KEY")
openai.api_key = api_key

In [None]:
class Query():
    """
    This class represents a query and its management lifecycle to find relevant courses.

    Flow as implemented:
      1. ``__init__`` embeds the natural-language query with ``embedder``.
      2. ``find_relevant_course`` asks ``sql_retriever`` to turn the query into SQL
         and read back the generated SQL and its result rows.
      3. If no rows came back, ``widen_search`` repeatedly drops the last ``AND``
         condition from the WHERE line and re-runs the SQL until rows are returned
         or nothing is left to drop.
    """
    # Placeholder for predefined filter weights; not read anywhere in this class yet.
    default_weights = {}

    def __init__(self, query, abbrev_vector_store, descr_vector_store, credit_vector_score, database,
                 sql_retriever, embedder):
        """
        Args:
            query: Natural-language search string.
            abbrev_vector_store: Vector store searched for department abbreviations.
            descr_vector_store: Vector store searched for course descriptions.
            credit_vector_score: Vector store searched for credit information
                (named "score" but used as a store).
            database: Object whose ``connect()`` returns a connection context
                manager with ``execute`` (used by ``run_sql``; presumably a
                SQLAlchemy Engine — TODO confirm).
            sql_retriever: Text-to-SQL retriever used by ``text_to_sql``.
            embedder: Object exposing ``embed_query(str)`` (assumed LangChain
                ``Embeddings`` interface such as ``OpenAIEmbeddings`` — TODO confirm).

        Some of these parameters should just be class variables, not instance attributes.
        """
        self.query = query
        # Must be set before get_embedding(), which reads it.
        self.embedder = embedder
        self.embedding = self.get_embedding(query)
        self.weights = {}
        self.abbrev_and_scores = {}
        self.abbrev_vector_store = abbrev_vector_store
        self.descr_vector_store = descr_vector_store
        self.credit_vector_score = credit_vector_score
        # Needed by text_to_sql().
        self.sql_retriever = sql_retriever
        # SQL text of the current search; set by find_relevant_course() and
        # rewritten by widen_search().
        self.sql = None
        self.database = database

    def get_embedding(self, query):
        """Return the embedding vector of *query* computed by ``self.embedder``."""
        return self.embedder.embed_query(query)

    def db_searches(self):
        """Run a top-3 similarity search for the query embedding in each vector store.

        Returns:
            ``(descriptions_w_scores, dept_abbrev_w_scores, credit_w_scores)``. Note
            that this order differs from the order in which the searches run.
        """
        dept_abbrev_w_scores = self.abbrev_vector_store.similarity_search_by_vector_with_score(self.embedding, k = 3)
        descriptions_w_scores = self.descr_vector_store.similarity_search_by_vector_with_score(self.embedding, k = 3)
        credit_w_scores = self.credit_vector_score.similarity_search_by_vector_with_score(self.embedding, k = 3)

        return descriptions_w_scores, dept_abbrev_w_scores, credit_w_scores

    def text_to_sql(self):
        """Ask ``self.sql_retriever`` to translate the query into SQL.

        Returns whatever the retriever's ``retrieve`` returns. ``find_relevant_course``
        reads ``[0].metadata['sql_query']`` and ``[0].metadata['result']`` from it.
        """
        descriptions_w_scores, department_abbrev_w_scores, credit_w_scores = self.db_searches()

        # NOTE(review): llama_index's BaseRetriever.retrieve takes a single query
        # argument. With a stock NLSQLRetriever, these two extra positional
        # arguments would raise TypeError. Confirm which retriever is passed in.
        response = self.sql_retriever.retrieve(self.query, department_abbrev_w_scores, credit_w_scores)

        return response

    def run_sql(self):
        """
            Execute ``self.sql`` against ``self.database`` and return all rows as a list.

            Pay attention to what this is returning, as the return format is different
            from the llama_index NLSQLRetriever return format.
        """
        with self.database.connect() as con:
            # Fetch while the connection is still open. A result object used after
            # the `with` block exits can no longer be read.
            rows = con.execute(text(self.sql)).fetchall()

        return rows

    def sort_relevance(self, returns):
        """Rank *returns* by relevance to the query.

        TODO: not implemented yet; currently always returns None.
        """
        pass

    @staticmethod
    def _drop_last_where_condition(sql):
        """Return *sql* with the last ``AND`` condition of its WHERE line removed.

        Only the first line containing ``WHERE`` is edited, and only ``' AND '``
        separators on that same line are considered. Returns None when *sql* is
        empty/None, has no WHERE line, or that line holds a single condition.

        NOTE(review): anything after the last condition on that same line (for
        example a trailing ``ORDER BY`` or ``LIMIT``) is dropped along with it.
        """
        if not sql:
            return None

        lines = sql.splitlines()
        for line_num, line in enumerate(lines):
            if 'WHERE' in line:
                break
        else:
            return None

        if ' AND ' not in line:
            return None

        lines[line_num] = line.rsplit(' AND ', 1)[0]
        return '\n'.join(lines)

    def widen_search(self):
        """ Reruns the query with a wider search, i.e. dropping where clauses until rows are returned.

        Each round removes the last ``AND`` condition from the WHERE line of
        ``self.sql`` (updating ``self.sql``) and re-runs it with ``run_sql``. The
        loop stops at the first non-empty result, or when no condition can be
        dropped. It always terminates because every round removes one ``' AND '``.

        Returns:
            ``self.sort_relevance`` applied to the last rows fetched, or to an empty
            list if nothing could be dropped.
        """
        returns = []
        while True:
            widened_sql = self._drop_last_where_condition(self.sql)
            if widened_sql is None:
                print('No more where clauses to drop')
                break

            self.sql = widened_sql
            returns = self.run_sql()
            if returns:
                break

        return self.sort_relevance(returns)

    def find_relevant_course(self):
        """ Translate the query to SQL and rank its results, widening the search if nothing matched.

        Intended design (not implemented yet; see ``sort_relevance``): combine
        dynamic query weighting with predefined weights to create a viable
        probability distribution for ranking relevant results.
        """
        response = self.text_to_sql()

        # Presumably llama_index NLSQLRetriever metadata (return_raw=True): the
        # generated SQL plus its raw result rows. TODO confirm for the retriever in use.
        # Keep the SQL on the instance so widen_search() has something to widen.
        self.sql = response[0].metadata['sql_query']
        returns = response[0].metadata['result']

        # None or an empty row list means the filters were too strict.
        if returns:
            return self.sort_relevance(returns)

        return self.widen_search()