In [None]:
# 데이터 출처 : https://github.com/glee4810/ehrsql-2024/tree/master/data/mimic_iv

In [30]:
from llm import *
import pandas as pd
import os
import sqlglot
from sqlglot import parse_one, exp

In [31]:
# Read DB credentials and the DSN from the local .env file.
# NOTE(review): `load_dotenv` and `oracledb` are not imported explicitly in this
# notebook; presumably they arrive through `from llm import *` - confirm.
load_dotenv('.env')

# The Instant Client directory is machine-specific. Allow an override through
# the ORACLE_CLIENT_LIB_DIR environment variable and fall back to the path this
# notebook was originally written for.
oracledb.init_oracle_client(
    lib_dir=os.getenv('ORACLE_CLIENT_LIB_DIR', r"C:\instant_client\instantclient_21_19")
)

pool = oracledb.create_pool(
    user=os.getenv('user'),
    password=os.getenv('password'),
    dsn=os.getenv('ORACLE_DSN'),
    min=2,        # minimum number of pooled connections
    max=10,       # maximum number of pooled connections
    increment=1,  # connections added each time the pool grows
)

# A single connection/cursor pair shared by the cells below.
conn = pool.acquire()
cur = conn.cursor()

In [32]:
def check_answer(conn, sql):
    """Run `sql` on `conn` and return the result set as a list of dicts.

    Args:
        conn: An open DB-API connection (must provide `cursor()`).
        sql: The SQL text to execute.

    Returns:
        One dict per fetched row, keyed by the column names reported in
        `cursor.description`. An empty list is returned when the statement
        produces no result set (`description` is None).

    Raises:
        Whatever the driver raises from `execute` / `fetchall`; the cursor is
        closed before the exception propagates.
    """
    cursor = conn.cursor()
    try:
        cursor.execute(sql)

        # Non-query statements expose no description; there is nothing to fetch.
        if cursor.description is None:
            return []

        columns = [col[0] for col in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]
    finally:
        # Always release the cursor, even when execution fails.
        cursor.close()

def save_TEXT2SQL_EVAL(cursor, conn, ID, QUESTION, ORACLE_QUERY, ANSWER, CREATING_TIME):
    """Insert one evaluation record into TEXT2SQL_EVAL and commit it.

    Args:
        cursor: Cursor used to execute the INSERT.
        conn: Connection that owns the transaction (committed or rolled back here).
        ID, QUESTION, ORACLE_QUERY, ANSWER, CREATING_TIME: Values bound to the
            identically named columns / bind variables.

    Raises:
        Exception: Any error from `execute` or `commit` is re-raised after the
            transaction has been rolled back and the error has been printed.
    """
    insert_stmt = """
        INSERT INTO TEXT2SQL_EVAL (ID, QUESTION, ORACLE_QUERY, ANSWER, CREATING_TIME)
        VALUES (:ID, :QUESTION, :ORACLE_QUERY, :ANSWER, :CREATING_TIME)
    """
    bind_values = dict(
        ID=ID,
        QUESTION=QUESTION,
        ORACLE_QUERY=ORACLE_QUERY,
        ANSWER=ANSWER,
        CREATING_TIME=CREATING_TIME,
    )

    try:
        cursor.execute(insert_stmt, bind_values)
        conn.commit()
    except Exception as err:
        # Undo the partial transaction before surfacing the failure.
        conn.rollback()
        print("❌ DB Error 발생:", err)
        raise

In [33]:
# Load the validation set (questions, LLM-generated queries/answers and the gold
# queries/answers) and preview the first rows.
EVAL_CSV_PATH = '검증셋.csv'
df = pd.read_csv(EVAL_CSV_PATH)
df.head()

