In [92]:
import csv
import pandas as pd
import pprint
import re
from typing import List, Dict, Any, Generator
import sqlglot
from typing import TypedDict, Any, Union, cast
from tqdm.notebook import tqdm

## Helper functions

In [100]:
DIALECT = "mysql"

log_csv_file_path = '/mnt/Code/code/AI/agentic-AI/SQL-QA/logs/turn_log_benchmark.csv'

DELIMITER = (
        "\t"  # Using tab as delimiter which is less likely to appear in text content
    )

CANDIDATE_NAMES = [
    'direct_generation',
    'cot_generation',
    'dac_cot_genration',
    'query_plan_generation'
]

In [4]:

class LogRowEntry(TypedDict):
    created_date:Any
    user_question:Any
    linking_structured_result:Any
    filtered_schema_tables:Any
    direct_generation_prompt:Any
    generation_response:Any
    query_validation_prompt:Any
    query_validation_response:Any
    final_sql:Any
    cot_generation_prompt:Any
    dac_cot_genration_prompt:Any
    query_plan_generation_prompt:Any
    merger_prompt:Any
    merger_response:Any
    merger_result:Any
    sql_result:Any
    response_enhancement_prompt:Any
    response_enhancement_response:Any
    response_enhancement_result:Any
    query_fixing_prompt:Any
    query_fixing_response: Any      

class LlmRowEntry(TypedDict):
    question: Any
    ground_truth_sql: Any
    level: Any
    ground_truth_result: Any
    error: Any
    generated_sql_query: Any
    generated_query_result: Any
    generated_sql_error: Any
    generated_raw_result: Any
    llm_exact_match: Any
    llm_execution_match: Any

    candidate_generations: Any
    candidate_em: Any
    candidate_ex: Any
    candidate_least_correct: Any

In [88]:
def normalize_sql(sql: str, pretty=True) -> str:
    # TODO: sqlglot not supporting UTF-8
    """
    Normalize SQL queries by removing extra whitespace and converting to lowercase.
    """
    try:
        return sqlglot.transpile(
                    sql, write=DIALECT, pretty=pretty,
            )[0]

    except Exception as e:
        print(f"Error normalizing SQL: {e} at\nQuery: {sql}")
        # Fallback to a simple normalization if sqlglot fails
        return sql
    # return sqlglot.parse_one(sql).sql().lower().strip()
def is_exact_match(src_sql: str, tg_sql: str) -> bool:
    """
    Check if the query matches the answer exactly.
    """
    # return query.strip() == answer.strip()
    try:
        return normalize_sql(src_sql) == normalize_sql(tg_sql)
    except Exception as e:
        print(f"Error normalizing SQL: {e} at\nQuery: {src_sql}\nAnswer: {tg_sql}")
        return False

def normalize_answer(answer: str) -> str:   
    """
    Normalize the answer by removing extra whitespace and converting to lowercase.
    """
    return answer.strip().lower()
    
def is_execution_match(src_result: str, tg_result: str) -> bool:
    """
    Check if the query matches the answer after execution.
    """
    try:
        # Normalize both query and answer
        normalized_src_query = normalize_answer(src_result)
        normalized_tg_answer = normalize_answer(tg_result)
        
        # Compare normalized versions
        return normalized_src_query == normalized_tg_answer
    except Exception as e:
        print(f"Error normalizing execution result: {e} at\nQuery: {src_result}\nAnswer: {tg_result}")
        return False


regex = r"Câu lệnh SQL \d: ```sql([^`]*)```"

def parse_sql_candidates(regex=regex, merge_prompt='') -> List[str]:

    matches = re.finditer(regex, merge_prompt, re.MULTILINE | re.DOTALL)
    candidates = []
    for matchNum, match in enumerate(matches, start=1):
        
        # print ("Match {matchNum} was found at {start}-{end}: {match}".format(matchNum = matchNum, start = match.start(), end = match.end(), match = match.group()))
        
        for groupNum in range(0, len(match.groups())):
            groupNum = groupNum + 1
            
            # print ("Group {groupNum} found at {start}-{end}: {group}".format(groupNum = groupNum, start = match.start(groupNum), end = match.end(groupNum), group = match.group(groupNum)))
            # print(match.group(groupNum).strip())
            # print('-' * 20)
            sql = match.group(groupNum).strip()
            sql = normalize_sql(sql)
            candidates.append(sql)

    return candidates


