In [10]:
import os
import re
import json
import time
import glob
import math
import statistics
from typing import Dict, Any, List, Tuple, Optional

import pandas as pd
import requests
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()

True

In [11]:
# Input/output locations: evaluate() reads every *.csv in DATASET_DIR and
# writes its timestamped CSV/JSON reports into REPORTS_DIR.
DATASET_DIR = "../data"
REPORTS_DIR = "../results"

# Text-to-SQL service under test and the per-request timeout (seconds).
API_URL = "http://localhost:8000/query"
API_TIMEOUT_S = 120

# LLM-as-a-judge settings.
USE_OPENAI_JUDGE = True
OPENAI_MODEL = "gpt-4o-mini"
OPENAI_TIMEOUT_S = 60
JUDGE_SAMPLE_LIMIT = None     # set int (e.g., 200) to stop judging after N judged rows

# Apply OPENAI_TIMEOUT_S as the client's request timeout so that a stalled
# judge request cannot hang the evaluation loop indefinitely.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"), timeout=OPENAI_TIMEOUT_S)

In [12]:
# Lower-case SQL reserved words; keyword_f1 only scores tokens found in this set.
SQL_KEYWORDS = {
    # query clauses
    "select", "distinct", "from", "where", "group", "by", "having",
    "order", "limit", "offset",
    # joins
    "join", "inner", "left", "right", "full", "cross", "on",
    # set operations
    "union", "all",
    # data modification
    "insert", "into", "values", "update", "set", "delete",
    # conditional expressions
    "case", "when", "then", "else", "end",
    # aliases, boolean logic and predicates
    "as", "and", "or", "not", "in", "exists", "between", "like", "is", "null",
}

def normalize_sql(sql: str) -> str:
    """Return a canonical, lower-cased single-line form of ``sql``.

    The normalization removes formatting noise only: surrounding whitespace,
    trailing statement terminators, runs of whitespace and backtick quoting.
    Everything, including the contents of string literals, is lower-cased.

    Args:
        sql: SQL text, or ``None``.

    Returns:
        The normalized string; ``""`` when ``sql`` is ``None``.
    """
    if sql is None:
        return ""
    s = sql.strip()
    # Drop every trailing terminator, including ones separated by whitespace
    # (e.g. "select 1; ;"), so they cannot defeat normalized exact match.
    while s.endswith(";"):
        s = s.rstrip(";").rstrip()
    s = re.sub(r"\s+", " ", s)          # collapse whitespace
    s = s.replace("`", "")              # remove backticks (optional)
    return s.lower().strip()

def sql_tokens(sql: str) -> List[str]:
    """Split ``sql`` into a flat list of lexical tokens (no external deps).

    The input is first passed through ``normalize_sql``. Recognized tokens are:
    identifiers/keywords (letters, digits and underscores, not starting with a
    digit), integer or decimal literals, the comparison operators
    ``<> <= >= != = < >``, and the punctuation ``* ( ) , .``. Any other
    character (quotes, arithmetic operators, ...) is skipped.

    Args:
        sql: SQL text, or ``None``.

    Returns:
        Tokens in order of appearance; an empty list for empty input.
    """
    s = normalize_sql(sql)
    # Identifiers may contain digits after the first character ("col1", "t2"),
    # so they must be matched as a whole rather than split into "col" + "1".
    # "<>" is listed before "<" and ">" so it is kept as one operator.
    return re.findall(
        r"[a-z_][a-z0-9_]*|\d+(?:\.\d+)?|<>|<=|>=|!=|=|<|>|\*|\(|\)|,|\.",
        s,
    )

def token_f1(pred: str, gold: str) -> Tuple[float, float, float]:
    """Multiset (bag-of-tokens) precision, recall and F1 between two SQL strings.

    Both strings are tokenized with ``sql_tokens``; a token repeated k times in
    both queries contributes k to the overlap.

    Returns:
        ``(precision, recall, f1)``. Two empty token lists score
        ``(1.0, 1.0, 1.0)``; exactly one empty list scores ``(0.0, 0.0, 0.0)``.
    """
    from collections import Counter

    pred_counts = Counter(sql_tokens(pred))
    gold_counts = Counter(sql_tokens(gold))

    if not pred_counts and not gold_counts:
        return (1.0, 1.0, 1.0)
    if not pred_counts or not gold_counts:
        return (0.0, 0.0, 0.0)

    # Counter intersection keeps the minimum count of every shared token.
    overlap = sum((pred_counts & gold_counts).values())

    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(gold_counts.values())
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return (precision, recall, f1)

