In [1]:
from operator import itemgetter
from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough, RunnableSequence, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, FewShotPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.llms import Ollama
from sqlalchemy import create_engine
from tqdm import tqdm
import pandas as pd
import re
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import SequentialChain

## Setup SQL Database
Previously, we created a SQLite database and tables from our original .csv file using the script in 'src/create_sql_database.py'.

In the cell below, we are setting up a connection and initializing tables within an SQLite database for managing data about ultrasound probes and systems. We also define custom table information, `custom_table_info = {...}`, which will provide our SQL query generator with more context to create accurate and informed queries.

A custom table information dictionary allows for the inclusion of notes and comments about each table and its columns. These notes can later be passed into our SQL prompt, enabling it to construct more precise and relevant queries. In contrast, using the default `db.get_table_info()` method would not provide this extra contextual information.

In [4]:
# Define the database URI
database_uri = "sqlite:///../data/SQL/Probes.db"

# Create an engine
engine = create_engine(database_uri)

# Define custom table information
custom_table_info = {
    "probes": """
    CREATE TABLE probes (
        Probe_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        Manufacturer TEXT,
        Probe_Model TEXT,
        Connection_Type TEXT,
        Array_Type TEXT,
        Frequency_Range TEXT,
        Applications TEXT,
        Stock INTEGER,
        Description TEXT
    );
    /* 
    Probe_ID is the unique identifier for each probe.
    Manufacturer is the company that manufactures the probe.
    Probe_Model is the model number or identifier of the probe.
    Connection_Type is the type of connection interface the probe uses (e.g., Cartridge).
    Array_Type is the configuration of the probe's sensor array (e.g., Linear, Convex).
    Frequency_Range is the operating frequency range of the probe.
    Applications is the typical applications or medical procedures the probe is used for.
    Stock is the number of units currently available for sale.
    Description is a detailed description of the probe's features and capabilities.
    */
    """,
    "systems": """
    CREATE TABLE systems (
        System_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        System_Name TEXT UNIQUE,
        Manufacturer TEXT
    );
    /* 
    System_ID is the unique identifier for each ultrasound system.
    System_Name is the name or model of the ultrasound system.
    Manufacturer is the company that manufactures the system.
    */
    """,
    "compatibility": """
    CREATE TABLE compatibility (
        Compatibility_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        Probe_ID INTEGER,
        System_ID INTEGER,
        FOREIGN KEY(Probe_ID) REFERENCES probes(Probe_ID),
        FOREIGN KEY(System_ID) REFERENCES systems(System_ID)
    );
    /* 
    Compatibility_ID is the unique identifier for each compatibility entry.
    Probe_ID is a foreign key linking to the Probes table.
    System_ID is a foreign key linking to the Systems table.
    */
    """
}

