In [1]:
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, FewShotPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_community.llms import Ollama
from sqlalchemy import create_engine
from tqdm.notebook import tqdm
import pandas as pd
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
import time
from sqlalchemy.exc import OperationalError, ProgrammingError
import re

In [2]:
# Define the database URI
database_uri = "sqlite:///../data/SQL/Probes.db"

# Create an engine
engine = create_engine(database_uri)

# Define custom table information
custom_table_info = {
    "probes": """
    CREATE TABLE probes (
        Probe_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        Manufacturer TEXT,
        Probe_Model TEXT,
        Connection_Type TEXT,
        Array_Type TEXT,
        Frequency_Range TEXT,
        Applications TEXT,
        Stock INTEGER,
        Description TEXT
    );
    /* 
    Unique identifier for each probe.
    The company that manufactures the probe.
    The model number or identifier of the probe.
    Type of connection interface the probe uses (e.g., Cartridge).
    The configuration of the probe's sensor array (e.g., Linear, Convex).
    The operating frequency range of the probe.
    Typical applications or medical procedures the probe is used for.
    The number of units currently available for sale.
    A detailed description of the probe's features and capabilities.
    */
    """,
    "systems": """
    CREATE TABLE systems (
        System_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        System_Name TEXT UNIQUE,
        Manufacturer TEXT
    );
    /* 
    Unique identifier for each ultrasound system.
    The name or model of the ultrasound system.
    The company that manufactures the system.
    */
    """,
    "compatibility": """
    CREATE TABLE compatibility (
        Compatibility_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        Probe_ID INTEGER,
        System_ID INTEGER,
        FOREIGN KEY(Probe_ID) REFERENCES probes(Probe_ID),
        FOREIGN KEY(System_ID) REFERENCES systems(System_ID)
    );
    /* 
    Unique identifier for each compatibility entry.
    Foreign key linking to the Probes table.
    Foreign key linking to the Systems table.
    */
    """
}