Unnamed: 0,ID,QUESTION,LLM_QUERY,LLM_ANSWER,CREATING_TIME,ORACLE_QUERY,ANSWER
0,0,What are the consumption methods of ampicillin...,SELECT DISTINCT UPPER(prescriptions.route) AS ...,"[{""ROUTE"": ""IM""}, {""ROUTE"": ""IV""}, {""ROUTE"": ""...",46,SELECT DISTINCT prescriptions.route FROM presc...,[['iv']]
1,1,Please tell me the sex of patient 10018328.,SELECT patients.gender FROM patients WHERE pat...,"[{""GENDER"": ""F""}]",41,SELECT patients.gender FROM patients WHERE pat...,[['f']]
2,2,What was an insurance plan for patient 1000586...,SELECT admissions.insurance FROM admissions WH...,"[{""INSURANCE"": ""Medicaid""}]",48,SELECT admissions.insurance FROM admissions WH...,[['medicaid']]
3,3,What was the insurance for patient 10016150 du...,SELECT admissions.insurance FROM admissions WH...,"[{""INSURANCE"": ""Medicare""}]",51,SELECT admissions.insurance FROM admissions WH...,[['medicare']]
4,4,Tell me the insurance plan in place for patien...,SELECT admissions.insurance FROM admissions WH...,[],45,SELECT admissions.insurance FROM admissions WH...,[['other']]


In [34]:
# import time

# for i in range(len(df)):
#     print(i)
#     question = df['question'].iloc[i]
#     start_time = time.perf_counter()
#     try:
#         oracle_query = llm_answer(question)
#         elapsed_seconds = int(time.perf_counter() - start_time)
#         sql = oracle_query.replace(';','')
#         answer = check_answer(conn, sql) # 정답 구하기
#         answer_json = json.dumps(answer, ensure_ascii=False, default=str)
#         save_TEXT2SQL_EVAL(cur, conn, i, question, oracle_query, answer_json, elapsed_seconds)
#     except:
#         save_TEXT2SQL_EVAL(cur, conn, i, question, "실패", "", 0)

In [35]:
# Remove statement terminators before the queries are sent to the database.
# regex=False makes the literal match explicit, so the result does not depend on
# the pandas version's default for `regex`.
df['LLM_QUERY'] = df['LLM_QUERY'].str.replace(';', '', regex=False)

In [36]:
def validate_sql_syntax(cur, sql_text):
    """Check whether the database can build an execution plan for `sql_text`.

    The statement is not executed; it is only passed to EXPLAIN PLAN and the
    resulting plan is read back through DBMS_XPLAN.DISPLAY.

    Args:
        cur: An open database cursor.
        sql_text: The SQL statement to validate (without a trailing semicolon).

    Returns:
        True when a non-empty plan was produced, False when the database
        rejected the statement or returned no plan output.
    """
    try:
        # NOTE(review): the statement text is interpolated into the SQL because
        # it is the statement itself rather than a value. Only run this on
        # queries from a trusted evaluation set.
        cur.execute(f'EXPLAIN PLAN FOR {sql_text}')
        cur.execute("SELECT * FROM TABLE(DBMS_XPLAN.DISPLAY)")
        # fetchall() yields one tuple per plan line; keep the non-empty lines.
        plan_lines = [row[0] for row in cur.fetchall() if row[0]]
    except Exception:
        # Any driver/database error means the statement could not be planned.
        return False

    return len(plan_lines) > 0
    
# Count the LLM-generated queries for which the database can build a plan,
# then report that count as a fraction of all rows.
cnt = sum(1 for llm_sql in df['LLM_QUERY'] if validate_sql_syntax(cur, llm_sql))

executable_ratio = round(cnt / len(df), 3)
print('DB에서 실행될 수 있는 확률 : ', executable_ratio)

DB에서 실행될 수 있는 확률 :  0.901