# Initialize SQLDatabase with the engine and custom table info
db = SQLDatabase(engine=engine, custom_table_info=custom_table_info)
context = db.get_context()["table_info"]
print(context)


    CREATE TABLE compatibility (
        Compatibility_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        Probe_ID INTEGER,
        System_ID INTEGER,
        FOREIGN KEY(Probe_ID) REFERENCES probes(Probe_ID),
        FOREIGN KEY(System_ID) REFERENCES systems(System_ID)
    );
    /* 
    Unique identifier for each compatibility entry.
    Foreign key linking to the Probes table.
    Foreign key linking to the Systems table.
    */
    


    CREATE TABLE probes (
        Probe_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        Manufacturer TEXT,
        Probe_Model TEXT,
        Connection_Type TEXT,
        Array_Type TEXT,
        Frequency_Range TEXT,
        Applications TEXT,
        Stock INTEGER,
        Description TEXT
    );
    /* 
    Unique identifier for each probe.
    The company that manufactures the probe.
    The model number or identifier of the probe.
    Type of connection interface the probe uses (e.g., Cartridge).
    The configuration of the probe's sensor array (e.g.,

## Create SQL query examples
Next we create a list of examples, where each example is a dictionary containing an "input" key with a question as the value, and an "sql_query" key with the corresponding SQL query as the value. These examples will be used to train a model to generate SQL queries based on input questions.

In [23]:
# Define examples
examples = [
    {
        "input": "Who is the maker of 9EVF4?",
        "sql_query": "SELECT Manufacturer FROM probes WHERE Probe_Model = '9EVF4';"
    },
    {
        "input": "Is 9EVF4 made by Siemens Acuson or Philips?",
        "sql_query": "SELECT Manufacturer, CASE WHEN Manufacturer = 'Siemens Acuson' THEN 'Yes' WHEN Manufacturer = 'Philips' THEN 'Yes' ELSE 'No' END AS MadeBy FROM probes WHERE Probe_Model = '9EVF4' AND Manufacturer IN ('Siemens Acuson', 'Philips');"
    },
    {
        "input": "What systems is the ATL C3 probe compatible with?",
        "sql_query": "SELECT s.System_Name FROM compatibility c JOIN probes p ON c.Probe_ID = p.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE p.Manufacturer = 'ATL' AND p.Probe_Model = 'C3';"
    },
    {
        "input": "What systems is the Siemens Acuson 9EVF4 compatible with?",
        "sql_query": "SELECT s.System_Name FROM compatibility c JOIN probes p ON c.Probe_ID = p.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' LIMIT 5;"
    },
    {
        "input": "Does the Siemens Acuson 9EVF4 work with the S3000?",
        "sql_query": "SELECT CASE WHEN COUNT(*) > 0 THEN 'Yes' ELSE 'No' END AS WorksWith FROM compatibility c JOIN systems s ON c.System_ID = s.System_ID JOIN probes p ON c.Probe_ID = p.Probe_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' AND s.System_Name = 'S3000';"
    },
    {
        "input": "Does the Siemens Acuson 9EVF4 work with the Voluson 730?",
        "sql_query": "SELECT CASE WHEN COUNT(*) > 0 THEN 'Yes' ELSE 'No' END AS WorksWith FROM compatibility c JOIN systems s ON c.System_ID = s.System_ID JOIN probes p ON c.Probe_ID = p.Probe_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' AND s.System_Name = 'Voluson 730';"
    },
    {
        "input": "What array type is the Siemens Acuson 9EVF4?",
        "sql_query": "SELECT Array_Type FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';"
    },
    {
        "input": "What can the Siemens Acuson 9EVF4 be used for?",
        "sql_query": "SELECT Applications FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';"
    },
    {
        "input": "Do you have any 9EVF4 for sale?",
        "sql_query": "SELECT CASE WHEN Stock > 0 THEN 'Yes' ELSE 'No' END AS AvailableForSale FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';"
    },
    {
        "input": "Do you have any ATL C3 probes in stock?",
        "sql_query": "SELECT CASE WHEN Stock > 0 THEN 'Yes' ELSE 'No' END AS InStock FROM probes WHERE Manufacturer = 'ATL' AND Probe_Model = 'C3';"
    },
    {
        "input": "Do you have any C3 probes for the HDI 5000 system in stock?",
        "sql_query": "SELECT CASE WHEN p.Stock > 0 AND EXISTS (SELECT 1 FROM compatibility c JOIN systems s ON c.System_ID = s.System_ID WHERE s.System_Name = 'HDI 5000' AND c.Probe_ID = p.Probe_ID) THEN 'Yes' ELSE 'No' END AS InStock FROM probes p WHERE p.Probe_Model = 'C3';"
    },
    {
        "input": "Do you have any convex array probes compatible with the HDI 1500 system?",
        "sql_query": "SELECT p.Probe_Model, p.Manufacturer, CASE WHEN p.Stock > 0 THEN 'In Stock' ELSE 'Out of Stock' END AS StockStatus FROM probes p JOIN compatibility c ON p.Probe_ID = c.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE s.System_Name = 'HDI 1500' AND p.Array_Type = 'Convex' AND p.Stock > 0;"
    },
    {
        "input": "What is the Siemens Acuson 9EVF4?",
        "sql_query": "SELECT Description FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';"
    }
]


## Create the Example Selector and Example Prompt

The code creates an instance of the `SemanticSimilarityExampleSelector` class from the `langchain` library. This class is used to select relevant examples from the list of examples based on the semantic similarity between the input question and the examples.

The `SemanticSimilarityExampleSelector` is initialized with the following parameters:

- `examples`: The list of examples defined earlier.
- `embeddings`: An instance of the `OllamaEmbeddings` class, which is used to generate embeddings (numerical representations) of the input questions and examples. The `mxbai-embed-large` model is used for generating the embeddings.
- `vectorstore_cls`: The `FAISS` class from `langchain.vectorstores` is used for efficient similarity search over the embeddings.
- `k`: The number of examples to retrieve for a given input question. In this case, it is set to 5.
- `input_keys`: The keys in the example dictionaries that contain the input text. In this case, it is set to `["input"]`.

In [27]:
# Create the SQL example selector
sql_example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=examples,
    embeddings=OllamaEmbeddings(model="mxbai-embed-large"),
    vectorstore_cls=FAISS,
    k=5,
    input_keys=["input"],
)

Next, we create an instance of the `PromptTemplate` class from `langchain` to format the examples for use in a prompt.

In [25]:
# Define the SQL example prompt format
sql_example_prompt = PromptTemplate(
    input_variables=["input", "sql_query"],
    template="{sql_query}"
)

## Define the FewShotPromptTemplate

Next, we want to create a prompt that can be used with a language model to generate SQL queries based on input questions. The prompt includes relevant examples, as well as a prefix and suffix that provide context for the task. 

The `FewShotPromptTemplate` is initialized with the following parameters:

- `example_selector`: The `SemanticSimilarityExampleSelector` instance created earlier (i.e. `sql_example_selector`).
- `example_prompt`: The `PromptTemplate` instance created earlier (i.e. `sql_example_prompt`).
- `prefix`: A string that will be added to the beginning of the prompt.
- `suffix`: A string that will be added to the end of the prompt, after the input question.
- `input_variables`: The variables that should be included in the prompt. In this case, it is set to `["input", "table_info", "top_k"]`.

In [29]:
# Define the FewShotPromptTemplate
fewshot_sql_prompt = FewShotPromptTemplate(
    example_selector=sql_example_selector,
    example_prompt=sql_example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Use the LIMIT clause as appropriate, querying for at most {top_k} results. Return only the SQL query. Do not include any explanations or additional text.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "table_info", "top_k"],
)


Next, we ask a sample question that will be used as input to generate an SQL query.

We then initialize an instance of the `Ollama` language model from the `langchain` library with the `llama3` model. This language model will be used to generate SQL queries based on the input question. 

Then we create an instance of the `create_sql_query_chain` that combines the language model, the database (`db`), and the few-shot prompt (`prompt`) created earlier to generate SQL queries.

We then call `.invoke()` on the chain with the question and the input question.

We then print the SQL query generated by the chain.

We then execute the SQL query on the database and print the result.




In [34]:
sample_question = "Do you have any linear array probes for the iU22 system?"

# Initialize language model and SQL query chain
llm = Ollama(model="llama3")
chain = create_sql_query_chain(llm=llm, db=db, prompt=fewshot_sql_prompt)

# Use StrOutputParser to extract only the SQL query
sql_query_chain = chain | StrOutputParser()

# Invoke the chain with a specific question
sql_query = sql_query_chain.invoke({
    "input": sample_question,
    "question": sample_question,
    "table_info": db.get_context()["table_info"],
    "top_k": 5
})

print("SQL Result:\n", sql_query)

# Execute the SQL query and get results
sql_result = db.run(sql_query)
print("\nQuery Result:\n", sql_result)


SQL Result:
 SELECT p.Probe_Model, p.Manufacturer 
FROM probes p JOIN compatibility c ON p.Probe_ID = c.Probe_ID 
JOIN systems s ON c.System_ID = s.System_ID 
WHERE s.System_Name = 'iU22' AND p.Array_Type = 'Linear';

Query Result:
 [('L12-3', 'Philips'), ('L15-7io', 'Philips'), ('L17-5', 'Philips')]


In [15]:
final_answer_examples = [
    {
        "input": "Who is the maker of 9EVF4?",
        "sql_query": "SELECT Manufacturer FROM Probes WHERE Probe_Model = '9EVF4';",
        "sql_result": "[('Siemens Acuson',)]",
        "answer": "The 9EVF4 probe is made by Siemens Acuson."
    },
    {
        "input": "Is 9EVF4 made by Siemens Acuson or Philips?",
        "sql_query": "SELECT Manufacturer, CASE WHEN Manufacturer = 'Siemens Acuson' THEN 'Yes' WHEN Manufacturer = 'Philips' THEN 'Yes' ELSE 'No' END AS MadeBy FROM Probes WHERE Probe_Model = '9EVF4' AND Manufacturer IN ('Siemens Acuson', 'Philips') LIMIT 5;",
        "sql_result": "[('Siemens Acuson', 'Yes')]",
        "answer": "The 9EVF4 probe is made by Siemens Acuson."
    },
    {
        "input": "What systems is Siemens Acuson 9EVF4 compatible with?",
        "sql_query": "SELECT s.System_Name FROM compatibility c JOIN probes p ON c.Probe_ID = p.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' LIMIT 5;",
        "sql_result": "[('S1000',), ('S2000',), ('S3000',)]",
        "answer": "The Siemens Acuson 9EVF4 is compatible with S1000, S2000, S3000 ultrasound systems."
    },
    {
        "input": "Does the Siemens Acuson 9EVF4 work with the S3000?",
        "sql_query": "SELECT CASE WHEN COUNT(*) > 0 THEN 'Yes' ELSE 'No' END AS WorksWith FROM Compatibility JOIN Systems ON Compatibility.System_ID = Systems.System_ID JOIN Probes ON Compatibility.Probe_ID = Probes.Probe_ID WHERE Probes.Manufacturer = 'Siemens Acuson' AND Probes.Probe_Model = '9EVF4' AND Systems.System_Name = 'S3000';",
        "sql_result": "[('Yes',)]",
        "answer": "Yes, the Siemens Acuson 9EVF4 is compatible with the S3000 ultrasound system, as well as S1000, S2000."
    },
    {
        "input": "Does the Siemens Acuson 9EVF4 work with the Voluson 730?",
        "sql_query": "SELECT CASE WHEN COUNT(*) > 0 THEN 'Yes' ELSE 'No' END AS WorksWith FROM Compatibility JOIN Systems ON Compatibility.System_ID = Systems.System_ID JOIN Probes ON Compatibility.Probe_ID = Probes.Probe_ID WHERE Probes.Manufacturer = 'Siemens Acuson' AND Probes.Probe_Model = '9EVF4' AND Systems.System_Name = 'Voluson 730';",
        "sql_result": "[('No',)]",
        "answer": "No, the Siemens Acuson 9EVF4 is not compatible with the Voluson 730 ultrasound system."
    },
    {
        "input": "What type of probe is the Siemens Acuson 9EVF4?",
        "sql_query": "SELECT Array_Type FROM Probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';",
        "sql_result": "[('Convex Array', '4-9 MHz')]",
        "answer": "The Siemens Acuson 9EVF4 is a Convex Array probe with a frequency range of 4-9 MHz."
    },
    {
        "input": "What can the Siemens Acuson 9EVF4 be used for?",
        "sql_query": "SELECT Applications FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';",
        "sql_result": "[('urological, endocavitary, obstetric, gynecological',)]",
        "answer": "The Siemens Acuson 9EVF4 can be used for urological, endocavitary, obstetric, and gynecological applications."
    },
    {
        "input": "Do you have any Siemens Acuson 9EVF4 for sale?",
        "sql_query": "SELECT CASE WHEN Stock > 0 THEN 'Yes' ELSE 'No' END AS AvailableForSale FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';",
        "sql_result": "[('No',)]",
        "answer": "We currently do not have any Siemens Acuson 9EVF4 in stock."
    },
    {
        "input": "Do you have any 9EVF4 for S3000 for sale?",
        "sql_query": "SELECT p.Probe_Model, p.Manufacturer, CASE WHEN p.Stock > 0 THEN 'Yes' ELSE 'No' END AS AvailableForSale FROM probes p JOIN compatibility c ON p.Probe_ID = c.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' AND s.System_Name = 'S3000' LIMIT 5;",
        "sql_result": "[('9EVF4', 'Siemens Acuson', 'No')]",
        "answer": "Unfortunately, we do not have any Siemens Acuson 9EVF4 for the S3000 available for sale."
    },
    {
        "input": "Do you have any linear array probes for iU22?",
        "sql_query": "SELECT p.Probe_Model, p.Manufacturer FROM probes p JOIN compatibility c ON p.Probe_ID = c.Probe_ID WHERE c.System_ID IN (SELECT s.System_ID FROM systems s WHERE s.System_Name = 'iU22') AND p.Array_Type = 'Linear' LIMIT 5;",
        "sql_result": "[('L12-3', 'Philips'), ('L15-7io', 'Philips'), ('L17-5', 'Philips')]",
        "answer": "The Philips L12-3, L15-7io, and L17-5 are linear array probes that are compatible with the iU22 ultrasound system."
    },
    {
        "input": "What is the Siemens Acuson 9EVF4?",
        "sql_query": "SELECT * FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';",
        "sql_result": "[(43, 'Siemens Acuson', '9EVF4', None, 'Convex', '4-9 MHz', 'urological, endocavitary, obstetric, gynecological', 0, 'The Siemens Acuson 9EVF4 is a convex array probe that can be used for urological, abdominal, obstetric, and gynecologic applications. The 9EVF4 has a frequency range of 4-9 MHz while its curved array design gives it a wide field of view beneficial for OB/GYN exams.')]",
        "answer": "The Siemens Acuson 9EVF4 is a convex array probe with a frequency range of 4-9 MHz. It can be used for urological, endocavitary, obstetric, and gynecological applications. The probe's curved array design provides a wide field of view, which is beneficial for OB/GYN exams."
    }
]

answer_example_prompt = PromptTemplate(
    input_variables=["input", "sql_query", "sql_result", "answer"],
    template="{answer}"
)

answer_prompt = FewShotPromptTemplate(
    examples=final_answer_examples,
    example_prompt=answer_example_prompt,
    prefix="Given the following user input question, corresponding SQL query, and SQL result, generate a concise and informative answer to the user input question. Your response should only answer the question asked while only re-iterating details that are already in the user input question. Do not include any explanations or additional text.",
    suffix="User input: {input}\nSQL query: {sql_query}\nSQL result: {sql_result}\n\nAnswer: ",
    input_variables=["input", "sql_query", "sql_result"],
)

# Create a chain to generate the final answer
final_answer_chain = (
    answer_prompt | llm | StrOutputParser()
)

# Invoke the final answer chain with the question, SQL query, and SQL result
final_answer = final_answer_chain.invoke({
    "input": "What type of probe is the Philips L17-5?",
    "sql_query": sql_query,
    "sql_result": sql_result
})

print(final_answer)

The Philips L17-5 is a linear array probe.


In [28]:
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=examples,
    embeddings=OllamaEmbeddings(model="mxbai-embed-large"),
    vectorstore_cls=FAISS,
    k=5,
    input_keys=["input"],
)

ollama_emb = OllamaEmbeddings(model="mxbai-embed-large")
example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specificed, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "top_k", "table_info"],
)

# Initialize language model and SQL query chain
llm = Ollama(model="llama3")
chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt)

# Define the prompt for cleaning up the SQL code
clean_sql_prompt = PromptTemplate(
    input_variables=["sql_query"],
    template="Extract and return only the SQL query from the input:\n\n{sql_query}"
)

# Define the function to clean up the SQL code
def clean_sql(sql_query):
    cleaned_sql = llm(clean_sql_prompt.format(sql_query=sql_query))
    return cleaned_sql

# Create a chain with the existing SQL query chain and the clean_sql function
final_chain = SequentialChain(
    chains=[chain, clean_sql],
    input_variables=["question"],
    output_variables=["cleaned_sql_query"],
    verbose=True
)

final_chain({"question": question})

# Execute the SQL query and get results
sql_result = db.run(sql_query)

# Define the final answer composition using RunnableSequence
answer_prompt_template = """Use the original question, SQL query, and SQL result provided to generate a concise and informative answer. Follow these guidelines:

    - If the question asks about a specific probe model, include the manufacturer name and model in the answer.
    - If the question asks about compatibility, list the compatible systems and provide a clear yes/no answer.
    - If the question asks about probe characteristics, include relevant details such as array type and frequency range.
    - If the question asks about applications, list the specific applications the probe is used for.
    - If the question asks about stock availability, provide a clear statement about current stock status.
    - If the requested item is not in stock, politely inform that it is currently unavailable.
    - If there are ambiguous column names in the SQL query, prefix the column names with the appropriate table names to resolve the ambiguity.

    Example Outputs:
    - 'The 9EVF4 probe is made by Siemens Acuson.'
    - 'Yes, the Siemens Acuson 9EVF4 is compatible with the S3000 ultrasound system, as well as S1000, S2000.'
    - 'The Siemens Acuson 9EVF4 is a Convex Array probe with a frequency range of 4-9 MHz.'
    - 'Unfortunately, we do not have any Siemens Acuson 9EVF4 for the S3000 available for sale.'

    Original Question: {question}
    SQL Query: {query}
    SQL Result: {result}

    Final Answer:
    """

answer_prompt = PromptTemplate(
    template=answer_prompt_template,
    input_variables=["question", "result", "query"]
)

# Create the final answer chain
answer_chain = RunnableSequence(
    RunnableLambda(lambda inputs: {
        "question": inputs["question"],
        "result": sql_result,
        "query": sql_query
    }),
    answer_prompt,
    llm
)

# Get the final answer
final_answer = answer_chain.invoke({"question": question})

# Parse the output to extract the SQL query and LLM response
output_lines = final_answer.split("\n")
for line in output_lines:
    if line.startswith("SQL Query:"):
        parsed_sql_query = line.split("SQL Query:")[1].strip()
    elif line.startswith("Final Answer:"):
        parsed_llm_response = line.split("Final Answer:")[1].strip()

print("Parsed SQL Query:", parsed_sql_query)
print("Parsed LLM Response:", parsed_llm_response)

"Based on your request, here's a SQL query that corresponds to your user input:\n\nWhat type of probe is the Philips L17-5?\n\nSQL query:\nSELECT Array_Type FROM Probes WHERE Manufacturer = 'Philips' AND Probe_Model = 'L17-5';\n\nNote: The manufacturer and model number are case-sensitive."

In [19]:
# Read the CSV file into a DataFrame
df = pd.read_csv("../data/qa_dataset.csv")

# Initialize language model and SQL query chain
llm = Ollama(model="llama3")
chain = create_sql_query_chain(llm, db)


system = """Extract and return only the SQL query from the user's {dialect} query. Check for common SQL mistakes as listed, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
- Ensuring subqueries return a single row when used in expressions that require a single value
- Avoiding ambiguous column names by specifying the table name with the column when multiple tables are involved
- Checking for potential SQL injection vulnerabilities by sanitizing inputs

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query. Format the output as follows:

\`\`\`sql
[insert sql query code here]
\`\`\`
"""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)

validation_chain = prompt | llm | StrOutputParser()

full_chain = {"query": chain} | validation_chain

# Create empty lists to store the results
sql_queries = []
sql_results = []
final_answers = []

# Iterate over each row in the DataFrame to extract and answer question
for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing questions"):
    question = row["Question"]

    # Invoke the chain to get the SQL query
    query_result = full_chain.invoke({"question": question})

    # Use regular expression to extract the SQL query
    match = re.search(r"```sql\n(.+?)\n```", query_result, re.DOTALL)
    if match:
        sql_query = match.group(1).strip()  # Extract and strip whitespace
    else:
        sql_query = "No valid SQL query found"

    # Execute the SQL query and get results
    sql_result = db.run(sql_query)


    # Define the final answer composition using RunnableSequence
    answer_prompt_template = """Use the original question, SQL query, and SQL result provided to generate a concise and informative answer. Follow these guidelines:

    - If the question asks about a specific probe model, include the manufacturer name and model in the answer.
    - If the question asks about compatibility, list the compatible systems and provide a clear yes/no answer.
    - If the question asks about probe characteristics, include relevant details such as array type and frequency range.
    - If the question asks about applications, list the specific applications the probe is used for.
    - If the question asks about stock availability, provide a clear statement about current stock status.
    - If the requested item is not in stock, politely inform that it is currently unavailable.
    - If there are ambiguous column names in the SQL query, prefix the column names with the appropriate table names to resolve the ambiguity.

    Example Outputs:
    - 'The 9EVF4 probe is made by Siemens Acuson.'
    - 'Yes, the Siemens Acuson 9EVF4 is compatible with the S3000 ultrasound system, as well as S1000, S2000.'
    - 'The Siemens Acuson 9EVF4 is a Convex Array probe with a frequency range of 4-9 MHz.'
    - 'Unfortunately, we do not have any Siemens Acuson 9EVF4 for the S3000 available for sale.'

    Original Question: {question}
    SQL Query: {query}
    SQL Result: {result}

    Final Answer:
    """

    answer_prompt = PromptTemplate(
        template=answer_prompt_template,
        input_variables=["question", "result", "query"]
    )

    # Create the final answer chain
    answer_chain = RunnableSequence(
        RunnableLambda(lambda inputs: {
            "question": inputs["question"],
            "result": sql_result,
            "query": sql_query
        }),
        answer_prompt,
        llm
    )

    # Get the final answer
    final_answer = answer_chain.invoke({"question": question})

    # Append the results to the respective lists
    df.loc[index, 'sql_query'] = sql_query
    df.loc[index, 'sql_result'] = str(sql_result)
    df.loc[index, 'final_answer'] = final_answer.strip()
    df.to_csv("../data/qa_dataset_with_results.csv", index=False)

Processing questions:   0%|          | 0/465 [00:10<?, ?it/s]


OperationalError: (sqlite3.OperationalError) near "No": syntax error
[SQL: No valid SQL query found]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [21]:
query = chain.invoke({"question": question})
query



'Here are the answers:\n\nQuestion: Who is the maker of C3?\nSQLQuery: SELECT "Manufacturer" FROM probes WHERE "Probe_Model" = \'C3'

In [16]:
sql_query

'Here is the SQL query:  sql SELECT "System_Name"  FROM compatibility c JOIN probes p ON c."Probe_ID" = p."Probe_ID" WHERE p."Manufacturer" = \'ATL\' AND p."Probe_Model" = \'C3\''

In [5]:
# Initialize language model and SQL query chain
llm = Ollama(model="llama3")
chain = create_sql_query_chain(llm, db)

system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only."""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)

validation_chain = prompt | llm | StrOutputParser()

# Define the chain to generate the SQL query, execute it, and format the answer
chain = (
    RunnablePassthrough.assign(query=create_sql_query_chain(llm, db))
    | itemgetter("query")  # Extract the 'query' part from the output
    | (lambda query: query.strip().strip('`'))  # Clean the query string
    | (lambda sql_query: db.run(sql_query))  # Execute the SQL query
    | (lambda sql_result: ', '.join([result[0] for result in sql_result]))  # Format results
    | (lambda system_names: f"The ATL C3 is compatible with the {system_names} ultrasound systems.")  # Create the final answer
)

# Invoke the chain with a specific question
final_answer = chain.invoke({
    "question": "What systems is ATL C3 compatible with?",
    "input": "What systems is ATL C3 compatible with?"
})

print(final_answer)

OperationalError: (sqlite3.OperationalError) near "Here": syntax error
[SQL: Here is the response:

Question: What systems is ATL C3 compatible with?
SQLQuery: SELECT "System_Name" FROM compatibility WHERE "Manufacturer" = 'ATL' AND "Probe_ID" = (SELECT "Probe_ID" FROM probes WHERE "Manufacturer" = 'ATL' AND "Probe_Model" = 'C3]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [30]:
# Initialize language model and database
llm = Ollama(model="llama3")

# Define the custom prompt template focused only on generating the SQL query
template = '''Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. Query only the necessary columns you can see in the tables below: {table_info}
Use the LIMIT clause as appropriate, querying for at most {top_k} results. Wrap column names in double quotes. Use the date('now') function for current date if needed.

Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Question: {input}'''
sql_prompt = PromptTemplate.from_template(template)

# Create a chain to convert questions to SQL queries using the custom prompt
write_query = create_sql_query_chain(
    llm=llm,
    db=db,
    prompt=sql_prompt
)

# Tool to execute SQL queries
execute_query = QuerySQLDataBaseTool(
    db=db,
    return_direct=True,
    verbose=True
)

# Define the answer prompt template
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, provide a full answer to the question.

Question: {input}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# Combine the components into a chain
chain = (
    RunnablePassthrough.assign(query=write_query)
    | itemgetter("SQLQuery")  # Extract the 'SQLQuery' part from the output
    | execute_query
    | answer_prompt | llm | StrOutputParser()
)

# Invoke the chain with a specific question
response = chain.invoke({
    "input": "What systems is ATL C3 compatible with?",
    "question": "What systems is ATL C3 compatible with?",
    "top_k": 5,
    "table_info": db.get_table_info()
})
print(response)


KeyError: 'SQLQuery'

In [7]:
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result