# Import Libraries

In [1]:
import sys
import os

# Put the parent of the current working directory at the front of sys.path
# (presumably so the `text_to_sql` import in the next cell resolves).
notebook_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir))

# Only insert when the entry is missing, so re-running the cell adds no duplicates.
already_on_path = parent_dir in sys.path
if not already_on_path:
    sys.path.insert(0, parent_dir)

In [2]:
import ast
import os
from datetime import datetime

import pandas as pd
from dotenv import load_dotenv

from text_to_sql import (
    TextToSQL,
    Config,
    LLMConfig,
    SLConfig,
    ContextConfig,
    QueryConfig,
)




# Constants

In [3]:
DATABASE = "northwind"  # database name; used below to build file paths and the DB_* env-var names
MODEL = "gemini-1.5-pro"  # model name passed to every LLM-backed config
PROVIDER = "gemini"  # provider name; also used to build the API_KEY_* env-var name
TOTAL_TABLES = 14  # denominator in calculate_reduction; presumably the number of tables in DATABASE — keep in sync

# Load Environment

In [4]:
# Load variables from a .env file so the os.getenv lookups in the config cell can find them.
load_dotenv()

True

# Set Experiment Timestamp and Output Directory

In [5]:
# Name the output folder after the minute at which this cell runs.
run_started_at = datetime.now()
timestamp = run_started_at.strftime("%Y_%m_%d_%H_%M")
output_dir = f"../files/experiment_result/{timestamp}"

# exist_ok=True keeps a re-run within the same minute from raising.
os.makedirs(output_dir, exist_ok=True)

# Config

In [6]:
# Environment-variable suffixes: upper-cased names with dashes turned into underscores.
db_key = DATABASE.upper().replace("-", "_")
provider_key = PROVIDER.upper().replace("-", "_")

# The rewriter, query generator and schema linker all share these settings.
llm_settings = {
    "type": "api",
    "model": MODEL,
    "provider": PROVIDER,
    "api_key": os.getenv(f"API_KEY_{provider_key}"),
}

# Connection settings are read from DB_<FIELD>_<db_key> environment variables.
db_settings = {
    field: os.getenv(f"DB_{field.upper()}_{db_key}")
    for field in ("host", "database", "user", "password", "port")
}

config = Config(
    max_retry_attempt=5,
    rewriter_config=LLMConfig(**llm_settings),
    query_generator_config=LLMConfig(**llm_settings),
    schema_linker_config=SLConfig(
        **llm_settings,
        schema_path=f"../files/schema/{DATABASE}.txt",
        metadata_path=f"../files/metadata/{DATABASE}.json",
    ),
    retrieve_context_config=ContextConfig(
        data_path=f"../files/dataset/dataset_{DATABASE}.csv"
    ),
    query_executor_config=QueryConfig(**db_settings),
)

# Model

In [7]:
# Build the pipeline object from the config; its predict_schema_only method is called in the experiment loop.
text_to_sql_model = TextToSQL(config=config)

Initializing API client for gemini using model gemini-1.5-pro.
Initializing API client for gemini using model gemini-1.5-pro.
Initializing API client for gemini using model gemini-1.5-pro.


  from google.protobuf import service as _service


# Import Dataset

In [8]:
# Read the evaluation set; the cells below use its "prompt" and "tables_used" columns.
dataset = pd.read_csv("../files/dataset/dataset_rewriter.csv")

In [9]:
# The CSV stores each table list as text; parse it back into a Python list.
# ast.literal_eval only accepts literals, so — unlike eval — a crafted cell in
# the file cannot execute arbitrary code. Values that are no longer strings are
# passed through unchanged so this cell can be re-run without raising.
dataset["tables_used"] = dataset["tables_used"].apply(
    lambda cell: ast.literal_eval(cell) if isinstance(cell, str) else cell
)

In [10]:
# Preview the loaded dataset after parsing the "tables_used" column.
dataset