In [39]:
def normalize_column(expr):
    """Return the lower-cased bare column name found inside `expr`.

    Walks down the `this` chain of the expression, skipping alias nodes and
    wrapper nodes (e.g. function calls such as UPPER() or TRIM()), until a
    column is reached. If no column is found, the lower-cased string form of
    the innermost node is returned instead.
    """
    node = expr
    while True:
        # An alias node: continue with the aliased expression.
        if isinstance(node, exp.Alias):
            node = node.this
            continue

        # A column node: report the column identifier's own name.
        if isinstance(node, exp.Column):
            return node.this.name.lower()

        # Any other wrapper: descend into its primary operand when it has one.
        inner = getattr(node, 'this', None)
        if inner is None:
            return str(node).lower()
        node = inner

def extract_sql_components(sql):
    """Parse `sql` with the Oracle dialect and list the tables and columns it uses.

    Returns:
        A dict with two keys, "tables" and "columns", each holding a sorted list
        of unique lower-cased names. When the statement cannot be parsed, the
        error is printed and None is returned.
    """
    try:
        tree = parse_one(sql, read="oracle")
    except Exception as e:
        print(f"Parsing Error: {e}")
        return None

    # Tables: take the identifier's name when the table node wraps an
    # identifier, otherwise use the table node's own name.
    table_names = set()
    for table in tree.find_all(exp.Table):
        ident = table.this
        if isinstance(ident, exp.Identifier):
            table_names.add(ident.name.lower())
        else:
            table_names.add(table.name.lower())

    # Columns: every column node anywhere in the statement. Only the column's
    # own identifier is read, so the result holds the column name rather than
    # any alias given to it in the SELECT list.
    column_names = {col.this.name.lower() for col in tree.find_all(exp.Column)}

    return {"tables": sorted(table_names), "columns": sorted(column_names)}

# --- Smoke test ---
# The aliases (UID, UNAME) must not show up in the output; the underlying
# column names (USER_ID, USER_NAME) are expected instead.
sql_test = """
SELECT 
    USER_ID AS UID, 
    USER_NAME AS UNAME, 
    DEPT_CODE
FROM USERS 
WHERE UPPER(LOCATION) = 'SEOUL'
"""

result = extract_sql_components(sql_test)
for label in ("Tables", "Columns"):
    print(f"{label}: {result[label.lower()]}")

Tables: ['users']
Columns: ['dept_code', 'location', 'user_id', 'user_name']


In [40]:
# Drop the rows where query generation failed (marked with the literal '실패').
# .copy() gives an independent frame, so the columns added to it in later cells
# do not trigger SettingWithCopyWarning.
df = df[df['LLM_QUERY'] != '실패'].copy()

In [41]:
def _overlap_pct(predicted, reference):
    """Return the percentage of names in `predicted` that also occur in `reference`."""
    if not predicted:
        # Nothing was predicted: a perfect match only if nothing was expected.
        return 100.0 if not reference else 0.0
    reference_set = set(reference)
    hits = sum(1 for name in predicted if name in reference_set)
    return (hits / len(predicted)) * 100


# For every row, compare the tables/columns referenced by the LLM query with
# those referenced by the gold query.
res = []

for llm_sql, gold_sql in zip(df['LLM_QUERY'], df['ORACLE_QUERY']):
    llm = extract_sql_components(llm_sql)
    real = extract_sql_components(gold_sql)

    if llm is None or real is None:
        # One of the statements could not be parsed, so no overlap can be measured.
        res.append([0.0, 0.0])
        continue

    table_equal = _overlap_pct(llm['tables'], real['tables'])
    column_equal = _overlap_pct(llm['columns'], real['columns'])

    res.append([table_equal, column_equal])

# Lists are assigned positionally, so this is independent of df's index labels.
df["table_equal"] = [scores[0] for scores in res]
df["column_equal"] = [scores[1] for scores in res]

In [42]:
import sqlglot
from sqlglot import exp, parse_one

