In [26]:
import csv
import pandas as pd
import pprint
import re
from typing import List, Dict, Any, Generator
import sqlglot
from typing import TypedDict, Any, Union, cast
from tqdm.notebook import tqdm
import json

In [2]:
from sql_qa.utils.constant import EColumn, LlmRowEntry, LogRowEntry

## Helper functions

In [3]:
DIALECT = "mysql"

log_csv_file_path = '/mnt/Code/code/AI/agentic-AI/SQL-QA/logs/turn_log_benchmark.csv'

DELIMITER = (
        "\t"  # Using tab as delimiter which is less likely to appear in text content
    )

CANDIDATE_NAMES = [
    'direct_generation',
    'cot_generation',
    'dac_cot_genration',
    'query_plan_generation'
]

In [4]:
def pd_read_csv(file_path: str) -> pd.DataFrame:
    df = pd.read_csv(
        file_path,
        encoding="utf-8",
        quoting=1,  # QUOTE_ALL mode
        escapechar="\\",
        on_bad_lines="warn",
    )
    return df

def pd_save_csv(df: pd.DataFrame, csv_result_path: str):
    df.to_csv(
        csv_result_path,
        encoding="utf-8",
        quoting=1,  # QUOTE_ALL mode
        escapechar="\\",
        index=False
    )

    

In [5]:
def normalize_sql(sql: str, pretty=True) -> str:
    # TODO: sqlglot not supporting UTF-8
    """
    Normalize SQL queries by removing extra whitespace and converting to lowercase.
    """
    try:
        return sqlglot.transpile(
                    sql, write=DIALECT, pretty=pretty,
            )[0]

    except Exception as e:
        print(f"Error normalizing SQL: {e} at\nQuery: {sql}")
        # Fallback to a simple normalization if sqlglot fails
        return sql
    # return sqlglot.parse_one(sql).sql().lower().strip()
def is_exact_match(src_sql: str, tg_sql: str) -> bool:
    """
    Check if the query matches the answer exactly.
    """
    # return query.strip() == answer.strip()
    try:
        return normalize_sql(src_sql) == normalize_sql(tg_sql)
    except Exception as e:
        print(f"Error normalizing SQL: {e} at\nQuery: {src_sql}\nAnswer: {tg_sql}")
        return False

def normalize_answer(answer: str) -> str:   
    """
    Normalize the answer by removing extra whitespace and converting to lowercase.
    """
    return answer.strip().lower()
    
def is_execution_match(src_result: str, tg_result: str) -> bool:
    """
    Check if the query matches the answer after execution.
    """
    try:
        # Normalize both query and answer
        normalized_src_query = normalize_answer(src_result)
        normalized_tg_answer = normalize_answer(tg_result)
        
        # Compare normalized versions
        return normalized_src_query == normalized_tg_answer
    except Exception as e:
        print(f"Error normalizing execution result: {e} at\nQuery: {src_result}\nAnswer: {tg_result}")
        return False


regex = r"Câu lệnh SQL \d: ```sql([^`]*)```"

def parse_sql_candidates(regex=regex, merge_prompt='') -> List[str]:

    matches = re.finditer(regex, merge_prompt, re.MULTILINE | re.DOTALL)
    candidates = []
    for matchNum, match in enumerate(matches, start=1):
        
        # print ("Match {matchNum} was found at {start}-{end}: {match}".format(matchNum = matchNum, start = match.start(), end = match.end(), match = match.group()))
        
        for groupNum in range(0, len(match.groups())):
            groupNum = groupNum + 1
            
            # print ("Group {groupNum} found at {start}-{end}: {group}".format(groupNum = groupNum, start = match.start(groupNum), end = match.end(groupNum), group = match.group(groupNum)))
            # print(match.group(groupNum).strip())
            # print('-' * 20)
            sql = match.group(groupNum).strip()
            sql = normalize_sql(sql)
            candidates.append(sql)

    return candidates


In [10]:
def gen_get_row_sql_gen_candidates(
    csv_file_path: str, delimiter: str = DELIMITER
) -> Generator[LogRowEntry, None, None]:
    """
    Extract Rows from a CSV file and yield each row as a dictionary.
    Each row contains fields like:
    
    """
    with open(
        csv_file_path, "r", newline="", encoding="utf-8"
    ) as csvfile:
        reader = csv.DictReader(
            csvfile, delimiter=delimiter, quoting=csv.QUOTE_ALL
        )
        i = 0
        for row in reader:
            # i += 1
            # if i > 5: break
            # print(f'----------Row {i}---------')
            # pprint.pprint(row, indent=4, width=80)
            yield cast(LogRowEntry,row)