In [9]:
def gen_get_row_sql_gen_candidates(
    csv_file_path: str, delimiter: str = DELIMITER
) -> Generator[LogRowEntry, None, None]:
    """
    Extract Rows from a CSV file and yield each row as a dictionary.
    Each row contains fields like:
    
    """
    with open(
        csv_file_path, "r", newline="", encoding="utf-8"
    ) as csvfile:
        reader = csv.DictReader(
            csvfile, delimiter=delimiter, quoting=csv.QUOTE_ALL
        )
        i = 0
        for row in reader:
            # i += 1
            # if i > 5: break
            # print(f'----------Row {i}---------')
            # pprint.pprint(row, indent=4, width=80)
            yield cast(LogRowEntry,row)

In [None]:
parse_sql_candidates(merge_prompt=merge_prompt)

['SELECT\n  id,\n  branch_name,\n  branch_code,\n  description\nFROM branch\nWHERE\n  is_active = 1',
 'SELECT\n  branch_name,\n  branch_code,\n  description\nFROM branch\nWHERE\n  is_active = 1',
 'SELECT\n  branch_name,\n  branch_code,\n  description\nFROM branch\nWHERE\n  is_active = 1',
 'SELECT\n  branch_name,\n  branch_code,\n  description\nFROM branch\nWHERE\n  is_active = 1']

In [7]:
row['merger_response']

" sql='SELECT branch_name, branch_code, description FROM branch WHERE is_active = 1;' explanation='The SQL query selects the branch_name, branch_code, and description from the branch table where the is_active flag is set to 1.'"

In [24]:
test_sql = ['SELECT\n  id,\n  branch_name,\n  branch_code,\n  description\nFROM branch\nWHERE\n  is_active = 1',
 'SELECT\n  branch_name,\n  branch_code,\n  description\nFROM branch\nWHERE\n  is_active = 1']

for sql in test_sql:
    print(f'Original SQL: \n```{sql}\n```')
    normalized_sql = normalize_sql(sql)
    print(f'Normalized SQL: \n```{normalized_sql}```')
    print(f'Is exact match: {is_exact_match(sql, normalized_sql)}')
    print('-' * 40)

Original SQL: 
```SELECT
  id,
  branch_name,
  branch_code,
  description
FROM branch
WHERE
  is_active = 1
```
Normalized SQL: 
```SELECT
  id,
  branch_name,
  branch_code,
  description
FROM branch
WHERE
  is_active = 1```
Is exact match: True
----------------------------------------
Original SQL: 
```SELECT
  branch_name,
  branch_code,
  description
FROM branch
WHERE
  is_active = 1
```
Normalized SQL: 
```SELECT
  branch_name,
  branch_code,
  description
FROM branch
WHERE
  is_active = 1```
Is exact match: True
----------------------------------------


In [None]:

candidates_gen = gen_get_row_sql_gen_candidates(csv_file_path=log_csv_file_path, delimiter=DELIMITER)
# row = next(iter(candidates_gen))  # Skip the first row (header)
row = next(candidates_gen)  # Get the first row
print(row['user_question'])

Liệt kê tất cả các chi nhánh hiện đang hoạt động.



In [84]:
from langchain_community.utilities import SQLDatabase

conn = 'mysql+pymysql://root:yolo2chill@localhost:3306/gsv'
db =  SQLDatabase.from_uri(conn)

def get_sql_result(sql: str) -> Any:
    """
    Execute the SQL query and return the result.
    """
    # try:
    #     result = db.execute(sql)
    #     return result.fetchall()
    # except Exception as e:
    #     print(f"Error executing SQL: {e} at\nSQL: {sql}")
    #     return None
    result = db.run_no_throw(sql)
    return result

In [None]:
import sys

csv.field_size_limit(sys.maxsize)

csv_result_path = "/mnt/Code/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/gen_success_data_results_LLM_judge_updated.csv"

df = pd.read_csv(
    csv_result_path,
    encoding="utf-8",
    quoting=1,  # QUOTE_ALL mode
    escapechar="\\",
    on_bad_lines="warn",
)
df["candidate_generations"] = None
df["candidate_em"] = None
df["candidate_ex"] = None
df["candidate_least_correct"] = None
df["candidate_passes"] = None

