# Import Libraries

In [None]:
import os
import time
from datetime import datetime

import pandas as pd
from dotenv import load_dotenv

from text_to_sql.text_to_sql import (
    TextToSQL,
    Config,
    LLMConfig,
    SLConfig,
    ContextConfig,
    QueryConfig,
)




# Constants

In [2]:
# Maximum number of SQL-generation attempts per question in the experiment loops.
MAX_RETRIES: int = 5
# Seconds passed to time.sleep between failed generation attempts.
RETRY_DELAY: int = 2

# Load Environment

In [3]:
# Load environment variables via python-dotenv; the config cell reads
# API_KEY and the DB_* settings with os.getenv.
load_dotenv()

True

# Set Experiment Timestamp and Output Directory

In [4]:
# Minute-resolution tag identifying this run (e.g. 2024_01_31_23_59).
run_started_at = datetime.now()
timestamp = run_started_at.strftime("%Y_%m_%d_%H_%M")

# Every result CSV of this run is written below this directory.
output_dir = f"./experiment_result/{timestamp}"
# exist_ok=True: re-running within the same minute reuses the directory instead of raising.
os.makedirs(output_dir, exist_ok=True)

# Config

In [5]:
# Keyword arguments shared by every Gemini-backed component (rewriter,
# query generator and schema linker all use the same model and API key).
gemini_api_settings = {
    "type": "api",
    "model": "gemini-1.5-flash",
    "provider": "gemini",
    "api_key": os.getenv("API_KEY"),
}

# Database connection parameters, all read from the environment.
# NOTE(review): os.getenv returns a str (or None when the variable is unset),
# so DB_PORT is passed as a string — confirm QueryConfig accepts that.
database_settings = {
    "host": os.getenv("DB_HOST"),
    "database": os.getenv("DB_DATABASE"),
    "user": os.getenv("DB_USER"),
    "password": os.getenv("DB_PASSWORD"),
    "port": os.getenv("DB_PORT"),
}

# Full pipeline configuration consumed by TextToSQL.
config = Config(
    max_retry_attempt=5,
    rewriter_config=LLMConfig(**gemini_api_settings),
    query_generator_config=LLMConfig(**gemini_api_settings),
    schema_linker_config=SLConfig(
        **gemini_api_settings,
        schema_path="./metadata/sakila.json",
    ),
    retrieve_context_config=ContextConfig(data_path="./dataset/dataset_sakila.csv"),
    query_executor_config=QueryConfig(**database_settings),
)

# Model

In [6]:
# Build the text-to-SQL model from `config`; every experiment cell below
# calls its generate_* and evaluate methods.
text_to_sql_model = TextToSQL(config=config)

Initializing API client for gemini using model gemini-1.5-flash.
Initializing API client for gemini using model gemini-1.5-flash.
Initializing API client for gemini using model gemini-1.5-flash.


  from google.protobuf import service as _service


# Import Dataset

In [7]:
# Read the benchmark questions, then keep only the rows that have a Summary.
dataset_raw = pd.read_csv("./dataset/dataset_sakila.csv")
dataset = dataset_raw.dropna(subset=["Summary"])

In [8]:
# Show the filtered dataset as the cell's rich output.
dataset

