## Env setup

In [1]:
from dotenv import load_dotenv
import os

load_dotenv(override=True, verbose=True)

# MySQL connection string format: mysql+pymysql://username:password@host:port/database_name
conn = os.environ.get("DB_CONN", None)


In [2]:
from shared.db import db


2025-05-14 15:40:05 - shared.db - INFO - DB_CONN: mysql+pymysql://root:yolo2chill@localhost:3306/gsv
2025-05-14 15:41:01 - shared.db - INFO - mysql
2025-05-14 15:41:01 - shared.db - INFO - First 10 tables: ['appointment', 'appointment_history', 'appointment_service', 'attachment', 'attribute', 'batch_info', 'birthday_care', 'branch', 'care_service', 'care_treatment']


In [3]:
r = db.run_no_throw("SELECT full_name, mobile_phone FROM patient WHERE full_name = 'Cường nguyễn'")

In [4]:
r

"[('Cường nguyễn', '0989149111')]"

In [11]:
r.startswith('Error:')


True

In [5]:
from pydantic import BaseModel
from typing import List, Optional, Literal


class Conversation(BaseModel):
    question: str
    answer: str
    level: Optional[Literal["easy", "hard"]] = None
    result: Optional[str] = None
    error: Optional[str] = None
    

In [15]:
c = Conversation(question="Có bao nhiêu khách hàng đặt lịch hẹn trong vòng 6 tháng trở lại đây?", answer="SELECT count(*) from patient where patient.id in (select appointment.patient_id FROM appointment where appointment.created_on > date_sub(current_date(), interval 6 month));", level="hard")
c.model_dump_json()


'{"question":"Có bao nhiêu khách hàng đặt lịch hẹn trong vòng 6 tháng trở lại đây?","answer":"SELECT count(*) from patient where patient.id in (select appointment.patient_id FROM appointment where appointment.created_on > date_sub(current_date(), interval 6 month));","level":"hard","error":null}'

In [7]:
import re
from tqdm.notebook import tqdm
# from tqdm import tqdm

data_files = [
    "E:/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/grok.txt",
    "E:/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/gemini-2.5-pro.txt",
]

gen_data: List[Conversation] = []
failed_data: List[str] = []

for file in tqdm(data_files[:], desc="Processing files: "):
    with open(file, "r", encoding="utf-8") as f:
        data = f.read()
    data = data.split("---")
    for d in tqdm(data[:], leave=False):
        # print("----")
        # print(d)
        question = re.search(r"question: (.*)(?=answer:)", d, re.DOTALL)
        answer = re.search(r"answer: (.*)(?=level:)", d, re.DOTALL)
        level = re.search(r"level: (.*)", d, re.DOTALL)
        # print(question, answer, level)
        if question and answer and level:
            question = question.group(1)
            answer = answer.group(1)
            level = level.group(1)
            # print(f'question: {question}\nanswer: {answer}\nlevel: {level}\n')
            result = db.run_no_throw(answer)
            if result.lower().startswith('error:'):
                error, result = result, None
            else:
                error = None
            c = Conversation(question=question, answer=answer, level=level.strip(), result=result, error=error)
            gen_data.append(c)
        else:
            failed_data.append(d)


Processing files:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/178 [00:00<?, ?it/s]

  0%|          | 0/248 [00:00<?, ?it/s]

In [10]:
failed_data

['', '', '', '\n']

In [12]:
sql_failed_data = [d for d in gen_data if d.error]
len(sql_failed_data)

101

In [None]:
import csv

gen_data[0].model_dump_json()
success_data = [
    d for d in gen_data if d.result and not d.result.lower().startswith("error:")
]
failed_data = [
    d for d in gen_data if not d.result or d.result.lower().startswith("error:")
]

In [16]:
with open(
    "E:/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/gen_success_data.csv",
    "w",
    encoding="utf-8",
) as f:
    field_names = list(Conversation.model_fields.keys())
    writer = csv.DictWriter(f, fieldnames=field_names)
    writer.writeheader()
    for d in success_data:
        writer.writerow(d.model_dump())

'{"question":"Tìm danh sách các kỹ thuật viên có lịch hẹn tại chi nhánh Hà Đông và có khiếu nại trong tháng 3 năm 2025.\\n","answer":"SELECT DISTINCT e.full_name \\nFROM employee e \\nJOIN appointment a ON e.id = a.technican_id \\nJOIN complaint c ON e.id = c.technican_id \\nWHERE a.branch_id = (SELECT id FROM branch WHERE branch_name = \'Chi nhánh Hà Đông\') \\nAND YEAR(c.created_on) = 2025 \\nAND MONTH(c.created_on) = 3;\\n","level":"hard","result":null,"error":"Error: (pymysql.err.OperationalError) (1054, \\"Unknown column \'e.full_name\' in \'field list\'\\")\\n[SQL: SELECT DISTINCT e.full_name \\nFROM employee e \\nJOIN appointment a ON e.id = a.technican_id \\nJOIN complaint c ON e.id = c.technican_id \\nWHERE a.branch_id = (SELECT id FROM branch WHERE branch_name = \'Chi nhánh Hà Đông\') \\nAND YEAR(c.created_on) = 2025 \\nAND MONTH(c.created_on) = 3;\\n]\\n(Background on this error at: https://sqlalche.me/e/20/e3q8)"}'

In [22]:
with open(
    "E:/code/AI/agentic-AI/SQL-QA/data/GSV/generated-data/gen_failed_data.csv",
    "w",
    encoding="utf-8",
) as f:
    field_names = list(Conversation.model_fields.keys())
    writer = csv.DictWriter(f, fieldnames=field_names)
    writer.writeheader()
    for d in failed_data + sql_failed_data:
        writer.writerow(d.model_dump())