def keyword_f1(pred: str, gold: str) -> Tuple[float, float, float]:
    """Set-based precision, recall and F1 over the SQL keywords of two queries.

    Only tokens that appear in ``SQL_KEYWORDS`` are considered, and each keyword
    counts once no matter how often it occurs.

    Returns:
        ``(precision, recall, f1)``. If neither query contains a keyword the
        result is ``(1.0, 1.0, 1.0)``; if only one does, ``(0.0, 0.0, 0.0)``.
    """
    pred_keywords = SQL_KEYWORDS.intersection(sql_tokens(pred))
    gold_keywords = SQL_KEYWORDS.intersection(sql_tokens(gold))

    if not (pred_keywords or gold_keywords):
        return (1.0, 1.0, 1.0)
    if not (pred_keywords and gold_keywords):
        return (0.0, 0.0, 0.0)

    shared = len(pred_keywords & gold_keywords)
    precision = shared / len(pred_keywords)
    recall = shared / len(gold_keywords)

    if (precision + recall) == 0:
        return (precision, recall, 0.0)
    return (precision, recall, 2 * precision * recall / (precision + recall))

def levenshtein(a: str, b: str) -> int:
    """Levenshtein edit distance (insert / delete / substitute, each cost 1).

    Falsy inputs such as ``None`` are treated as the empty string. The table is
    kept as a single row sized by the shorter input and updated in place.
    """
    shorter = a or ""
    longer = b or ""
    if shorter == longer:
        return 0
    if not shorter:
        return len(longer)
    if not longer:
        return len(shorter)

    if len(shorter) > len(longer):
        shorter, longer = longer, shorter

    # row[j] holds the distance between the processed prefix of `longer`
    # and the first j characters of `shorter`.
    row = list(range(len(shorter) + 1))
    for i, ch_long in enumerate(longer, start=1):
        diagonal = row[0]          # previous row's value at column j-1
        row[0] = i
        for j, ch_short in enumerate(shorter, start=1):
            above = row[j]         # previous row's value at column j
            substitution = diagonal + (0 if ch_short == ch_long else 1)
            row[j] = min(row[j - 1] + 1, above + 1, substitution)
            diagonal = above
    return row[-1]

def edit_similarity(pred: str, gold: str) -> float:
    """Character-level similarity of two SQL strings after normalization.

    Computed as ``1 - levenshtein / len(longer normalized string)``; two strings
    that both normalize to empty are treated as identical (``1.0``).
    """
    pred_norm = normalize_sql(pred)
    gold_norm = normalize_sql(gold)

    longest = max(len(pred_norm), len(gold_norm))
    if longest == 0:
        return 1.0

    return 1.0 - (levenshtein(pred_norm, gold_norm) / longest)

In [13]:
def call_sql_api(api_url: str, user_query: str, timeout_s: int = 30) -> Dict[str, Any]:
    """POST ``user_query`` to the text-to-SQL endpoint and return its JSON body.

    Args:
        api_url: Endpoint that receives the request.
        user_query: Natural-language question, sent as ``{"user_query": ...}``.
        timeout_s: Timeout passed to ``requests.post``.

    Returns:
        The decoded JSON response.

    Raises:
        requests.HTTPError: via ``raise_for_status`` on an error status code.
    """
    response = requests.post(
        api_url,
        json={"user_query": user_query},
        timeout=timeout_s,
    )
    response.raise_for_status()
    return response.json()

# LLM as a Judge

