## Correct invalid/inaccurate SQL or instruction translations

Jump to sections:
- <a href="#invalid">Correct invalid translated SQL</a>
- <a href="#inaccurate">Correct valid but inaccurate translated SQL</a>
- <a href="#instructions">Translate just the instructions column with the instructions_to_dialect function</a>
- <a href="#inspectinstr">Inspect the translated instructions column</a>


In [None]:
import pandas as pd
import os
from tqdm import tqdm
import time
from utils.creds import db_creds_all
import sqlparse
# Register tqdm's progress_* methods on pandas objects (progress_apply is used further down)
tqdm.pandas()
# Show full cell contents when displaying dataframes instead of truncating them
pd.options.display.max_colwidth = None

# Target dialect and the csv holding the queries translated into it
dialect = "mysql"
csv_file = "data/instruct_advanced_{}.csv".format(dialect)
df = pd.read_csv(csv_file)

In [None]:
# Check for all invalid SQL in the csv file
# NOTE: You are encouraged to update defog-data databases first before running this cell
from utils.dialects import test_valid_tsql, test_valid_mysql, test_valid_bq, test_valid_sqlite
from eval.eval import get_all_minimal_queries

# Two ways to obtain per-statement validity:
#   1. the query column already carries <INVALID ...> markers -> parse them out
#   2. otherwise -> run every statement against the defog-data databases
marker_mask = df["query"].str.contains("<INVALID ERR MSG>")
if len(df[marker_mask]) > 0:
    # Non-empty ";"-separated pieces of each query
    pieces = df["query"].apply(lambda q: [part for part in q.split(";") if part])

    def _translated_sql(piece):
        # Text between the translation marker and the dashed divider, else the piece itself
        if "<INVALID TRANSLATION>:" in piece:
            return piece.split("<INVALID TRANSLATION>: ")[1].split("-----------------")[0]
        return piece

    def _error_text(piece):
        # Text between the error marker and the dashed divider, else an empty string
        if "<INVALID ERR MSG>:" in piece:
            return piece.split("<INVALID ERR MSG>: ")[1].split("-----------------")[0]
        return ""

    df["sql_list"] = pieces.apply(lambda parts: [_translated_sql(p) for p in parts])
    df["err_msg_list"] = pieces.apply(lambda parts: [_error_text(p) for p in parts])
    # A piece counts as valid when it carries no error marker
    df["valid_list"] = pieces.apply(lambda parts: ["<INVALID ERR MSG>:" not in p for p in parts])

else:
    sql_col = "query"
    # Placeholder columns that are filled row by row in the loop below
    df["result_tuple_list"] = ""
    df["sql_list"] = ""

    # Dialect name -> validity tester; the same name is the key into db_creds_all
    validators = {
        "bigquery": test_valid_bq,
        "mysql": test_valid_mysql,
        "sqlite": test_valid_sqlite,
        "tsql": test_valid_tsql,
    }

    for idx, record in tqdm(df.iterrows(), total=len(df)):
        statements = get_all_minimal_queries(record[sql_col])
        df.at[idx, "sql_list"] = statements
        if dialect not in validators:
            raise ValueError("Dialect not supported")
        df.at[idx, "result_tuple_list"] = validators[dialect](
            db_creds_all[dialect], statements, record.db_name
        )

    # Element 0 of each result feeds valid_list, element 1 feeds err_msg_list
    df["valid_list"] = df["result_tuple_list"].apply(lambda results: [res[0] for res in results])
    df["err_msg_list"] = df["result_tuple_list"].apply(lambda results: [res[1] for res in results])
    del df["result_tuple_list"]

# Rows with at least one statement flagged False
has_invalid = df["valid_list"].apply(lambda flags: False in flags)
invalid_df = df[has_invalid]

print("No. of invalid queries:", len(invalid_df))

<a id="invalid"></a>
### Correct invalid translated SQL

Use the next few cells to correct a single SQL query that was translated to a different dialect but found to be invalid when executed against a database.

In [None]:
# Pick one random row from the invalid set
invalid_eg = invalid_df.sample(n=1)
sampled_index = invalid_eg.index[0]