def get_pure_conditions(sql, dialect="oracle"):
    """Extract the de-duplicated comparison conditions under the WHERE clause of `sql`.

    Args:
        sql: SQL text to parse.
        dialect: sqlglot dialect used for parsing (default "oracle").

    Returns:
        A set of `(column, operator, value)` tuples. `column` is the lower-cased
        name reached by following the left operand's `this` chain, `operator`
        is the SQL operator text, and `value` is the lower-cased right-hand
        side with surrounding single quotes stripped ("SUBQUERY" for a
        subquery, a sorted tuple of values for an IN list, "null" when there
        is no right-hand side). An empty set is returned when the statement
        cannot be parsed or has no WHERE clause.
    """
    try:
        tree = parse_one(sql, read=dialect)
    except Exception:
        # Unparseable SQL contributes no conditions.
        return set()

    where = tree.find(exp.Where)
    if not where:
        return set()

    target_types = (exp.EQ, exp.GT, exp.GTE, exp.LT, exp.LTE, exp.Like, exp.In, exp.Is, exp.NEQ)
    op_map = {"eq": "=", "neq": "!=", "gt": ">", "gte": ">=", "lt": "<", "lte": "<=", "like": "LIKE", "is": "IS", "in": "IN"}

    conditions = set()
    for node in where.find_all(target_types):
        # 1. Left-hand side: descend through wrappers until an identifier
        #    (or a node without `this`) is reached.
        lhs = node.left if hasattr(node, 'left') else node.this
        while hasattr(lhs, 'this') and not isinstance(lhs, exp.Identifier):
            lhs = lhs.this
        col = lhs.name.lower() if hasattr(lhs, 'name') else str(lhs).lower()

        # 2. Operator text.
        op = op_map.get(node.key, node.key.upper())

        # 3. Right-hand side value.
        rhs = node.right if hasattr(node, 'right') else (node.expression if hasattr(node, 'expression') else None)
        if isinstance(node, exp.In):
            if node.args.get("query") is not None:
                # IN (SELECT ...): sqlglot keeps the subquery under "query",
                # not as the right operand.
                val = "SUBQUERY"
            else:
                val = tuple(sorted(e.sql().strip("'").lower() for e in node.expressions))
        elif isinstance(rhs, exp.Subquery):
            val = "SUBQUERY"
        else:
            val = rhs.sql().strip("'").lower() if rhs else "null"

        # Stored as tuples so the set removes duplicates automatically.
        conditions.add((col, op, val))

    return conditions

def calculate_accuracy(sql_origin, sql_new):
    """Return the percentage of the reference query's conditions found in the new query.

    Args:
        sql_origin: Reference (gold) SQL.
        sql_new: SQL to evaluate.

    Returns:
        A float between 0 and 100: the share of conditions extracted from
        `sql_origin` that were also extracted from `sql_new`. When
        `sql_origin` yields no conditions, the result is 100.0 if `sql_new`
        yields none either and 0.0 otherwise, so the return value is always
        numeric and can be aggregated directly.
    """
    # De-duplicated condition sets for both statements.
    conds_origin = get_pure_conditions(sql_origin)
    conds_new = get_pure_conditions(sql_new)

    total_origin = len(conds_origin)
    if total_origin == 0:
        # No reference conditions: any condition in the new query is spurious.
        return 100.0 if not conds_new else 0.0

    # Conditions present in both statements.
    matches = conds_new.intersection(conds_origin)

    return (len(matches) / total_origin) * 100

# Score the conditions of each LLM query against those of its gold query.
where_acc = [
    calculate_accuracy(gold_sql, llm_sql)
    for gold_sql, llm_sql in zip(df['ORACLE_QUERY'], df['LLM_QUERY'])
]

df['where_equal'] = where_acc

In [45]:
# Final per-row score: the three component scores summed and divided by three.
score_columns = ['table_equal', 'column_equal', 'where_equal']
df['최종평가'] = df[score_columns].sum(axis=1) / 3

In [49]:
# Report the mean final score over all evaluated rows, rounded to two decimals.
final_accuracy = round(df['최종평가'].mean(), 2)
print('최종 SQL 정확도 :', final_accuracy)

최종 SQL 정확도 : 70.31