Unnamed: 0,Question,Answer,Summary,Alternative Prompt 1,Alternative Prompt 2
0,Which actors have the first name ‘Scarlett’,SELECT * FROM actor WHERE first_name = 'Scarle...,This SQL query retrieves all records from the ...,Can you give me actors that have first name Sc...,I want to know actors that have the first name...
1,How many distinct actors last names are there?,SELECT COUNT(DISTINCT last_name) FROM actor;,This SQL query calculates the count of distinc...,"Hey, I'm curious—how many different last names...",Can you tell me how many unique actor last nam...
2,Which actor has appeared in the most films?,"SELECT \r\n a.actor_id, \r\n a.first_nam...",This SQL query identifies the actor who has ap...,Who is the most frequently appearing actor in ...,Can you find out which actor has been in the m...
3,List the top five genres in gross revenue in d...,"SELECT \r\n c.name AS genre, \r\n SUM(p....",This SQL query calculates the top 5 highest-gr...,Which five movie genres have made the most mon...,Can you show me the top five highest-earning g...
4,"Write a query to display how much business, in...","SELECT s.store_id, SUM(p.amount) AS total_reve...",This SQL query calculates the total revenue ge...,How much money has each store made from rentals?,Can you break down the total revenue for each ...
5,Which language is used in most films?,SELECT \n l.name\nFROM \n language l\n ...,This SQL query identifies the most frequently ...,Which language is the most common in the movie...,Can you find out which language is used the mo...
6,List the top five customers in number of rente...,"SELECT\n c.customer_id,\n c.first_name,\...",This SQL query identifies the top 5 customers ...,Who are our five most active customers based o...,Can you tell me which five customers have rent...
7,Which customers have rented films from more th...,"SELECT \r\n c.customer_id, \r\n c.first_...",This SQL query identifies customers who have r...,Which customers have explored a variety of gen...,Can you list customers who have rented movies ...
8,Which films have never been rented out? Show t...,"SELECT \r\n f.film_id, \r\n f.title, \r\...",This SQL query identifies films that have neve...,Are there any movies in our collection that no...,Can you find out which films have never been r...
9,What is the total revenue generated by each ac...,"SELECT \r\n a.actor_id, \r\n a.first_nam...",This SQL query calculates the total revenue ge...,How much money has each actor helped generate ...,Can you show me the total revenue for each act...


# Experiment Baseline Multistage

In [None]:
# Baseline / Multistage experiment: generate SQL for each alternative prompt
# of every dataset row, score it against the reference query, and report the
# mean execution accuracy over all prompts.

# Dataset columns that hold the alternative phrasings of each question.
prompt_columns = ["Alternative Prompt 1", "Alternative Prompt 2"]

EA = 0  # running sum of per-question execution-accuracy scores
total_questions = len(dataset) * len(prompt_columns)
results_list = []  # one record per (row, prompt); the following cell saves it to CSV

for idx, row in dataset.iterrows():
    answer = row["Answer"]  # reference SQL the generated query is scored against

    for prompt_id, column in enumerate(prompt_columns, start=1):
        question = row[column]
        question_id = f"{idx + 1}.{prompt_id}"
        print(f"\nProcessing Question {question_id}: {question}")

        # Retry generation on any exception, pausing RETRY_DELAY seconds
        # between attempts (but not after the final one).
        generated = False
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                result = text_to_sql_model.generate_baseline(user_prompt=question, method="Multistage")
                generated = True
                break
            except Exception as e:
                print(f"[Attempt {attempt}] Failed to generate SQL: {e}")
                if attempt < MAX_RETRIES:
                    time.sleep(RETRY_DELAY)

        if not generated:
            print("Max retries reached. Setting result as 'ERROR'")
            result = "ERROR"

        print(f"Generated SQL Query: {result}")

        if generated:
            try:
                acc = text_to_sql_model.evaluate(query=result, true_query=answer)
            except Exception as e:
                # Best effort: a failed evaluation counts as a miss.
                print(f"Evaluation failed: {e}")
                acc = 0.0
        else:
            # The "ERROR" placeholder is not SQL, so score it as a miss
            # instead of passing it to evaluate().
            acc = 0.0

        print(f"Execution Accuracy: {acc:.4f}")

        results_list.append({
            "Question ID": question_id,
            "Question": question,
            "Generated SQL Query": result,
            "Expected SQL Query": answer,
            "Execution Accuracy": acc
        })

        EA += acc

# Calculate final execution accuracy (mean score over all prompts).
final_accuracy = EA / total_questions if total_questions > 0 else 0
print(f"\nFinal Execution Accuracy: {final_accuracy:.4f}")


Processing Question 1.1: Can you give me actors that have first name Scarlett
Related Tables: {'film', 'category', 'actor', 'staff', 'address', 'film_category', 'inventory', 'film_actor', 'language', 'city', 'store', 'country'}
Generated SQL Query: SELECT * FROM actor WHERE first_name = 'Scarlett';
Execution Accuracy: 1.0000