# Load the postgres version of the dataset and fetch the query at the same row label
postgres_csv_file = csv_file.replace(dialect, "postgres")
postgres_df = pd.read_csv(postgres_csv_file)
postgres_query = postgres_df.at[sampled_index, "query"]
formatted_postgres_query = sqlparse.format(postgres_query, reindent=True, keyword_case='upper')
print("Postgres query for comparison:\n", formatted_postgres_query)

invalid_eg

In [None]:
# Positions in the sampled row's valid_list whose flag is falsy
sampled_flags = invalid_eg["valid_list"].iloc[0]
invalid_indices = [pos for pos, is_valid in enumerate(sampled_flags) if not is_valid]
print(invalid_indices)

# Copy the sampled row's fields into plain variables for the cells that follow
db_name = invalid_eg["db_name"].values[0]
query = invalid_eg["query"].values[0]
question = invalid_eg["question"].values[0]
if "instructions" in invalid_eg.columns:
    instructions = invalid_eg["instructions"].values[0]
else:
    # Datasets without an instructions column fall back to an empty string
    instructions = ""
sql_list = invalid_eg["sql_list"].values[0]
err_msg_list = invalid_eg["err_msg_list"].values[0]

In [None]:
# Use LLM to rewrite the SQL for the dialect
from openai import OpenAI
from defog_utils.utils_sql import normalize_sql
import json
# Model name and client used by the rewrite function below; the key is read from the environment
model = "gpt-4o"
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# TODO: Specify special SQL syntax rules for the dialect. e.g. "STR_TO_DATE function is not supported in T-SQL. Instead use the DATEFROMPARTS function to concatenate the date parts."
special_instructions = """
"""

def rewrite_invalid_sql(
    model: str,
    sql: str,
    question: str,
    instructions: str,
    err_msg: str,
    to_dialect: str,
    special_instructions: str,
) -> dict:
    """
    Use LLM to rewrite invalid SQL for the dialect

    Sends the question, instructions, invalid SQL and its error message to the
    chat completions endpoint of the module-level `openai` client and parses
    the JSON reply.

    Args:
        model: Name of the chat model to call.
        sql: The invalid SQL statement to rewrite.
        question: The question the SQL is meant to answer.
        instructions: Instructions text included in the prompt.
        err_msg: Error message produced by the invalid SQL.
        to_dialect: Name of the target SQL dialect.
        special_instructions: Extra text appended to the system prompt.

    Returns:
        A dict that always has the keys "sql" and "reason". Both are None when
        the reply could not be parsed into a JSON object.
    """
    # Remove the sentence about possibly outdated data values from the error
    # text (str.replace is a no-op when the sentence is absent)
    err_msg = err_msg.replace("Or data values in defog-data databases could be outdated.", "")
    messages = [
        {
            "role": "system",
            "content": f"""Your task is to rewrite an invalid SQL query in the {to_dialect} dialect to answer a specific question.
{special_instructions}""",
        },
        {
            "role": "user",
            "content": f"""Question to answer: {question}
Instructions: {instructions}
Invalid SQL: {sql}
Error to fix: {err_msg}

Format your response as a valid JSON string with reason and sql keys. 
Your response should look like the string below:
{{
    "reason": "Your reasoning for the response",
    "sql": "The valid rewritten query for {to_dialect}"
}}

Do not include any other information before and after the JSON string.
""",
        },
    ]
    response = openai.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=2000,
        temperature=0,
        response_format={"type": "json_object"},
    )
    completion = response.choices[0].message.content
    try:
        completion_dict = json.loads(completion)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a reply whose content is None
        completion_dict = None
    if not isinstance(completion_dict, dict):
        print(f"Error parsing completion {completion}")
        completion_dict = {}
    # Callers index both keys directly, so make sure they are always present
    completion_dict.setdefault("sql", None)
    completion_dict.setdefault("reason", None)
    return completion_dict

