# Import Libraries

In [1]:
import os
import sys

# Make the project root (the directory above this notebook) importable so that
# `from text_to_sql import ...` resolves. The entry is prepended so the local
# checkout wins over anything else on the path; the membership check keeps
# re-runs of this cell from stacking duplicate entries.
notebook_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir))

needs_path_entry = parent_dir not in sys.path
if needs_path_entry:
    sys.path.insert(0, parent_dir)

In [2]:
from text_to_sql import (
    TextToSQL,
    Config,
    LLMConfig,
    SLConfig,
    ContextConfig,
    QueryConfig,
)
from dotenv import load_dotenv
from datetime import datetime

import pandas as pd
import os




# Constants

In [3]:
# --- Target under test --------------------------------------------------------
DATABASE = "soccer"          # picks schema/metadata/dataset files and the DB_*_<DATABASE upper-cased> env vars
PROVIDER = "gemini"          # picks the API_KEY_<PROVIDER upper-cased> env var
MODEL = "gemini-1.5-pro"

# --- SQL-generation retry policy ----------------------------------------------
MAX_RETRIES = 5              # generation attempts per prompt before recording "ERROR"
RETRY_DELAY = 2              # seconds to sleep between failed attempts

# Load Environment

In [4]:
# Pull API keys and DB credentials (API_KEY_*, DB_*_*) from a local .env file into os.environ.
# load_dotenv() returns False when the .env file is missing or defines no variables. Say so
# up front, because the config cell would otherwise silently receive None credentials.
env_loaded = load_dotenv()
if not env_loaded:
    print("Warning: load_dotenv() loaded no variables; relying on the existing process environment.")
env_loaded

True

# Set Experiment Timestamp

In [5]:
# One output folder per run, named after the minute the run started (YYYY_MM_DD_HH_MM).
# NOTE: runs started within the same minute share a folder (exist_ok=True) and can
# overwrite each other's files.
run_started_at = datetime.now()
timestamp = run_started_at.strftime("%Y_%m_%d_%H_%M")
output_dir = "../files/experiment_result/" + timestamp
os.makedirs(output_dir, exist_ok=True)

# Config

In [6]:
# Env-var suffixes derived from the constants: upper-cased, with "-" mapped to "_"
# (e.g. DATABASE="soccer" -> DB_HOST_SOCCER, PROVIDER="gemini" -> API_KEY_GEMINI).
db_key = DATABASE.upper().replace("-", "_")
provider_key = PROVIDER.upper().replace("-", "_")

# The rewriter, query generator and schema linker all use the same provider key; read it once.
api_key = os.getenv(f"API_KEY_{provider_key}")

# os.getenv() returns None for unset variables, so a missing credential would otherwise only
# show up later, far from its cause. Report every missing one up front.
required_env_vars = [f"API_KEY_{provider_key}"] + [
    f"DB_{field}_{db_key}" for field in ("HOST", "DATABASE", "USER", "PASSWORD", "PORT")
]
missing_env_vars = [name for name in required_env_vars if os.getenv(name) is None]
if missing_env_vars:
    print(f"Warning: missing environment variables: {', '.join(missing_env_vars)}")

config = Config(
    # Reuse the MAX_RETRIES constant instead of a second literal 5 so the retry budgets stay in sync.
    max_retry_attempt=MAX_RETRIES,
    rewriter_config=LLMConfig(
        type="api",
        model=MODEL,
        provider=PROVIDER,
        api_key=api_key,
    ),
    query_generator_config=LLMConfig(
        type="api",
        model=MODEL,
        provider=PROVIDER,
        api_key=api_key,
    ),
    schema_linker_config=SLConfig(
        type="api",
        model=MODEL,
        provider=PROVIDER,
        api_key=api_key,
        schema_path=f"../files/schema/{DATABASE}.txt",
        metadata_path=f"../files/metadata/{DATABASE}.json",
    ),
    # NOTE(review): this is the same CSV that is loaded later as the evaluation set. If it is used
    # to retrieve few-shot examples, evaluated questions may retrieve their own gold SQL, so
    # check for test-set leakage.
    retrieve_context_config=ContextConfig(data_path=f"../files/dataset/dataset_{DATABASE}.csv"),
    query_executor_config=QueryConfig(
        host=os.getenv(f"DB_HOST_{db_key}"),
        database=os.getenv(f"DB_DATABASE_{db_key}"),
        user=os.getenv(f"DB_USER_{db_key}"),
        password=os.getenv(f"DB_PASSWORD_{db_key}"),
        port=os.getenv(f"DB_PORT_{db_key}"),
    ),
)