Processing Question 1.2: I want to know actors that have the first name Scarlett
Related Tables: {'film', 'category', 'actor', 'staff', 'address', 'film_category', 'inventory', 'film_actor', 'language', 'city', 'store', 'country'}
Generated SQL Query: SELECT * FROM actor WHERE first_name = 'Scarlett';
Execution Accuracy: 1.0000

Processing Question 2.1: Hey, I'm curious—how many different last names do the actors in our database have?
Related Tables: {'film', 'category', 'actor', 'staff', 'address', 'film_category', 'inventory', 'film_actor', 'language', 'city', 'store', 'country'}
Generated SQL Query: SELECT COUNT(DISTINCT last_name) FROM actor;
Execution Accura

In [9]:
# Persist the Baseline / Multistage records as CSV inside this run's output directory.
# NOTE: results_list is reassigned by every experiment cell, so run this cell
# right after the Baseline / Multistage experiment cell.
results_csv_path = f"{output_dir}/sql_execution_results_baseline_multistage.csv"
df_results_baseline_multistage = pd.DataFrame(results_list)
df_results_baseline_multistage.to_csv(results_csv_path, index=False)

# Experiment Baseline Incremental

In [14]:
# Baseline / Incremental experiment: generate SQL for each alternative prompt
# of every dataset row, score it against the reference query, and report the
# mean execution accuracy over all prompts.

# Dataset columns that hold the alternative phrasings of each question.
prompt_columns = ["Alternative Prompt 1", "Alternative Prompt 2"]

EA = 0  # running sum of per-question execution-accuracy scores
total_questions = len(dataset) * len(prompt_columns)
results_list = []  # one record per (row, prompt); the following cell saves it to CSV

for idx, row in dataset.iterrows():
    answer = row["Answer"]  # reference SQL the generated query is scored against

    for prompt_id, column in enumerate(prompt_columns, start=1):
        question = row[column]
        question_id = f"{idx + 1}.{prompt_id}"
        print(f"\nProcessing Question {question_id}: {question}")

        # Retry generation on any exception, pausing RETRY_DELAY seconds
        # between attempts (but not after the final one).
        generated = False
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                result = text_to_sql_model.generate_baseline(user_prompt=question, method="Incremental")
                generated = True
                break
            except Exception as e:
                print(f"[Attempt {attempt}] Failed to generate SQL: {e}")
                if attempt < MAX_RETRIES:
                    time.sleep(RETRY_DELAY)

        if not generated:
            print("Max retries reached. Setting result as 'ERROR'")
            result = "ERROR"

        print(f"Generated SQL Query: {result}")

        if generated:
            try:
                acc = text_to_sql_model.evaluate(query=result, true_query=answer)
            except Exception as e:
                # Best effort: a failed evaluation counts as a miss.
                print(f"Evaluation failed: {e}")
                acc = 0.0
        else:
            # The "ERROR" placeholder is not SQL, so score it as a miss
            # instead of passing it to evaluate().
            acc = 0.0

        print(f"Execution Accuracy: {acc:.4f}")

        results_list.append({
            "Question ID": question_id,
            "Question": question,
            "Generated SQL Query": result,
            "Expected SQL Query": answer,
            "Execution Accuracy": acc
        })

        EA += acc

# Calculate final execution accuracy (mean score over all prompts).
final_accuracy = EA / total_questions if total_questions > 0 else 0
print(f"\nFinal Execution Accuracy: {final_accuracy:.4f}")


Processing Question 1.1: Can you give me actors that have first name Scarlett
Related Tables: {'film', 'category', 'actor', 'staff', 'address', 'film_category', 'inventory', 'film_actor', 'language', 'city', 'store', 'country'}
Generated SQL Query: SELECT first_name, last_name FROM actor WHERE first_name = 'Scarlett';
Execution Accuracy: 1.0000

Processing Question 1.2: I want to know actors that have the first name Scarlett
Related Tables: {'film', 'category', 'actor', 'staff', 'address', 'film_category', 'inventory', 'film_actor', 'language', 'city', 'store', 'country'}
Generated SQL Query: SELECT actor_id, first_name, last_name FROM actor WHERE first_name = 'Scarlett';
Execution Accuracy: 1.0000