# Initialize SQLDatabase with the engine and custom table info
db = SQLDatabase(engine=engine, custom_table_info=custom_table_info)
context = db.get_context()["table_info"]
print(context)


    CREATE TABLE compatibility (
        Compatibility_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        Probe_ID INTEGER,
        System_ID INTEGER,
        FOREIGN KEY(Probe_ID) REFERENCES probes(Probe_ID),
        FOREIGN KEY(System_ID) REFERENCES systems(System_ID)
    );
    /* 
    Unique identifier for each compatibility entry.
    Foreign key linking to the Probes table.
    Foreign key linking to the Systems table.
    */
    


    CREATE TABLE probes (
        Probe_ID INTEGER PRIMARY KEY AUTOINCREMENT,
        Manufacturer TEXT,
        Probe_Model TEXT,
        Connection_Type TEXT,
        Array_Type TEXT,
        Frequency_Range TEXT,
        Applications TEXT,
        Stock INTEGER,
        Description TEXT
    );
    /* 
    Unique identifier for each probe.
    The company that manufactures the probe.
    The model number or identifier of the probe.
    Type of connection interface the probe uses (e.g., Cartridge).
    The configuration of the probe's sensor array (e.g.,

## Create SQL query examples
Next we create a list of examples, where each example is a dictionary containing an "input" key with a question as the value, and an "sql_query" key with the corresponding SQL query as the value. These examples will be used to train a model to generate SQL queries based on input questions.

In [4]:
# Define examples
sql_examples = [
    {
        "input": "Who is the maker of 9EVF4?",
        "sql_query": "SELECT Manufacturer FROM probes WHERE Probe_Model = '9EVF4';"
    },
    {
        "input": "Is 9EVF4 made by Siemens Acuson or Philips?",
        "sql_query": "SELECT Manufacturer, CASE WHEN Manufacturer = 'Siemens Acuson' THEN 'Yes' WHEN Manufacturer = 'Philips' THEN 'Yes' ELSE 'No' END AS MadeBy FROM probes WHERE Probe_Model = '9EVF4' AND Manufacturer IN ('Siemens Acuson', 'Philips');"
    },
    {
        "input": "What systems is the ATL C3 probe compatible with?",
        "sql_query": "SELECT s.System_Name FROM compatibility c JOIN probes p ON c.Probe_ID = p.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE p.Manufacturer = 'ATL' AND p.Probe_Model = 'C3';"
    },
    {
        "input": "What systems is the Siemens Acuson 9EVF4 compatible with?",
        "sql_query": "SELECT s.System_Name FROM compatibility c JOIN probes p ON c.Probe_ID = p.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4';"
    },
    {
        "input": "Does the Siemens Acuson 9EVF4 work with the S3000?",
        "sql_query": "SELECT CASE WHEN COUNT(*) > 0 THEN 'Yes' ELSE 'No' END AS WorksWith FROM compatibility c JOIN systems s ON c.System_ID = s.System_ID JOIN probes p ON c.Probe_ID = p.Probe_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' AND s.System_Name = 'S3000';"
    },
    {
        "input": "Does the Siemens Acuson 9EVF4 work with the Voluson 730?",
        "sql_query": "SELECT CASE WHEN COUNT(*) > 0 THEN 'Yes' ELSE 'No' END AS WorksWith FROM compatibility c JOIN systems s ON c.System_ID = s.System_ID JOIN probes p ON c.Probe_ID = p.Probe_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' AND s.System_Name = 'Voluson 730';"
    },
    {
        "input": "What array type is the Siemens Acuson 9EVF4?",
        "sql_query": "SELECT Array_Type FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';"
    },
    {
        "input": "What can the Siemens Acuson 9EVF4 be used for?",
        "sql_query": "SELECT Applications FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';"
    },
    {
        "input": "Do you have any 9EVF4 for sale?",
        "sql_query": "SELECT CASE WHEN Stock > 0 THEN 'Yes' ELSE 'No' END AS AvailableForSale FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';"
    },
    {
        "input": "Do you have any ATL C3 probes in stock?",
        "sql_query": "SELECT CASE WHEN Stock > 0 THEN 'Yes' ELSE 'No' END AS InStock FROM probes WHERE Manufacturer = 'ATL' AND Probe_Model = 'C3';"
    },
    {
        "input": "Do you have any C3 probes for the HDI 5000 system in stock?",
        "sql_query": "SELECT CASE WHEN p.Stock > 0 AND EXISTS (SELECT 1 FROM compatibility c JOIN systems s ON c.System_ID = s.System_ID WHERE s.System_Name = 'HDI 5000' AND c.Probe_ID = p.Probe_ID) THEN 'Yes' ELSE 'No' END AS InStock FROM probes p WHERE p.Probe_Model = 'C3';"
    },
    {
        "input": "Do you have any convex array probes compatible with the HDI 1500 system?",
        "sql_query": "SELECT p.Probe_Model, p.Manufacturer, CASE WHEN p.Stock > 0 THEN 'In Stock' ELSE 'Out of Stock' END AS StockStatus FROM probes p JOIN compatibility c ON p.Probe_ID = c.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE s.System_Name = 'HDI 1500' AND p.Array_Type = 'Convex' AND p.Stock > 0;"
    },
    {
        "input": "What is the Siemens Acuson 9EVF4?",
        "sql_query": "SELECT Description FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';"
    }
]

# Create the SQL example selector
sql_example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=sql_examples,
    embeddings=OllamaEmbeddings(model="mxbai-embed-large"),
    vectorstore_cls=FAISS,
    k=5,
    input_keys=["input"],
)


In [6]:
# Define the SQL example prompt format
sql_example_prompt = PromptTemplate(
    input_variables=["input", "sql_query"],
    template="{sql_query}"
)

# Define the FewShotPromptTemplate
sql_prompt = FewShotPromptTemplate(
    example_selector=sql_example_selector,
    example_prompt=sql_example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Use the LIMIT clause if necessary to limit the results to at most {top_k} rows, and place it immediately before the final semicolon. Return a single SQL statement with only one semicolon at the end. Do not include any explanations, backticks, or additional text.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: ",
    input_variables=["input", "table_info", "top_k"],
)

# Initialize language model and SQL query chain
llm = Ollama(model="llama3")
sql_chain = create_sql_query_chain(llm=llm, db=db, prompt=sql_prompt)

# Use StrOutputParser to extract only the SQL query
sql_query_chain = sql_chain | StrOutputParser()

final_answer_examples = [
    {
        "input": "Who is the maker of 9EVF4?",
        "sql_query": "SELECT Manufacturer FROM Probes WHERE Probe_Model = '9EVF4';",
        "sql_result": "[('Siemens Acuson',)]",
        "answer": "The 9EVF4 probe is made by Siemens Acuson."
    },
    {
        "input": "Is 9EVF4 made by Siemens Acuson or Philips?",
        "sql_query": "SELECT Manufacturer, CASE WHEN Manufacturer = 'Siemens Acuson' THEN 'Yes' WHEN Manufacturer = 'Philips' THEN 'Yes' ELSE 'No' END AS MadeBy FROM Probes WHERE Probe_Model = '9EVF4' AND Manufacturer IN ('Siemens Acuson', 'Philips') LIMIT 5;",
        "sql_result": "[('Siemens Acuson', 'Yes')]",
        "answer": "The 9EVF4 probe is made by Siemens Acuson."
    },
    {
        "input": "What systems is Siemens Acuson 9EVF4 compatible with?",
        "sql_query": "SELECT s.System_Name FROM compatibility c JOIN probes p ON c.Probe_ID = p.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' LIMIT 5;",
        "sql_result": "[('S1000',), ('S2000',), ('S3000',)]",
        "answer": "The Siemens Acuson 9EVF4 is compatible with S1000, S2000, S3000 ultrasound systems."
    },
    {
        "input": "Does the Siemens Acuson 9EVF4 work with the S3000?",
        "sql_query": "SELECT CASE WHEN COUNT(*) > 0 THEN 'Yes' ELSE 'No' END AS WorksWith FROM Compatibility JOIN Systems ON Compatibility.System_ID = Systems.System_ID JOIN Probes ON Compatibility.Probe_ID = Probes.Probe_ID WHERE Probes.Manufacturer = 'Siemens Acuson' AND Probes.Probe_Model = '9EVF4' AND Systems.System_Name = 'S3000';",
        "sql_result": "[('Yes',)]",
        "answer": "Yes, the Siemens Acuson 9EVF4 is compatible with the S3000 ultrasound system, as well as S1000, S2000."
    },
    {
        "input": "Does the Siemens Acuson 9EVF4 work with the Voluson 730?",
        "sql_query": "SELECT CASE WHEN COUNT(*) > 0 THEN 'Yes' ELSE 'No' END AS WorksWith FROM Compatibility JOIN Systems ON Compatibility.System_ID = Systems.System_ID JOIN Probes ON Compatibility.Probe_ID = Probes.Probe_ID WHERE Probes.Manufacturer = 'Siemens Acuson' AND Probes.Probe_Model = '9EVF4' AND Systems.System_Name = 'Voluson 730';",
        "sql_result": "[('No',)]",
        "answer": "No, the Siemens Acuson 9EVF4 is not compatible with the Voluson 730 ultrasound system."
    },
    {
        "input": "What type of probe is the Siemens Acuson 9EVF4?",
        "sql_query": "SELECT Array_Type FROM Probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';",
        "sql_result": "[('Convex Array', '4-9 MHz')]",
        "answer": "The Siemens Acuson 9EVF4 is a Convex Array probe with a frequency range of 4-9 MHz."
    },
    {
        "input": "What can the Siemens Acuson 9EVF4 be used for?",
        "sql_query": "SELECT Applications FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';",
        "sql_result": "[('urological, endocavitary, obstetric, gynecological',)]",
        "answer": "The Siemens Acuson 9EVF4 can be used for urological, endocavitary, obstetric, and gynecological applications."
    },
    {
        "input": "Do you have any Siemens Acuson 9EVF4 for sale?",
        "sql_query": "SELECT CASE WHEN Stock > 0 THEN 'Yes' ELSE 'No' END AS AvailableForSale FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';",
        "sql_result": "[('No',)]",
        "answer": "We currently do not have any Siemens Acuson 9EVF4 in stock."
    },
    {
        "input": "Do you have any 9EVF4 for S3000 for sale?",
        "sql_query": "SELECT p.Probe_Model, p.Manufacturer, CASE WHEN p.Stock > 0 THEN 'Yes' ELSE 'No' END AS AvailableForSale FROM probes p JOIN compatibility c ON p.Probe_ID = c.Probe_ID JOIN systems s ON c.System_ID = s.System_ID WHERE p.Manufacturer = 'Siemens Acuson' AND p.Probe_Model = '9EVF4' AND s.System_Name = 'S3000' LIMIT 5;",
        "sql_result": "[('9EVF4', 'Siemens Acuson', 'No')]",
        "answer": "Unfortunately, we do not have any Siemens Acuson 9EVF4 for the S3000 available for sale."
    },
    {
        "input": "Do you have any linear array probes for iU22?",
        "sql_query": "SELECT p.Probe_Model, p.Manufacturer FROM probes p JOIN compatibility c ON p.Probe_ID = c.Probe_ID WHERE c.System_ID IN (SELECT s.System_ID FROM systems s WHERE s.System_Name = 'iU22') AND p.Array_Type = 'Linear' LIMIT 5;",
        "sql_result": "[('L12-3', 'Philips'), ('L15-7io', 'Philips'), ('L17-5', 'Philips')]",
        "answer": "The Philips L12-3, L15-7io, and L17-5 are linear array probes that are compatible with the iU22 ultrasound system."
    },
    {
        "input": "What is the Siemens Acuson 9EVF4?",
        "sql_query": "SELECT * FROM probes WHERE Manufacturer = 'Siemens Acuson' AND Probe_Model = '9EVF4';",
        "sql_result": "[(43, 'Siemens Acuson', '9EVF4', None, 'Convex', '4-9 MHz', 'urological, endocavitary, obstetric, gynecological', 0, 'The Siemens Acuson 9EVF4 is a convex array probe that can be used for urological, abdominal, obstetric, and gynecologic applications. The 9EVF4 has a frequency range of 4-9 MHz while its curved array design gives it a wide field of view beneficial for OB/GYN exams.')]",
        "answer": "The Siemens Acuson 9EVF4 is a convex array probe with a frequency range of 4-9 MHz. It can be used for urological, endocavitary, obstetric, and gynecological applications. The probe's curved array design provides a wide field of view, which is beneficial for OB/GYN exams."
    },
    {
        "input": "Can the ATL C3 work with the HDI 5000?",
        "sql_query": "SELECT CASE WHEN EXISTS (SELECT 1 FROM compatibility c JOIN systems s ON c.System_ID = s.System_ID WHERE s.System_Name = 'HDI 5000' AND c.Probe_ID = p.Probe_ID) THEN 'Yes' ELSE 'No' END AS CanWork FROM probes p WHERE p.Manufacturer = 'ATL' AND p.Probe_Model = 'C3' LIMIT 5;",
        "sql_result": "[('Yes',)]",
        "answer": "Yes, the ATL C3 works with the HDI 5000."
    }
]

answer_example_prompt = PromptTemplate(
    input_variables=["input", "sql_query", "sql_result", "answer"],
    template="{answer}"
)

# Create the final_answer_example_selector
final_answer_example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=final_answer_examples,
    embeddings=OllamaEmbeddings(model="mxbai-embed-large"),
    vectorstore_cls=FAISS,
    k=5,
    input_keys=["input"]
)

answer_prompt = FewShotPromptTemplate(
    example_selector=final_answer_example_selector,
    example_prompt=answer_example_prompt,
    prefix="Given the following user input question, corresponding SQL query, and SQL result, generate a concise and informative answer to the user input question. Your response should only answer the question asked while only re-iterating details that are already in the user input question. Do not include any explanations or additional text.",
    suffix="User input: {input}\nSQL query: {sql_query}\nSQL result: {sql_result}\n\nAnswer: ",
    input_variables=["input", "sql_query", "sql_result"],
)

# Create a chain to generate the final answer
final_answer_chain = (
    answer_prompt | llm | StrOutputParser()
)

In [9]:
# Read the CSV file
#qa_dataset = pd.read_csv("../data/qa_dataset.csv")
qa_dataset = pd.read_csv("qa_dataset_with_answers.csv")

# Function to get the last completed row index
def get_last_completed_row_index(df):
    required_columns = ["SQL Query", "SQL Result", "Final Answer", "SQL Query Time", "Final Answer Time"]
    if all(col in df.columns for col in required_columns):
        # Check for the last row where all required columns are not null
        not_null_df = df.dropna(subset=required_columns)
        if not not_null_df.empty:
            return not_null_df.index[-1] + 1
    return 0  # If the file doesn't exist or no rows are fully completed, start from the first row

# Function to save the last completed row index
def save_last_completed_row_index(index):
    with open("last_completed_row.txt", "w") as file:
        file.write(str(index))

def handle_programming_error(e, sql_query):
    if "You can only execute one statement at a time" in str(e):
        if sql_query.count(";") > 1:
            sql_query = sql_query.rstrip(";").replace(";", "") + ";"
        elif sql_query.count(";") == 1 and not sql_query.endswith(";"):
            sql_query = sql_query.replace(";", "") + ";"
        return sql_query
    else:
        return None

def preprocess_sql_query(sql_query):
    # Strip leading and trailing whitespace
    sql_query = sql_query.strip()
    
    # Remove all semicolons and then add exactly one at the end
    sql_query = sql_query.replace(';', '')
    sql_query += ';'
    
    return sql_query

# Define a function to process each row
def process_row(row, row_index):
    question = row["Question"]
    max_retries = 3
    retries = 0
    sql_query_time = 0
    final_answer_time = 0

    while retries < max_retries:
        start_time = time.time()  # Start timer before the try block
        try:
            sql_query = sql_query_chain.invoke({
                "question": question,
                "input": question,
                "table_info": context,
                "top_k": 5
            })
            # Preprocess the SQL query
            sql_query = preprocess_sql_query(sql_query)
            sql_result = db.run(sql_query)
            sql_query_time += time.time() - start_time  # Stop timer after successful execution
            break  # If no exception is raised, exit the loop
        except ProgrammingError as e:
            sql_query_time += time.time() - start_time  # Update time with the failed attempt
            retries += 1
            if retries == max_retries:
                sql_query = "Error generating SQL query"
                sql_result = "Error generating SQL query"
                break
            else:
                print(f"ProgrammingError: {e}. Retrying...")
        except OperationalError as e:
            sql_query_time += time.time() - start_time  # Update time with the failed attempt
            retries += 1
            if retries == max_retries:
                sql_query = "Error generating SQL query"
                sql_result = "Error generating SQL query"
                break
            else:
                print(f"OperationalError: {e}. Retrying...")

    # Timer for final answer generation
    start_time = time.time()
    final_answer = final_answer_chain.invoke({
        "input": question,
        "sql_query": sql_query,
        "sql_result": sql_result
    })
    final_answer_time += time.time() - start_time

    # Update the row with new columns
    qa_dataset.loc[row_index, "SQL Query"] = sql_query
    qa_dataset.loc[row_index, "SQL Result"] = str(sql_result)
    qa_dataset.loc[row_index, "Final Answer"] = final_answer
    qa_dataset.loc[row_index, "SQL Query Time"] = sql_query_time
    qa_dataset.loc[row_index, "Final Answer Time"] = final_answer_time

    # Save the updated DataFrame to a new CSV file
    qa_dataset.to_csv("qa_dataset_with_answers.csv", index=False)

    # Save the last completed row index
    save_last_completed_row_index(row_index)

# Get the last completed row index
last_completed_row_index = get_last_completed_row_index(qa_dataset)
# For starting on first row
#last_completed_row_index = 0

# Apply the function to each row starting from the last completed row
progress_bar = tqdm(total=len(qa_dataset) - last_completed_row_index, unit="row")
for row_index, row in qa_dataset.iloc[last_completed_row_index:].iterrows():
    process_row(row, row_index)
    progress_bar.update(1)
progress_bar.close()



  0%|          | 0/312 [00:00<?, ?row/s]

OperationalError: (sqlite3.OperationalError) no such column: p.Probe_ID
[SQL: SELECT CASE WHEN c.Compatibility_ID IS NOT NULL THEN 'Yes' ELSE 'No' END AS WorksWith FROM compatibility c JOIN systems s ON c.System_ID = s.System_ID WHERE s.System_Name = 'Voluson 730' AND p.Probe_ID = c.Probe_ID AND p.Manufacturer = 'G.E.' AND p.Probe_Model = 'RAB2-5L' LIMIT 1;]
(Background on this error at: https://sqlalche.me/e/20/e3q8). Retrying...
OperationalError: (sqlite3.OperationalError) no such column: p
[SQL: SELECT p Probe_Model FROM probes p WHERE Manufacturer = 'G.E.' AND Probe_Model = 'RIC5-9D' LIMIT 1;]
(Background on this error at: https://sqlalche.me/e/20/e3q8). Retrying...
OperationalError: (sqlite3.OperationalError) no such column: p.System_ID
[SQL: SELECT p.Array_Type
FROM probes p
JOIN systems s ON p.System_ID = s.System_ID
WHERE p.Manufacturer = 'Philips' AND p.Probe_Model = 'S4-2' AND Connection_Type = 'Cartridge' LIMIT 5;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)