# Model

In [7]:
# Build the text-to-SQL pipeline from the config above.
# The log below prints one "Initializing API client" line per API-backed config (rewriter,
# query generator, schema linker). Presumably that is one client each; confirm in TextToSQL.
text_to_sql_model = TextToSQL(config=config)

Initializing API client for gemini using model gemini-1.5-pro.
Initializing API client for gemini using model gemini-1.5-pro.
Initializing API client for gemini using model gemini-1.5-pro.


  from google.protobuf import service as _service


# Import Dataset

In [8]:
# Evaluation set for DATABASE. Each row has a question, its gold SQL ("Answer"), two
# paraphrased prompts, and the expected result columns ("Expected Result").
dataset_path = f"../files/dataset/dataset_{DATABASE}.csv"
dataset = pd.read_csv(dataset_path)

In [9]:
# Preview the evaluation set. The "Unnamed: 0" column is a leftover index column stored in the CSV.
dataset

Unnamed: 0,Question,Answer,Summary,Alternative Prompt 1 (English),Alternative Prompt 2 (Bahasa Indonesia),Expected Result
0,Which players have scored more than 2 goals in...,"SELECT gd.player_id, pm.player_name, gd.match_...",This SQL query identifies players who scored m...,Can you show me the players who scored more th...,Tolong tampilkan pemain yang mencetak lebih da...,"['player_id', 'player_name', 'match_no', 'tota..."
1,Show all matches where the final score was a d...,SELECT \r\n md1.match_no\r\nFROM \r\n ma...,This SQL query identifies football matches whe...,Show all matches where the final score was a d...,Tampilkan semua pertandingan yang berakhir imb...,['match_no']
2,Which teams have won all matches in the group ...,"SELECT team_id, COUNT(*) AS total_wins\n FROM...",This SQL query identifies teams that won all o...,Which teams managed to win all their group sta...,Tim mana saja yang menang terus di babak grup?...,"['team_id', 'total_wins']"
3,Find referees who officiated more than 3 match...,"SELECT rm.referee_id, rm.referee_name, COUNT(*...",This SQL query identifies the most active refe...,Can you list referees who led more than 3 matc...,Saya ingin tahu wasit yang memimpin lebih dari...,"['referee_id', 'referee_name', 'total_matches']"
4,Which players were substituted in and out in t...,"SELECT DISTINCT pio.player_id, pm.player_name,...",This SQL query identifies players who were bot...,Show me which players were both substituted in...,Tampilkan pemain yang pernah diganti masuk dan...,"['player_id', 'player_name', 'match_no']"
5,Identify players who have been booked and also...,"SELECT DISTINCT pb.player_id, pm.player_name, ...",This SQL query identifies players who were bot...,Which players got a card and also scored a goa...,Siapa saja pemain yang dapat kartu dan cetak g...,"['player_id', 'player_name', 'match_no']"
6,Show teams with more goals for than against in...,"SELECT team_id, goal_for, goal_agnst, goal_dif...",This SQL query identifies teams with a positiv...,Can you list teams that scored more goals than...,Tampilkan tim-tim yang jumlah gol masuknya leb...,"['team_id', 'goal_for', 'goal_agnst', 'goal_di..."
7,Which goalkeepers (player_gk) appeared in the ...,"SELECT player_gk AS player_id, COUNT(*) AS mat...",This SQL query analyzes goalkeeper appearances...,Who are the goalkeepers that played the most m...,Kiper mana saja yang paling sering tampil seba...,"['player_id', 'match_count']"
8,List venues with the highest audience attendan...,"SELECT sv.venue_name, sc.city, MAX(mm.audence)...",This SQL query identifies the maximum audience...,Show me the venues where audience attendance r...,Venue mana saja yang penontonnya banyak? Saya ...,"['venue_name', 'city', 'max_audience']"
9,Find players who were never substituted out. S...,"SELECT DISTINCT pm.player_id, pm.player_name\n...",This SQL query identifies players who were nev...,Find the players who were never substituted ou...,Tampilkan pemain yang tidak pernah diganti kel...,"['player_id', 'player_name']"


# Experiment

In [10]:
import time
import ast

