Copyright (c) Microsoft Corporation.

Licensed under the MIT License.

# Text2SQL with Semantic Kernel & Azure OpenAI

This notebook demonstrates how to integrate the SQL plugin with Semantic Kernel and Azure OpenAI to answer questions about a database, based on the schemas provided.

A multi-shot approach is used for SQL generation for more reliable results and reduced token usage. More details can be found in the README.md.

In [None]:
import logging
import os
import yaml
import dotenv
import json
from semantic_kernel.connectors.ai.open_ai import (
    AzureChatCompletion,
)
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.kernel import Kernel
from plugins.vector_based_sql_plugin.vector_based_sql_plugin import VectorBasedSQLPlugin
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig
from IPython.display import display, Markdown
# Configure the root logger at INFO level so the logging.info calls in this notebook are emitted.
logging.basicConfig(level=logging.INFO)

## Kernel Setup

In [None]:
# Load settings from a .env file into the process environment; the
# OpenAI__* variables read further down are expected to come from there.
dotenv.load_dotenv()
# The kernel that the chat service, the SQL plugin and the prompt function are added to.
kernel = Kernel()

## Set up GPT connections

In [None]:
# Identifier passed as service_id when the Azure chat completion service is created.
service_id = "chat"

In [None]:
# Build the Azure OpenAI chat completion service. The deployment name,
# endpoint and API key are read from environment variables; a missing
# variable raises KeyError here.
chat_service = AzureChatCompletion(
    service_id=service_id,
    deployment_name=os.environ["OpenAI__CompletionDeployment"],
    endpoint=os.environ["OpenAI__Endpoint"],
    api_key=os.environ["OpenAI__ApiKey"],
)
# Register the service with the kernel.
kernel.add_service(chat_service)

In [None]:
# Instantiate the vector-based SQL plugin (constructed with no arguments) and
# register it with the kernel under the plugin name "SQL".
sql_plugin = VectorBasedSQLPlugin()
kernel.add_plugin(sql_plugin, "SQL")

## Prompt Setup

In [None]:
# Load the prompt and execution settings from the YAML file.
# The encoding is given explicitly so the prompt is decoded the same way on
# every platform instead of depending on the locale's default encoding.
with open("./prompt.yaml", "r", encoding="utf-8") as file:
    # safe_load accepts the open stream directly; no need to read it first.
    data = yaml.safe_load(file)

# Build the template config once the file has been closed.
prompt_template_config = PromptTemplateConfig(**data)

In [None]:
# Register the prompt loaded above as a kernel function. ask_question invokes
# it by the same plugin name ("ChatBot") and function name ("Chat").
chat_function = kernel.add_function(
    prompt_template_config=prompt_template_config,
    plugin_name="ChatBot",
    function_name="Chat",
)

## ChatBot setup

In [None]:
# Chat history object that is passed to every ask_question call below.
history = ChatHistory()

In [None]:
async def ask_question(question: str, chat_history: ChatHistory) -> str:
    """Ask the chatbot a question, render the answer and return it.

    The SQL plugin is asked for the database information relevant to the
    question. That text is passed, together with the question and the chat
    history, to the ``ChatBot.Chat`` kernel function. The question and the
    raw response are then appended to ``chat_history``.

    Args:
        question (str): The question to ask the chatbot.
        chat_history (ChatHistory): The chat history object. It is mutated:
            the user question and the assistant response are appended.

    Returns:
        str: The answer from the chatbot. This is the ``"answer"`` field of
        the JSON response, or the raw response text when the response is not
        a JSON object containing that field.
    """

    # Engine-specific guidance handed to the SQL plugin when it builds the
    # database information text.
    engine_specific_rules = "Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error."

    # Awaited outside the f-string so the template below only interpolates.
    sql_information = await sql_plugin.sql_prompt_injection(
        engine_specific_rules=engine_specific_rules, question=question
    )
    sql_database_information_prompt = f"""
    [SQL DATABASE INFORMATION]
    {sql_information}
    [END SQL DATABASE INFORMATION]
    """

    arguments = KernelArguments()
    arguments["sql_database_information"] = sql_database_information_prompt
    arguments["user_input"] = question

    logging.info("Question: %s", question)

    answer = await kernel.invoke(
        function_name="Chat",
        plugin_name="ChatBot",
        arguments=arguments,
        chat_history=chat_history,
    )

    logging.info("Answer: %s", answer)

    # The invocation result is converted to text once and reused below.
    answer_text = str(answer)

    # Log the question and answer to the chat history. The assistant turn is
    # added with the dedicated helper (mirroring add_user_message) so the
    # response text is stored as the message content.
    chat_history.add_user_message(question)
    chat_history.add_assistant_message(answer_text)

    # The response is expected to be a JSON object with an "answer" field.
    # Fall back to the raw text rather than failing after the history has
    # already been updated.
    try:
        answer_markdown = json.loads(answer_text)["answer"]
    except (json.JSONDecodeError, KeyError, TypeError):
        logging.warning(
            "Response is not a JSON object with an 'answer' field; showing raw text."
        )
        answer_markdown = answer_text

    display(Markdown(answer_markdown))

    return answer_markdown

In [None]:
# Example question 1; the shared `history` object is passed so this exchange is appended to it.
await ask_question("What are the different product categories we have?", history)

In [None]:
# Example question 2, asked with the same `history` object as the previous cell.
await ask_question("What is the top performing product by quantity of units sold?", history)

In [None]:
# Example question 3, asked with the same `history` object as the previous cells.
await ask_question("Which country did we sell the most to in June 2008?", history)