# Import Libraries

In [1]:
import sys
import os

# Resolve the directory one level above the kernel's working directory and put
# it at the front of the import search path (presumably where the project's
# `text_to_sql` package lives — TODO confirm).
notebook_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(notebook_dir, ".."))

# Prepend only when absent, so re-running the cell never stacks duplicates.
if sys.path.count(parent_dir) == 0:
    sys.path.insert(0, parent_dir)

In [2]:
from text_to_sql import (
    TextToSQL,
    Config,
    LLMConfig,
    SLConfig,
    ContextConfig,
    QueryConfig,
)
from dotenv import load_dotenv
from datetime import datetime

import pandas as pd
import os




# Constants

In [3]:
# Experiment-wide settings. Later cells derive file paths and environment
# variable names from these values, so change them here rather than inline.

# LLM provider id; its upper-cased form selects the API_KEY_<PROVIDER> env var.
PROVIDER = "gemini"
# Model name handed to every LLM-backed config.
MODEL = "gemini-1.5-pro"
# Database id; names the schema/metadata/dataset files and the
# DB_<FIELD>_<DATABASE> env vars.
DATABASE = "northwind"
# Presumably the number of tables in the target schema — TODO confirm; it is
# not referenced by any of the cells visible in this part of the notebook.
TOTAL_TABLES = 14

# Load Environment

In [4]:
# Load variables from a .env file into the process environment so that the
# os.getenv(...) lookups in the config cell (API key, DB connection settings)
# can find them. The call's return value is shown as the cell output.
load_dotenv()

True

# Set Experiment Timestamp

In [5]:
# Tag this run with a minute-resolution timestamp and make sure its result
# folder exists. A folder left by an earlier run in the same minute is reused
# as-is because exist_ok=True suppresses the "already exists" error.
timestamp = format(datetime.now(), "%Y_%m_%d_%H_%M")
output_dir = f"../files/experiment_result/{timestamp}"
os.makedirs(name=output_dir, exist_ok=True)

# Config

In [6]:
# Environment-variable suffixes: upper-cased ids with dashes mapped to
# underscores (e.g. "northwind" -> "NORTHWIND").
db_key = DATABASE.upper().replace("-", "_")
provider_key = PROVIDER.upper().replace("-", "_")

# The rewriter, query generator and schema linker all use the same API
# backend, model and key, so those settings are collected once and unpacked
# into each config. os.getenv returns None when the variable is unset.
llm_api_settings = dict(
    type="api",
    model=MODEL,
    provider=PROVIDER,
    api_key=os.getenv(f"API_KEY_{provider_key}"),
)

config = Config(
    max_retry_attempt=5,
    rewriter_config=LLMConfig(**llm_api_settings),
    query_generator_config=LLMConfig(**llm_api_settings),
    # The schema linker additionally needs the per-database schema and
    # metadata files.
    schema_linker_config=SLConfig(
        **llm_api_settings,
        schema_path=f"../files/schema/{DATABASE}.txt",
        metadata_path=f"../files/metadata/{DATABASE}.json",
    ),
    retrieve_context_config=ContextConfig(
        data_path=f"../files/dataset/dataset_{DATABASE}.csv",
    ),
    # Connection settings come from DB_<FIELD>_<DATABASE> env vars; every
    # value (including the port) is passed through as the raw string, or None
    # when the variable is not set.
    query_executor_config=QueryConfig(
        host=os.getenv(f"DB_HOST_{db_key}"),
        database=os.getenv(f"DB_DATABASE_{db_key}"),
        user=os.getenv(f"DB_USER_{db_key}"),
        password=os.getenv(f"DB_PASSWORD_{db_key}"),
        port=os.getenv(f"DB_PORT_{db_key}"),
    ),
)

# Model

In [7]:
# Build the pipeline object from the assembled config. The captured output
# shows three "Initializing API client" messages (presumably one per
# LLM-backed sub-config — TODO confirm).
text_to_sql_model = TextToSQL(config=config)

Initializing API client for gemini using model gemini-1.5-pro.
Initializing API client for gemini using model gemini-1.5-pro.
Initializing API client for gemini using model gemini-1.5-pro.


  from google.protobuf import service as _service