# Every dataset row is evaluated twice: once per paraphrased prompt column.
PROMPT_COLUMNS = [
    "Alternative Prompt 1 (English)",
    "Alternative Prompt 2 (Bahasa Indonesia)",
]

# Sentinel stored as the "generated query" when every generation attempt failed.
GENERATION_ERROR = "ERROR"


def _generate_sql(question):
    """Generate SQL for ``question``, retrying on failure.

    Calls ``text_to_sql_model.predict_sql_with_example_only`` up to MAX_RETRIES
    times, sleeping RETRY_DELAY seconds between failed attempts.

    Returns:
        The generated query, or GENERATION_ERROR if every attempt raised.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            return text_to_sql_model.predict_sql_with_example_only(user_prompt=question)
        except Exception as e:
            print(f"[Attempt {attempt}] Failed to generate SQL: {e}")
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_DELAY)
    print("Max retries reached. Setting result as 'ERROR'")
    return GENERATION_ERROR


def _score(query, true_query, expected_columns):
    """Return the execution accuracy of ``query`` as a float, or 0.0 if it cannot be scored.

    A failed generation (GENERATION_ERROR) scores 0.0 without being sent to the
    database as if it were SQL. Exceptions raised by ``evaluate`` are printed and
    score 0.0.
    """
    if query == GENERATION_ERROR:
        return 0.0
    try:
        acc = text_to_sql_model.evaluate(query=query, true_query=true_query, expected_columns=expected_columns)
    except Exception as e:
        print(f"Evaluation failed: {e}")
        return 0.0
    # NOTE(review): in the saved run, output stops right after evaluate() printed a DB error
    # instead of raising. That suggests it may return None, and "{acc:.4f}" would then crash
    # the whole loop. Score None as 0.0; confirm against TextToSQL.evaluate.
    return 0.0 if acc is None else float(acc)


EA = 0  # running sum of per-prompt execution accuracy
total_questions = len(dataset) * len(PROMPT_COLUMNS)
results_list = []

for idx, row in dataset.iterrows():
    answer = row["Answer"]
    # "Expected Result" stores a Python-literal list of column names, e.g. "['match_no']".
    expected_columns = ast.literal_eval(row["Expected Result"])

    for prompt_id, column in enumerate(PROMPT_COLUMNS, start=1):
        question = row[column]
        print(f"\nProcessing Question {idx + 1}.{prompt_id}: {question}")

        result = _generate_sql(question)
        print(f"Generated SQL Query: {result}")

        acc = _score(result, answer, expected_columns)
        print(f"Execution Accuracy: {acc:.4f}")

        results_list.append({
            "Question ID": f"{idx + 1}.{prompt_id}",
            "Question": question,
            "Generated SQL Query": result,
            "Expected SQL Query": answer,
            "Execution Accuracy": acc
        })

        EA += acc

# Calculate final execution accuracy (mean over every evaluated prompt)
final_accuracy = EA / total_questions if total_questions > 0 else 0
print(f"\nFinal Execution Accuracy: {final_accuracy:.4f}")


Processing Question 1.1: Can you show me the players who scored more than 2 goals in one game? I’d like to see their player_id, name, match number, and how many goals they scored.
Generated SQL Query: SELECT gd.player_id, pm.player_name, gd.match_no, COUNT(*) AS total_goals
FROM goal_details AS gd
JOIN player_mast AS pm ON gd.player_id = pm.player_id
GROUP BY gd.player_id, pm.player_name, gd.match_no
HAVING COUNT(*) > 2;
Execution Accuracy: 1.0000

Processing Question 1.2: Tolong tampilkan pemain yang mencetak lebih dari 2 gol dalam satu pertandingan. Saya mau lihat player_id, nama pemain, match_no, dan total golnya.
Generated SQL Query: SELECT player_id, player_name, match_no, COUNT(*) AS total_goals
FROM goal_details AS gd
JOIN player_mast AS pm ON gd.player_id = pm.player_id
GROUP BY player_id, player_name, match_no
HAVING COUNT(*) > 2;
Error executing query: column reference "player_id" is ambiguous
LINE 1: SELECT player_id, player_name, match_no, COUNT(*) AS total_g...
          

In [11]:
# Save one row per evaluated prompt into this run's timestamped output folder.
results_csv_path = f"{output_dir}/{MODEL}_{DATABASE}_relevant_example.csv"
df_results = pd.DataFrame(results_list)
df_results.to_csv(results_csv_path, index=False)