# Create a LangChain NL2SQL Agent using Azure OpenAI and Azure SQL Database
This notebook walks through creating a LangChain SQL agent that uses Azure OpenAI as the LLM to query an Azure SQL Database.

## Install the required Python libraries
Start by installing the required libraries. Run the following at the terminal in the project folder so it references the project's requirements.txt:

```bash
pip install -r requirements.txt
```


## ODBC Driver for MS SQL Install

Use the **odbcDriverInstallUbuntu.txt** script to install the Microsoft ODBC Driver for MS SQL (version 18).

If you are not using Codespaces or Ubuntu, you can find the correct instructions to install the driver for Linux [here](https://learn.microsoft.com/en-us/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-for-sql-server), for Windows [here](https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server), and for macOS [here](https://learn.microsoft.com/en-us/sql/connect/odbc/linux-mac/install-microsoft-odbc-driver-sql-server-macos).

## Create the table in the database
(all SQL commands are in the database.sql script)

In the database that will be used for this notebook, run the following:

(The CREATE TABLE permission and access to the dbo schema are needed. It's best to keep roles and permissions to a minimum when working with NL2SQL.)

```SQL
create table [dbo].[langtable] (id int Identity, username nvarchar(100))
GO

insert into [dbo].[langtable] (username) values('sammy')
insert into [dbo].[langtable] (username) values('mary')
insert into [dbo].[langtable] (username) values('jane')
insert into [dbo].[langtable] (username) values('fred')
insert into [dbo].[langtable] (username) values('billy')
insert into [dbo].[langtable] (username) values('jonny')
insert into [dbo].[langtable] (username) values('kenny')
insert into [dbo].[langtable] (username) values('dan')
insert into [dbo].[langtable] (username) values('frank')
insert into [dbo].[langtable] (username) values('jenny')
GO

select * from [dbo].[langtable]
GO
```


## .env file
Fill out the .env file with your server, database, and key values. For this notebook, you only need to edit the following values.

For SQL Authentication or Entra ID Service Principal Authentication:

```BASH
AZURE_OPENAI_API_GPT_BASE=""
AZURE_OPENAI_API_GPT_KEY=""
AZURE_OPENAI_API_GPT_VERSION=""
AZURE_OPENAI_API_GPT_DEPLOYMENT=""
AZURE_SQL_SERVER=""
AZURE_SQL_USER=""
AZURE_SQL_PASSWORD=""
AZURE_SQL_DATABASE_NAME=""
```

For Entra ID Integrated Authentication:

```BASH
AZURE_OPENAI_API_GPT_BASE=""
AZURE_OPENAI_API_GPT_KEY=""
AZURE_OPENAI_API_GPT_VERSION=""
AZURE_OPENAI_API_GPT_DEPLOYMENT=""
AZURE_SQL_SERVER=""
AZURE_SQL_DATABASE_NAME=""
```

#### Locate your .env file (should be base of repo)

In [1]:
from dotenv import dotenv_values
config = dotenv_values('../../.env')

In [2]:
azure_openai_endpoint = config['AZURE_OPENAI_API_GPT_BASE']
azure_openai_key = config['AZURE_OPENAI_API_GPT_KEY']
azure_openai_version = config['AZURE_OPENAI_API_GPT_VERSION']
azure_openai_gpt_deployment = config['AZURE_OPENAI_API_GPT_DEPLOYMENT']
azure_sql_server = config['AZURE_SQL_SERVER']
azure_sql_database = config['AZURE_SQL_DATABASE_NAME'] 
#azure_sql_database = "mySampleDatabase" 
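
A missing or misspelled key in the .env file only surfaces later as a `KeyError`. As a minimal sketch (the helper name `missing_keys` and the sample dictionary are hypothetical, not part of the notebook), the keys read by the cell above can be checked up front:

```python
# Keys read by this notebook's configuration cell.
REQUIRED_KEYS = [
    "AZURE_OPENAI_API_GPT_BASE",
    "AZURE_OPENAI_API_GPT_KEY",
    "AZURE_OPENAI_API_GPT_VERSION",
    "AZURE_OPENAI_API_GPT_DEPLOYMENT",
    "AZURE_SQL_SERVER",
    "AZURE_SQL_DATABASE_NAME",
]

def missing_keys(config, required=REQUIRED_KEYS):
    """Return the required keys that are absent or blank in config."""
    return [key for key in required if not (config.get(key) or "").strip()]

# Stand-in for the dict returned by dotenv_values(); values are placeholders.
sample_config = {
    "AZURE_OPENAI_API_GPT_BASE": "https://example.invalid/",
    "AZURE_OPENAI_API_GPT_KEY": "",
    "AZURE_SQL_SERVER": "myserver",
}
print(missing_keys(sample_config))
```

In the notebook itself, the same function could be called on `config` before any value is read.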

## Notebook Kernel
Be sure to select a kernel for the Python notebook by using the **Select Kernel** button in the upper right of the notebook.

## Starting the Example
The first section sets up the Python environment by importing the LangChain and SQLAlchemy components; the values from the .env file were already loaded in the cells above.

In [3]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_types import AgentType
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import AzureChatOpenAI

from sqlalchemy import create_engine

#### Create the database connection

Using SQL Login or Service Principal

In [4]:
azure_sql_user = config['AZURE_SQL_USER']
azure_sql_pass = config['AZURE_SQL_PASSWORD']

# using an Azure SQL login
connectionString=f'mssql+pyodbc://{azure_sql_user}:{azure_sql_pass}@{azure_sql_server}.database.windows.net/{azure_sql_database}?driver=ODBC+Driver+18+for+SQL+Server'
# or a service principal (the password comes from azure_sql_pass, defined above)
#connectionString=f"mssql+pyodbc:///?odbc_connect=DRIVER={{ODBC Driver 18 for SQL Server}};SERVER={azure_sql_server}.database.windows.net;PORT=1433;UID={azure_sql_user};DATABASE={azure_sql_database};PWD={azure_sql_pass};Authentication=ActiveDirectoryServicePrincipal"

# define the connection
db_engine = create_engine(connectionString)
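
The SQL-login connection string embeds the user name and password directly in a URL, so characters such as `@`, `:`, or `/` in the password would be misparsed. The sketch below, which assumes made-up credentials and a hypothetical helper `build_sql_login_url`, percent-encodes both values with the standard library's `urllib.parse.quote_plus` before formatting the URL:

```python
from urllib.parse import quote_plus

def build_sql_login_url(user, password, server, database):
    """Build a mssql+pyodbc URL with percent-encoded credentials."""
    return (
        f"mssql+pyodbc://{quote_plus(user)}:{quote_plus(password)}"
        f"@{server}.database.windows.net/{database}"
        "?driver=ODBC+Driver+18+for+SQL+Server"
    )

# Placeholder values only; real ones come from the .env file.
url = build_sql_login_url("appuser", "p@ss:word/1", "myserver", "mydb")
print(url)
```

The resulting string can be passed to `create_engine` in place of `connectionString`.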

Alternative: Entra ID login

In [14]:
# #### extra steps if using Entra ID on Linux to get token (Interactive auth only on Windows)
# from azure.identity import DefaultAzureCredential
# import struct
# credential = DefaultAzureCredential(exclude_interactive_browser_credential=False)

# token_bytes = credential.get_token("https://database.windows.net/.default").token.encode("UTF-16-LE")
# token_struct = struct.pack(f'<I{len(token_bytes)}s', len(token_bytes), token_bytes)
# SQL_COPT_SS_ACCESS_TOKEN = 1256  # This connection option is defined by microsoft in msodbcsql.h

# connecting_string=f'Driver={{ODBC Driver 18 for SQL Server}};Server=tcp:{azure_sql_server}.database.windows.net,1433;Database={azure_sql_database};Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30'
# connectionString = 'mssql+pyodbc:///?odbc_connect={}'.format(connecting_string)

# # define the connection
# db_engine = create_engine(connectionString, connect_args={'attrs_before': {SQL_COPT_SS_ACCESS_TOKEN: token_struct}})
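
The commented-out Entra ID cell hands the access token to the driver as a byte structure: a 4-byte little-endian length followed by the token encoded as UTF-16-LE. The sketch below reproduces that packing with a dummy string instead of a real token, so the layout can be inspected without calling Azure (`pack_access_token` is a hypothetical helper name):

```python
import struct

def pack_access_token(token):
    """Pack a token as <4-byte little-endian length><UTF-16-LE bytes>."""
    token_bytes = token.encode("UTF-16-LE")
    return struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)

dummy = "abc"  # placeholder; a real token comes from credential.get_token(...)
packed = pack_access_token(dummy)

# Read the length prefix back and decode the payload to confirm the layout.
length = struct.unpack("<I", packed[:4])[0]
print(length, packed[4:].decode("UTF-16-LE"))
```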

#### Next, test the database connection

In [8]:
# connect to the Azure SQL database
db = SQLDatabase(db_engine, view_support=True, schema="dbo")

# test the connection
print(db.dialect)
print(db.get_usable_table_names())
db.run("select convert(varchar(25), getdate(), 120)")


mssql
['BuildVersion', 'ErrorLog', 'MSchange_tracking_history', 'foodreview']


"[('2024-05-10 01:59:38',)]"

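Create a reference to Azure OpenAI as the LLM to be used with the SQL agent. The deployment name is read from `AZURE_OPENAI_API_GPT_DEPLOYMENT` in the .env file; because the cell uses `AzureChatOpenAI`, it should point to a chat model deployment rather than a completions-only model such as gpt-3.5-turbo-instruct.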

In [9]:
azurellm = AzureChatOpenAI(
    azure_deployment=azure_openai_gpt_deployment,
    openai_api_version = azure_openai_version,
    azure_endpoint=azure_openai_endpoint,
    api_key = azure_openai_key,
)

Run the following to create the SQL Agent

In [10]:
toolkit = SQLDatabaseToolkit(db=db, llm=azurellm)

agent_executor = create_sql_agent(
    llm=azurellm,
    toolkit=toolkit,
    verbose=True,
    #agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    agent_type="openai-tools",
)


Now, test the agent by creating a prompt using natural language asking about a database object.

In [11]:
agent_executor.invoke("count the rows in the foodreview table.")



> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`

BuildVersion, ErrorLog, MSchange_tracking_history, foodreview

Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) AS TotalRows FROM foodreview'}`

[(99,)]

The `foodreview` table contains a total of 99 rows.

> Finished chain.


{'input': 'count the rows in the foodreview table.',
 'output': 'The `foodreview` table contains a total of 99 rows.'}

In [12]:
agent_executor.invoke("count the rows in the langtable table.")



> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`

A, B, C, Employees, MSchange_tracking_history, Products, ProductsNameNotNull, ProductsWithIdentity, langtable

Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) AS RowCount FROM langtable'}`

Error: (pyodbc.OperationalError) ('08S01', '[08S01] [Microsoft][ODBC Driver 18 for SQL Server]TCP Provider: Error code 0x68 (104) (SQLExecDirectW)')
[SQL: SELECT COUNT(*) AS RowCount FROM langtable]
(Background on this error at: https://sqlalche.me/e/14/e3q8)

Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) AS TotalRows FROM langtable;'}`

[(10,)]

The `langtable` table contains a total of 10 rows.

> Finished chain.


{'input': 'count the rows in the langtable table.',
 'output': 'The `langtable` table contains a total of 10 rows.'}

### Using a dynamic few-shot prompt

* Example from https://python.langchain.com/v0.1/docs/use_cases/sql/agents

examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of invoices.",
        "query": "SELECT COUNT(*) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
        "input": "Which albums are from the year 2000?",
        "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]
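
Two of the queries above use SQLite-only syntax (`LIMIT` and `strftime`), which SQL Server does not accept. Below is a sketch of T-SQL equivalents, assuming the same table and column names as the original examples; `tsql_overrides` and `adapt_examples` are hypothetical names introduced for illustration:

```python
# T-SQL rewrites keyed by the natural-language input they belong to.
tsql_overrides = {
    "Who are the top 5 customers by total purchase?":
        "SELECT TOP 5 CustomerId, SUM(Total) AS TotalPurchase FROM Invoice "
        "GROUP BY CustomerId ORDER BY TotalPurchase DESC;",
    "Which albums are from the year 2000?":
        "SELECT * FROM Album WHERE YEAR(ReleaseDate) = 2000;",
}

def adapt_examples(examples, overrides):
    """Return a copy of examples with overridden queries swapped in."""
    return [
        {**ex, "query": overrides.get(ex["input"], ex["query"])}
        for ex in examples
    ]

# Small stand-in for the full examples list.
sample = [
    {"input": "Which albums are from the year 2000?",
     "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';"},
    {"input": "Find the total number of invoices.",
     "query": "SELECT COUNT(*) FROM Invoice;"},
]
adapted = adapt_examples(sample, tsql_overrides)
print(adapted[0]["query"])
```

Examples whose input has no override keep their original query, so the same call can be applied to the full `examples` list.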

In [None]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)
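
`SemanticSimilarityExampleSelector` picks the `k` examples whose inputs are closest to the user's question, so only relevant examples reach the prompt. The sketch below illustrates the idea with plain word overlap (Jaccard similarity) as a stand-in for embedding similarity. It is not how the LangChain class computes similarity, and `select_examples` is a hypothetical name:

```python
import re

def tokenize(text):
    """Lowercase a string and split it into a set of word tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def select_examples(question, examples, k=2):
    """Return the k examples whose input shares the most words with question."""
    q = tokenize(question)

    def score(example):
        words = tokenize(example["input"])
        union = q | words
        return len(q & words) / len(union) if union else 0.0

    # sorted() is stable, so ties keep their original order.
    return sorted(examples, key=score, reverse=True)[:k]

sample = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {"input": "Find the total number of invoices.",
     "query": "SELECT COUNT(*) FROM Invoice;"},
    {"input": "How many employees are there",
     "query": 'SELECT COUNT(*) FROM "Employee"'},
]
chosen = select_examples("How many artists are there", sample, k=2)
print([ex["input"] for ex in chosen])
```

Word overlap ranks the "employees" example first because it shares the phrasing of the question; an embedding model would also weigh meaning, which is why the notebook uses one.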

In [None]:
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix="",
)
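
Conceptually, the few-shot template concatenates the prefix, each selected example rendered through `example_prompt`, and the suffix. The following plain-Python sketch shows that assembly, assuming a blank-line separator between parts and a shortened, made-up prefix; `render_few_shot` is an illustrative helper, not a LangChain function:

```python
EXAMPLE_TEMPLATE = "User input: {input}\nSQL query: {query}"

def render_few_shot(prefix, examples, suffix="", separator="\n\n", **variables):
    """Join prefix, formatted examples, and suffix, filling in variables."""
    parts = [prefix.format(**variables)]
    parts += [EXAMPLE_TEMPLATE.format(**ex) for ex in examples]
    if suffix:
        parts.append(suffix.format(**variables))
    return separator.join(parts)

short_prefix = "Write a {dialect} query. Return at most {top_k} rows."
selected = [
    {"input": "Find the total number of invoices.",
     "query": "SELECT COUNT(*) FROM Invoice;"},
]
text = render_few_shot(short_prefix, selected, dialect="mssql", top_k=5)
print(text)
```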

In [None]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

In [None]:
# Example formatted prompt
prompt_val = full_prompt.invoke(
    {
        "input": "How many artists are there",
        "top_k": 5,
        "dialect": "SQLite",
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())