# Build the candidate statement list: statements flagged invalid are rewritten
# by the LLM, all others are carried over unchanged.
rewritten_sql_list = []
for i, sql in enumerate(sql_list):
    if i not in invalid_indices:
        print(i)
        rewritten_sql_list.append(sql)
        continue

    err_msg = err_msg_list[i]
    completion_dict = rewrite_invalid_sql(model, sql, question, instructions, err_msg, dialect, special_instructions)
    sql_rewritten = completion_dict.get('sql')

    if sql_rewritten is None:
        # Make the omission visible instead of dropping the statement silently
        print(f"No rewritten SQL returned for statement {i}; it is left out of rewritten_sql_list")
        continue

    sql_rewritten = normalize_sql(sql_rewritten)
    rewritten_sql_list.append(sql_rewritten)
    print("Reason: ", completion_dict.get('reason'))
    print("Rewritten SQL: ", sqlparse.format(sql_rewritten, reindent=True, keyword_case='upper'))
    print("\n")

# ensure no duplicates in rewritten_sql_list
# dict.fromkeys keeps first-seen order; set() would reorder the statements unpredictably
rewritten_sql_list = list(dict.fromkeys(rewritten_sql_list))

In [None]:
# RUN THIS CELL ONLY if you want to manually rewrite the SQL for the dialect
# Running it replaces whatever rewritten_sql_list currently holds with the list
# below, so add the SQL strings between the brackets before executing.
rewritten_sql_list = [
]

In [None]:
# Test validity of all SQL in rewritten_sql_list
# Dialect name -> validity tester; the same name is the key into db_creds_all
dialect_validators = {
    "bigquery": test_valid_bq,
    "mysql": test_valid_mysql,
    "sqlite": test_valid_sqlite,
    "tsql": test_valid_tsql,
}
if dialect not in dialect_validators:
    raise ValueError("Dialect not supported")

# Fail early with a readable message; an empty list previously surfaced as an
# opaque tuple-unpacking ValueError after the results came back
if not rewritten_sql_list:
    raise ValueError("rewritten_sql_list is empty. Rewrite the SQL before testing its validity.")

result_tuple_list = dialect_validators[dialect](db_creds_all[dialect], rewritten_sql_list, db_name)

# Element 0 of each result is the validity flag, element 1 the error message
valid_list = [result[0] for result in result_tuple_list]
err_msg_list = [result[1] for result in result_tuple_list]
for i in rewritten_sql_list:
    print(i)
print("Valid list:", valid_list)

In [None]:
# Write the rewrite back into the dataset, but only when every entry of valid_list is truthy
if all(valid_list):
    df_index = invalid_eg.index[0]
    replacements = {
        "sql_list": rewritten_sql_list,
        "valid_list": valid_list,
        "err_msg_list": err_msg_list,
        # Statements are re-joined with ";" and doubled separators collapsed
        "query": ";".join(rewritten_sql_list).replace(";;", ";"),
    }
    for column, value in replacements.items():
        df.at[df_index, column] = value
    print("Updated original dataset with rewritten SQL at index", df_index)

    # The list columns are working columns only and are left out of the csv
    df2 = df.drop(columns=["valid_list", "err_msg_list", "sql_list"])
    df2.to_csv(csv_file, index=False)

    # The row is fixed, so take it out of the pool of invalid rows
    invalid_df = invalid_df.loc[invalid_df.index != df_index]
    print("Removed invalid query from the invalid dataset")


<a id="inaccurate"></a>
### Correct valid but inaccurate translated SQL

Use the next few cells if you've discovered a translated SQL query that is valid but does not accurately answer the question.

In [None]:
#TODO: Insert question that has inaccurate SQL
qn_inaccurate = ""

# Get rows from df whose question column contains qn_inaccurate (case-insensitive).
# regex=False matches the text literally; question text may contain characters
# such as "?" or "(" that would otherwise be interpreted as regex syntax.
inacc_eg = df[df["question"].str.contains(qn_inaccurate, case=False, na=False, regex=False)]
inacc_eg

In [None]:
#TODO: Specify the indices of inaccurate SQL in the sql_list
inacc_indices = [0]

# Fail with a readable message when no row matched the question
if len(inacc_eg) == 0:
    raise IndexError("No row matches qn_inaccurate. Please check the question again.")