In [14]:
# Prompt template for the LLM judge. It is filled in with str.format using the
# named fields {question}, {gold} and {pred}, and asks the model to answer with
# a JSON object holding the keys "equivalent", "score" and "reason".
JUDGE_PROMPT = """You are a strict SQL evaluator.

Task:
Given (1) a natural-language question, (2) a GOLD SQL, and (3) a PREDICTED SQL,
judge whether GOLD and PREDICTED are semantically equivalent for the same schema/data.

Return JSON ONLY with keys:
- "equivalent": true/false
- "score": integer from 0 to 5  (5 = fully equivalent, 0 = totally wrong)
- "reason": short explanation (<= 2 sentences)

Be careful:
- Different formatting is fine.
- Different table aliases are fine.
- If selected columns, filters, joins, grouping, or ordering change the meaning, it's NOT equivalent.

NL_QUESTION:
{question}

GOLD_SQL:
{gold}

PREDICTED_SQL:
{pred}
"""

def _coerce_judge_equivalent(value: Any) -> Optional[bool]:
    """Map the judge's ``equivalent`` field to a bool, or None if it is not boolean-like."""
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return {"true": True, "false": False}.get(value.strip().lower())
    return None


def _coerce_judge_score(value: Any) -> Optional[int]:
    """Map the judge's ``score`` field to an int in 0..5, or None if it is not one."""
    if isinstance(value, bool):
        return None
    if isinstance(value, str):
        try:
            value = float(value.strip())
        except ValueError:
            return None
    # Range check first so that huge ints never reach float().
    if isinstance(value, (int, float)) and 0 <= value <= 5 and float(value).is_integer():
        return int(value)
    return None


def call_openai_judge(question: str, gold_sql: str, pred_sql: str) -> Dict[str, Any]:
    """Ask the OpenAI judge whether ``pred_sql`` is equivalent to ``gold_sql``.

    Args:
        question: Natural-language question both queries should answer.
        gold_sql: Reference SQL.
        pred_sql: SQL produced by the system under test.

    Returns:
        A dict with at least the keys ``"equivalent"`` (bool), ``"score"``
        (int, 0..5) and ``"reason"``. When the reply cannot be parsed or fails
        validation, ``"equivalent"`` and ``"score"`` are ``None`` and
        ``"reason"`` describes the problem.

    Raises:
        Whatever ``client.responses.create`` raises (network/API errors).
    """
    prompt = JUDGE_PROMPT.format(question=question, gold=gold_sql, pred=pred_sql)

    # Ask the model to return only JSON; we still parse defensively.
    resp = client.responses.create(
        model=OPENAI_MODEL,
        input=prompt,
        # You can add reasoning controls for some models if desired (optional):
        # reasoning={"effort": "low"},
    )

    text = (resp.output_text or "").strip()

    # Decode the first JSON object in the reply. raw_decode stops at the end of
    # that object, so code fences or trailing commentary (even text containing
    # braces) after it do not break parsing.
    start = text.find("{")
    if start == -1:
        return {"equivalent": None, "score": None, "reason": f"Could not parse JSON. Raw: {text[:200]}"}

    try:
        obj, _ = json.JSONDecoder().raw_decode(text, start)
    except ValueError as e:
        return {"equivalent": None, "score": None, "reason": f"JSON parse error: {e}. Raw: {text[:200]}"}

    if not isinstance(obj, dict) or not {"equivalent", "score", "reason"} <= obj.keys():
        return {"equivalent": None, "score": None, "reason": f"Invalid judge JSON: {obj}"}

    # Models sometimes answer "true" / "4" as strings; normalize the types so
    # downstream aggregation sees real bools and ints, and reject anything else.
    equivalent = _coerce_judge_equivalent(obj["equivalent"])
    score = _coerce_judge_score(obj["score"])
    if equivalent is None or score is None:
        return {"equivalent": None, "score": None, "reason": f"Invalid judge JSON: {obj}"}

    return {**obj, "equivalent": equivalent, "score": score}

