## Automated Error Analysis

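This notebook uses an LLM to analyze and summarize the errors in the eval results of all models evaluated on a single common dataset: correctness by query category, recurring database execution errors, and recurring queries that run but return wrong results.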

In [None]:
import pandas as pd
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', None)
import matplotlib.pyplot as plt
import os
from openai import OpenAI
import os
import json
import warnings
from eval.eval import get_all_minimal_queries
warnings.filterwarnings('ignore')

# OpenAI client used by the LLM helper functions below; the API key is read from the environment.
openai = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)
# Model snapshot used for every analysis call in this notebook.
model_4_latest = "gpt-4-turbo-2024-04-09"

### 1. Extract all results for a single dataset name

In [None]:
results_dir_path = "results/"  # directory holding the eval result CSVs (keep the trailing slash)
ds_to_analyze = "006b"  # dataset name to analyze, e.g. "001", "002", etc.


In [None]:
def list_csv_files(path):
    """Return the ``.csv`` files found in ``path`` as an ``{index: filename}`` dict.

    Only names ending in ``.csv`` (case-sensitive) are kept. They are sorted
    alphabetically and numbered from 0 in that order.
    """
    ordered_names = sorted(name for name in os.listdir(path) if name.endswith('.csv'))
    return dict(enumerate(ordered_names))

# Index -> filename map of every result CSV in the results directory
csv_files = list_csv_files(path=results_dir_path)

In [None]:
# Select the result files whose names contain every keyword and none of the
# excluded keywords. By default this selects every CSV for `ds_to_analyze`.
keywords = [ds_to_analyze]
keywords_to_exclude = []
selected_models = [
    key
    for key, file_name in csv_files.items()
    if all(kw in file_name for kw in keywords)
    and not any(kw in file_name for kw in keywords_to_exclude)
]
# Fail here with a clear message rather than later with an opaque pd.concat error.
if not selected_models:
    raise ValueError(
        f"No CSV files in {results_dir_path!r} match keywords {keywords} "
        f"(excluding {keywords_to_exclude})"
    )

# Print selected models
print("Results to analyze:")
for key in selected_models:
    print(csv_files[key])

# Load each selected result file into a dataframe keyed by its file name without the
# extension; lines starting with '#' in the CSV are treated as comments.
dfs = {}
for key in selected_models:
    file_name = csv_files[key]
    model = os.path.splitext(file_name)[0]
    dfs[model] = pd.read_csv(os.path.join(results_dir_path, file_name), comment='#')

In [None]:
# Combine all dataframes of selected models into one, tagging each row with the file
# name it came from. `assign` returns new frames, so `dfs` is left unmodified and
# re-running this cell is idempotent.
df = pd.concat(
    [frame.assign(model=model) for model, frame in dfs.items()],
    ignore_index=True,  # per-file row indices overlap; give every row a unique label
)
# Apply get_all_minimal_queries (from eval.eval) to all reference queries
df['true_queries'] = df['query'].apply(get_all_minimal_queries)
# Keep only the first paragraph of each error message, minus the generic prefix.
# `astype("string")` keeps missing messages as <NA> and lets `.str` work even when the
# column was parsed as all-NaN floats (no result file recorded any error message).
df['error_msg_short'] = (
    df['error_msg']
    .astype("string")
    .str.split("\n\n")
    .str[0]
    .str.replace("QUERY EXECUTION ERROR:", "", regex=False)
)

# Split the file-name-derived model column at the last underscore:
# the part after it becomes `eval`, the part before it becomes `model`.
model_and_eval = df['model'].str.rsplit(pat='_', n=1)
df['eval'] = model_and_eval.str[1]
df['model'] = model_and_eval.str[0]
df.head(1)


### 2. Analyze correctness by category

In [None]:
# Mean correctness per SQL category, highest first
df_category_correct = (
    df.pivot_table(values="correct", index="query_category", aggfunc="mean")
    .sort_values('correct', ascending=False)
)

fig, ax = plt.subplots(figsize=(10, 6))
df_category_correct.plot(kind='barh', color='skyblue', ax=ax)
ax.set_title('Correctness by SQL category')
ax.set_xlabel('Correctness')
ax.set_ylabel('Category')
# Annotate each bar with its correctness as a percentage
for bar_pos, value in enumerate(df_category_correct['correct']):
    ax.text(value, bar_pos, f"{value*100:.2f}%", color='black', va='center')
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
plt.show()

In [None]:
# Category -> mean correctness, ordered from the weakest to the strongest category
by_correct_ascending = df_category_correct.sort_values('correct', ascending=True)
category_corr_dict = by_correct_ascending['correct'].to_dict()
category_corr_dict

### 3. Analyze invalid SQL with DB exec errors

In [None]:
# Rows whose generated SQL failed to execute, across all result files
df_error_exec = df.loc[
    df['error_db_exec'] == 1,
    ['model', 'db_name', 'question', 'error_db_exec', 'error_msg_short', 'true_queries', 'generated_query'],
].sort_values(['db_name', 'question'])
# A question "recurs" when the same (db_name, question) fails in more than one result file
is_recurring_exec = df_error_exec.duplicated(subset=['db_name', 'question'], keep=False)
df_error_exec_recurr = df_error_exec.loc[
    is_recurring_exec,
    ['question', 'error_db_exec', 'error_msg_short', 'true_queries', 'generated_query'],
]
# Count distinct (db_name, question) pairs, matching the key used for recurrence above;
# the same question text may appear in several databases.
n_recurring_exec = len(df_error_exec.loc[is_recurring_exec, ['db_name', 'question']].drop_duplicates())
print(f"{n_recurring_exec} questions with recurring exec errors")
df_error_exec_recurr.head(2)

In [None]:
# Build the LLM prompt input: one "- <message>" bullet per recurring exec error
error_exec_str = "\n".join(
    f"- {msg}" for msg in df_error_exec_recurr['error_msg_short']
)

Use the LLM to extract recurring patterns from the database execution error messages.

In [None]:
# Get error exec patterns
def get_error_exec_patterns(
    model: str, error_exec_str: str
) -> str:
    """
    Use LLM to extract recurring patterns in a list of error messages.

    Args:
        model: OpenAI chat model name passed to the chat completions API.
        error_exec_str: newline-separated bullet list of DB execution error messages.

    Returns:
        The model's numbered list of recurring patterns. If ``error_exec_str`` is
        empty or blank, no request is made and a fixed message is returned instead,
        so the LLM cannot invent patterns from an empty list.
    """
    if not (error_exec_str or "").strip():
        return "No recurring error messages to analyze."

    messages = [
        {
            "role": "system",
            "content": f"""Your task is to identify recurring patterns in the error messages below and provide a summary of the patterns."""
        },
        {
            "role": "user",
            "content": f"""List of error messages:
{error_exec_str}

Format your response as a numbered list of recurring patterns in the error messages. 
Each point should be a concise yet detailed summary of a trend identified in the error messages along with specific examples.
Do not include any other information before and after the list.
""",
        },
    ]

    # temperature=0 keeps the summary as repeatable as the API allows
    completion = openai.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=1000,
        temperature=0,
    )
    return completion.choices[0].message.content

In [None]:
# Summarize the recurring DB execution errors with the LLM
error_exec_summary = get_error_exec_patterns(model=model_4_latest, error_exec_str=error_exec_str)
print(error_exec_summary)

### 4. Analyze valid but wrong examples

In [None]:
# Get valid but wrong examples (query executed, result incorrect) across all result files
df_valid_wrong = (
    df.loc[
        (df['correct'] == 0) & (df['error_db_exec'] == 0),
        ['model', 'db_name', 'query_category', 'question', 'instructions', 'correct', 'error_db_exec', 'true_queries', 'generated_query'],
    ]
    .sort_values(['db_name', 'question'])
    .fillna('')
)
# Get questions that were repeatedly valid but wrong (same db_name and question in >1 result file)
is_recurring_wrong = df_valid_wrong.duplicated(subset=['db_name', 'question'], keep=False)
df_valid_wrong_recurr = df_valid_wrong.loc[
    is_recurring_wrong,
    ['db_name', 'query_category', 'question', 'instructions', 'true_queries', 'generated_query'],
]
# Count distinct (db_name, question) pairs, matching the key used for recurrence above
n_recurring_wrong = len(df_valid_wrong_recurr[['db_name', 'question']].drop_duplicates())
print(f"{n_recurring_wrong} unique questions that are recurring valid but wrong")
df_valid_wrong_recurr.head(3)

In [None]:
# Keep one row per (db_name, question). To limit LLM calls we assume that every
# result file that got a question wrong got it wrong in the same way.
first_per_question = df_valid_wrong_recurr.drop_duplicates(subset=['db_name', 'question'], keep='first')
df_valid_wrong_recurr_first = first_per_question[
    ['query_category', 'question', 'instructions', 'true_queries', 'generated_query']
]
df_valid_wrong_recurr_first.head(4)

In [None]:
# Get reasons for valid but wrong examples
def explain_incorrect(
    model: str, question: str, instructions: str, true_sqls: list, generated_sql: str
) -> str:
    """
    Use LLM to explain why a SQL query is incorrect given a question, instructions and the true SQL queries.

    Args:
        model: OpenAI chat model name passed to the chat completions API.
        question: natural-language question the SQL was generated for.
        instructions: optional extra instructions; left out of the prompt when falsy.
        true_sqls: reference queries; interpolated into the prompt with ``str()``.
        generated_sql: the incorrect generated query.

    Returns:
        The ``"reason"`` value from the model's JSON reply, or ``None`` if the reply
        cannot be parsed as a JSON object or has no ``"reason"`` key.
    """
    # Add an "Instructions:" line only when there are instructions; otherwise add
    # nothing, so falsy values such as None never appear in the prompt text.
    instructions = f"\nInstructions: {instructions}\n" if instructions else ""
    messages = [
        {
            "role": "system",
            "content": f"""Your task is to explain why the SQL query is incorrect given the question, instructions and the true SQL queries."""
        },
        {
            "role": "user",
            "content": f"""Question: {question}{instructions}
Incorrect SQL: {generated_sql}

True SQL queries:
{true_sqls}

Format your response as a valid JSON string with reason as a key. 
Your response should look like the string below:
{{ "reason": "Your reasoning for why the SQL query is incorrect according to the question and the true SQL queries."
}}

Do not include any other information before and after the JSON string.
""",
        },
    ]

    completion = openai.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=500,
        temperature=0,
        response_format={"type": "json_object"},
    )
    content = completion.choices[0].message.content
    try:
        completion_dict = json.loads(content)
    except (json.JSONDecodeError, TypeError):
        # JSONDecodeError: malformed/truncated JSON (e.g. the reply hit max_tokens).
        # TypeError: json.loads(None), in case the API returns no content.
        print(f"Error parsing completion {content}", flush=True)
        return None
    if not isinstance(completion_dict, dict):
        # Valid JSON but not an object (e.g. a list or string): no "reason" key to read
        print(f"Error parsing completion {content}", flush=True)
        return None
    return completion_dict.get("reason", None)

In [None]:
from tqdm.notebook import tqdm
tqdm.pandas()
# Ask the LLM why each first-occurrence valid-but-wrong query is incorrect (one request per row).
# Iterating over plain records instead of `progress_apply(axis=1)` sidesteps pandas'
# empty-frame apply path, which may call the function once on a placeholder row to
# infer the result shape (i.e. a wasted LLM request with NaN inputs).
records = df_valid_wrong_recurr_first.to_dict('records')
reasons = [
    explain_incorrect(
        model_4_latest,
        rec['question'],
        rec['instructions'],
        rec['true_queries'],
        rec['generated_query'],
    )
    for rec in tqdm(records, desc="explaining")
]
# `assign` returns a new frame, so no column is set on a possibly-sliced frame
df_valid_wrong_recurr_first = df_valid_wrong_recurr_first.assign(reason_incorrect=reasons)
df_valid_wrong_recurr_first

In [None]:
# Convert reason_incorrect col to a string of bullet points. Rows whose explanation
# could not be parsed (reason is None/NaN) are skipped so "- None" lines don't reach
# the pattern-extraction prompt.
reason_incorrect_str = "\n".join(
    f"- {reason}" for reason in df_valid_wrong_recurr_first['reason_incorrect'].dropna()
)

Use the LLM to extract recurring patterns from the reasons why the valid but wrong queries are incorrect.

In [None]:
# Extract recurring patterns from the reasons for valid but wrong examples
def get_valid_wrong_patterns(
    model: str, reason_incorrect_str: str
) -> str:
    """
    Use LLM to extract recurring patterns in a list of error messages that describe why a SQL query is wrong according to the question and the true SQL queries.

    Args:
        model: OpenAI chat model name passed to the chat completions API.
        reason_incorrect_str: newline-separated bullet list of explanations of why
            individual queries are wrong.

    Returns:
        The model's numbered list of recurring patterns. If ``reason_incorrect_str``
        is empty or blank, no request is made and a fixed message is returned instead,
        so the LLM cannot invent patterns from an empty list.
    """
    if not (reason_incorrect_str or "").strip():
        return "No valid but wrong explanations to analyze."

    messages = [
        {
            "role": "system",
            "content": f"""Your task is to identify recurring patterns in the error messages below and provide a summary of the patterns."""
        },
        {
            "role": "user",
            "content": f"""List of error messages that describe why a SQL query is wrong according to the question and the true SQL queries:
{reason_incorrect_str}

Format your response as a numbered list of recurring patterns in the error messages. 
Each point should be a concise yet detailed summary of a trend identified in the error messages along with specific examples (e.g. inability to follow instructions, common errors in specific SQL categories, etc.).
Do not include any other information before and after the list.
""",
        },
    ]

    # temperature=0 keeps the summary as repeatable as the API allows
    completion = openai.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=1000,
        temperature=0,
    )
    return completion.choices[0].message.content

In [None]:
# Summarize the recurring reasons for valid but wrong queries with the LLM
valid_wrong_summary = get_valid_wrong_patterns(model=model_4_latest, reason_incorrect_str=reason_incorrect_str)
print(valid_wrong_summary)

In [19]:
# Store summaries in a json file inside the results directory
summary_dict = {
    "category_corr_dict": category_corr_dict,
    "error_exec_summary": error_exec_summary,
    "valid_wrong_summary": valid_wrong_summary
}
# os.path.join works whether or not results_dir_path ends with a separator
output_file = os.path.join(results_dir_path, f"error_analysis_ds_{ds_to_analyze}.json")
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(summary_dict, f, indent=4)
print(f"Saved error analysis summary to {output_file}")