# Text to IRIS SQL with langchain

An experiment on how to use the langchain framework, IRIS Vector Search, and LLMs to generate IRIS-compatible SQL from user prompts.

## Setup

In [1]:
!pip install --upgrade --quiet langchain langchain-openai langchain-iris pandas


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [2]:
import os
import datetime
import hashlib
from copy import deepcopy

from sqlalchemy import create_engine

import getpass

import pandas as pd

from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.docstore.document import Document
from langchain_community.document_loaders import DataFrameLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache

from langchain_iris import IRISVector

In [3]:
# Cache for LLM calls
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

In [4]:
# IRIS database connection parameters
os.environ["ISC_LOCAL_SQL_HOSTNAME"] = "localhost"
os.environ["ISC_LOCAL_SQL_PORT"] = "1972"
os.environ["ISC_LOCAL_SQL_NAMESPACE"] = "IRISAPP"
os.environ["ISC_LOCAL_SQL_USER"] = "_system"
os.environ["ISC_LOCAL_SQL_PWD"] = "SYS"

In [5]:
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass.getpass()



In [6]:
# IRIS database connection string
args = {
    'hostname': os.getenv("ISC_LOCAL_SQL_HOSTNAME"), 
    'port': os.getenv("ISC_LOCAL_SQL_PORT"), 
    'namespace': os.getenv("ISC_LOCAL_SQL_NAMESPACE"), 
    'username': os.getenv("ISC_LOCAL_SQL_USER"), 
    'password': os.getenv("ISC_LOCAL_SQL_PWD")
}
iris_conn_str = f"iris://{args['username']}:{args['password']}@{args['hostname']}:{args['port']}/{args['namespace']}"
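The f-string above inserts the password verbatim, so a password containing characters such as `@`, `:` or `/` would produce a malformed URL. SQLAlchemy's documentation recommends escaping such credentials with `urllib.parse.quote_plus`. A minimal sketch, assuming the same `iris://user:password@host:port/namespace` layout (`build_iris_conn_str` is a hypothetical helper, not part of any library):

```python
from urllib.parse import quote_plus


def build_iris_conn_str(username, password, hostname, port, namespace):
    """Build an iris:// SQLAlchemy URL with percent-escaped credentials."""
    # quote_plus escapes characters such as '@', ':' and '/' that would
    # otherwise be read as URL delimiters
    return (
        f"iris://{quote_plus(username)}:{quote_plus(password)}"
        f"@{hostname}:{port}/{namespace}"
    )


print(build_iris_conn_str("_system", "p@ss/w:rd", "localhost", "1972", "IRISAPP"))
# iris://_system:p%40ss%2Fw%3Ard@localhost:1972/IRISAPP
```

Because the parameter names match the keys of `args`, the notebook's string could be built with `build_iris_conn_str(**args)`.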

In [7]:
# Connection to IRIS database
engine = create_engine(iris_conn_str)
cnx = engine.connect().connection

In [8]:
# Dict for context information for system prompt
context = {}
context["top_k"] = 3

## Prompt creation

First, we provide an initial prompt with basic instructions for the LLM to create SQL queries. This template was derived from the [langchain default prompts for MSSQL](https://github.com/langchain-ai/langchain/blob/b00c0fc558c278e3299a81ddcda9c61cdeff3043/libs/langchain/langchain/chains/sql_database/prompt.py#L106), with IRIS database-specific instructions.

In [9]:
# Basic prompt template with IRIS database SQL instructions
iris_sql_template = """
You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".
Use double quotes to delimit columns identifiers.
Return just plain SQL; don't apply any kind of formatting.
"""

The basic prompt just sets up the LLM to act as an SQL expert, with hints specific to the IRIS database. To avoid hallucinations, we also need an auxiliary prompt that provides information about the schema of the database tables.

In [10]:
# SQL template extension for including tables context information
tables_prompt_template = """
Only use the following tables:
{table_info}
"""

A good strategy to improve the accuracy of the LLM's responses is to present it with some examples. This technique is known as few-shot prompting. The following is a template that adds few-shot examples to the basic prompt.

In [11]:
# SQL template extension for including few shots
prompt_sql_few_shots_template = """
Below are a number of examples of questions and their corresponding SQL queries.

{examples_value}
"""

In [12]:
# Few shots prompt template
example_prompt_template = "User input: {input}\nSQL query: {query}"
example_prompt = PromptTemplate.from_template(example_prompt_template)

Build the user prompt using the few-shot example template.

In [13]:
# User prompt template
user_prompt = "\n"+example_prompt.invoke({"input": "{input}", "query": ""}).to_string()
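The trick in this cell is that the example template is filled with the literal text `{input}`, which leaves a placeholder for the final chat prompt to fill later, and with an empty query, which leaves `SQL query: ` open for the model to complete. The sketch below approximates this with plain `str.format`, on the assumption that `PromptTemplate` uses the same f-string-style placeholders by default; it is an illustration, not LangChain's internals, and the variable names are hypothetical:

```python
# Same f-string-style template as `example_prompt_template` above
template = "User input: {input}\nSQL query: {query}"

# A real few-shot example renders as an input/query pair
example = {"input": "List all aircrafts.", "query": "SELECT * FROM Aviation.Aircraft"}
rendered_example = template.format(**example)

# Passing the literal text "{input}" re-creates a placeholder for the chat
# prompt; the empty query leaves "SQL query: " for the model to complete
user_prompt_sketch = "\n" + template.format(input="{input}", query="")

print(rendered_example)
print(repr(user_prompt_sketch))  # '\nUser input: {input}\nSQL query: '
```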

Finally, compose all prompts to create the final one.

In [14]:
# Complete prompt template
prompt = (
    ChatPromptTemplate.from_messages([("system", iris_sql_template)])
    + ChatPromptTemplate.from_messages([("system", tables_prompt_template)])
    + ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)])
    + ChatPromptTemplate.from_messages([("human", user_prompt)])
)
prompt

ChatPromptTemplate(input_variables=['examples_value', 'input', 'table_info', 'top_k'], messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['top_k'], template='\nYou are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes (\'\') to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay a

This prompt expects the variables `examples_value`, `input`, `table_info` and `top_k`.

Let's provide some placeholder values for them to see how the prompt will be sent to the LLM:

In [15]:
prompt_value = prompt.invoke({
    "top_k": "<top_k>",
    "table_info": "<table_info>",
    "examples_value": "<examples_value>",
    "input": "<input>"
})
print(prompt_value.to_string())

System: 
You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most <top_k> results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".
Use double quotes to delimit columns identifiers.
Return just plain SQL;

So, in order to send this prompt to the LLM we must provide those variables. Let's do that!

## Providing table information

It may sound obvious that you need to provide table information in order to create SQL queries. However, if we don't tell the LLM which tables it should base its response on, it will probably return a query that seems plausible but won't work: an example of hallucination.

So, first let's create a function that queries the INFORMATION_SCHEMA and returns the table definitions for a schema.

In [16]:
def get_table_definitions_array(cnx, schema, table=None):
    """
    Generate SQL `CREATE TABLE` statements for tables in a given schema.

    This function queries the database to retrieve column definitions and 
    constructs SQL `CREATE TABLE` statements for each table within the specified schema. 
    If a specific table name is provided, it generates the statement only for that table.

    Args:
        cnx: An IRIS connection object.
        schema (str): The name of the schema to retrieve table definitions from.
        table (str, optional): The name of a specific table to retrieve the definition for. 
                               If not provided, definitions for all tables in the schema are retrieved.

    Returns:
        list of str: A list of SQL `CREATE TABLE` statements as strings.
    """
    
    cursor = cnx.cursor()

    # Base query to get columns information
    query = """
    SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA = %s
    """
    
    # Parameters for the query
    params = [schema]

    # Adding optional filters
    if table:
        query += " AND TABLE_NAME = %s"
        params.append(table)
    
    # Execute the query
    cursor.execute(query, params)

    # Fetch the results
    rows = cursor.fetchall()
    
    # Process the results to generate the table definition(s)
    table_definitions = {}
    for row in rows:
        table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row
        if table_name not in table_definitions:
            table_definitions[table_name] = []
        table_definitions[table_name].append({
            "column_name": column_name,
            "column_type": column_type,
            "is_nullable": is_nullable,
            "column_default": column_default,
            "column_key": column_key,
            "extra": extra
        })

    primary_keys = {}
    
    # Build the output string
    result = []
    for table_name, columns in table_definitions.items():
        table_def = f"CREATE TABLE {schema}.{table_name} (\n"
        column_definitions = []
        for column in columns:
            column_def = f"  {column['column_name']} {column['column_type']}"
            if column['is_nullable'] == "NO":
                column_def += " NOT NULL"
            if column['column_default'] is not None:
                column_def += f" DEFAULT {column['column_default']}"
            if column['extra']:
                column_def += f" {column['extra']}"
            column_definitions.append(column_def)
        if table_name in primary_keys:
            pk_def = f"  PRIMARY KEY ({', '.join(primary_keys[table_name])})"
            column_definitions.append(pk_def)
        table_def += ",\n".join(column_definitions)
        table_def += "\n);"
        result.append(table_def)

    return result
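In `get_table_definitions_array`, the `primary_keys` dictionary is never filled, so the `PRIMARY KEY` branch never runs, even though the query already fetches the `PRIMARY_KEY` column (unpacked as `column_key`). The sketch below shows one way to collect the key columns while building the DDL from already-fetched rows, so it runs without a database. It assumes, without verification here, that `PRIMARY_KEY` holds `'YES'` for key columns (check what your IRIS version returns), and it expects a simplified row layout; `build_create_table_statements` is a hypothetical helper:

```python
from collections import defaultdict


def build_create_table_statements(schema, rows, pk_flag="YES"):
    """Build CREATE TABLE statements, including a PRIMARY KEY clause.

    rows: (table_name, column_name, data_type, is_nullable, column_default,
    primary_key) tuples, already fetched from INFORMATION_SCHEMA.COLUMNS.
    pk_flag: the value assumed to mark a primary-key column.
    """
    columns = defaultdict(list)
    primary_keys = defaultdict(list)
    for table_name, column_name, data_type, is_nullable, default, pk in rows:
        column_def = f"  {column_name} {data_type}"
        if is_nullable == "NO":
            column_def += " NOT NULL"
        if default is not None:
            column_def += f" DEFAULT {default}"
        columns[table_name].append(column_def)
        # Unlike the function above, actually record the key columns
        if pk == pk_flag:
            primary_keys[table_name].append(column_name)

    statements = []
    for table_name, column_defs in columns.items():
        body = list(column_defs)
        if primary_keys[table_name]:
            body.append(f"  PRIMARY KEY ({', '.join(primary_keys[table_name])})")
        statements.append(f"CREATE TABLE {schema}.{table_name} (\n" + ",\n".join(body) + "\n);")
    return statements


rows = [("Crew", "ID", "varchar", "NO", None, "YES"), ("Crew", "Age", "integer", "YES", None, "NO")]
print(build_create_table_statements("Aviation", rows)[0])
```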

Now, we are able to retrieve all tables for a schema.

For this example, we're going to use the Aviation schema, available [here](https://openexchange.intersystems.com/package/Samples-Aviation).

In [17]:
get_table_definitions_array(cnx, "Aviation")

['CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  AccidentExplosion varchar,\n  AccidentFire varchar,\n  AirFrameHours varchar,\n  AirFrameHoursSince varchar,\n  AirFrameHoursSinceLastInspection varchar,\n  AircraftCategory varchar,\n  AircraftCertMaxGrossWeight integer,\n  AircraftHomeBuilt varchar,\n  AircraftKey integer NOT NULL,\n  AircraftManufacturer varchar,\n  AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  AircraftSerialNo varchar,\n  AircraftSeries varchar,\n  Damage varchar,\n  DepartureAirportId varchar,\n  DepartureCity varchar,\n  DepartureCountry varchar,\n  DepartureSameAsEvent varchar,\n  DepartureState varchar,\n  DepartureTime integer,\n  DepartureTimeZone varchar,\n  DestinationAirportId varchar,\n  DestinationCity varchar,\n  DestinationCountry varchar,\n  DestinationSameAsLocal varchar,\n  DestinationState varchar,\n  EngineCount integer,\n  EvacuationOccurred varchar,\n  EventId varchar NOT NULL,\n  FlightMedi

## Selecting the most relevant tables

For small databases like the one used in this example, it's OK to send the DDL for all tables in the prompt. However, real databases can easily have hundreds or even thousands of tables, making it infeasible to send all of them to the LLM.

Furthermore, it is unlikely that the LLM needs to know about all the tables in order to create a given SQL query.

So, we need a way to select just the most relevant tables for a given user prompt.

We can achieve this using the semantic search capabilities of IRIS Vector Search.

Note: this approach only makes sense if your SQL element identifiers (tables, fields, keys, etc.) have some meaning. If your identifiers are just codes, you must use some kind of data dictionary instead.
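As an illustration of the data-dictionary idea, the sketch below appends a human-readable description to each coded column line as an SQL comment before the DDL is embedded, giving the embedding model meaningful words to match. The `data_dictionary` mapping, the `Ops.T01` table and the `C01`/`C02` codes are all hypothetical, and `annotate_columns` is not part of any library:

```python
def annotate_columns(ddl, data_dictionary):
    """Append '-- description' to each DDL line whose first token is a known code.

    data_dictionary: hypothetical mapping from coded identifiers to descriptions.
    """
    annotated = []
    for line in ddl.split("\n"):
        # First token of a column line is the column name (trailing comma removed)
        tokens = line.strip().rstrip(",").split()
        description = data_dictionary.get(tokens[0]) if tokens else None
        annotated.append(f"{line} -- {description}" if description else line)
    return "\n".join(annotated)


ddl = "CREATE TABLE Ops.T01 (\n  C01 varchar,\n  C02 integer\n);"
print(annotate_columns(ddl, {"C01": "aircraft manufacturer", "C02": "number of engines"}))
```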

First, let's retrieve the table information into a pandas dataframe.

In [18]:
# Retrieve the tables information into a pandas dataframe
table_def = get_table_definitions_array(cnx=cnx, schema='Aviation')
table_df = pd.DataFrame(data=table_def, columns=["col_def"])
table_df["id"] = table_df.index + 1
table_df

| | col_def | id |
|---|---|---|
| 0 | CREATE TABLE Aviation.Aircraft (\n Event bigi... | 1 |
| 1 | CREATE TABLE Aviation.Crew (\n Aircraft varch... | 2 |
| 2 | CREATE TABLE Aviation.Event (\n ID bigint NOT... | 3 |


Next, let's split the definitions into langchain [Documents](https://api.python.langchain.com/en/latest/documents/langchain_core.documents.base.Document.html). Splitting lets large table definitions be sent for embedding extraction, which generally has input length limits.

In [19]:
loader = DataFrameLoader(table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
tables_docs = text_splitter.split_documents(documents)
tables_docs

[Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,\n  ID varchar NOT NULL,\n  AccidentExplosion varchar,\n  AccidentFire varchar,\n  AirFrameHours varchar,\n  AirFrameHoursSince varchar,\n  AirFrameHoursSinceLastInspection varchar,\n  AircraftCategory varchar,\n  AircraftCertMaxGrossWeight integer,\n  AircraftHomeBuilt varchar,\n  AircraftKey integer NOT NULL,\n  AircraftManufacturer varchar,'),
 Document(metadata={'id': 1}, page_content='AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  AircraftSerialNo varchar,\n  AircraftSeries varchar,\n  Damage varchar,\n  DepartureAirportId varchar,\n  DepartureCity varchar,\n  DepartureCountry varchar,\n  DepartureSameAsEvent varchar,\n  DepartureState varchar,\n  DepartureTime integer,\n  DepartureTimeZone varchar,\n  DestinationAirportId varchar,\n  DestinationCity varchar,'),
 Document(metadata={'id': 1}, page_content='DestinationCountry varchar,\n  DestinationSameAsLocal varch
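As the output shows, each chunk keeps the `metadata` of the document it came from, so every piece of a table's DDL still carries that table's `id`. The simplified sketch below packs newline-separated lines greedily into chunks of at most `chunk_size` characters and tags each chunk with its table id. It is only an illustration: LangChain's `CharacterTextSplitter` also applies `chunk_overlap` and differs in details, and `split_ddl` is a hypothetical helper:

```python
def split_ddl(doc_id, text, chunk_size=400, separator="\n"):
    """Greedily pack separator-delimited lines into chunks of at most chunk_size chars.

    Simplified illustration without overlap; a single line longer than
    chunk_size becomes its own (oversized) chunk.
    """
    chunks, current = [], ""
    for line in text.split(separator):
        candidate = line if not current else current + separator + line
        if len(candidate) <= chunk_size or not current:
            current = candidate
        else:
            # Current chunk is full: emit it, tagged with the table id
            chunks.append((doc_id, current))
            current = line
    if current:
        chunks.append((doc_id, current))
    return chunks


print(split_ddl(1, "CREATE TABLE Aviation.Crew (\n  Age integer,\n  Sex varchar\n);", chunk_size=45))
```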

Now we're ready to extract the embedding vectors and store them in IRIS, using the `IRISVector` class from [langchain-iris](https://pypi.org/project/langchain-iris/).

In [20]:
tables_vector_store = IRISVector.from_documents(
    embedding=OpenAIEmbeddings(),
    documents=tables_docs,
    connection_string=iris_conn_str,
    collection_name="sql_tables",
    pre_delete_collection=True
)

Note: the `pre_delete_collection` flag is used here for demo purposes only, to ensure a fresh collection on every test run during development. In a production environment this flag should probably not be used; its default value is `False`.

Now we can find the most relevant documents for a user prompt.

For instance, let's query for aircraft manufacturers:

In [21]:
input = "List the first 2 manufacturers"
relevant_tables_docs = tables_vector_store.similarity_search(input, k=3)
relevant_tables_docs

[Document(metadata={'id': 1}, page_content='GearType varchar,\n  LastInspectionDate timestamp,\n  LastInspectionType varchar,\n  Missing varchar,\n  OperationDomestic varchar,\n  OperationScheduled varchar,\n  OperationType varchar,\n  OperatorCertificate varchar,\n  OperatorCertificateNum varchar,\n  OperatorCode varchar,\n  OperatorCountry varchar,\n  OperatorIndividual varchar,\n  OperatorName varchar,\n  OperatorState varchar,\n  Owner varchar,'),
 Document(metadata={'id': 1}, page_content='AircraftModel varchar,\n  AircraftRegistrationClass varchar,\n  AircraftSerialNo varchar,\n  AircraftSeries varchar,\n  Damage varchar,\n  DepartureAirportId varchar,\n  DepartureCity varchar,\n  DepartureCountry varchar,\n  DepartureSameAsEvent varchar,\n  DepartureState varchar,\n  DepartureTime integer,\n  DepartureTimeZone varchar,\n  DestinationAirportId varchar,\n  DestinationCity varchar,'),
 Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n  LocationState varchar,

As you can see in the `id` metadata field, the two most similar documents come from the table with ID 1. By looking up that ID in the `id` column of the `table_df` dataframe, we find that it is the `Aviation.Aircraft` table, which makes sense for a question about manufacturers. The third document comes from the table with ID 3 (`Aviation.Event`).

But this isn't always perfect. Let's try another prompt:

In [22]:
input = "List the top 10 most crash sites"
relevant_tables_docs = tables_vector_store.similarity_search(input, k=3)
relevant_tables_docs

[Document(metadata={'id': 3}, page_content='LocationSiteZipCode varchar,\n  LocationState varchar,\n  MidAir varchar,\n  NTSBId varchar,\n  NarrativeCause varchar,\n  NarrativeFull varchar,\n  NarrativeSummary varchar,\n  OnGroundCollision varchar,\n  SkyConditionCeiling varchar,\n  SkyConditionCeilingHeight integer,\n  SkyConditionNonCeiling varchar,\n  SkyConditionNonCeilingHeight integer,\n  TimeZone varchar,\n  Type varchar,\n  Visibility varchar,'),
 Document(metadata={'id': 3}, page_content='InjuriesGroundSerious integer,\n  InjuriesHighest varchar,\n  InjuriesTotal integer,\n  InjuriesTotalFatal integer,\n  InjuriesTotalMinor integer,\n  InjuriesTotalNone integer,\n  InjuriesTotalSerious integer,\n  InvestigatingAgency varchar,\n  LightConditions varchar,\n  LocationCity varchar,\n  LocationCoordsLatitude double,\n  LocationCoordsLongitude double,\n  LocationCountry varchar,'),
 Document(metadata={'id': 1}, page_content='CREATE TABLE Aviation.Aircraft (\n  Event bigint NOT NULL,

Although the (correct) table `Aviation.Event` (ID 3) was selected twice, the (incorrect) table `Aviation.Aircraft` (ID 1) was also selected.

This could probably be improved by adding a similarity threshold for accepting results, but that is a subject for future work.
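One possible shape for such a threshold, sketched under assumptions: LangChain's vector store interface defines `similarity_search_with_score`, which returns `(Document, score)` pairs, and assuming `IRISVector` implements it, the pairs could be turned into `(doc.metadata["id"], score)` tuples. Whether the score is a distance (lower is better) or a similarity depends on the store, so the direction of the comparison and the value of `max_distance` are assumptions to verify. `filter_by_distance` is a hypothetical helper and the distances below are made up:

```python
def filter_by_distance(scored_ids, max_distance):
    """Keep the table ids whose closest chunk lies within max_distance.

    scored_ids: (table_id, distance) pairs; a smaller distance is assumed
    to mean "more similar". Returns the surviving ids, closest first.
    """
    best = {}
    for table_id, distance in scored_ids:
        if distance <= max_distance:
            # Several chunks can belong to one table; keep its best distance
            best[table_id] = min(distance, best.get(table_id, distance))
    return sorted(best, key=best.get)


# Made-up distances, for illustration only
hits = [(3, 0.15), (3, 0.18), (1, 0.31)]
print(filter_by_distance(hits, max_distance=0.25))  # [3]
```

Picking `max_distance` would require inspecting real scores for prompts whose correct tables are known.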

So, in general, this approach gives us a way to select only the top k most relevant tables to send to the LLM, which helps decrease the prompt length.

Let's define a function for it, whose result will be used to populate the `table_info` variable:

In [23]:
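def get_relevant_tables(user_input, tables_vector_store, table_df):
    """Return the full DDL of the tables whose chunks are most similar to user_input."""
    # Chunks most similar to the input (using the store's default k)
    relevant_tables_docs = tables_vector_store.similarity_search(user_input)
    # Each chunk carries the id of the table it was split from
    relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs]
    # Map chunk ids back to complete table definitions (each table once)
    indices = table_df["id"].isin(relevant_tables_docs_indices)
    relevant_tables_array = [x for x in table_df[indices]["col_def"]]
    return relevant_tables_array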
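`get_relevant_tables` works because several chunks of the same table collapse into a single full DDL statement through their shared `id`. A pure-Python sketch of that mapping follows, where `ddl_by_id` is a hypothetical dict standing in for `table_df` (it could be built with `dict(zip(table_df["id"], table_df["col_def"]))`). One design difference: `isin` returns the tables in `table_df` order, while this sketch keeps the order of the first hit, so the most relevant table comes first. `relevant_table_ddl` is a hypothetical helper:

```python
def relevant_table_ddl(hit_ids, ddl_by_id):
    """Collapse chunk hits (table ids, most similar first) into unique table DDL."""
    ordered_ids = []
    for table_id in hit_ids:
        # Several chunks of the same table map to a single DDL statement
        if table_id in ddl_by_id and table_id not in ordered_ids:
            ordered_ids.append(table_id)
    return [ddl_by_id[table_id] for table_id in ordered_ids]


ddl_by_id = {
    1: "CREATE TABLE Aviation.Aircraft (...);",
    2: "CREATE TABLE Aviation.Crew (...);",
    3: "CREATE TABLE Aviation.Event (...);",
}
# Chunk ids in the order returned for the "crash sites" prompt above
print(relevant_table_ddl([3, 3, 1], ddl_by_id))
```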

## Selecting the most relevant examples (few shots)

The next variable we need to provide is `examples_value`.

First, let's define a list of examples. These examples try to cover most aspects of IRIS SQL syntax and the tables available in the database, to help keep the LLM from hallucinating.

In [24]:
examples = [
    {
        "input": "List all aircrafts.", 
        "query": "SELECT * FROM Aviation.Aircraft"
    },{
        "input": "Find all incidents for the aircraft with ID 'N12345'.",
        "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"
    },{
        "input": "List all incidents in the 'Commercial' operation type.",
        "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')"
    },{
        "input": "Find the total number of incidents.",
        "query": "SELECT COUNT(*) FROM Aviation.Event"
    },{
        "input": "List all incidents that occurred in 'Canada'.",
        "query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'"
    },{
        "input": "How many incidents are associated with the aircraft with AircraftKey 5?",
        "query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5"
    },{
        "input": "Find the total number of distinct aircrafts involved in incidents.",
        "query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft"
    },{
        "input": "List all incidents that occurred after 5 PM.",
        "query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700"
    },{
        "input": "Who are the top 5 operators by the number of incidents?",
        "query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC"
    },{
        "input": "Which incidents occurred in the year 2020?",
        "query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'"
    },{
        "input": "What was the month with most events in the year 2020?",
        "query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC"
    },{
        "input": "How many crew members were involved in incidents?",
        "query": "SELECT COUNT(*) FROM Aviation.Crew"
    },{
        "input": "List all incidents with detailed aircraft information for incidents that occurred in the year 2012.",
        "query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012"
    },{
        "input": "Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.",
        "query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5"
    },{
        "input": "List all crew members involved in incidents with serious injuries, along with the incident date and location.",
        "query": "SELECT c.CrewNumber AS \"Crew Number\", c.Age, c.Sex Gender, e.EventDate AS \"Event Date\", e.LocationCity AS \"Location City\", e.LocationState AS \"Location State\" FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'"
    },
]

As with the tables, we don't need to send all the examples to the LLM. In a production environment, this list would keep growing with new examples entered by users, and it would soon become infeasible to send all of them.

In this case we will also use IRIS Vector Search, but this time with the help of the [SemanticSimilarityExampleSelector class](https://api.python.langchain.com/en/latest/example_selectors/langchain_core.example_selectors.semantic_similarity.SemanticSimilarityExampleSelector.html):

In [25]:
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    IRISVector,
    k=5,
    input_keys=["input"],
    connection_string=iris_conn_str,
    collection_name="sql_samples",
    pre_delete_collection=True
)

Remember that the `pre_delete_collection` flag is used here for demo purposes only, to ensure a fresh collection on every test run during development, and should be avoided in production.

This class allows us to retrieve the k examples that best fit a user prompt. Let's check it out:

In [26]:
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
example_selector.select_examples({"input": input})

[{'input': 'List all incidents with detailed aircraft information for incidents that occurred in the year 2012.',
  'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012'},
 {'input': "Find all incidents for the aircraft with ID 'N12345'.",
  'query': "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"},
 {'input': 'Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.',
  'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
 {'input': 'List all aircrafts.', 'query': 'SELECT * FROM Aviation.Aircraft'},
 {'input': 'Find the total number of distinct aircrafts involved in incidents.',
  'query': 'SELECT C

As you can see, the results contain no `COUNT(*)` statements, even though many examples in the list use it.

Now let's ask for quantities:

In [27]:
input = "What is the number of incidents involving Boeing aircraft."
example_selector.select_examples({"input": input})

[{'input': 'How many incidents are associated with the aircraft with AircraftKey 5?',
  'query': 'SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5'},
 {'input': 'Find the total number of distinct aircrafts involved in incidents.',
  'query': 'SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft'},
 {'input': 'How many crew members were involved in incidents?',
  'query': 'SELECT COUNT(*) FROM Aviation.Crew'},
 {'input': 'Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.',
  'query': 'SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5'},
 {'input': 'List all incidents with detailed aircraft information for incidents that occurred in the year 2012.',
  'query': 'SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft 

Now note that the first examples, which are the most similar ones, contain `COUNT` statements.

As with the tables, we also get some results that aren't good, and a complementary approach should be implemented in the future to exclude them.
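Both selection steps in this notebook rest on the same idea: embed the texts, then rank them by vector similarity to the embedded user prompt. The toy sketch below ranks examples by cosine similarity over hand-made 2-dimensional vectors. The vectors are stand-ins for real OpenAI embeddings, the function names are hypothetical, and the distance strategy actually used by `IRISVector` may differ:

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two non-zero vectors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def select_top_k(query_vector, examples, k):
    """Return the k examples whose 'vector' is most similar to query_vector."""
    ranked = sorted(
        examples,
        key=lambda ex: cosine_similarity(query_vector, ex["vector"]),
        reverse=True,
    )
    return ranked[:k]


# Toy "embeddings": axis 0 ~ counting questions, axis 1 ~ listing questions
toy_examples = [
    {"input": "Find the total number of incidents.", "vector": [1.0, 0.0]},
    {"input": "List all aircrafts.", "vector": [0.0, 1.0]},
    {"input": "How many crew members were involved in incidents?", "vector": [0.8, 0.3]},
]
query_vector = [0.9, 0.1]  # pretend embedding of a "number of incidents" question
print([ex["input"] for ex in select_top_k(query_vector, toy_examples, k=2)])
```

With these vectors the two counting questions outrank the listing one, mirroring how the `COUNT` examples surfaced for the quantity question above.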

## Accuracy test

Now we have all the information needed to set up the prompt and send it to the LLM.

Let's create a function to do that:

In [28]:
def get_sql_from_text(context, prompt, user_input, use_few_shots, tables_vector_store, table_df, example_selector=None, example_prompt=None):
    # Retrieve the DDL of the tables most similar to the user input
    relevant_tables = get_relevant_tables(user_input, tables_vector_store, table_df)
    context["table_info"] = "\n\n".join(relevant_tables)

    # Few-shot examples are used whenever an example selector is given;
    # note that `use_few_shots` is not checked by this function
    examples = example_selector.select_examples({"input": user_input}) if example_selector is not None else []
    context["examples_value"] = "\n\n".join([
        example_prompt.invoke(x).to_string() for x in examples
    ])
    
    # temperature=0 keeps the output as repeatable as possible;
    # the parser returns the model reply as a plain string
    model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    output_parser = StrOutputParser()
    chain_model = prompt | model | output_parser
    
    response = chain_model.invoke({
        "top_k": context["top_k"],
        "table_info": context["table_info"],
        "examples_value": context["examples_value"],
        "input": user_input
    })
    return response
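The system prompt asks for plain SQL because chat models sometimes wrap their answer in a Markdown code fence, which the database would reject. If that happens, a defensive post-processing step can strip the fence before execution. `extract_sql` is a hypothetical helper, not part of langchain, and it only handles the common multi-line fenced form:

```python
def extract_sql(response: str) -> str:
    """Return the SQL inside a surrounding Markdown code fence, or the text unchanged."""
    fence = "`" * 3  # three backticks, as a model would emit around a code block
    text = response.strip()
    lines = text.split("\n")
    # A fenced reply starts with a fence line (possibly tagged, e.g. "sql")
    # and ends with a bare fence line; keep only what lies between them
    if len(lines) >= 2 and lines[0].startswith(fence) and lines[-1].strip() == fence:
        text = "\n".join(lines[1:-1])
    return text.strip()


fence = "`" * 3
print(extract_sql(fence + "sql\nSELECT TOP 3 ID FROM Aviation.Event\n" + fence))
# SELECT TOP 3 ID FROM Aviation.Event
```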

Let's execute the prompt with and without examples:

In [29]:
# Prompt execution **with** few shots
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
response_with_few_shots = get_sql_from_text(
    context, 
    prompt, 
    user_input=input, 
    use_few_shots=True, 
    tables_vector_store=tables_vector_store, 
    table_df=table_df,
    example_selector=example_selector, 
    example_prompt=example_prompt,
)
print(response_with_few_shots)

SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.EventId = a.EventId
WHERE Year(e.EventDate) = 2010


In [30]:
# Prompt execution **without** few shots
input = "Find all events in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model."
response_with_no_few_shots = get_sql_from_text(
    context, 
    prompt, 
    user_input=input, 
    use_few_shots=False, 
    tables_vector_store=tables_vector_store, 
    table_df=table_df,
)
print(response_with_no_few_shots)

SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel"
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.ID = a.Event
WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'


Let's check whether those SQL queries work by creating some utility functions:

In [31]:
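def execute_sql_query(cnx, query):
    """Run `query` and return all rows, or None if it fails."""
    try:
        cursor = cnx.cursor()
        cursor.execute(query)
        rows = cursor.fetchall()
        return rows
    except Exception:
        # Report the failing query instead of raising, so the tests can continue
        print('error on running query: ')
        print(query)
        print('-'*80)
    return None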

In [32]:
# SQL test for prompt **with** few shots
print("SQL is OK" if execute_sql_query(cnx, response_with_few_shots) is not None else "SQL is not OK")

SQL is OK


In [33]:
# SQL test for prompt **without** few shots
print("SQL is OK" if execute_sql_query(cnx, response_with_no_few_shots) is not None else "SQL is not OK")

error on running query: 
SELECT TOP 3 "EventId", "EventDate", "LocationCity", "LocationState", "AircraftManufacturer", "AircraftModel"
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.ID = a.Event
WHERE e.EventDate >= '2010-01-01' AND e.EventDate < '2011-01-01'
--------------------------------------------------------------------------------
SQL is not OK


To perform a comparison, let's define a set of test prompts and their expected results (the reference SQL query for each one is shown as a comment):

In [34]:
tests = [{
    # SELECT TOP 3 YEAR(EventDate) AS EventYear, COUNT(*) AS EventCount
    # FROM Aviation.Event
    # GROUP BY YEAR(EventDate)
    # ORDER BY EventCount DESC
    "input": "What were the top 3 years with the most recorded events?",
    "expected": [{128, 2003}, {122, 2007}, {117, 2005}]
},{
    # SELECT COUNT(*) FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE AircraftManufacturer = 'Boeing')
    "input": "How may incidents involving Boeing aircraft.",
    "expected": [{5}]
},{
    # SELECT count(*) FROM Aviation.Event WHERE InjuriesTotalFatal > 0
    "input": "How may incidents that resulted in fatalities.",
    "expected": [{237}]
},{
    # SELECT e.EventId, e.EventDate, c.CrewNumber, c.Age, c.Sex FROM Aviation.Event e JOIN Aviation.Crew c ON e.EventId = c.EventId WHERE YEAR(e.EventDate) = 2013
    "input": "List event Id and date and, crew number, age and gender for incidents that occurred in 2013.",
    "expected": [{1, datetime.datetime(2013, 3, 4, 11, 6), '20130305X71252', 59, 'M'},
                 {1, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 32, 'M'},
                 {2, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 35, 'M'},
                 {1, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 25, 'M'},
                 {2, datetime.datetime(2013, 1, 12, 15, 0), '20130113X42535', 34, 'M'},
                 {1, datetime.datetime(2013, 2, 1, 15, 0), '20130203X53401', 29, 'M'},
                 {1, datetime.datetime(2013, 2, 15, 15, 0), '20130218X70747', 27, 'M'},
                 {1, datetime.datetime(2013, 3, 2, 15, 0), '20130303X21011', 49, 'M'},
                 {1, datetime.datetime(2013, 3, 23, 13, 52), '20130326X85150', 'M', None}]
},{
    # SELECT COUNT(*) FROM Aviation.Event WHERE LocationCountry = 'United States'
    "input": "Find the total number of incidents that occurred in the United States.",
    "expected": [{1178}]
},{
    # SELECT LocationCoordsLatitude, LocationCoordsLongitude, InjuriesTotal FROM Aviation.Event WHERE YEAR(EventDate) = 2010 AND InjuriesTotal > 5
    "input": "List all incidents lattitude and longitude coordinates with more than 5 injuries that occurred in 2010.",
    "expected": [{-78.76833333333333, 43.25277777777778}]
},{
    # SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel 
    # FROM Aviation.Event e 
    # JOIN Aviation.Aircraft a ON e.EventId = a.EventId 
    # WHERE YEAR(e.EventDate) = 2010 AND LocationState = 'New York'
    "input": "Find all incidents in 2010 informing the Event Id and date, location city and state, aircraft manufacturer and model.",
    "expected": [
        {datetime.datetime(2010, 5, 20, 13, 43), '20100520X60222', 'CIRRUS DESIGN CORP', 'Farmingdale', 'New York', 'SR22'},
        {datetime.datetime(2010, 4, 11, 15, 0), '20100411X73253', 'CZECH AIRCRAFT WORKS SPOL SRO', 'Millbrook', 'New York', 'SPORTCRUISER'},
        {'108', datetime.datetime(2010, 1, 9, 12, 55), '20100111X41106', 'Bayport', 'New York', 'STINSON'},
        {datetime.datetime(2010, 8, 1, 14, 20), '20100801X85218', 'A185F', 'CESSNA', 'New York', 'Newfane'}
    ]
}]
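Each expected row above is written as a Python set of column values rather than a tuple, presumably because the generated SQL may select the same columns in a different order. A minimal sketch of what that representation does and does not distinguish, using one crew row from the 2013 test case (the variable names are illustrative only):

```python
import datetime

# One crew row from the 2013 test case, in the column order of its SQL comment:
# EventId, EventDate, CrewNumber, Age, Sex
row = ('20130101X94035', datetime.datetime(2013, 1, 1, 15, 0), 2, 35, 'M')
row_as_set = set(row)

# The expected entry lists the same values in a different order and still matches,
# because a set ignores the order (and the names) of the columns.
expected_entry = {2, datetime.datetime(2013, 1, 1, 15, 0), '20130101X94035', 35, 'M'}
print(row_as_set == expected_entry)  # True

# Caveat: repeated values inside one row collapse into a single element,
# so a row (1, 1) and a row (1,) are indistinguishable.
print(set((1, 1)) == set((1,)))  # True
```

The trade-off is that a set-based check tolerates column reordering but cannot tell apart rows whose values differ only by repetition.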

Now, let's define a function that checks if a SQL query returns the expected results:

In [35]:
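def sql_result_equals(cnx, query, expected):
    # Run the query; execute_sql_query returns None when the query fails
    rows = execute_sql_query(cnx, query)
    # Represent each row as the set of its column values (column order is ignored)
    result = [set(row._asdict().values()) for row in rows or []]
    # Report queries that ran successfully but returned unexpected results
    if rows is not None and result != expected:
        print('result not expected for query: ')
        print(query)
        print('-'*80)
    return result == expected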
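Note that `sql_result_equals` compares two lists, so it is sensitive to row order, and SQL does not guarantee any row order without an `ORDER BY` clause (as in the 2013 crew test). A minimal sketch of an order-insensitive alternative, assuming the same list-of-sets representation; `rows_match_unordered` is a hypothetical helper, not part of the notebook, and `namedtuple` stands in for SQLAlchemy result rows because it offers the same `_asdict()` method:

```python
from collections import Counter, namedtuple


def rows_match_unordered(result_rows, expected_rows):
    """Compare two lists of row-sets, ignoring row order but keeping duplicate rows.

    Sets are unhashable, so each row is turned into a frozenset before
    Counter tallies how many times each row occurs.
    """
    return Counter(map(frozenset, result_rows)) == Counter(map(frozenset, expected_rows))


# Stand-in for SQLAlchemy rows (hypothetical data shaped like the crew test).
Row = namedtuple("Row", ["EventId", "CrewNumber"])
rows = [Row("20130101X94035", 2), Row("20130101X94035", 1)]
result = [set(row._asdict().values()) for row in rows]

expected = [{1, "20130101X94035"}, {2, "20130101X94035"}]
print(result == expected)                      # False: list comparison is order-sensitive
print(rows_match_unordered(result, expected))  # True: same rows, different order
```

Using `Counter` instead of a plain set of rows keeps the check strict about how many times each row appears.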

Then let's put it all together and define a function that runs every test case and computes the overall accuracy:

In [36]:
def execute_tests(cnx, context, prompt, use_few_shots, tables_vector_store, table_df, example_selector, example_prompt):
    # Generate SQL for each test prompt (reads the global `tests` list)
    tests_generated_sql = [(x, get_sql_from_text(
            context, 
            prompt, 
            user_input=x['input'], 
            use_few_shots=use_few_shots, 
            tables_vector_store=tables_vector_store, 
            table_df=table_df,
            example_selector=example_selector if use_few_shots else None, 
            example_prompt=example_prompt if use_few_shots else None,
        )) for x in deepcopy(tests)]
    
    # Run each generated query and compare its rows with the expected ones
    tests_sql_executions = [(test, sql_result_equals(cnx, sql, test['expected'])) 
                            for test, sql in tests_generated_sql]
    
    # Accuracy = fraction of test cases whose results matched
    accuracy = sum(1 for _, passed in tests_sql_executions if passed) / len(tests_sql_executions)
    print(f'accuracy: {accuracy}')
    print('-'*80)
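`execute_tests` only prints the accuracy and reads the global `tests` list, which makes its results hard to reuse (for example, to list which prompts failed). A minimal sketch of a variant that returns the results instead, assuming you wrap `get_sql_from_text` and `sql_result_equals` in small callables such as `lambda text: get_sql_from_text(context, prompt, user_input=text, ...)` and `lambda sql, expected: sql_result_equals(cnx, sql, expected)`; `evaluate` and the fake test data below are hypothetical:

```python
from typing import Callable, Iterable, List, Tuple


def evaluate(
    tests: Iterable[dict],
    generate_sql: Callable[[str], str],
    check: Callable[[str, list], bool],
) -> Tuple[float, List[Tuple[str, bool]]]:
    """Hypothetical variant of execute_tests that returns its results instead of printing them."""
    outcomes = []
    for test in tests:
        sql = generate_sql(test["input"])        # prompt -> SQL
        passed = check(sql, test["expected"])    # SQL -> do the rows match?
        outcomes.append((test["input"], passed))
    if not outcomes:
        raise ValueError("no test cases")
    accuracy = sum(passed for _, passed in outcomes) / len(outcomes)
    return accuracy, outcomes


# Offline usage with stand-ins: 7 fake test cases, 4 of which "pass".
fake_tests = [{"input": f"q{i}", "expected": [{i}]} for i in range(7)]
passing_prompts = {"q0", "q1", "q2", "q3"}
accuracy, outcomes = evaluate(
    fake_tests,
    generate_sql=lambda prompt: f"-- SQL for {prompt}",
    check=lambda sql, expected: sql.split()[-1] in passing_prompts,
)
print(f"accuracy: {accuracy}")
print([prompt for prompt, passed in outcomes if not passed])  # ['q4', 'q5', 'q6']
```

Injecting the generator and the checker as parameters also lets the accuracy logic be exercised without an LLM or a database connection.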

In [37]:
# Accuracy tests for prompts executed **without** few shots
use_few_shots=False
execute_tests(
    cnx,
    context, 
    prompt, 
    use_few_shots, 
    tables_vector_store, 
    table_df, 
    example_selector, 
    example_prompt
)

error on running query: 
SELECT "EventDate", COUNT("EventId") as "TotalEvents"
FROM Aviation.Event
GROUP BY "EventDate"
ORDER BY "TotalEvents" DESC
TOP 3;
--------------------------------------------------------------------------------
error on running query: 
SELECT "EventId", "EventDate", "C"."CrewNumber", "C"."Age", "C"."Sex"
FROM "Aviation.Event" AS "E"
JOIN "Aviation.Crew" AS "C" ON "E"."ID" = "C"."EventId"
WHERE "E"."EventDate" >= '2013-01-01' AND "E"."EventDate" < '2014-01-01'
--------------------------------------------------------------------------------
result not expected for query: 
SELECT TOP 3 "e"."EventId", "e"."EventDate", "e"."LocationCity", "e"."LocationState", "a"."AircraftManufacturer", "a"."AircraftModel"
FROM "Aviation"."Event" AS "e"
JOIN "Aviation"."Aircraft" AS "a" ON "e"."ID" = "a"."Event"
WHERE "e"."EventDate" >= '2010-01-01' AND "e"."EventDate" < '2011-01-01'
--------------------------------------------------------------------------------
accuracy: 0.5714285714285714
--------------------------------------------------------------------------------

In [38]:
# Accuracy tests for prompts executed **with** few shots
use_few_shots=True
execute_tests(
    cnx,
    context, 
    prompt, 
    use_few_shots, 
    tables_vector_store, 
    table_df, 
    example_selector, 
    example_prompt
)

error on running query: 
SELECT e.EventId, e.EventDate, e.LocationCity, e.LocationState, a.AircraftManufacturer, a.AircraftModel
FROM Aviation.Event e
JOIN Aviation.Aircraft a ON e.EventId = a.EventId
WHERE Year(e.EventDate) = 2010 TOP 3
--------------------------------------------------------------------------------
accuracy: 0.8571428571428571
--------------------------------------------------------------------------------


As you can see, SQL queries generated with few shots were 50% more accurate than those generated without few shots: 6 of the 7 tests passed (about 86%) versus 4 of 7 (about 57%).
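Computed from the exact pass counts rather than from rounded percentages (one failing query out of 7 with few shots, three without), the improvement can be stated in two ways: as an absolute gain in percentage points or as a relative gain. A short sketch of that arithmetic using exact fractions:

```python
from fractions import Fraction

# Pass rates from the two runs above: 4 of 7 without few shots, 6 of 7 with.
without_few_shots = Fraction(4, 7)
with_few_shots = Fraction(6, 7)

absolute_gain = with_few_shots - without_few_shots      # 2/7 of the test cases
relative_gain = with_few_shots / without_few_shots - 1  # (6/7) / (4/7) - 1 = 1/2

print(f"absolute: {float(absolute_gain):.1%}")  # absolute: 28.6%
print(f"relative: {float(relative_gain):.0%}")  # relative: 50%
```

With only 7 test cases, each test flips the accuracy by about 14 percentage points, so a larger test set would give a more stable comparison.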

## References

- https://python.langchain.com/v0.1/docs/expression_language/get_started/
- https://python.langchain.com/v0.1/docs/use_cases/sql/prompting/
- https://python.langchain.com/v0.1/docs/modules/model_io/prompts/composition/