In [15]:
def load_all_csvs(dataset_dir: str) -> pd.DataFrame:
    """Read every ``*.csv`` in ``dataset_dir`` into one evaluation frame.

    Files are processed in sorted path order. Only the ``user_query`` and
    ``sql_query`` columns are kept; ``source_file`` (the file's base name) and
    a running ``row_id`` are added.

    Raises:
        FileNotFoundError: if the directory contains no CSV file.
        ValueError: if a file lacks ``user_query`` or ``sql_query``.
    """
    required_columns = {"user_query", "sql_query"}

    csv_paths = sorted(glob.glob(os.path.join(dataset_dir, "*.csv")))
    if len(csv_paths) == 0:
        raise FileNotFoundError(f"No CSV files found in '{dataset_dir}/'")

    frames = []
    for csv_path in csv_paths:
        frame = pd.read_csv(csv_path)
        if not required_columns.issubset(frame.columns):
            raise ValueError(f"Missing required columns in {csv_path}. Need: user_query, sql_query")
        # assign() returns a new frame, so the frame read from disk is not modified.
        frames.append(
            frame[["user_query", "sql_query"]].assign(source_file=os.path.basename(csv_path))
        )

    combined = pd.concat(frames, ignore_index=True)
    combined["row_id"] = range(len(combined))
    return combined

In [16]:
def evaluate():
    """Run the end-to-end evaluation and write per-row and summary reports.

    For every row of ``load_all_csvs(DATASET_DIR)`` the question is sent to
    ``API_URL`` and the returned SQL is compared with the gold SQL (exact match,
    normalized exact match, token F1, keyword F1, edit similarity). When
    ``USE_OPENAI_JUDGE`` is set, rows with a non-empty prediction are also
    graded by ``call_openai_judge`` until ``JUDGE_SAMPLE_LIMIT`` is reached.

    A failure of the SQL API or of the judge is recorded on the affected row
    and does not abort the run. The per-row CSV is written before the summary
    is aggregated, so collected results survive an aggregation error.

    Returns:
        ``(out, summary)``: the per-row results DataFrame and the dict of
        aggregate metrics that is also saved as JSON in ``REPORTS_DIR``.
    """
    os.makedirs(REPORTS_DIR, exist_ok=True)

    df = load_all_csvs(DATASET_DIR)
    total = len(df)
    results = []
    judge_count = 0

    # Count processed rows explicitly instead of relying on the index labels.
    for done, (_, row) in enumerate(df.iterrows(), start=1):
        user_q = str(row["user_query"])
        gold = str(row["sql_query"])

        api_ok = True
        api_err = None
        pred = ""
        relevant_tables = []
        columns = []

        # perf_counter is monotonic, so the latency cannot be distorted by
        # system clock adjustments during a long run.
        t0 = time.perf_counter()
        try:
            resp = call_sql_api(API_URL, user_q, timeout_s=API_TIMEOUT_S)
            pred = resp.get("sql", "") or ""
            relevant_tables = resp.get("relevant_tables", []) or []
            columns = resp.get("columns", []) or []
        except Exception as e:
            # Best effort: a failed request is scored as an empty prediction.
            api_ok = False
            api_err = str(e)
        latency = time.perf_counter() - t0

        # Metrics
        exact = (pred == gold)
        norm_exact = (normalize_sql(pred) == normalize_sql(gold))
        p_tok, r_tok, f1_tok = token_f1(pred, gold)
        p_kw, r_kw, f1_kw = keyword_f1(pred, gold)
        ed_sim = edit_similarity(pred, gold)

        # OpenAI judge (optional)
        judge_equiv = None
        judge_score = None
        judge_reason = None

        if USE_OPENAI_JUDGE and api_ok and pred.strip():
            if (JUDGE_SAMPLE_LIMIT is None) or (judge_count < JUDGE_SAMPLE_LIMIT):
                try:
                    j = call_openai_judge(user_q, gold, pred)
                except Exception as e:
                    # Treat judge outages like SQL API outages: record the
                    # error on the row and keep the results gathered so far.
                    j = {"equivalent": None, "score": None, "reason": f"Judge call failed: {e}"}
                judge_equiv = j.get("equivalent")
                judge_score = j.get("score")
                judge_reason = j.get("reason")
                judge_count += 1

        results.append({
            "row_id": int(row["row_id"]),
            "source_file": row["source_file"],
            "user_query": user_q,
            "gold_sql": gold,
            "pred_sql": pred,
            "api_ok": api_ok,
            "api_error": api_err,
            "latency_s": latency,
            "exact_match": exact,
            "normalized_exact_match": norm_exact,
            "token_precision": p_tok,
            "token_recall": r_tok,
            "token_f1": f1_tok,
            "keyword_precision": p_kw,
            "keyword_recall": r_kw,
            "keyword_f1": f1_kw,
            "edit_similarity": ed_sim,
            "judge_equivalent": judge_equiv,
            "judge_score_0_5": judge_score,
            "judge_reason": judge_reason,
            "relevant_tables": json.dumps(relevant_tables, ensure_ascii=False),
            "returned_columns": json.dumps(columns, ensure_ascii=False),
        })

        if done % 25 == 0:
            print(f"Processed {done}/{total}")

    out = pd.DataFrame(results)

    ts = time.strftime("%Y%m%d_%H%M%S")
    csv_path = os.path.join(REPORTS_DIR, f"eval_results_{ts}.csv")
    json_path = os.path.join(REPORTS_DIR, f"eval_summary_{ts}.json")

    # Persist the per-row results before aggregating them.
    out.to_csv(csv_path, index=False)

    # Summary
    def safe_mean(x):
        """Mean of the entries that are neither None nor NaN; None if none remain."""
        xs = [v for v in x if v is not None and not (isinstance(v, float) and math.isnan(v))]
        return statistics.mean(xs) if xs else None

    # Non-numeric judge scores become NaN instead of breaking the mean.
    judge_scores = pd.to_numeric(out["judge_score_0_5"], errors="coerce")

    summary = {
        "n": len(out),
        "api_success_rate": float(out["api_ok"].mean()),
        "avg_latency_s": float(out["latency_s"].mean()),
        "exact_match_rate": float(out["exact_match"].mean()),
        "normalized_exact_match_rate": float(out["normalized_exact_match"].mean()),
        "avg_token_f1": float(out["token_f1"].mean()),
        "avg_keyword_f1": float(out["keyword_f1"].mean()),
        "avg_edit_similarity": float(out["edit_similarity"].mean()),
        "judge_used": USE_OPENAI_JUDGE,
        "judge_rows": int(judge_scores.notna().sum()),
        "avg_judge_score_0_5": safe_mean(judge_scores.tolist()),
        "judge_equivalent_rate": (
            safe_mean([1.0 if v is True else 0.0 if v is False else None for v in out["judge_equivalent"].tolist()])
        ),
    }

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print("\n=== SUMMARY ===")
    print(json.dumps(summary, indent=2))
    print(f"\nSaved:\n- {csv_path}\n- {json_path}")

    return out, summary

