In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Batch-generates NL2SQL 结果并保存为 JSON 文件。

关键优化：
1. **一次读取模板 & format_map**避免在循环里多次 I/O 和正则替换。
2. **批量推理**利用 pipeline 的 `batch_size` 提升吞吐。
3. **no_grad + fp16/8bit**减小显存占用，推理更快。
4. **路径与超参集中管理**便于脚本化 / 日后复现。
5. **进度可视化 (tqdm)**长任务更直观。
6. **异常兜底**避免单条数据崩溃拖垮整体。
"""

import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, List

import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
)

# ---------- 可调参数 ----------
@dataclass
class Config:
    """Tunable settings for batch NL2SQL generation.

    ``prompt_file``, ``input_json`` and ``output_dir`` are relative names that
    the caller joins onto a base directory of its choosing.
    """
    # Local checkpoint directory; loaded with local_files_only=True (no download).
    model_name: str = r"/data/qwen2-7b-instruct"
    # Prompt template with {db_id}/{question}/{evidence}/{schema_linking} slots.
    prompt_file: str = "QP_prompt.txt"
    # JSON records read via .get("db_id"/"question"/"evidence"/"schema_linking").
    input_json: str = "schema_linking_result.json"
    # Directory created for generation outputs.
    output_dir: str = "result"
    # Upper bound on newly generated tokens per prompt.
    max_new_tokens: int = 1536
    batch_size: int = 4          # prompts per pipeline call; tune to available GPU memory
    # Sample instead of greedy decoding (no seed is set, so runs are not reproducible).
    do_sample: bool = True
    # Sampling temperature for generation.
    temperature: float = 0.6
    use_fp16: bool = True        # True: float16 weights; False: 8-bit bitsandbytes quantization (4-bit is not wired up)
    # Forwarded as `device_map` when loading the model / building the pipeline.
    device_map: str = "auto"


# Default configuration instance used by the code below.
CFG = Config()
# --------------------------------


def load_prompt_template(path: str) -> str:
    """Return the complete text of the prompt template at ``path``.

    ``path`` may be a ``str`` or a ``pathlib.Path``; the file is decoded as UTF-8.
    """
    template_text = Path(path).read_text(encoding="utf-8")
    return template_text


def format_prompt(template: str, db_id: str, question: str,
                  evidence: str, schema_linking: Dict[str, Any]) -> str:
    """
    Fill the ``{db_id}``, ``{question}``, ``{evidence}`` and ``{schema_linking}``
    placeholders of ``template`` via ``str.format_map``.

    ``schema_linking`` is serialized with ``json.dumps(..., ensure_ascii=False)``
    so non-ASCII identifiers stay readable in the prompt. Literal braces in the
    template must be escaped as ``{{`` / ``}}``.
    """
    placeholders = dict(
        db_id=db_id,
        question=question,
        evidence=evidence,
        schema_linking=json.dumps(schema_linking, ensure_ascii=False),
    )
    return template.format_map(placeholders)


def load_model_and_tokenizer(cfg: Config):
    """
    Load the tokenizer and causal-LM from ``cfg.model_name`` (local files only).

    Precision:
    - ``cfg.use_fp16=True``: weights are loaded with ``torch_dtype=float16``.
    - ``cfg.use_fp16=False``: weights are quantized to 8-bit via
      ``BitsAndBytesConfig(load_in_8bit=True)`` and ``torch_dtype`` is float32.

    The tokenizer is left-padded and guaranteed to have a pad token so the
    text-generation pipeline can batch prompts.

    Returns:
        ``(tokenizer, model)``
    """
    # 8-bit is the only quantized alternative to fp16 implemented here.
    quant_cfg = None if cfg.use_fp16 else BitsAndBytesConfig(load_in_8bit=True)

    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name,
        trust_remote_code=True,
        # Decoder-only generation needs left padding in a batch so each prompt
        # ends exactly where new tokens are appended.
        padding_side="left",
        local_files_only=True,
    )
    # Batched pipeline inference cannot pad without a pad token; some
    # checkpoints ship without one, so reuse EOS in that case.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16 if cfg.use_fp16 else torch.float32,
        quantization_config=quant_cfg,
        device_map=cfg.device_map,
        local_files_only=True,
    )
    return tokenizer, model


def batched(iterable: List[Any], n: int):
    """Lazily yield consecutive slices of ``iterable`` with at most ``n`` items.

    The final slice may be shorter. Any sliceable sequence works (list, str,
    tuple), and each slice keeps the input's type. This is similar to
    ``itertools.batched`` (Python 3.12+), which yields tuples instead.
    """
    yield from (
        iterable[offset:offset + n]
        for offset in range(0, len(iterable), n)
    )

import re

def extract_sql_block(generated_text: str) -> str:
    """
    Extract the final SQL query from a model completion.

    Resolution order:
    1. Content of the last complete ```sql ... ``` fenced block (case-insensitive).
    2. Content after the last opening ```sql fence when no closing fence
       follows it, e.g. the completion was cut off by ``max_new_tokens``
       before the fence was closed.
    3. The whole completion, stripped, when no ```sql fence is present.
    """
    flags = re.DOTALL | re.IGNORECASE
    complete_blocks = re.findall(r"```sql\s+(.*?)\s*```", generated_text, flags)
    if complete_blocks:
        return complete_blocks[-1].strip()

    # Truncated output: an opening fence exists but its closing fence was never
    # generated. Returning only the SQL tail is better than the whole query plan.
    openings = list(re.finditer(r"```sql\s+", generated_text, flags))
    if openings:
        return generated_text[openings[-1].end():].strip()

    return generated_text.strip()  # fallback: no fenced SQL at all
    
def _resolve_root(root=None) -> Path:
    """Return the base directory that the relative paths in CFG are joined onto.

    ``root`` wins when given. Otherwise the directory containing this script is
    used, falling back to the current working directory when ``__file__`` is
    not defined (e.g. when the code runs inside a Jupyter kernel).
    """
    if root is not None:
        return Path(root)
    try:
        return Path(__file__).resolve().parent
    except NameError:  # __file__ does not exist in notebooks / interactive sessions
        return Path.cwd()


def generate(root=None, output_name: str = "qp_result.json") -> List[Dict[str, Any]]:
    """
    Run batched NL2SQL generation over ``CFG.input_json`` and save the results.

    Args:
        root: base directory for CFG's relative paths. Defaults to the script
            directory, or the CWD when ``__file__`` is undefined (e.g. Jupyter).
        output_name: JSON file name written inside ``CFG.output_dir``.

    Returns:
        One dict per input record with keys ``db_id``, ``question``,
        ``evidence``, ``schema_linking`` and ``sql``. Records from a batch whose
        generation raised get ``sql=""`` plus an ``error`` key, so one failing
        batch does not abort the whole run.
    """
    cfg = CFG
    base_dir = _resolve_root(root)
    input_path = base_dir / cfg.input_json
    out_dir = base_dir / cfg.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    data: List[Dict[str, Any]] = json.loads(input_path.read_text(encoding="utf-8"))

    prompt_tpl = load_prompt_template(base_dir / cfg.prompt_file)
    tokenizer, model = load_model_and_tokenizer(cfg)

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        batch_size=cfg.batch_size,
        do_sample=cfg.do_sample,
        temperature=cfg.temperature,  # was not forwarded, so cfg.temperature had no effect
        device_map=cfg.device_map,
        return_full_text=False,       # only the newly generated text, not the prompt
    )

    results: List[Dict[str, Any]] = []

    # -- main loop --
    with torch.no_grad():
        for batch in tqdm(list(batched(data, cfg.batch_size)), desc="Generating"):
            prompts = [
                format_prompt(
                    prompt_tpl,
                    item.get("db_id", ""),
                    item.get("question", ""),
                    item.get("evidence", ""),
                    item.get("schema_linking", {}),
                )
                for item in batch
            ]

            try:
                # With a list input the pipeline returns List[List[Dict]].
                outputs = generator(prompts, max_new_tokens=cfg.max_new_tokens)
                # The prompt requires the final SQL inside a ```sql block; the
                # former "sql statement:" marker never appears, which stored the
                # entire query plan as the SQL.
                sqls = [extract_sql_block(gen[0]["generated_text"]) for gen in outputs]
                error = None
            except Exception as exc:  # keep the long run alive; failure is logged and recorded
                print(f"[generate] batch failed: {exc!r}", file=sys.stderr)
                sqls = [""] * len(batch)
                error = repr(exc)

            for item, sql in zip(batch, sqls):
                result = {
                    "db_id": item.get("db_id"),
                    "question": item.get("question"),
                    "evidence": item.get("evidence"),
                    "schema_linking": item.get("schema_linking"),
                    "sql": sql,
                }
                if error is not None:
                    result["error"] = error
                results.append(result)

    # out_dir was created but never written to; persist the results there.
    out_path = out_dir / output_name
    out_path.write_text(json.dumps(results, ensure_ascii=False, indent=2), encoding="utf-8")
    return results

# Notebook alias for the module-level config; the bare expression displays it.
cfg = CFG
cfg

Config(model_name='/data/qwen2-7b-instruct', prompt_file='QP_prompt.txt', input_json='schema_linking_result.json', output_dir='result', max_new_tokens=1536, batch_size=4, do_sample=True, temperature=0.6, use_fp16=True, device_map='auto')

In [2]:

# Base directory holding the input JSON, the prompt template and the output
# folder. Set CHASE_CANDIDATES_DIR to run elsewhere instead of editing this
# machine-specific default.
root = Path(os.environ.get("CHASE_CANDIDATES_DIR", "/home/yangliu26/CHASE/candidates"))
input_path = root / cfg.input_json
out_dir = root / cfg.output_dir
out_dir.mkdir(parents=True, exist_ok=True)

In [3]:
data: List[Dict[str, Any]] = json.loads(Path(input_path).read_text(encoding="utf-8"))
# Later cells slice `data` into batches and call `.get` on every record, so fail
# fast here if the file does not have that shape.
assert isinstance(data, list) and all(isinstance(rec, dict) for rec in data), \
    f"{input_path} must contain a JSON array of objects"
len(data)

10

In [38]:
prompt_tpl = load_prompt_template(root / cfg.prompt_file)

# Fail fast on a malformed template: format_map raises (e.g. KeyError for an
# unknown {placeholder}, ValueError for an unescaped literal brace; escape
# braces as `{{` / `}}`). The trailing `;` suppresses displaying the prompt.
format_prompt(prompt_tpl, "test", "hello, world", "this is a hint", {});

In [5]:
tokenizer, model = load_model_and_tokenizer(cfg)

# Text-generation pipeline driven entirely by `cfg`; return_full_text=False
# returns only the newly generated tokens, not the echoed prompt.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
    device_map=cfg.device_map,
    batch_size=cfg.batch_size,
    do_sample=cfg.do_sample,
    temperature=cfg.temperature,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [39]:
# Render the prompt for a single sample record to inspect it end to end.
item = data[3]
prompts = format_prompt(
    prompt_tpl,
    *(item.get(field, "") for field in ("db_id", "question", "evidence")),
    item.get("schema_linking", {}),
)
print(prompts)

You are an expert in translating natural language questions into SQL queries using a Query Plan approach.

Your task:
- Given the database, question, evidence, and schema linking, generate a step-by-step Query Plan followed by the final SQL query.

Rules:
- Think step-by-step according to the standard Query Plan structure.
- Do not skip any steps.
- Output only what is required: the Query Plan and Final SQL.
- Final SQL must be enclosed inside a ```sql code block.
- Do not output any additional text, headers, or explanations beyond the required sections.

Reference structure for Query Plan:
1. Understand the intent
2. Locate target tables and columns
3. Identify filter conditions
4. Determine aggregation, grouping, ordering
5. Handle joins if needed
6. Build subqueries if needed
7. Formulate final SQL

[Input]
Given the following information:

- **Database**: movie_platform
- **Question**: Name the movie with the most ratings.
- **Evidence** (schema and sample data): movie with the mos

In [40]:
# `prompts` is a single string here (not a list), so the pipeline returns a flat
# list of dicts, [{"generated_text": ...}], rather than the batched List[List[Dict]].
output = generator(prompts, max_new_tokens=cfg.max_new_tokens)
print(output)

[{'generated_text': " sql\n### Query Plan:\n1. **Understand the intent**: The question asks for the movie title that has the highest total sum of ratings.\n2. **Locate target tables and columns**: We need to join'movies' table with 'ratings' table to get the total ratings for each movie.\n   - Target table:'movies', 'ratings'\n   - Target columns:'movie_title', 'rating_score'\n3. **Identify filter conditions**: We need to calculate the sum of ratings for each movie.\n   - Filter condition: Join'movies' on 'ratings' using'movie_id'.\n4. **Determine aggregation, grouping, ordering**: Aggregate the sum of ratings per movie and order by descending sum.\n   - Aggregation: SUM(rating_score)\n   - Grouping: GROUP BY'movie_title'\n   - Ordering: ORDER BY SUM(rating_score) DESC\n5. **Handle joins if needed**: Join'movies' and 'ratings' tables based on'movie_id'.\n   - Join type: INNER JOIN\n6. **Build subqueries if needed**: No subqueries are necessary for this query.\n7. **Formulate final SQL*

In [21]:
# Generation budget per prompt; a completion that hits it may stop before the closing ```sql fence.
cfg.max_new_tokens

1536

In [41]:
# Raw completion text (the prompt is excluded because return_full_text=False).
output[0]['generated_text']

" sql\n### Query Plan:\n1. **Understand the intent**: The question asks for the movie title that has the highest total sum of ratings.\n2. **Locate target tables and columns**: We need to join'movies' table with 'ratings' table to get the total ratings for each movie.\n   - Target table:'movies', 'ratings'\n   - Target columns:'movie_title', 'rating_score'\n3. **Identify filter conditions**: We need to calculate the sum of ratings for each movie.\n   - Filter condition: Join'movies' on 'ratings' using'movie_id'.\n4. **Determine aggregation, grouping, ordering**: Aggregate the sum of ratings per movie and order by descending sum.\n   - Aggregation: SUM(rating_score)\n   - Grouping: GROUP BY'movie_title'\n   - Ordering: ORDER BY SUM(rating_score) DESC\n5. **Handle joins if needed**: Join'movies' and 'ratings' tables based on'movie_id'.\n   - Join type: INNER JOIN\n6. **Build subqueries if needed**: No subqueries are necessary for this query.\n7. **Formulate final SQL**: Select the top mo

In [42]:
# Pull the final SQL out of the ```sql fenced block of the completion.
print(extract_sql_block(output[0]['generated_text']))

SELECT m.movie_title
FROM movies m
INNER JOIN ratings r ON m.movie_id = r.movie_id
GROUP BY m.movie_title
ORDER BY SUM(r.rating_score) DESC
LIMIT 1;