In [24]:
test_sql = ['SELECT\n  id,\n  branch_name,\n  branch_code,\n  description\nFROM branch\nWHERE\n  is_active = 1',
 'SELECT\n  branch_name,\n  branch_code,\n  description\nFROM branch\nWHERE\n  is_active = 1']

for sql in test_sql:
    print(f'Original SQL: \n```{sql}\n```')
    normalized_sql = normalize_sql(sql)
    print(f'Normalized SQL: \n```{normalized_sql}```')
    print(f'Is exact match: {is_exact_match(sql, normalized_sql)}')
    print('-' * 40)

Original SQL: 
```SELECT
  id,
  branch_name,
  branch_code,
  description
FROM branch
WHERE
  is_active = 1
```
Normalized SQL: 
```SELECT
  id,
  branch_name,
  branch_code,
  description
FROM branch
WHERE
  is_active = 1```
Is exact match: True
----------------------------------------
Original SQL: 
```SELECT
  branch_name,
  branch_code,
  description
FROM branch
WHERE
  is_active = 1
```
Normalized SQL: 
```SELECT
  branch_name,
  branch_code,
  description
FROM branch
WHERE
  is_active = 1```
Is exact match: True
----------------------------------------


In [None]:

candidates_gen = gen_get_row_sql_gen_candidates(csv_file_path=log_csv_file_path, delimiter=DELIMITER)
# row = next(iter(candidates_gen))  # Skip the first row (header)
row = next(candidates_gen)  # Get the first row
print(row['user_question'])

Liệt kê tất cả các chi nhánh hiện đang hoạt động.



In [None]:
from langchain_community.utilities import SQLDatabase

conn = 'mysql+pymysql://root:yolo2chill@localhost:3306/gsv'
db =  SQLDatabase.from_uri(conn)

def run_sql(sql: str) -> Any:
    """
    Execute the SQL query and return the result.
    """
    # try:
    #     result = db.execute(sql)
    #     return result.fetchall()
    # except Exception as e:
    #     print(f"Error executing SQL: {e} at\nSQL: {sql}")
    #     return None
    result = db.run_no_throw(sql)
    return result

In [None]:
import sys

csv.field_size_limit(sys.maxsize)

csv_result_path = "/mnt/Code/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/gen_success_data_results_LLM_judge_updated.csv"

df = pd_read_csv(csv_result_path)

df["candidate_generations"] = None
df["candidate_em"] = None
df["candidate_ex"] = None
df["candidate_least_correct"] = None
df["candidate_passes"] = None

candidates_gen = gen_get_row_sql_gen_candidates(
    csv_file_path=log_csv_file_path, delimiter=DELIMITER
)
pbar = tqdm(df.iterrows(), total=len(df), leave=False)
for idx, llm_row in pbar:
    # if idx > 5:
    #     break
    llm_row = cast(LlmRowEntry, llm_row)  # Ensure row is of type LlmRowEntry
    pbar.write(f"----------Row {idx}---------")
    pbar.write(f'Processing llm_row with question: {llm_row["question"]}')
    while True:
        try:
            candidate_row = next(candidates_gen)
        except StopIteration:
            pbar.write(f"No candidate row found for index {idx}")
            candidate_row = None
            break

        if candidate_row["user_question"] != llm_row["question"]:
            pbar.write(f'skip candidate row:\n{candidate_row["user_question"]}')
            continue  # Skip if the question matches the current row's question
        else:
            break
    if not candidate_row:
        break
    sql_candidates = parse_sql_candidates(merge_prompt=candidate_row["merger_prompt"])
    df.at[idx, "candidate_generations"] = sql_candidates
    em_candidates = [
        is_exact_match(c, llm_row["ground_truth_sql"]) for c in sql_candidates
    ]
    df.at[idx, "candidate_em"] = em_candidates
    ex_canddiates = [
        is_execution_match(run_sql(c), llm_row["ground_truth_result"])
        for c in sql_candidates
    ]
    df.at[idx, "candidate_ex"] = ex_canddiates
    df.at[idx, "candidate_least_correct"] = [
        c
        for c, em, ex in zip(sql_candidates, em_candidates, ex_canddiates)
        if (em) or (ex)
    ]
    df.at[idx, "candidate_passes"] = [
        CANDIDATE_NAMES[i]
        for i, (em, ex) in enumerate(zip(em_candidates, ex_canddiates))
        if (em) or (ex)
    ]
    # pbar.write(f'Row {idx} candidates: {row["candidate_generations"]}')

    # if int( idx ) > 5:
    #     break
    # pbar.write(llm_row)
    # pbar.write(llm_row['question'])
    # pbar.write(llm_row['candidate_generations'])
    # pbar.write(llm_row['candidate_em'])
    # pbar.write(llm_row['candidate_ex'])
    # pbar.write(llm_row['candidate_least_correct'])
    pass
