In [36]:
import json

# 1. cot.json 파일 로드
dataset = json.load(open('cot_ko.json', 'r', encoding='utf-8'))

In [37]:
import random
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import os
# SentenceTransformer 모델 로드
model_path = os.path.join('plm', 'paraphrase-multilingual-MiniLM-L12-v2')
sentence_encoder = SentenceTransformer(model_path, device="cpu")

def get_static_shots(static_count):
    shots = set()
    while len(shots) < static_count:
        shots.add(random.randint(0, len(dataset) - 1))
    shots = [dataset[id] for id in shots]

    return shots

def get_dynamic_shots(user_question, dataset, dynamic_count):
    question_embedding = sentence_encoder.encode(
        [user_question],
        batch_size=32,
        show_progress_bar=True,
        normalize_embeddings=True,
        convert_to_tensor=True,
        device="cpu"
    ).cpu().tolist()

    scores = cos_sim(question_embedding, [example['question_embedding'] for example in dataset]).squeeze(0).tolist()
    scores = sorted(enumerate(scores), key=lambda x: -x[1])
    shots = []
    for item in scores:
        shots.append(dataset[item[0]])
        if len(shots) == dynamic_count:
            break
    return shots



In [38]:
def get_eot():
    tables_path = os.path.join("data", "spider_data", "tables.json")
    with open(tables_path, "r", encoding="utf-8") as f:
        dbs = json.load(f)
    result = {}
    for db in dbs:
            db_id = db["db_id"]
            table_names = db["table_names_original"]             # ['singer', 'concert', ...]
            col_names = db["column_names_original"]              # [(table_idx, col_name), ...] (주의: index 0 == (*))
            col_types = db["column_types"]                       # ['text','number',...], col_names와 동일 인덱스
            primary_keys = db["primary_keys"]                    # [col_id, ...]
            foreign_keys = db["foreign_keys"]                    # [[src_col_id, tgt_col_id], ...]

            ddl_chunks = []

            # 각 테이블에 대해 EoT 스타일 CREATE TABLE 생성
            for t_idx, t_name in enumerate(table_names):
                # 1) 컬럼 정의 (스파이더 포맷의 0번째 컬럼은 (*), 건너뜀)
                column_defs = []
                for col_id, (owner_t_idx, col_name) in enumerate(col_names):
                    if col_id == 0:          # skip '*'
                        continue
                    if owner_t_idx != t_idx: # 다른 테이블 소속
                        continue
                    ctype = col_types[col_id]
                    column_defs.append(f"    {col_name} {ctype}")

                # 2) 테이블 레벨 제약 (PRIMARY KEY, FOREIGN KEY) — EoT
                #    - PK: 현재 테이블 소속 col_id만 추림
                pk_cols = [col_names[cid][1] for cid in primary_keys if col_names[cid][0] == t_idx]

                #    - FK: src 컬럼이 현재 테이블 소속인 경우만 추림
                fk_defs = []
                for src_col_id, tgt_col_id in foreign_keys:
                    src_owner, src_name = col_names[src_col_id]
                    if src_owner != t_idx:
                        continue
                    tgt_owner, tgt_name = col_names[tgt_col_id]
                    ref_table = table_names[tgt_owner]
                    fk_defs.append(f"    foreign key ({src_name}) references {ref_table}({tgt_name})")

                # 3) 본문 구성: 컬럼들 + (옵션) PK + (옵션) FK들
                body_lines = list(column_defs)
                if pk_cols:
                    body_lines.append(f"    primary key ({', '.join(pk_cols)})")
                body_lines.extend(fk_defs)

                # 4) 최종 CREATE TABLE 문자열
                ddl = "create table {name} (\n{body}\n);\n".format(
                    name=t_name,
                    body=",\n".join(body_lines) if body_lines else ""  # 컬럼 없는 테이블 방어
                )
                ddl_chunks.append(ddl)

            # 하나의 DB에 대한 모든 테이블 DDL을 합쳐 저장
            ddl_all = "".join(ddl_chunks)
            result[db_id] = ddl_all
    return result


In [39]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

llm = ChatOpenAI(
    model = "openai/gpt-oss-20b",
    temperature = 0.0,
    base_url = "http://202.31.200.184:8111/v1",
    api_key = "not_used"
)

In [40]:
prompt_template = [
    {"role": "system", "content": "Given the database schema, you need to translate the question into the SQL query.\n No Markdown, No Explanation, Only SQL"},
    {"role": "user", "content": "Database schema:\n {eot}\n question: {question}"},
    {"role": "assistant", "content": "{cot}"}
]

In [41]:
# dev.json 파일 읽기
data = json.load(open("data/spider_data/dev_ko.json", "r", encoding="utf-8"))

prompt = ChatPromptTemplate.from_messages(prompt_template)
chain = prompt | llm
answers = []
for d in data:
    question = d["question"]
    static_shots = get_static_shots(2)
    dynamic_shots = get_dynamic_shots(question, dataset, 2)
    eot = get_eot()[d["db_id"]]
    cots = [shot["cot"] for shot in static_shots + dynamic_shots]

    result = chain.invoke({"eot": eot, "question": question, "cot": "\n".join(cots)})

    answers.append(f"{result.content}\t{d['db_id']}")


Batches: 100%|██████████| 1/1 [00:00<00:00, 29.40it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 45.47it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 43.50it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 43.46it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  9.62it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 41.68it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 52.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 41.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 47.64it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 41.67it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 47.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 52.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 45.44it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 52.56it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 38.47it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 41.68it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 45.45it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 43.51it/s]
Batches: 1

In [42]:
answers = [a.replace("\n", " ") for a in answers]

tmp = "\n".join(answers)

with open("pred_ko.txt", "w", encoding="utf-8") as f:
    f.write(tmp)