# Check that indices are valid positions in the first matched row's sql_list.
# Negative values are rejected too: the rewrite loop compares against
# positions 0..len-1, so a negative index would silently rewrite nothing.
num_statements = len(inacc_eg["sql_list"].iloc[0])
if any(i < 0 or i >= num_statements for i in inacc_indices):
    raise Exception("Index out of range. Please check the indices again.")

# Store values
db_name = inacc_eg["db_name"].values[0]
query = inacc_eg["query"].values[0]
question = inacc_eg["question"].values[0]
instructions = inacc_eg["instructions"].values[0] if "instructions" in inacc_eg.columns else ""
sql_list = inacc_eg["sql_list"].values[0]

In [None]:
# use LLM to rewrite the SQL for the dialect
from openai import OpenAI
from defog_utils.utils_sql import normalize_sql
import json
# Model name and client used by the rewrite function below; the key is read from the environment
model = "gpt-4o"
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# TODO: Specify reasons for why the SQL is inaccurate for the question
special_instructions = """
"""

def rewrite_inacc_sql(
    model: str,
    sql: str,
    question: str,
    instructions: str,
    to_dialect: str,
    special_instructions: str,
) -> dict:
    """
    Use LLM to rewrite inaccurate SQL for the dialect

    Sends the question, instructions and inaccurate SQL to the chat
    completions endpoint of the module-level `openai` client and parses the
    JSON reply.

    Args:
        model: Name of the chat model to call.
        sql: The inaccurate SQL statement to rewrite.
        question: The question the SQL is meant to answer.
        instructions: Instructions text included in the prompt.
        to_dialect: Name of the target SQL dialect.
        special_instructions: Extra text placed after the inaccurate SQL in
            the user prompt.

    Returns:
        A dict that always has the keys "sql" and "reason". Both are None when
        the reply could not be parsed into a JSON object.
    """
    messages = [
        {
            "role": "system",
            "content": f"""Your task is to rewrite an inaccurate SQL query in the {to_dialect} dialect to answer a specific question. Analyze the question and SQL query to determine why the SQL query is inaccurate before rewriting it.""",
        },
        {
            "role": "user",
            "content": f"""Question to answer: {question}
Instructions: {instructions}
Inaccurate SQL: {sql}
{special_instructions}

Format your response as a valid JSON string with reason and sql keys. 
Your response should look like the string below:
{{
    "reason": "Your reasoning for the response",
    "sql": "The valid rewritten query for {to_dialect}"
}}

Do not include any other information before and after the JSON string.
""",
        },
    ]
    response = openai.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=2000,
        temperature=0,
        response_format={"type": "json_object"},
    )
    completion = response.choices[0].message.content
    try:
        completion_dict = json.loads(completion)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a reply whose content is None
        completion_dict = None
    if not isinstance(completion_dict, dict):
        print(f"Error parsing completion {completion}")
        completion_dict = {}
    # Callers index both keys directly, so make sure they are always present
    completion_dict.setdefault("sql", None)
    completion_dict.setdefault("reason", None)
    return completion_dict

# Build the candidate statement list: statements listed in inacc_indices are
# rewritten by the LLM, all others are carried over unchanged.
rewritten_sql_list = []
for i, sql in enumerate(sql_list):
    if i not in inacc_indices:
        print(i)
        rewritten_sql_list.append(sql)
        continue

    completion_dict = rewrite_inacc_sql(model, sql, question, instructions, dialect, special_instructions)
    sql_rewritten = completion_dict.get('sql')

    if sql_rewritten is None:
        # Make the omission visible instead of dropping the statement silently
        print(f"No rewritten SQL returned for statement {i}; it is left out of rewritten_sql_list")
        continue

    sql_rewritten = normalize_sql(sql_rewritten)
    rewritten_sql_list.append(sql_rewritten)
    print("Reason: ", completion_dict.get('reason'))
    print("Rewritten SQL: ", sqlparse.format(sql_rewritten, reindent=True, keyword_case='upper'))
    print("\n")

# Ensure no duplicates in rewritten_sql_list
# dict.fromkeys keeps first-seen order; set() would reorder the statements unpredictably
rewritten_sql_list = list(dict.fromkeys(rewritten_sql_list))