candidates_gen = gen_get_row_sql_gen_candidates(
    csv_file_path=log_csv_file_path, delimiter=DELIMITER
)
pbar = tqdm(df.iterrows(), total=len(df), leave=False)
for idx, llm_row in pbar:
    # if idx > 5:
    #     break
    llm_row = cast(LlmRowEntry, llm_row)  # Ensure row is of type LlmRowEntry
    pbar.write(f"----------Row {idx}---------")
    pbar.write(f'Processing llm_row with question: {llm_row["question"]}')
    while True:
        try:
            candidate_row = next(candidates_gen)
        except StopIteration:
            pbar.write(f"No candidate row found for index {idx}")
            candidate_row = None
            break

        if candidate_row["user_question"] != llm_row["question"]:
            pbar.write(f'skip candidate row:\n{candidate_row["user_question"]}')
            continue  # Skip if the question matches the current row's question
        else:
            break
    if not candidate_row:
        break
    sql_candidates = parse_sql_candidates(merge_prompt=candidate_row["merger_prompt"])
    df.at[idx, "candidate_generations"] = sql_candidates
    em_candidates = [
        is_exact_match(c, llm_row["ground_truth_sql"]) for c in sql_candidates
    ]
    df.at[idx, "candidate_em"] = em_candidates
    ex_canddiates = [
        is_execution_match(get_sql_result(c), llm_row["ground_truth_result"])
        for c in sql_candidates
    ]
    df.at[idx, "candidate_ex"] = ex_canddiates
    df.at[idx, "candidate_least_correct"] = [
        c
        for c, em, ex in zip(sql_candidates, em_candidates, ex_canddiates)
        if (em) or (ex)
    ]
    df.at[idx, "candidate_passes"] = [
        CANDIDATE_NAMES[i]
        for i, (em, ex) in enumerate(zip(em_candidates, ex_canddiates))
        if (em) or (ex)
    ]
    # pbar.write(f'Row {idx} candidates: {row["candidate_generations"]}')

    # if int( idx ) > 5:
    #     break
    # pbar.write(llm_row)
    # pbar.write(llm_row['question'])
    # pbar.write(llm_row['candidate_generations'])
    # pbar.write(llm_row['candidate_em'])
    # pbar.write(llm_row['candidate_ex'])
    # pbar.write(llm_row['candidate_least_correct'])
    pass
pbar.close()

  0%|          | 0/289 [00:00<?, ?it/s]

----------Row 0---------
Processing llm_row with question: Có bao nhiêu lịch hẹn được tạo trong tháng 3 năm 2025?

----------Row 1---------
Processing llm_row with question: Liệt kê tất cả các chi nhánh hiện đang hoạt động.

skip candidate row:
Có bao nhiêu lịch hẹn được tạo trong tháng 3 năm 2025?

----------Row 2---------
Processing llm_row with question: Có bao nhiêu khách hàng đặt lịch hẹn tại chi nhánh Hà Đông?

skip candidate row:
Có bao nhiêu lịch hẹn được tạo trong tháng 3 năm 2025?

skip candidate row:
Có bao nhiêu lịch hẹn được tạo trong tháng 3 năm 2025?

skip candidate row:
Liệt kê tất cả các chi nhánh hiện đang hoạt động.

skip candidate row:
Liệt kê tất cả các chi nhánh hiện đang hoạt động.

----------Row 3---------
Processing llm_row with question: Liệt kê tất cả các lịch hẹn có trạng thái "Đã xác nhận" (appointment_status_id = 16).

skip candidate row:
Có bao nhiêu khách hàng đặt lịch hẹn tại chi nhánh Hà Đông?

skip candidate row:
Có bao nhiêu lịch hẹn được tạo trong t

In [105]:

df.to_csv(
    csv_result_path,
    encoding="utf-8",
    quoting=1,  # QUOTE_ALL mode
    escapechar="\\",
)

In [103]:
for idx, row in df.iterrows():
    df.at[idx, "candidate_passes"] = [CANDIDATE_NAMES[i] for i, (em, ex) in enumerate(zip(row['candidate_em'], row['candidate_ex'])) if em or ex]

In [104]:
df.iloc[-5:]