pbar.close()

In [None]:
pd_save_csv(df, csv_result_path)

In [103]:
for idx, row in df.iterrows():
    df.at[idx, "candidate_passes"] = [CANDIDATE_NAMES[i] for i, (em, ex) in enumerate(zip(row['candidate_em'], row['candidate_ex'])) if em or ex]

In [104]:
df.iloc[-5:]

Unnamed: 0,question,ground_truth_sql,level,ground_truth_result,error,generated_sql_query,generated_query_result,generated_sql_error,generated_raw_result,llm_exact_match,llm_execution_match,candidate_generations,candidate_em,candidate_ex,candidate_least_correct,candidate_passes
284,Liệt kê 5 dịch vụ (service_info) được sử dụng ...,"SELECT si.service_info_name, COUNT(aps.id) AS ...",hard,"[('Phuong Test', 8), ('Gói điều trị thâm vùng ...",,"SELECT si.service_info_name AS 'Tên dịch vụ', ...",Dưới đây là danh sách 5 dịch vụ được sử dụng n...,,"[('Phuong Test', 8), ('Gói điều trị thâm vùng ...",True,True,"[SELECT\n si.service_info_name,\n COUNT(aps....","[False, False, False, False]","[True, True, True, True]","[SELECT\n si.service_info_name,\n COUNT(aps....","[direct_generation, cot_generation, dac_cot_ge..."
285,Bệnh nhân nào có nhiều lịch hẹn bị hủy (appoin...,"SELECT p.full_name, COUNT(a.id) AS so_lich_huy...",hard,"[('Cường nguyễn', 5)]",,"SELECT p.full_name AS patient_name, COUNT(a.pa...",Bệnh nhân có nhiều lịch hẹn bị hủy nhất là Cườ...,,"[('Cường nguyễn', 5)]",True,True,"[SELECT\n p.id AS patient_id,\n p.full_name ...","[False, False, False, False]","[False, True, False, True]","[SELECT\n p.full_name AS patient_name,\n COU...","[cot_generation, query_plan_generation]"
286,Liệt kê các nhân viên và số lượng lịch hẹn họ ...,"SELECT e.first_name, e.last_name, COUNT(a.id) ...",hard,"[('Admin', None, 95), ('Cường Cầu giấy', 'Cườn...",,"SELECT e.first_name, e.last_name, COUNT(a.id) ...",Dưới đây là danh sách nhân viên và số lượng lị...,,"[('Admin', None, 95), ('Cường Cầu giấy', 'Cườn...",,,"[SELECT\n e.id AS employee_id,\n e.first_nam...","[False, False, False, False]","[False, False, False, True]","[SELECT\n e.first_name,\n e.last_name,\n CO...",[query_plan_generation]
287,Tìm thông tin các gói (package) được mua bởi b...,SELECT pkg.* FROM package pkg JOIN patient p O...,hard,"[(datetime.datetime(2025, 4, 17, 13, 17, 32, 8...",,"SELECT p.package_code AS 'Mã gói', p.package_n...",Thông tin về các gói mà bệnh nhân có mã 'KH 00...,,"[('GDV20250417201732531', None, datetime.datet...",False,False,"[SELECT\n p.id AS package_id,\n p.package_co...","[False, False, False, False]","[False, False, False, False]",[],[]
288,Chi nhánh nào có nhiều lịch hẹn nhất?\n,"SELECT b.branch_name, COUNT(a.id) AS total_app...",hard,"[('Chi nhánh Cầu Giấy', 63)]",,"SELECT b.branch_name, COUNT(a.id) AS total_app...","Chi nhánh Cầu Giấy có nhiều lịch hẹn nhất, với...",,"[('Chi nhánh Cầu Giấy', 63)]",True,True,"[SELECT\n b.branch_name,\n COUNT(a.id) AS to...","[False, False, False, False]","[True, True, True, True]","[SELECT\n b.branch_name,\n COUNT(a.id) AS to...","[direct_generation, cot_generation, dac_cot_ge..."


In [None]:
df['candidate_least_correct'].value_counts()

In [None]:
df['llm_exact_match'].value_counts()

llm_exact_match
True     127
False     97
Name: count, dtype: int64

In [None]:
df['llm_execution_match'].value_counts()

llm_execution_match
True     135
False     89
Name: count, dtype: int64

In [106]:
candidate_counter = {}
for candidate in CANDIDATE_NAMES:
    candidate_counter[candidate] = df['candidate_passes'].apply(lambda x: candidate in x).sum()
candidate_counter
# df['candidate_passes']

{'direct_generation': np.int64(100),
 'cot_generation': np.int64(109),
 'dac_cot_genration': np.int64(104),
 'query_plan_generation': np.int64(102)}

## GT_question + PRED_sql -> GT_execution + PRED_sql

In [50]:
import pandas as pd
import csv

In [51]:
csv_log_file = '/mnt/Code/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/GSV-data-Nam-800_results.csv'

In [52]:
df = pd_read_csv(csv_log_file)
# df['gt_execution'] = None
# df['llm_final_generation'] = None
# df['llm_execution'] = None

In [None]:
for idx, row in df.iterrows():
    df.at[idx, 'gt_execution'] = run_sql(df.at[idx, 'sql'])
    # df.at[idx, 'llm_final']


In [None]:
# pd_save_csv(df, csv_log_file)

## Eval on first 200

In [None]:
# df = df[:200]

In [54]:
df['eval_em'] = None
df['eval_ex'] = None
df['llm_em'] = None
df['llm_ex'] = None

In [55]:
len(df)

927

In [None]:
max_len = 1000
pbar = tqdm(df.iterrows(), total=len(df))

accepted_ids = []
for idx, row in pbar:
    if len(accepted_ids) > 200: 
        break
    try:
        if (
            len(str(row["generated_sql_query"])) > max_len
            or len(str(row["generated_raw_result"])) > max_len
        ):
            continue

        df.at[idx, "eval_em"] = is_exact_match(row["sql"], row["generated_sql_query"])
        df.at[idx, "eval_ex"] = is_execution_match(
            str(row["gt_execution"]), str(row["generated_raw_result"])
        )
        accepted_ids.append(idx)
    except Exception as e:
        print("-------")
        print(idx)
        print(row)
        print("-------")
        raise (e)

pbar.close()

In [61]:
df['eval_ex'].value_counts()

eval_ex
False    105
True      96
Name: count, dtype: int64

In [62]:
df['eval_em'].value_counts()

eval_em
False    199
True       2
Name: count, dtype: int64

In [63]:
saved_df = df[['question', 'sql', 'gt_execution', 'generated_sql_query', 'generated_raw_result', 'llm_em', 'llm_ex']]

In [65]:
csv_log_file = '/mnt/Code/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/GSV-data-Nam-200_results-llm-as-judge.csv'
pd_save_csv(saved_df, csv_log_file)

## Extract table, column

In [6]:
from sqlglot import parse_one, exp, transpile

In [7]:
import pandas as pd

In [8]:
csv_file_path = '/mnt/Code/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/GSV-data-Nam-200_results_eval_LLM-as-judge_KT_KD_results.csv'

In [9]:
df = pd.read_csv(csv_file_path)

In [10]:
sql = df.at[1, EColumn.gen_sql.value]

In [11]:
print(transpile(sql, write="mysql", pretty=True)[0])

WITH MonthlyAppointments AS (
  SELECT
    doctor_id,
    DATE_FORMAT(appointment_time, '%Y-%m') AS appointment_month,
    COUNT(*) AS appointment_count
  FROM appointment
  WHERE
    appointment_time >= DATE_SUB(CURDATE(), INTERVAL (INTERVAL '1' MONTH) DAY)
    AND appointment_time <= CURDATE()
  GROUP BY
    doctor_id,
    appointment_month
), MonthlyChanges AS (
  SELECT
    ma1.doctor_id,
    ma1.appointment_month,
    ma1.appointment_count,
    LAG(ma1.appointment_count, 1, 0) OVER (PARTITION BY ma1.doctor_id ORDER BY ma1.appointment_month) AS previous_month_count
  FROM MonthlyAppointments AS ma1
), GrowthRates AS (
  SELECT
    doctor_id,
    appointment_month,
    appointment_count,
    previous_month_count,
    (
      appointment_count - previous_month_count
    ) AS growth
  FROM MonthlyChanges
  WHERE
    appointment_month = DATE_FORMAT(CURDATE(), '%Y-%m')
)
SELECT
  e.first_name,
  e.last_name,
  gr.growth
FROM GrowthRates AS gr
JOIN employee AS e
  ON gr.doctor_id = e.id


In [12]:
def extract_with_detailed_info(sql_string, dialect=""):
    """
    Extract tables and columns with more detailed information including aliases.
    
    Args:
        sql_string (str): The SQL query string to analyze
        dialect (str): SQL dialect
    
    Returns:
        dict: Detailed information about tables and columns
    """
    try:
        parsed = sqlglot.parse_one(sql_string, dialect=dialect)
        
        result = {
            "tables": {},
            "columns": [],
            "table_aliases": {}
        }
        
        # Extract table information including aliases
        for table in parsed.find_all(exp.Table):
            table_name = table.name
            table_alias = table.alias if table.alias else None
            
            result["tables"][table_name] = {
                "alias": table_alias,
                "columns": []
            }
            
            if table_alias:
                result["table_aliases"][table_alias] = table_name
        
        # Extract column information
        for column in parsed.find_all(exp.Column):
            column_info = {
                "column_name": column.name,
                "table_reference": None,
                "actual_table": None
            }
            
            if column.table:
                column_info["table_reference"] = column.table
                # Check if it's an alias
                if column.table in result["table_aliases"]:
                    column_info["actual_table"] = result["table_aliases"][column.table]
                else:
                    column_info["actual_table"] = column.table
            
            result["columns"].append(column_info)
            
            # Add column to the appropriate table
            if column_info["actual_table"]:
                if column_info["actual_table"] in result["tables"]:
                    result["tables"][column_info["actual_table"]]["columns"].append(column.name)
        
        # Remove duplicates from column lists
        for table_info in result["tables"].values():
            table_info["columns"] = list(set(table_info["columns"]))
        
        return result
        
    except Exception as e:
        return {"error": str(e)}

In [75]:
extract_with_detailed_info(sql)

{'tables': {'GrowthRates': {'alias': 'gr', 'columns': ['doctor_id', 'growth']},
  'employee': {'alias': 'e', 'columns': ['last_name', 'id', 'first_name']},
  'appointment': {'alias': None, 'columns': []},
  'MonthlyAppointments': {'alias': 'ma1',
   'columns': ['appointment_count', 'appointment_month', 'doctor_id']},
  'MonthlyChanges': {'alias': None, 'columns': []}},
 'columns': [{'column_name': 'first_name',
   'table_reference': 'e',
   'actual_table': 'employee'},
  {'column_name': 'last_name',
   'table_reference': 'e',
   'actual_table': 'employee'},
  {'column_name': 'growth',
   'table_reference': 'gr',
   'actual_table': 'GrowthRates'},
  {'column_name': 'doctor_id',
   'table_reference': 'gr',
   'actual_table': 'GrowthRates'},
  {'column_name': 'id', 'table_reference': 'e', 'actual_table': 'employee'},
  {'column_name': 'growth',
   'table_reference': 'gr',
   'actual_table': 'GrowthRates'},
  {'column_name': 'doctor_id', 'table_reference': None, 'actual_table': None},
  {'

In [20]:
for e in EColumn:
    if e.value not in df.columns:
        df[e.value] = None

In [27]:
for idx, row in df.iterrows():
    gt_sql = row[EColumn.gt_sql.value]
    table_with_cols = extract_with_detailed_info(gt_sql)
    schema = table_with_cols.get("tables", table_with_cols.get("error"))
    df.at[idx, EColumn.gt_schema.value] = json.dumps(schema, indent=2)

    gen_sql = row[EColumn.gen_sql.value]
    table_with_cols = extract_with_detailed_info(gen_sql)
    schema = table_with_cols.get("tables", table_with_cols.get("error"))
    df.at[idx, EColumn.gen_schema.value] = json.dumps(schema, indent=2)

In [28]:
pd_save_csv(df, csv_file_path)