Processing Question 2.1: Hey, I'm curious—how many different last names do the actors in our database have?
Related Tables: {'film', 'category', 'actor', 'staff', 'address', 'film_category', 'inventory', 'film_actor', 'language', 'city', 'store', 'country'}
Generated SQL Query: SELECT COUN

In [15]:
# Persist the Baseline / Incremental records as CSV inside this run's output directory.
# NOTE: results_list is reassigned by every experiment cell, so run this cell
# right after the Baseline / Incremental experiment cell.
results_csv_path = f"{output_dir}/sql_execution_results_baseline_incremental.csv"
df_results_baseline_incremental = pd.DataFrame(results_list)
df_results_baseline_incremental.to_csv(results_csv_path, index=False)

# Experiment V1 Multistage

In [16]:
# V1 / Multistage experiment: generate SQL with generate_v1 for each
# alternative prompt of every dataset row, score it against the reference
# query, and report the mean execution accuracy over all prompts.

# Dataset columns that hold the alternative phrasings of each question.
prompt_columns = ["Alternative Prompt 1", "Alternative Prompt 2"]

EA = 0  # running sum of per-question execution-accuracy scores
total_questions = len(dataset) * len(prompt_columns)
results_list = []  # one record per (row, prompt); the following cell saves it to CSV

for idx, row in dataset.iterrows():
    answer = row["Answer"]  # reference SQL the generated query is scored against

    for prompt_id, column in enumerate(prompt_columns, start=1):
        question = row[column]
        question_id = f"{idx + 1}.{prompt_id}"
        print(f"\nProcessing Question {question_id}: {question}")

        # Retry generation on any exception, pausing RETRY_DELAY seconds
        # between attempts (but not after the final one).
        generated = False
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                result = text_to_sql_model.generate_v1(user_prompt=question, method="Multistage")
                generated = True
                break
            except Exception as e:
                print(f"[Attempt {attempt}] Failed to generate SQL: {e}")
                if attempt < MAX_RETRIES:
                    time.sleep(RETRY_DELAY)

        if not generated:
            print("Max retries reached. Setting result as 'ERROR'")
            result = "ERROR"

        print(f"Generated SQL Query: {result}")

        if generated:
            try:
                acc = text_to_sql_model.evaluate(query=result, true_query=answer)
            except Exception as e:
                # Best effort: a failed evaluation counts as a miss.
                print(f"Evaluation failed: {e}")
                acc = 0.0
        else:
            # The "ERROR" placeholder is not SQL, so score it as a miss
            # instead of passing it to evaluate().
            acc = 0.0

        print(f"Execution Accuracy: {acc:.4f}")

        results_list.append({
            "Question ID": question_id,
            "Question": question,
            "Generated SQL Query": result,
            "Expected SQL Query": answer,
            "Execution Accuracy": acc
        })

        EA += acc

# Calculate final execution accuracy (mean score over all prompts).
final_accuracy = EA / total_questions if total_questions > 0 else 0
print(f"\nFinal Execution Accuracy: {final_accuracy:.4f}")


Processing Question 1.1: Can you give me actors that have first name Scarlett
Rewritten Prompt: Retrieve actors with the first name Scarlett.

Related Tables: {'film', 'category', 'actor', 'staff', 'address', 'film_category', 'inventory', 'film_actor', 'language', 'city', 'store', 'country'}
Generated SQL Query: SELECT * FROM actor WHERE first_name = 'Scarlett';
Execution Accuracy: 1.0000

Processing Question 1.2: I want to know actors that have the first name Scarlett
Rewritten Prompt: Retrieve actors with the first name Scarlett.

Related Tables: {'film', 'category', 'actor', 'staff', 'address', 'film_category', 'inventory', 'film_actor', 'language', 'city', 'store', 'country'}
Generated SQL Query: SELECT * FROM actor WHERE first_name = 'Scarlett';
Execution Accuracy: 1.0000

Processing Question 2.1: Hey, I'm curious—how many different last names do the actors in our database have?
Rewritten Prompt: Determine the number of unique last names among actors in the database.

Related Tab

