In [None]:
import logging
import os
import yaml
import dotenv
import json
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
from langchain.schema import BaseOutputParser
from IPython.display import display, Markdown
from plugins.prompt_based_sql_plugin.prompt_based_sql_plugin import PromptBasedSQLPlugin
# Configure the root logger at INFO so the Question/Answer logs below are shown
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")

In [None]:
# Populate os.environ from a local .env file (OpenAI__ApiKey, Text2Sql__DatabaseName, ...).
# load_dotenv() returns False when it loaded nothing (no .env file found, or it was empty);
# surface that here instead of failing later with an opaque credential/database error.
# Using the result in an `if` also keeps a stray `True` out of the cell output.
if not dotenv.load_dotenv():
    logging.warning(
        "No variables were loaded from a .env file; relying on the existing process environment."
    )

In [None]:
# LangChain OpenAI LLM wrapper; the key is read from the environment
# (None if OpenAI__ApiKey is unset).
llm = OpenAI(
    api_key=os.environ.get("OpenAI__ApiKey"),
)

# Prompt-based SQL plugin bound to the database named in the environment
# (None if Text2Sql__DatabaseName is unset).
sql_plugin = PromptBasedSQLPlugin(
    database=os.environ.get("Text2Sql__DatabaseName"),
)


In [None]:
# Load the prompt definition from prompt.yaml; only its "template" key is used below.
# An explicit encoding keeps loading stable across platforms: without it, open()
# decodes with the locale's preferred encoding (e.g. cp1252 on many Windows setups).
with open("./prompt.yaml", "r", encoding="utf-8") as file:
    data = yaml.safe_load(file)

# Prompt template; ask_question() supplies values for exactly these three variables.
prompt_template = PromptTemplate(
    input_variables=["chat_history", "important_information", "user_input"],
    template=data["template"],
)

# Chain that formats the prompt and sends it to the LLM.
chain = LLMChain(llm=llm, prompt=prompt_template)


In [None]:
# Running conversation log; ask_question() appends a user and an assistant entry per call.
history: list = list()

In [None]:
async def ask_question(question: str, chat_history: list) -> str:
    """Ask the text-to-SQL chatbot a question, record the exchange, and render the answer.

    The prompt is filled with the chat history, the SQL plugin's database
    description, and the user's question. The model's reply is parsed as JSON
    and its ``"answer"`` field is displayed as Markdown (presumably prompt.yaml
    instructs the model to reply in that shape — confirm).

    Args:
        question (str): The question to ask the chatbot.
        chat_history (list): The chat history list. It is passed to the prompt
            as-is and mutated in place: a user entry and an assistant entry
            (the model's raw reply) are appended once the model answers.

    Returns:
        str: The ``"answer"`` field of the model's JSON reply.

    Raises:
        json.JSONDecodeError: If the model's reply is not valid JSON.
        KeyError: If the parsed reply has no ``"answer"`` field.
    """

    # Create important information prompt that contains the SQL database information.
    # The TOP-instead-of-LIMIT rule suggests a SQL Server (T-SQL) style engine — confirm.
    engine_specific_rules = "Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error."
    important_information_prompt = f"""
    [SQL DATABASE INFORMATION]
    {sql_plugin.system_prompt(engine_specific_rules=engine_specific_rules)}
    [END SQL DATABASE INFORMATION]
    """

    # Prepare the input for the LLMChain; keys match the template's input variables.
    inputs = {
        "chat_history": chat_history,
        "important_information": important_information_prompt,
        "user_input": question,
    }

    logging.info("Question: %s", question)

    # LLMChain.__call__ is synchronous and returns a dict, so `await chain(inputs)`
    # raised TypeError. acall() is the awaitable counterpart; its result dict holds
    # the generated text under the chain's output key ("text" by default).
    result = await chain.acall(inputs)
    answer = result[chain.output_key]

    logging.info("Answer: %s", answer)

    # Record the exchange so follow-up questions see it in the prompt.
    chat_history.append({"role": "user", "message": question})
    chat_history.append({"role": "assistant", "message": answer})

    answer_text = json.loads(answer)["answer"]

    display(Markdown(answer_text))

    return answer_text

In [None]:
# Example usage
await ask_question("What are the different product categories we have?", history)