In [None]:
# Test validity of all SQL in rewritten_sql_list
# Dialect name -> validity tester; the same name is the key into db_creds_all
dialect_validators = {
    "bigquery": test_valid_bq,
    "mysql": test_valid_mysql,
    "sqlite": test_valid_sqlite,
    "tsql": test_valid_tsql,
}
if dialect not in dialect_validators:
    raise ValueError("Dialect not supported")

# Fail early with a readable message; an empty list previously surfaced as an
# opaque tuple-unpacking ValueError after the results came back
if not rewritten_sql_list:
    raise ValueError("rewritten_sql_list is empty. Rewrite the SQL before testing its validity.")

result_tuple_list = dialect_validators[dialect](db_creds_all[dialect], rewritten_sql_list, db_name)

# Element 0 of each result is the validity flag, element 1 the error message
valid_list = [result[0] for result in result_tuple_list]
err_msg_list = [result[1] for result in result_tuple_list]
for i in rewritten_sql_list:
    print(i)
print("Valid list:", valid_list)

In [None]:
# Write the rewrite back into the dataset, but only when every entry of valid_list is truthy
if all(valid_list):
    df_index = inacc_eg.index[0]
    replacements = {
        "sql_list": rewritten_sql_list,
        "valid_list": valid_list,
        "err_msg_list": err_msg_list,
        # Statements are re-joined with ";" and doubled separators collapsed
        "query": ";".join(rewritten_sql_list).replace(";;", ";"),
    }
    for column, value in replacements.items():
        df.at[df_index, column] = value
    print("Updated original dataset with rewritten SQL at index", df_index)

    # The list columns are working columns only and are left out of the csv
    df2 = df.drop(columns=["valid_list", "err_msg_list", "sql_list"])
    df2.to_csv(csv_file, index=False)

<a id="instructions"></a>
### Translate just the instructions column

Use the next cell to translate the instructions with dialect-specific SQL syntax using the `instructions_to_{dialect}` functions in `utils/dialects.py`.

In [None]:
# Feel free to modify this cell for future dialects

from utils.dialects import instructions_to_sqlite, instructions_to_tsql, instructions_to_mysql

# Dialect name -> function that translates instruction text for that dialect.
# Add an entry here once a new instructions_to_{dialect} function exists.
instruction_translators = {
    "sqlite": instructions_to_sqlite,
    "tsql": instructions_to_tsql,
    "mysql": instructions_to_mysql,
}

# Both instruction columns receive identical treatment
for instr_col in ("instructions", "full_instructions"):
    if instr_col not in df.columns:
        print(f"No {instr_col} column in the dataframe")
        continue
    if dialect not in instruction_translators:
        raise ValueError(f"Dialect not yet supported for instructions translation. Please add an instructions_to_{dialect} function in utils/dialects.py")
    translate = instruction_translators[dialect]
    # Replace missing values with empty strings before translating
    df[instr_col] = df[instr_col].fillna("")
    df[instr_col] = df.progress_apply(
        lambda row: translate(row[instr_col]), axis=1
    )

<a id="inspectinstr"></a>
### Inspect translated instructions column

In [None]:
# Show all unique values in the instructions column
if "instructions" in df.columns:
    # Named distinctly so the per-row `instructions` value that the rewrite
    # cells pass into the LLM prompt is not overwritten by this cell
    unique_instructions = df["instructions"].unique()
    print("Instructions in the dataset:")
    for instruction in unique_instructions:
        print("-", instruction)

In [None]:
# Print each distinct full_instructions value, separated by blank lines
if "full_instructions" in df.columns:
    full_instructions = df["full_instructions"].unique()
    print("Full instructions in the dataset:")
    for entry in full_instructions:
        print(entry, "\n")

In [None]:
# Update the csv file with the new translated instructions
# sql_list / valid_list / err_msg_list are working columns added by this
# notebook; the other save cells drop them before writing, so do the same here
# instead of persisting them into the dataset. errors="ignore" keeps this cell
# usable when those columns were never created.
helper_columns = ["valid_list", "err_msg_list", "sql_list"]
df.drop(columns=helper_columns, errors="ignore").to_csv(csv_file, index=False)