In [17]:
# Persist the V1 / Multistage records as CSV inside this run's output directory.
# NOTE: results_list is reassigned by every experiment cell, so run this cell
# right after the V1 / Multistage experiment cell.
results_csv_path = f"{output_dir}/sql_execution_results_v1_multistage.csv"
df_results_v1_multistage = pd.DataFrame(results_list)
df_results_v1_multistage.to_csv(results_csv_path, index=False)

# Experiment V1 Incremental

In [9]:
# V1 / Incremental experiment: generate SQL with generate_v1 for each
# alternative prompt of every dataset row, score it against the reference
# query, and report the mean execution accuracy over all prompts.

# Dataset columns that hold the alternative phrasings of each question.
prompt_columns = ["Alternative Prompt 1", "Alternative Prompt 2"]

EA = 0  # running sum of per-question execution-accuracy scores
total_questions = len(dataset) * len(prompt_columns)
results_list = []  # one record per (row, prompt); the following cell saves it to CSV

for idx, row in dataset.iterrows():
    answer = row["Answer"]  # reference SQL the generated query is scored against

    for prompt_id, column in enumerate(prompt_columns, start=1):
        question = row[column]
        question_id = f"{idx + 1}.{prompt_id}"
        print(f"\nProcessing Question {question_id}: {question}")

        # Retry generation on any exception, pausing RETRY_DELAY seconds
        # between attempts (but not after the final one).
        generated = False
        for attempt in range(1, MAX_RETRIES + 1):
            try:
                result = text_to_sql_model.generate_v1(user_prompt=question, method="Incremental")
                generated = True
                break
            except Exception as e:
                print(f"[Attempt {attempt}] Failed to generate SQL: {e}")
                if attempt < MAX_RETRIES:
                    time.sleep(RETRY_DELAY)

        if not generated:
            print("Max retries reached. Setting result as 'ERROR'")
            result = "ERROR"

        print(f"Generated SQL Query: {result}")

        if generated:
            try:
                acc = text_to_sql_model.evaluate(query=result, true_query=answer)
            except Exception as e:
                # Best effort: a failed evaluation counts as a miss.
                print(f"Evaluation failed: {e}")
                acc = 0.0
        else:
            # The "ERROR" placeholder is not SQL, so score it as a miss
            # instead of passing it to evaluate().
            acc = 0.0

        print(f"Execution Accuracy: {acc:.4f}")

        results_list.append({
            "Question ID": question_id,
            "Question": question,
            "Generated SQL Query": result,
            "Expected SQL Query": answer,
            "Execution Accuracy": acc
        })

        EA += acc

# Calculate final execution accuracy (mean score over all prompts).
final_accuracy = EA / total_questions if total_questions > 0 else 0
print(f"\nFinal Execution Accuracy: {final_accuracy:.4f}")


Processing Question 1.1: Can you give me actors that have first name Scarlett
Rewritten Prompt: Retrieve actors with the first name Scarlett.

Related Tables: {'actor', 'store', 'film', 'city', 'country', 'address', 'category', 'staff', 'language', 'inventory', 'film_actor', 'film_category'}
Generated SQL Query: SELECT first_name, last_name FROM actor WHERE first_name = 'Scarlett';
Execution Accuracy: 1.0000

Processing Question 1.2: I want to know actors that have the first name Scarlett
Rewritten Prompt: Retrieve all actors with the first name Scarlett.

Related Tables: {'actor', 'store', 'film', 'city', 'country', 'address', 'category', 'staff', 'language', 'inventory', 'film_actor', 'film_category'}
Generated SQL Query: SELECT first_name, last_name FROM actor WHERE first_name = 'Scarlett';
Execution Accuracy: 1.0000

Processing Question 2.1: Hey, I'm curious—how many different last names do the actors in our database have?
Rewritten Prompt: Determine the count of unique last names

In [10]:
# Persist the V1 / Incremental records as CSV inside this run's output directory.
# NOTE: results_list is reassigned by every experiment cell, so run this cell
# right after the V1 / Incremental experiment cell.
results_csv_path = f"{output_dir}/sql_execution_results_v1_incremental.csv"
df_results_v1_incremental = pd.DataFrame(results_list)
df_results_v1_incremental.to_csv(results_csv_path, index=False)