# Import Libraries

In [17]:
import sys
import os

notebook_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(notebook_dir, '..'))

if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

In [18]:
from text_to_sql import (
    TextToSQL,
    Config,
    LLMConfig,
    SLConfig,
    ContextConfig,
    QueryConfig,
)
from dotenv import load_dotenv
from datetime import datetime

import pandas as pd
import os

# Constants

In [19]:
DATABASE = "northwind"
MODEL = "gemini-1.5-pro"
PROVIDER = "gemini"
TOTAL_TABLES = 14

# Load Environment

In [20]:
load_dotenv()

True

# Set Timestamp Experiment

In [21]:
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M")
output_dir = f"../files/experiment_result/{timestamp}"
os.makedirs(output_dir, exist_ok=True)

# Config

In [22]:
db_key = DATABASE.upper().replace("-", "_")
provider_key = PROVIDER.upper().replace("-", "_")

config = Config(
    max_retry_attempt=5,
    rewriter_config=LLMConfig(
        type="api",
        model=MODEL,
        provider=PROVIDER,
        api_key=os.getenv(f"API_KEY_{provider_key}"),
    ),
    query_generator_config=LLMConfig(
        type="api",
        model=MODEL,
        provider=PROVIDER,
        api_key=os.getenv(f"API_KEY_{provider_key}"),
    ),
    schema_linker_config=SLConfig(
        type="api",
        model=MODEL,
        provider=PROVIDER,
        api_key=os.getenv(f"API_KEY_{provider_key}"),
        schema_path=f"../files/schema/{DATABASE}.txt",
        metadata_path=f"../files/metadata/{DATABASE}.json",
    ),
    retrieve_context_config=ContextConfig(data_path=f"../files/dataset/dataset_{DATABASE}.csv"),
    query_executor_config = QueryConfig(
        host=os.getenv(f"DB_HOST_{db_key}"),
        database=os.getenv(f"DB_DATABASE_{db_key}"),
        user=os.getenv(f"DB_USER_{db_key}"),
        password=os.getenv(f"DB_PASSWORD_{db_key}"),
        port=os.getenv(f"DB_PORT_{db_key}"),
    ),
)

# Model

In [23]:
text_to_sql_model = TextToSQL(config=config)

Initializing API client for gemini using model gemini-1.5-pro.
Initializing API client for gemini using model gemini-1.5-pro.
Initializing API client for gemini using model gemini-1.5-pro.


# Import Dataset

In [24]:
dataset = pd.read_csv(f"../files/dataset/dataset_schema_linker.csv")

In [25]:
dataset["tables_used"] = dataset["tables_used"].apply(eval)

In [26]:
dataset

Unnamed: 0,prompt,tables_used
0,Total pendapatan kita tahun 1997 berapa sih? T...,"[order_details, orders]"
1,Tiap customer udah bayar ke kita totalnya bera...,"[customers, order_details, orders]"
2,10 produk paling laku berdasarkan pendapatan s...,"[order_details, products]"
3,Customer dari UK yang bayar lebih dari $1000 s...,"[customers, order_details, orders]"
4,Berapa total pembayaran yang dilakukan oleh se...,"[customers, order_details, orders]"
...,...,...
95,Top 1 popular category by quantity ordered? Sh...,"[categories, order_details, products]"
96,Show customers (not company) and how many empl...,"[customers, orders]"
97,Total orders per country shipped to? Show 'shi...,[orders]
98,How many orders per shipper per year? Show 'sh...,"[orders, shippers]"


# Experiment

In [27]:
def extract_table_names_from_schema(schema: dict) -> list:
    """
    Given a schema dictionary (from predict_schema_only),
    return a list of lowercase table names.
    """
    return [table["name"].lower() for table in schema.get("tables", [])]

In [28]:
def calculate_reduction(predicted_tables):
    return 1 - (len(predicted_tables) / TOTAL_TABLES)

In [29]:
accuracies = []
predicted_tables_list = []

for _, row in dataset.iterrows():
    prompt = row["prompt"]
    true_tables = set(row["tables_used"])

    predicted_schema = text_to_sql_model.predict_schema_only(prompt)
    predicted_tables = set(extract_table_names_from_schema(predicted_schema))

    # Save predicted tables
    predicted_tables_list.append(list(predicted_tables))

    # Calculate accuracy
    intersection_count = len(true_tables.intersection(predicted_tables))
    total_true_tables = len(true_tables)

    if total_true_tables > 0:
        accuracy = intersection_count / total_true_tables
    else:
        accuracy = 1.0 if not predicted_tables else 0.0

    accuracies.append(accuracy)

Related Tables: {'order_details', 'suppliers', 'territories', 'region', 'customer_customer_demo', 'employees', 'products', 'us_states', 'shippers', 'customers', 'customer_demographics', 'employee_territories', 'orders', 'categories'}
Related Tables: {'order_details', 'suppliers', 'customer_customer_demo', 'employees', 'products', 'shippers', 'customers', 'customer_demographics', 'orders', 'categories'}
Related Tables: {'order_details', 'suppliers', 'customer_customer_demo', 'employees', 'products', 'shippers', 'customers', 'customer_demographics', 'orders', 'categories'}
Related Tables: {'order_details', 'suppliers', 'customer_customer_demo', 'employees', 'products', 'shippers', 'customers', 'customer_demographics', 'orders', 'categories'}
Related Tables: {'order_details', 'suppliers', 'customer_customer_demo', 'employees', 'products', 'shippers', 'customers', 'customer_demographics', 'orders', 'categories'}
Related Tables: {'order_details', 'suppliers', 'territories', 'region', 'custo

In [30]:
dataset["predicted_tables"] = predicted_tables_list
dataset["schema_accuracy"] = accuracies
dataset['schema_reduction'] = dataset['predicted_tables'].apply(calculate_reduction)

In [31]:
final_accuracy = sum(accuracies) / len(accuracies)
print(f"Schema Prediction Accuracy (Intersection-Based): {final_accuracy:.2%}")

average_reduction = dataset['schema_reduction'].mean()
print(f"Average Schema Reduction: {average_reduction:.2%}")

Schema Prediction Accuracy (Intersection-Based): 100.00%
Average Schema Reduction: 22.29%


# Save

In [32]:
dataset.to_csv(f"{output_dir}/{MODEL}_{DATABASE}_schema_linker.csv", index=False)