In [3]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Divide-and-Conquer CoT 方法实现
- 将复杂自然语言问题分解为多个子问题
- 对每个子问题分别生成SQL片段
- 最后合并这些片段，构造完整的SQL
"""

import json
import os
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Any, List
import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
)

# ---------- 可调参数 ----------
@dataclass
class Config:
    model_name: str = r"/home/yangliu26/qwen3-8b"  # 请根据实际模型路径调整
    input_json: str = r"/home/yangliu26/CHASE/schema_linking_lxy/schema_linking_result.json"
    output_dir: str = r"/home/yangliu26/CHASE/candidates/cot_result"
    # 文本生成超参
    max_new_tokens: int = 1024
    do_sample: bool = True
    temperature: float = 0.2
    # 性能设置
    batch_size: int = 4
    use_fp16: bool = True
    device_map: str = "auto"

# 配置实例
CFG = Config()

def load_model_and_tokenizer(cfg: Config):
    """加载模型和分词器"""
    # 量化配置
    quant_cfg = None
    if not cfg.use_fp16:
        quant_cfg = BitsAndBytesConfig(load_in_8bit=True)

    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name,
        trust_remote_code=True,
        padding_side="left",
        local_files_only=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16 if cfg.use_fp16 else torch.float32,
        quantization_config=quant_cfg,
        device_map=cfg.device_map,
    )
    return tokenizer, model

def batched(iterable: List[Any], n: int):
    """将列表分批切片"""
    for i in range(0, len(iterable), n):
        yield iterable[i : i + n]

def load_prompt_template(path: str) -> str:
    with open(path, encoding="utf-8") as f:
        return f.read()
    
def decompose_question(question: str, db_schema: str, generator) -> List[str]:
    # 分解步骤的提示词模板
    DECOMPOSE_TEMPLATE = load_prompt_template(r"/home/yangliu26/CHASE/template/decompose_template.txt")

    """将问题分解为多个子问题"""
    prompt = DECOMPOSE_TEMPLATE.format(
        question=question,
        db_schema=db_schema
    )
    
    response = generator(prompt, max_new_tokens=512, do_sample=True)
    text = response[0]["generated_text"]
    
    # 解析子问题
    sub_questions = []
    for line in text.strip().split('\n'):
        if line.strip() and any(line.strip().startswith(str(i)) for i in range(1, 10)):
            # 移除序号前缀
            sub_q = line.strip()
            for i in range(1, 10):
                prefix = f"{i}. "
                if sub_q.startswith(prefix):
                    sub_q = sub_q[len(prefix):]
                    break
            sub_questions.append(sub_q)
    
    return sub_questions

def generate_partial_sql(sub_question: str, db_schema: str, generator) -> str:
     # 生成SQL片段的提示词模板
    PARTIAL_SQL_TEMPLATE = load_prompt_template(r"CHASE_work\template\partial_sql_template.txt")
    
    """为子问题生成SQL片段"""
    prompt = PARTIAL_SQL_TEMPLATE.format(
        sub_question=sub_question,
        db_schema=db_schema
    )
    
    response = generator(prompt, max_new_tokens=512, do_sample=True)
    return response[0]["generated_text"].strip()

def assemble_sql(question: str, db_schema: str, sub_questions: List[str], 
                partial_sqls: List[str], generator) -> str:
    # 组合SQL的提示词模板
    ASSEMBLE_TEMPLATE = load_prompt_template(r"CHASE_work\template\assemble_template.txt")
    
    """组合SQL片段为完整SQL"""
    # 格式化子问题和SQL片段
    sub_qs_and_sqls = ""
    for i, (q, sql) in enumerate(zip(sub_questions, partial_sqls), 1):
        sub_qs_and_sqls += f"{i}. 子问题: {q}\n   SQL片段: {sql}\n\n"
    
    prompt = ASSEMBLE_TEMPLATE.format(
        question=question,
        db_schema=db_schema,
        sub_questions_and_sqls=sub_qs_and_sqls
    )
    
    response = generator(prompt, max_new_tokens=1024, do_sample=False)
    return response[0]["generated_text"].strip()

def optimize_sql(sql: str, generator) -> str:
    """优化SQL查询（可选）"""
    # 这里可以添加SQL优化逻辑，如去除冗余等
    # 简单实现，直接返回
    return sql

def divide_and_conquer_sql(question: str, db_schema: str, generator):
    """主函数：使用分而治之方法生成SQL"""
    # 步骤1: 分解问题
    sub_questions = decompose_question(question, db_schema, generator)
    
    # 步骤2: 生成每个子问题的SQL片段
    partial_sqls = []
    for q in sub_questions:
        partial_sql = generate_partial_sql(q, db_schema, generator)
        partial_sqls.append(partial_sql)
    
    # 步骤3: 汇总构造最终SQL
    final_sql = assemble_sql(question, db_schema, sub_questions, partial_sqls, generator)
    
    return optimize_sql(final_sql, generator)

In [6]:
"""处理数据并生成SQL"""
# 创建输出目录(若已存在,则不用创建)
os.makedirs(CFG.output_dir, exist_ok=True)

# 加载数据
with open(CFG.input_json, 'r', encoding='utf-8') as f:
    data = json.load(f)

data[:2]

[{'db_id': 'movie_platform',
  'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
  'evidence': 'released in the year 1945 refers to movie_release_year = 1945;',
  'keywords': ['movie titles', 'released', '1945', 'descending', 'popularity'],
  'schema_linking': {'movies': ['movie_title',
    'movies',
    'movie_title_language',
    'movie_release_year',
    'movie_popularity',
    'movie_release_year',
    'movie_title',
    'movies',
    'movie_id',
    'movie_release_year',
    'movies',
    'movie_title',
    'movie_popularity',
    'movies',
    'movie_title',
    'movie_popularity',
    'movies'],
   'ratings': ['movie_id',
    'critic',
    'ratings',
    'critic',
    'ratings',
    'critic_likes',
    'rating_url'],
   'lists': ['lists']}},
 {'db_id': 'movie_platform',
  'question': 'State the most popular movie? When was it released and who is the director for the movie?',
  'evidence': 'most popular movie ref

In [7]:
# 加载模型和分词器
tokenizer, model = load_model_and_tokenizer(CFG)

# 创建生成器
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.
Device set to use cuda:0


In [8]:
item=data[0]
item

{'db_id': 'movie_platform',
 'question': 'Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
 'evidence': 'released in the year 1945 refers to movie_release_year = 1945;',
 'keywords': ['movie titles', 'released', '1945', 'descending', 'popularity'],
 'schema_linking': {'movies': ['movie_title',
   'movies',
   'movie_title_language',
   'movie_release_year',
   'movie_popularity',
   'movie_release_year',
   'movie_title',
   'movies',
   'movie_id',
   'movie_release_year',
   'movies',
   'movie_title',
   'movie_popularity',
   'movies',
   'movie_title',
   'movie_popularity',
   'movies'],
  'ratings': ['movie_id',
   'critic',
   'ratings',
   'critic',
   'ratings',
   'critic_likes',
   'rating_url'],
  'lists': ['lists']}}

In [9]:
question = item.get("question", "")
db_schema = item.get("schema_linking", "")  # 假设数据中有db_schema字段
db_id = item.get("db_id", "")
question, db_schema, db_id

('Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.',
 {'movies': ['movie_title',
   'movies',
   'movie_title_language',
   'movie_release_year',
   'movie_popularity',
   'movie_release_year',
   'movie_title',
   'movies',
   'movie_id',
   'movie_release_year',
   'movies',
   'movie_title',
   'movie_popularity',
   'movies',
   'movie_title',
   'movie_popularity',
   'movies'],
  'ratings': ['movie_id',
   'critic',
   'ratings',
   'critic',
   'ratings',
   'critic_likes',
   'rating_url'],
  'lists': ['lists']},
 'movie_platform')

In [11]:
# 分解步骤的提示词模板
DECOMPOSE_TEMPLATE = load_prompt_template(r"/home/yangliu26/CHASE/template/decompose_template.txt")
print(DECOMPOSE_TEMPLATE)

You are a professional database expert. Your task is to **decompose a complex natural language question** into **2 to 4 simpler sub-questions**, each of which can be more easily translated into an individual SQL fragment.

**Database Schema:**
{db_schema}

**Original Question:**
{question}

**Instructions:**
* Break the original question down into **logically sequential sub-questions**.
* Each sub-question should be **atomic** and **specific**, targeting one aspect of the data.
* Do **not** include explanations or SQL queries—only output the sub-questions as a **numbered list**.

**Output Format:**
1. Sub-question 1
2. Sub-question 2
    ...

---

### **Example**

**Database Schema:**

```sql
Table: Students(id, name, age, major_id)  
Table: Majors(id, name, department)  
Table: Courses(id, name, instructor_id)  
Table: Enrollments(student_id, course_id, grade)  
Table: Instructors(id, name, department)
```

**Original Question:**
Which students majoring in Computer Science have enroll

In [12]:
"""将问题分解为多个子问题"""
prompt = DECOMPOSE_TEMPLATE.format(
    question=question,
    db_schema=db_schema
)
print(prompt)

You are a professional database expert. Your task is to **decompose a complex natural language question** into **2 to 4 simpler sub-questions**, each of which can be more easily translated into an individual SQL fragment.

**Database Schema:**
{'movies': ['movie_title', 'movies', 'movie_title_language', 'movie_release_year', 'movie_popularity', 'movie_release_year', 'movie_title', 'movies', 'movie_id', 'movie_release_year', 'movies', 'movie_title', 'movie_popularity', 'movies', 'movie_title', 'movie_popularity', 'movies'], 'ratings': ['movie_id', 'critic', 'ratings', 'critic', 'ratings', 'critic_likes', 'rating_url'], 'lists': ['lists']}

**Original Question:**
Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.

**Instructions:**
* Break the original question down into **logically sequential sub-questions**.
* Each sub-question should be **atomic** and **specific**, targeting one aspect of the data.
* Do **not** include explanations o

In [13]:
response = generator(prompt, max_new_tokens=512, do_sample=True)
text = response[0]["generated_text"]
print(text)

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.16 GiB. GPU 0 has a total capacity of 39.49 GiB of which 1011.81 MiB is free. Process 2436076 has 36.02 GiB memory in use. Process 2442398 has 2.47 GiB memory in use. Of the allocated memory 1.94 GiB is allocated by PyTorch, and 41.67 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [35]:
# 解析子问题
sub_questions = []
for line in text.strip().split('\n'):
    if line.strip() and any(line.strip().startswith(str(i)) for i in range(1, 10)):
        # 移除序号前缀
        sub_q = line.strip()
        for i in range(1, 10):
            prefix = f"{i}. "
            if sub_q.startswith(prefix):
                sub_q = sub_q[len(prefix):]
                break
        sub_questions.append(sub_q)
sub_questions

['1']

In [None]:
# 处理每个样本
results = []
for i, item in enumerate(tqdm(data, desc="Processing")):
    try:
        question = item.get("question", "")
        db_schema = item.get("db_schema", "")  # 假设数据中有db_schema字段
        db_id = item.get("db_id", f"db_{i}")
        
        # 使用Divide-and-Conquer方法生成SQL
        sql = divide_and_conquer_sql(question, db_schema, generator)
        
        # 保存结果
        result = {
            "db_id": db_id,
            "question": question,
            "sql": sql,
        }
        results.append(result)
        
        # 写入单个文件
        output_file = os.path.join(CFG.output_dir, f"{db_id}_{i}.json")
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
            
    except Exception as e:
        print(f"Error processing item {i}: {e}")

# 写入汇总文件
with open(os.path.join(CFG.output_dir, "all_results.json"), 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print(f"处理完成，结果保存在: {CFG.output_dir}")