Unnamed: 0,prompt,tables_used
0,hjsz pendapatan kita tahun 1997 berapa59 sih? ...,"[order_details, orders]"
1,Tiap yv udah bayar0 ke kita totalnya berapa?26...,"[customers, order_details, orders]"
2,kspxm paling laku dt pendapatan setelah ndion...,"[order_details, products]"
3,C~ustomer UK yang bayar dari $1000 Tampilk...,"[customers, order_details, orders]"
4,B~erapa total pembayaran YANG dilakukan oleh s...,"[customers, order_details, orders]"
...,...,...
95,Top 1 plourpa category by quantity ordered? S~...,"[categories, order_details, products]"
96,owSh customers company) and how many73 ce the...,"[customers, orders]"
97,Total ORDERS per country shipped to? Show 'shi...,[orders]
98,How many orders per shipper per year? Show 'sh...,"[orders, shippers]"


# Experiment

In [11]:
def extract_table_names_from_schema(schema: dict) -> list:
    """
    Collect the lowercased name of every table listed in a schema dictionary.

    Args:
        schema: Dictionary (e.g. the result of predict_schema_only) whose
            optional "tables" entry is an iterable of dicts with a "name" key.

    Returns:
        The table names, lowercased, in the order they appear; an empty list
        when the "tables" key is missing.
    """
    table_names = []
    for table_entry in schema.get("tables", []):
        table_names.append(table_entry["name"].lower())
    return table_names

In [12]:
def calculate_reduction(predicted_tables):
    """
    Return the fraction of the schema left out by a prediction.

    Args:
        predicted_tables: Sized collection of predicted table names.

    Returns:
        1 minus the ratio of len(predicted_tables) to TOTAL_TABLES.
    """
    kept_fraction = len(predicted_tables) / TOTAL_TABLES
    return 1 - kept_fraction

In [13]:
accuracies = []
predicted_tables_list = []

# Walk the prompts and their expected table lists in parallel, one row at a time.
for prompt, expected_tables in zip(dataset["prompt"], dataset["tables_used"]):
    true_tables = set(expected_tables)

    # Ask the model for the linked schema only, then reduce it to table names.
    predicted_schema = text_to_sql_model.predict_schema_only(prompt)
    predicted_tables = set(extract_table_names_from_schema(predicted_schema))

    # Keep the prediction so it can be attached to the dataset afterwards.
    predicted_tables_list.append(list(predicted_tables))

    # Score = share of the expected tables that the prediction contains.
    # With no expected tables, only an empty prediction counts as correct.
    if true_tables:
        accuracy = len(true_tables & predicted_tables) / len(true_tables)
    elif predicted_tables:
        accuracy = 0.0
    else:
        accuracy = 1.0

    accuracies.append(accuracy)

Related Tables: {'employee_territories', 'employees', 'orders', 'region', 'shippers', 'us_states', 'order_details', 'customer_demographics', 'customers', 'suppliers', 'customer_customer_demo', 'territories', 'products', 'categories'}
Related Tables: {'employee_territories', 'employees', 'orders', 'region', 'shippers', 'us_states', 'order_details', 'customer_demographics', 'customers', 'suppliers', 'customer_customer_demo', 'territories', 'products', 'categories'}
Related Tables: {'employees', 'orders', 'shippers', 'order_details', 'customers', 'suppliers', 'customer_demographics', 'customer_customer_demo', 'products', 'categories'}
Related Tables: {'employees', 'orders', 'shippers', 'order_details', 'customer_demographics', 'customers', 'suppliers', 'customer_customer_demo', 'products', 'categories'}
Related Tables: {'employees', 'orders', 'shippers', 'order_details', 'customers', 'customer_demographics', 'suppliers', 'customer_customer_demo', 'products', 'categories'}
Related Tables: 

In [14]:
# Attach the per-prompt results collected by the experiment loop.
dataset["predicted_tables"] = predicted_tables_list
dataset["schema_accuracy"] = accuracies

# Schema reduction is derived from each row's predicted table list.
reduction_per_row = dataset["predicted_tables"].apply(calculate_reduction)
dataset["schema_reduction"] = reduction_per_row

In [15]:
# Average the per-prompt scores and the per-row schema reduction.
final_accuracy = sum(accuracies) / len(accuracies)
average_reduction = dataset["schema_reduction"].mean()

print(f"Schema Prediction Accuracy (Intersection-Based): {final_accuracy:.2%}")
print(f"Average Schema Reduction: {average_reduction:.2%}")

Schema Prediction Accuracy (Intersection-Based): 99.50%
Average Schema Reduction: 23.29%