# Experiment

In [8]:
# Rewriter prompt: an informal, heavily abbreviated Indonesian request used as
# the raw input for the rewriter call in the next cell.
user_prompt = "tampilin itu pesen yg bnyk dri kota smw aj klo bsa cepet"
print(f"User Prompt: {user_prompt}")

User Prompt: tampilin itu pesen yg bnyk dri kota smw aj klo bsa cepet


In [9]:
# Pass the raw prompt to predict_rewriter_only and print what it returns
# (presumably only the rewriter stage runs, per the method name — TODO confirm).
rewritten_prompt = text_to_sql_model.predict_rewriter_only(user_prompt)
print(f"Rewritten Prompt: {rewritten_prompt}")

Rewritten Prompt: Display all orders from every city as quickly as possible.



In [None]:
# Schema Linking
# NOTE: this rebinds `user_prompt` (previously the rewriter input) to a clean
# English question; the schema-linking call that follows reads this new value.
user_prompt = "Show all orders from every city"
print(f"User Prompt: {user_prompt}")

User Prompt: Show all orders from every city


In [11]:
# Run predict_schema_only on the current prompt. The "Related Tables: ..." line
# in the output is emitted during this call (the cell itself prints nothing);
# the return value is kept in `schema` and is not displayed here.
schema = text_to_sql_model.predict_schema_only(user_prompt)

Related Tables: {'customers', 'order_details', 'employees', 'shippers', 'orders'}


In [24]:
# Relevant Example
# Rebinds `user_prompt` once more, this time to the question used for the
# relevant-example lookup.
user_prompt = "Which city has the highest number of orders that were shipped late? Show city and late_order_count."
print(f"User Prompt: {user_prompt}")
# The printed result is a dict-like object with the keys 'relevant_question',
# 'relevant_answer' and 'relevant_summary' (see the captured output).
example = text_to_sql_model.predict_relevance_only(user_prompt)
print(example)

User Prompt: Which city has the highest number of orders that were shipped late? Show city and late_order_count.
{'relevant_question': 'What are the top 3 cities with the highest number of orders? Show city and total_orders.', 'relevant_answer': 'SELECT ship_city AS city, COUNT(order_id) AS total_orders\n  FROM orders\n  WHERE ship_city IS NOT NULL\n  GROUP BY ship_city\n  ORDER BY total_orders DESC\n  LIMIT 3;', 'relevant_summary': "This SQL query identifies the top 3 cities with the highest number of orders by analyzing the 'orders' table. It counts orders per city (excluding NULL values), groups by 'ship_city', and returns the cities with the most orders in descending order. Key operations include COUNT aggregation, GROUP BY, and ORDER BY with LIMIT. The primary metric is 'total_orders', showing order volume by destination city. The WHERE clause filters out NULL values to ensure data quality. The output helps identify key markets and shipping destinations, valuable for logistics pla

In [23]:
# Repeats the predict_relevance_only lookup with whatever `user_prompt`
# currently holds and overwrites `example` with the new result.
# NOTE(review): the saved execution counts show this cell (In [23]) was last
# run before the cell above it (In [24]); confirm the notebook still behaves
# as intended on a fresh top-to-bottom run.
example = text_to_sql_model.predict_relevance_only(user_prompt)
print(example)

{'relevant_question': 'What are the top 3 cities with the highest number of orders? Show city and total_orders.', 'relevant_answer': 'SELECT ship_city AS city, COUNT(order_id) AS total_orders\n  FROM orders\n  WHERE ship_city IS NOT NULL\n  GROUP BY ship_city\n  ORDER BY total_orders DESC\n  LIMIT 3;', 'relevant_summary': "This SQL query identifies the top 3 cities with the highest number of orders by analyzing the 'orders' table. It counts orders per city (excluding NULL values), groups by 'ship_city', and returns the cities with the most orders in descending order. Key operations include COUNT aggregation, GROUP BY, and ORDER BY with LIMIT. The primary metric is 'total_orders', showing order volume by destination city. The WHERE clause filters out NULL values to ensure data quality. The output helps identify key markets and shipping destinations, valuable for logistics planning and regional sales analysis. Potential modifications include: (1) adding a date filter to analyze recent or