In [17]:
# Run the full evaluation; evaluate() returns the per-row results frame and the summary dict.
out_df, summary = evaluate()


=== SUMMARY ===
{
  "n": 15,
  "api_success_rate": 0.8,
  "avg_latency_s": 14.163130458196004,
  "exact_match_rate": 0.0,
  "normalized_exact_match_rate": 0.06666666666666667,
  "avg_token_f1": 0.56254590695965,
  "avg_keyword_f1": 0.7205817711700065,
  "avg_edit_similarity": 0.46726268151567546,
  "judge_used": true,
  "judge_rows": 12,
  "avg_judge_score_0_5": 3.8333333333333335,
  "judge_equivalent_rate": 0.5833333333333334
}

Saved:
- ../results/eval_results_20251221_213255.csv
- ../results/eval_summary_20251221_213255.json


In [18]:
# Preview the first 10 result rows together with the summary dict (rendered as one tuple).
out_df.head(10), summary

(   row_id source_file                                         user_query  \
 0       0    test.csv  Find the products supplied by more than 2 vend...   
 1       1    test.csv  List the product categories along with the num...   
 2       2    test.csv          Get the top 5 customers by total revenue.   
 3       3    test.csv  Retrieve all products that have never been ord...   
 4       4    test.csv  Get the total sales value for each month of ea...   
 5       5     val.csv         Find the top-selling product in each year.   
 6       6     val.csv  Retrieve the most expensive product in each ca...   
 7       7     val.csv  Which products have the highest and lowest pro...   
 8       8     val.csv              Show all sales orders placed in 2011.   
 9       9     val.csv  Customers who purchased more than 100 differen...   
 
                                             gold_sql  \
 0  SELECT ProductID, COUNT(*) AS VendorCount FROM...   
 1  SELECT pc.ProductCategoryID, pc.N