Unnamed: 0,question,ground_truth_sql,level,ground_truth_result,error,generated_sql_query,generated_query_result,generated_sql_error,generated_raw_result,llm_exact_match,llm_execution_match,candidate_generations,candidate_em,candidate_ex,candidate_least_correct,candidate_passes
284,Liệt kê 5 dịch vụ (service_info) được sử dụng ...,"SELECT si.service_info_name, COUNT(aps.id) AS ...",hard,"[('Phuong Test', 8), ('Gói điều trị thâm vùng ...",,"SELECT si.service_info_name AS 'Tên dịch vụ', ...",Dưới đây là danh sách 5 dịch vụ được sử dụng n...,,"[('Phuong Test', 8), ('Gói điều trị thâm vùng ...",True,True,"[SELECT\n si.service_info_name,\n COUNT(aps....","[False, False, False, False]","[True, True, True, True]","[SELECT\n si.service_info_name,\n COUNT(aps....","[direct_generation, cot_generation, dac_cot_ge..."
285,Bệnh nhân nào có nhiều lịch hẹn bị hủy (appoin...,"SELECT p.full_name, COUNT(a.id) AS so_lich_huy...",hard,"[('Cường nguyễn', 5)]",,"SELECT p.full_name AS patient_name, COUNT(a.pa...",Bệnh nhân có nhiều lịch hẹn bị hủy nhất là Cườ...,,"[('Cường nguyễn', 5)]",True,True,"[SELECT\n p.id AS patient_id,\n p.full_name ...","[False, False, False, False]","[False, True, False, True]","[SELECT\n p.full_name AS patient_name,\n COU...","[cot_generation, query_plan_generation]"
286,Liệt kê các nhân viên và số lượng lịch hẹn họ ...,"SELECT e.first_name, e.last_name, COUNT(a.id) ...",hard,"[('Admin', None, 95), ('Cường Cầu giấy', 'Cườn...",,"SELECT e.first_name, e.last_name, COUNT(a.id) ...",Dưới đây là danh sách nhân viên và số lượng lị...,,"[('Admin', None, 95), ('Cường Cầu giấy', 'Cườn...",,,"[SELECT\n e.id AS employee_id,\n e.first_nam...","[False, False, False, False]","[False, False, False, True]","[SELECT\n e.first_name,\n e.last_name,\n CO...",[query_plan_generation]
287,Tìm thông tin các gói (package) được mua bởi b...,SELECT pkg.* FROM package pkg JOIN patient p O...,hard,"[(datetime.datetime(2025, 4, 17, 13, 17, 32, 8...",,"SELECT p.package_code AS 'Mã gói', p.package_n...",Thông tin về các gói mà bệnh nhân có mã 'KH 00...,,"[('GDV20250417201732531', None, datetime.datet...",False,False,"[SELECT\n p.id AS package_id,\n p.package_co...","[False, False, False, False]","[False, False, False, False]",[],[]
288,Chi nhánh nào có nhiều lịch hẹn nhất?\n,"SELECT b.branch_name, COUNT(a.id) AS total_app...",hard,"[('Chi nhánh Cầu Giấy', 63)]",,"SELECT b.branch_name, COUNT(a.id) AS total_app...","Chi nhánh Cầu Giấy có nhiều lịch hẹn nhất, với...",,"[('Chi nhánh Cầu Giấy', 63)]",True,True,"[SELECT\n b.branch_name,\n COUNT(a.id) AS to...","[False, False, False, False]","[True, True, True, True]","[SELECT\n b.branch_name,\n COUNT(a.id) AS to...","[direct_generation, cot_generation, dac_cot_ge..."


In [95]:
df['candidate_least_correct'].value_counts()

candidate_least_correct
[]                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              

In [None]:
df['llm_exact_match'].value_counts()

llm_exact_match
True     127
False     97
Name: count, dtype: int64

In [None]:
df['llm_execution_match'].value_counts()

llm_execution_match
True     135
False     89
Name: count, dtype: int64

In [106]:
candidate_counter = {}
for candidate in CANDIDATE_NAMES:
    candidate_counter[candidate] = df['candidate_passes'].apply(lambda x: candidate in x).sum()
candidate_counter
# df['candidate_passes']

{'direct_generation': np.int64(100),
 'cot_generation': np.int64(109),
 'dac_cot_genration': np.int64(104),
 'query_plan_generation': np.int64(102)}