In [1]:
import os
import psycopg2
import google.generativeai as genai
import json
import time
import re
import logging
import random
from dotenv import load_dotenv
from typing import Dict, List, Tuple
# 导入语言对应的PROMPTS
from prompts.en import PROMPTS
#from prompts.zh import PROMPTS
# ===== 新增: 为处理日期、Decimal 等类型 =====
import datetime
import decimal

def custom_json_handler(obj):
    """
    ``default=`` hook for ``json.dumps`` covering values it cannot encode natively.

    Conversions:
        decimal.Decimal                    -> float
        datetime.date / datetime.datetime  -> ISO-8601 string via ``isoformat()``
        anything else                      -> ``str(obj)``
    """
    if isinstance(obj, decimal.Decimal):
        return float(obj)
    is_temporal = isinstance(obj, (datetime.datetime, datetime.date))
    return obj.isoformat() if is_temporal else str(obj)

# ------------------------------------------
# Logging configuration
# ------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler()]
)
logging.info("Logging is configured correctly.")

# Load environment variables (python-dotenv) before any configuration below reads them.
load_dotenv()

# PostgreSQL connection parameters taken from the environment (None when unset).
DATABASE = {
    'database': os.getenv('DB_NAME'),
    'user': os.getenv('DB_USER'),
    'password': os.getenv('DB_PASSWORD'),
    'host': os.getenv('DB_HOST'),
    'port': os.getenv('DB_PORT')
}
# Schema that holds the MAUDE tables; used by the metadata and sample-data queries.
SCHEMA = 'maude'

# Configure the Google Generative AI client. The model name may be overridden via
# GENAI_MODEL; when that is unset or empty the original default is used.
genai.configure(api_key=os.getenv('GENAI_API_KEY'))
model = genai.GenerativeModel(os.getenv('GENAI_MODEL') or "gemini-2.0-flash-exp")

# Per-call token usage records appended by generate_response().
token_counts = []

# ------------------------------------------
# Custom research question (when empty, the model proposes one)
# ------------------------------------------
CUSTOM_RESEARCH_QUESTION = ""
# Example: CUSTOM_RESEARCH_QUESTION = "What are the most common device-related adverse events for metal bone screws?"

# ------------------------------------------
# Performance metrics and timing
# ------------------------------------------
performance_metrics = {
    "script_start_time": None,
    "script_end_time": None,
    "total_duration_seconds": None,
    "model_calls": [],            # token/timing records, one per model call
    "total_sql_queries": 0,       # number of SQL statements submitted for execution
    "total_retry_count": 0,       # SQL corrections + empty-result redesigns
    "final_success_queries": [],  # queries that finally returned rows
}

# Record the start time as soon as the module is set up.
performance_metrics["script_start_time"] = time.time()

# ------------------------------------------
# 大模型调用函数
# ------------------------------------------
def generate_response(prompt: str) -> str:
    """
    Send *prompt* to the configured model and return the stripped reply text.

    On success a usage record (completion timestamp, prompt/candidate/total
    token counts and call duration) is appended to both ``token_counts`` and
    ``performance_metrics["model_calls"]``. Token accounting is best-effort:
    missing usage metadata is recorded as 0 instead of discarding the reply.

    Args:
        prompt: the full prompt text.

    Returns:
        The stripped response text, or None if the API call or reading
        ``response.text`` raises, or the text is not a string. (None is
        returned despite the ``str`` annotation; callers check for it.)
    """
    try:
        start_time = time.time()
        response = model.generate_content(prompt)
        end_time = time.time()
        raw_text = response.text
    except Exception as e:
        logging.error(f"Error calling Google Generative AI API: {e}")
        return None

    # Validate the type *before* calling .strip(); checking afterwards could never fail.
    if not isinstance(raw_text, str):
        logging.error(f"generate_response expected a string, got {type(raw_text)} instead.")
        return None
    response_text = raw_text.strip()

    # A missing or partial usage block must not throw away a valid reply.
    usage = getattr(response, "usage_metadata", None)
    logging.info(f"Token Count:{usage}")
    token_info = {
        'timestamp': end_time,
        'prompt_token_count': getattr(usage, 'prompt_token_count', 0) or 0,
        'candidates_token_count': getattr(usage, 'candidates_token_count', 0) or 0,
        'total_token_count': getattr(usage, 'total_token_count', 0) or 0,
        'call_duration_seconds': round(end_time - start_time, 3)
    }
    token_counts.append(token_info)
    performance_metrics["model_calls"].append(token_info)

    return response_text

# ------------------------------------------
# 读取文件函数
# ------------------------------------------
def read_prompt_file(file_path):
    """
    Return the full text of *file_path*, decoded as UTF-8.

    Any failure (missing file, permission problem, decode error) is logged and
    reported as None instead of being raised.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except Exception as exc:
        logging.error(f"Error reading file {file_path}: {exc}")
        return None

# ------------------------------------------
# 数据库连接与查询执行函数
# ------------------------------------------
def connect_database():
    """
    Open a PostgreSQL connection using the module-level ``DATABASE`` settings.

    Returns:
        The psycopg2 connection, or None if connecting raised (the error is logged).
    """
    try:
        connection = psycopg2.connect(**DATABASE)
    except Exception as exc:
        logging.error(f"Database connection failed: {exc}")
        return None
    logging.info("Successfully connected to the database.")
    return connection

def get_table_structure(cursor, table_name):
    """
    Return column metadata for ``SCHEMA.table_name`` from information_schema.

    Args:
        cursor: an open psycopg2 cursor.
        table_name: exact (case-sensitive) table name.

    Returns:
        A list of dicts with keys 'column_name', 'data_type',
        'character_max_length' and 'is_nullable', ordered by column position.
        The list is empty when no such table exists. Returns None if the query
        fails; in that case the transaction is rolled back so the caller can keep
        using the same connection.
    """
    # Schema and table are bound as query parameters rather than formatted into
    # the SQL, so a name containing a quote cannot break (or inject into) the query.
    query = """
        SELECT
            column_name,
            data_type,
            character_maximum_length,
            is_nullable
        FROM
            information_schema.columns
        WHERE
            table_schema = %s
            AND table_name = %s
        ORDER BY
            ordinal_position;
    """
    try:
        cursor.execute(query, (SCHEMA, table_name))
        columns = cursor.fetchall()
        return [
            {
                'column_name': col[0],
                'data_type': col[1],
                'character_max_length': col[2],
                'is_nullable': col[3],
            }
            for col in columns
        ]
    except Exception as e:
        logging.error(f"Error retrieving table structure ({table_name}): {e}")
        # A failed statement aborts the whole PostgreSQL transaction; without a
        # rollback every later query on this connection would fail as well.
        try:
            cursor.connection.rollback()
        except Exception as rollback_error:
            logging.error(f"Failed to rollback transaction: {rollback_error}")
        return None

def get_sample_data(cursor, table_name, limit=3):
    """
    Return up to *limit* rows of ``SCHEMA.table_name`` as ``{column: value}`` dicts.

    Args:
        cursor: an open psycopg2 cursor.
        table_name: exact (case-sensitive) table name.
        limit: maximum number of rows to fetch (default 3).

    Returns:
        A list of row dicts, or None if the query fails. On failure the
        transaction is rolled back so the connection stays usable.
    """
    # Local import: psycopg2.sql builds safely quoted identifiers.
    from psycopg2 import sql

    try:
        # Identifier() quotes and escapes both names. The table name was already
        # double-quoted before, but an embedded '"' could break out of it. SCHEMA is
        # lower-case, so quoting it resolves to the same schema as before. LIMIT is
        # bound as a parameter instead of being formatted into the SQL text.
        query = sql.SQL('SELECT * FROM {}.{} LIMIT %s;').format(
            sql.Identifier(SCHEMA),
            sql.Identifier(table_name),
        )
        cursor.execute(query, (limit,))
        rows = cursor.fetchall()
        col_names = [desc[0] for desc in cursor.description]
        return [dict(zip(col_names, row)) for row in rows]
    except Exception as e:
        logging.error(f"Error retrieving sample data ({table_name}): {e}")
        # Keep the shared connection usable for the next table (aborted transaction).
        try:
            cursor.connection.rollback()
        except Exception as rollback_error:
            logging.error(f"Failed to rollback transaction: {rollback_error}")
        return None

# One alternation scanned left to right. Quoted literals/identifiers are matched
# (and kept) before any comment marker inside them is considered, and whichever
# of "--" or "/*" starts first wins.
_SQL_TOKEN_RE = re.compile(
    r"'(?:[^']|'')*'"      # single-quoted string literal ('' is an escaped quote)
    r'|"(?:[^"]|"")*"'     # double-quoted identifier
    r"|--[^\n]*"           # line comment, up to (not including) the newline
    r"|/\*.*?\*/",         # block comment (non-nested)
    re.S,
)


def clean_sql(sql_query):
    """
    Remove SQL comments from *sql_query* and strip surrounding whitespace.

    Comment markers inside single-quoted string literals or double-quoted
    identifiers are preserved. Line comments are removed up to the end of the
    line. Block comments are replaced by a single space so the tokens on either
    side are not glued together (``a/**/FROM`` -> ``a FROM``).

    Limitations: nested block comments, dollar-quoted strings ($$...$$) and
    backslash escapes inside E'...' strings are not recognised.

    Returns:
        The cleaned query ("" when it held only comments/whitespace).
    """
    def _replace(match):
        token = match.group(0)
        if token.startswith('--'):
            return ''
        if token.startswith('/*'):
            return ' '
        return token  # quoted literal/identifier: keep verbatim

    return _SQL_TOKEN_RE.sub(_replace, sql_query).strip()

def execute_sql(conn, sql_query):
    """
    Run one SQL string on *conn* and report the outcome as ``(data, error)``.

    The query is first passed through ``clean_sql``. The session ``search_path``
    is then set to ``SCHEMA`` so that unqualified table names resolve there.

    Rows are fetched whenever the executed statement produced a result set
    (``cursor.description`` is not None). This replaces guessing from the
    leading keyword: EXPLAIN / VALUES / ``(SELECT ...)`` results are no longer
    discarded, and ``WITH ... INSERT`` without RETURNING no longer fails on
    ``fetchall``. Every successful statement is committed, so no transaction is
    left open between calls.

    Args:
        conn: an open psycopg2 connection.
        sql_query: the SQL text, possibly containing comments.

    Returns:
        ``(rows, None)`` when a result set was returned; *rows* is a list of
            ``{column: value}`` dicts and may be empty.
        ``(None, None)`` on success without a result set.
        ``(None, message)`` for an empty/comment-only query, or with the
            psycopg2 error text after the transaction has been rolled back.
    """
    cleaned_query = clean_sql(sql_query)
    if not cleaned_query:
        return None, "Empty or commented query."

    try:
        with conn.cursor() as cursor:
            cursor.execute(f'SET search_path TO {SCHEMA};')
            cursor.execute(cleaned_query)

            data = None
            if cursor.description is not None:
                col_names = [desc[0] for desc in cursor.description]
                data = [dict(zip(col_names, row)) for row in cursor.fetchall()]
        conn.commit()
        return data, None

    except psycopg2.Error as e:
        logging.error(f"SQL Execution Error: {e}")
        try:
            conn.rollback()
            logging.info("Transaction has been rolled back.")
        except Exception as rollback_error:
            logging.error(f"Failed to rollback transaction: {rollback_error}")
        return None, str(e)

# ------------------------------------------
# 提取与处理表信息
# ------------------------------------------
def process_execution_steps_and_tables(execution_steps: str, table_info: str) -> Tuple[str, List[str]]:
    """
    Replace pseudo table names in the model-written execution steps with real tables.

    Pseudo names of the form ``Merged_Table_<N>`` are resolved in two stages.
    The hand-maintained ``merged_to_real`` map is tried first. Otherwise the
    mapping parsed from *table_info* by ``parse_prompt_txt`` is used, and
    ``get_real_table_name`` picks a real table from it. Intermediate-result
    markers ``<Result from Step N>`` / ``<Result from Join Step N>`` are
    rewritten to ``step_N_result``.

    Args:
        execution_steps: execution-plan text produced by the model.
        table_info: schema description text (e.g. the contents of schema.txt).

    Returns:
        ``(updated_steps, tables)``. *tables* is the sorted, de-duplicated list
        of real tables that were substituted into the steps.
    """
    # 1. Hand-maintained mapping; takes priority over the parsed one.
    merged_to_real = {
        "Merged_Table_1": "ASR_2019",
        "Merged_Table_11": "DISCLAIM",
        "Merged_Table_12": "foidevproblem",
        "Merged_Table_13": "patientproblemcode",
        "Merged_Table_4": "DEVICE2023",
        "Merged_Table_5": "foiclass",
        "Merged_Table_6": "mdr97",
        "Merged_Table_7": "mdrfoiThru2023",
        "Merged_Table_8": "patientThru2023",
        "Merged_Table_9": "foitext2023",
    }

    # 2. Pseudo-table -> real-tables mapping parsed from the schema text.
    pseudo_tables_mapping = parse_prompt_txt(table_info)

    # 3. One pattern for both detection and replacement: whole word and
    #    case-insensitive. The number is normalised into the canonical
    #    "Merged_Table_<N>" key used by both mappings. Previously detection
    #    ignored case but extraction/replacement did not, so e.g.
    #    "merged_table_4" was silently left unresolved.
    pseudo_table_re = re.compile(r'\bMerged_Table_(\d+)\b', re.IGNORECASE)
    pseudo_names = sorted(
        {f"Merged_Table_{m.group(1)}" for m in pseudo_table_re.finditer(execution_steps)}
    )

    # 4. Resolve each pseudo table: custom map first, parsed mapping as fallback.
    resolved = {}
    for pseudo_table in pseudo_names:
        real_table = merged_to_real.get(pseudo_table) or get_real_table_name(
            pseudo_tables_mapping, pseudo_table
        )
        if real_table:
            resolved[pseudo_table] = real_table
        else:
            logging.warning(f"Unable to determine a real table for {pseudo_table}.")

    # Single pass with a function replacement. Real names are inserted literally
    # (never interpreted as a regex template); unresolved mentions stay unchanged.
    updated_steps = pseudo_table_re.sub(
        lambda m: resolved.get(f"Merged_Table_{m.group(1)}", m.group(0)),
        execution_steps,
    )

    # 5. "<Result from Step 1>" / "<Result from Join Step 1>" -> "step_1_result".
    updated_steps = re.sub(
        r'<Result from (?:Join )?Step (\d+)>',
        r'step_\g<1>_result',
        updated_steps,
    )

    # 6. Fixed tables that should always be included (none at present), followed
    #    by the sorted real tables referenced in the steps.
    standard_tables: List[str] = []
    involved_tables = sorted(set(resolved.values()))
    final_tables = standard_tables + [t for t in involved_tables if t not in standard_tables]

    return updated_steps, final_tables

def parse_prompt_txt(table_info: str) -> Dict:
    """
    Build a pseudo-table mapping from schema description text.

    Each entry in *table_info* of the form

        Table '<pseudo>' (merged from: T1, T2): Fields: col_a (type), col_b (type).

    yields ``{'<pseudo>': {'real_tables': ['T1', 'T2'], 'fields': ['col_a', 'col_b']}}``.
    A later entry with the same pseudo name overwrites an earlier one.
    """
    table_re = re.compile(
        r"Table\s+'(?P<pseudo_table>[^']+)'\s+\(merged from:\s+([^)]*)\):\s+Fields:\s+(?P<fields>[^.]+)\.",
        re.IGNORECASE
    )
    field_re = re.compile(r"([A-Za-z0-9_]+)\s+\([^()]+\)")

    mapping = {}
    for entry in table_re.finditer(table_info):
        source_tables = entry.group(2).strip().split(',')
        field_text = entry.group('fields').strip()
        mapping[entry.group('pseudo_table').strip()] = {
            'real_tables': [name.strip() for name in source_tables],
            'fields': [hit.group(1).strip() for hit in field_re.finditer(field_text)],
        }
    return mapping

def get_real_table_name(pseudo_tables: Dict, pseudo_table: str) -> str:
    """
    Pick one real table behind *pseudo_table* from a ``parse_prompt_txt`` mapping.

    When several real tables were merged into the pseudo table, one is chosen
    with ``random.choice``, so repeated calls may return different tables.

    Returns:
        The chosen real table name. Returns None (after logging an error) when the
        pseudo table is unknown or lists no real tables.
    """
    if pseudo_table in pseudo_tables:
        candidates = pseudo_tables[pseudo_table]['real_tables']
        if candidates:
            return random.choice(candidates)
        logging.error(f"Pseudo table '{pseudo_table}' has no corresponding real tables.")
        return None
    logging.error(f"Pseudo table '{pseudo_table}' not found in mapping.")
    return None

# ------------------------------------------
# 分析并生成报告相关函数
# ------------------------------------------
def analyze_data(research_question, data):
    """
    Ask the model to evaluate *data* against *research_question* and save the report.

    Args:
        research_question: the research question text.
        data: either an already formatted text summary (``main`` passes one) or
            a JSON-serialisable structure, such as lists of row dicts, which is
            encoded with ``custom_json_handler`` for dates/Decimals.

    Side effects:
        Writes the report to ``finalreport.md``. A write failure is logged, not raised.

    Returns:
        None.
    """
    if not data:
        logging.warning("No data available for analysis.")
        return

    # A pre-formatted string is used as-is. json.dumps on it would wrap the whole
    # text in quotes and escape every newline, degrading the prompt.
    if isinstance(data, str):
        data_json = data
    else:
        data_json = json.dumps(data, ensure_ascii=False, default=custom_json_handler)

    analysis_prompt = PROMPTS["ANALYSIS_PROMPT"].format(
        research_question=research_question,
        data_json=data_json
    )
    analysis_report = generate_response(analysis_prompt)
    if not analysis_report:
        logging.error("Failed to generate analysis report.")
        return

    logging.info(f"Analysis Report:\n{analysis_report}")
    # Guarded like the DQC report writer so a file error does not abort the caller.
    try:
        with open("finalreport.md", "w", encoding="utf-8") as file:
            file.write(f"# Analysis Report\n\n{analysis_report}\n")
        logging.info("Analysis report successfully written to finalreport.md.")
    except OSError as e:
        logging.error(f"Failed to write analysis report to file: {e}")

# ------------------------------------------
# DQC（数据质量控制）相关函数
# ------------------------------------------
def generate_dqc_plan(execution_steps, table_info, add_content):
    """
    Ask the model for a Data Quality Control (DQC) plan.

    Args:
        execution_steps: the (optimized) analysis execution steps.
        table_info: dict of table structures and sample rows; serialised to JSON
            (dates/Decimals via ``custom_json_handler``) for the prompt.
        add_content: extra metadata text inserted into the prompt.

    Returns:
        The plan text. Returns None if the model call failed or the plan
        contains no ```sql fenced code block.
    """
    table_info_json = json.dumps(table_info, ensure_ascii=False, indent=2, default=custom_json_handler)
    prompt = PROMPTS["DQC_PLAN"].format(
        table_info_json=table_info_json,
        add_content=add_content,
        execution_steps=execution_steps
    )

    dqc_plan = generate_response(prompt)
    if not dqc_plan:
        logging.error("Failed to generate DQC plan.")
        return None

    # Same tolerant fence pattern as the SQL-repair path: "sql" tag in any case,
    # optional whitespace (including "\r") before the newline. Plans fenced as
    # ```SQL or with a trailing space are therefore not rejected.
    if not re.search(r'```sql\s*\n.*?```', dqc_plan, re.DOTALL | re.IGNORECASE):
        logging.error("DQC plan does not contain any SQL queries in code blocks.")
        return None

    return dqc_plan

def extract_sql_queries(dqc_plan):
    """
    Return the stripped contents of every non-empty ```sql fenced block in *dqc_plan*.

    The fence tag is matched case-insensitively and may be followed by
    whitespace before the newline. This is the same pattern the SQL-repair path
    uses to read corrected queries.

    Returns:
        A list of SQL strings in document order. Returns [] when *dqc_plan* is
        not a string (logged as an error) or has no SQL blocks (logged as a warning).
    """
    if not isinstance(dqc_plan, str):
        logging.error(f"extract_sql_queries expects a string, got {type(dqc_plan)} instead.")
        return []
    matches = re.findall(r'```sql\s*\n(.*?)```', dqc_plan, re.DOTALL | re.IGNORECASE)
    # Empty fences would only produce "Empty or commented query." execution errors.
    sql_queries = [match.strip() for match in matches if match.strip()]
    if not sql_queries:
        logging.warning("No SQL queries found in the provided plan.")
    return sql_queries

def execute_query_list(sql_queries, table_info, usage_label=""):
    """
    Execute a list of SQL statements, using the model to repair failing or empty ones.

    For each query, on a single shared connection:
      * On a SQL error, the error is sent to the model with
        PROMPTS["CORRECTION_PROMPT"]. The first ```sql block of the reply
        replaces the query, which is retried. At most ``max_retries`` error
        attempts are made.
      * If the query succeeds but returns no rows, the model is asked for a
        redesigned query via PROMPTS["REDESIGN_PROMPT"]. At most
        ``max_empty_retries`` redesigns are made.
      * If rows are returned, they are recorded and the (possibly repaired)
        query is stored as a final successful query.

    Args:
        sql_queries: list of SQL strings to run, in order.
        table_info: table structure/sample dict, serialised into the repair prompts.
        usage_label: tag for this batch (e.g. "DQC", "Advanced Analysis"). Used in
            log messages and as the prefix of result keys.

    Returns:
        A dict keyed by f"{usage_label} Query {idx}" (1-based). Each value is
        either {"data": rows, "row_count": n} or {"error": message}. A query
        abandoned after a failed rollback gets no entry. Returns {} if the
        database connection cannot be opened.

    Side effects:
        Updates performance_metrics["total_sql_queries"],
        ["total_retry_count"] and ["final_success_queries"].

    NOTE(review): the repair prompts read ``add_content`` as a module-level
    global, which is only assigned in the ``__main__`` block. If that global does
    not exist, the first repair attempt raises NameError.
    """
    conn = connect_database()
    if not conn:
        logging.error("Database connection failed.")
        return {}

    dqc_results = {}
    # NOTE(review): dataall is accumulated below but never used or returned.
    dataall = []
    max_retries = 10        # cap on attempts that end in a SQL error
    max_empty_retries = 5   # cap on redesigns after an empty result

    # Count every query in this batch toward the global total.
    performance_metrics["total_sql_queries"] += len(sql_queries)

    # Queries that finally succeeded (after any repairs), with their row counts.
    final_success_sql = []

    try:
        for idx, sql_query in enumerate(sql_queries, start=1):
            logging.info(f"Executing {usage_label} SQL Query {idx}/{len(sql_queries)}:\n{sql_query}\n")
            attempt = 0
            empty_attempt = 0
            current_query = sql_query
            error = None  # last execution error for this query (None = last run succeeded)

            while attempt < max_retries:
                data, error = execute_sql(conn, current_query)
                if error:
                    logging.error(f"SQL Execution Error on {usage_label} Query {idx}: {error}\n")

                    # Every failed execution counts as one retry.
                    performance_metrics["total_retry_count"] += 1

                    # Aborted transaction: roll back and re-run the same query
                    # without asking the model for a correction.
                    if "current transaction is aborted" in error.lower():
                        try:
                            conn.rollback()
                            logging.info("Rolled back the aborted transaction.")
                        except Exception as rollback_error:
                            logging.error(f"Failed to rollback transaction: {rollback_error}")
                            break
                        attempt += 1
                        continue

                    correction_prompt = PROMPTS["CORRECTION_PROMPT"].format(
                        current_query=current_query,
                        error=error,
                        table_info=json.dumps(table_info, ensure_ascii=False, default=custom_json_handler),
                        add_content=add_content
                    )
                    # Short pause before calling the model again (presumably to respect API rate limits).
                    time.sleep(2)
                    corrected_sql_full = generate_response(correction_prompt)

                    if not corrected_sql_full:
                        logging.warning("Failed to correct SQL query. Skipping to the next query.\n")
                        dqc_results[f"{usage_label} Query {idx}"] = {"error": error}
                        break

                    # Use the first ```sql block of the reply as the new query.
                    pattern = r'```sql\s*\n(.*?)```'
                    matches = re.findall(pattern, corrected_sql_full, re.DOTALL | re.IGNORECASE)
                    if matches:
                        corrected_query = matches[0].strip()
                        logging.info(f"Updating {usage_label} Query {idx} with corrected SQL.")
                        current_query = corrected_query
                    else:
                        logging.warning("No SQL code block found in the corrected response. Skipping to the next query.")
                        dqc_results[f"{usage_label} Query {idx}"] = {"error": error}
                        break

                    attempt += 1
                else:
                    if data:
                        row_count = len(data)
                        logging.info(f"{usage_label} SQL Query {idx} executed successfully with {row_count} rows returned.\n")
                        dqc_results[f"{usage_label} Query {idx}"] = {"data": data, "row_count": row_count}
                        dataall.extend(data)

                        # Record the SQL that actually succeeded (it may differ from the input).
                        final_success_sql.append({
                            "query": current_query,
                            "usage_label": usage_label,
                            "rows_returned": row_count
                        })
                        break
                    else:
                        # Success without rows: ask the model to redesign the query.
                        # These iterations do not advance `attempt`; empty_attempt bounds them.
                        logging.info(f"{usage_label} SQL Query {idx} executed successfully but returned no data.\n")
                        if empty_attempt < max_empty_retries:
                            logging.info(f"No data returned. Attempting to redesign the SQL query (Retry {empty_attempt + 1}/{max_empty_retries}).")
                            performance_metrics["total_retry_count"] += 1
                            redesign_prompt = PROMPTS["REDESIGN_PROMPT"].format(
                                current_query=current_query,
                                table_info=json.dumps(table_info, ensure_ascii=False, default=custom_json_handler),
                                add_content=add_content
                            )
                            time.sleep(2)
                            redesigned_sql_full = generate_response(redesign_prompt)

                            if not redesigned_sql_full:
                                logging.warning("Failed to redesign SQL query. Skipping to the next query.\n")
                                dqc_results[f"{usage_label} Query {idx}"] = {"error": "No data returned and failed to redesign query."}
                                break

                            pattern = r'```sql\s*\n(.*?)```'
                            matches = re.findall(pattern, redesigned_sql_full, re.DOTALL | re.IGNORECASE)
                            if matches:
                                redesigned_query = matches[0].strip()
                                logging.info(f"Updating {usage_label} Query {idx} with redesigned SQL.")
                                current_query = redesigned_query
                                empty_attempt += 1
                            else:
                                logging.warning("No SQL code block found in the redesigned response. Skipping to the next query.")
                                dqc_results[f"{usage_label} Query {idx}"] = {"error": "No data returned and failed to extract redesigned query."}
                                break
                        else:
                            logging.warning(f"No data returned after {max_empty_retries} redesign attempts. Skipping to the next query.\n")
                            dqc_results[f"{usage_label} Query {idx}"] = {"error": "No data returned after multiple redesign attempts."}
                            break

            # The loop ran out of attempts while the last execution still failed.
            if error and attempt == max_retries:
                logging.error(f"Reached maximum retry attempts for {usage_label} Query {idx}. Unable to execute this query.\n")
                dqc_results[f"{usage_label} Query {idx}"] = {"error": error}

        total_records = sum(len(v["data"]) for v in dqc_results.values() if v.get("data"))
        logging.info(f"Total records retrieved from all {usage_label} queries: {total_records}")

    finally:
        try:
            conn.close()
            logging.info("Database connection closed.")
        except Exception as close_error:
            logging.error(f"Error closing database connection: {close_error}")

    # Publish the successful queries to the global metrics so they are saved together later.
    performance_metrics["final_success_queries"].extend(final_success_sql)

    return dqc_results

def generate_dqc_report(dqc_plan, dqc_results):
    """
    Ask the model for a Data Quality Control report from the DQC plan and results.

    Args:
        dqc_plan: the DQC plan text.
        dqc_results: result dict from ``execute_query_list``. Its rows are raw
            database values.

    Returns:
        The report text, or None if the model call failed (logged).
    """
    # The rows may contain date/datetime/Decimal values, which plain json.dumps
    # rejects with TypeError; use the same handler as every other serialisation here.
    dqc_results_str = json.dumps(dqc_results, ensure_ascii=False, indent=2, default=custom_json_handler)
    prompt = PROMPTS["DQC_REPORT"].format(
        dqc_plan=dqc_plan,
        dqc_results_str=dqc_results_str
    )
    dqc_report = generate_response(prompt)
    if not dqc_report:
        logging.error("Failed to generate DQC report.")
    return dqc_report

def perform_data_quality_control(execution_steps_new, table_info, add_content):
    """
    Run the full Data Quality Control pass.

    The steps are: generate a DQC plan, extract its SQL, execute it (with
    model-assisted repair), then generate a report and save it to
    ``data_quality_report.md``. Every failure is logged and ends the pass early.
    Nothing is returned.
    """
    dqc_plan = generate_dqc_plan(execution_steps_new, table_info, add_content)
    if not dqc_plan:
        logging.error("Failed to generate Data Quality Check plan.")
        return
    for message in ("Data Quality Check Plan:\n", dqc_plan):
        logging.info(message)

    sql_queries = extract_sql_queries(dqc_plan)
    if not sql_queries:
        logging.error("No SQL queries found in the DQC plan.")
        return
    logging.info("Extracted DQC SQL Queries:\n")
    for number, statement in enumerate(sql_queries, start=1):
        logging.info(f"--- DQC SQL Query {number} ---")
        logging.info(statement)
        logging.info("\n")

    dqc_results = execute_query_list(sql_queries, table_info, usage_label="DQC")
    dqc_report = generate_dqc_report(dqc_plan, dqc_results)
    if not dqc_report:
        logging.error("Failed to generate Data Quality Control report.")
        return

    logging.info("\nData Quality Control Report:\n")
    logging.info(dqc_report)
    try:
        with open("data_quality_report.md", "w", encoding="utf-8") as out_file:
            out_file.write(dqc_report)
        logging.info("Data Quality Control Report has been successfully saved to data_quality_report.md.")
    except Exception as exc:
        logging.error(f"Failed to write DQC report to file: {exc}")

# ------------------------------------------
# 主流程示例
# ------------------------------------------
if __name__ == "__main__":
    # Script entry point. Pipeline: research question -> execution plan -> real
    # table metadata -> refined plan -> SQL generation/execution -> analysis
    # report, then performance and token-usage reporting.
    # NOTE(review): `exit(1)` below relies on the `exit` helper installed by the
    # `site` module (or IPython); `sys.exit(1)` would not depend on it.

    # 1. Read schema.txt (table descriptions) and metadata.txt; both are required.
    schema_file = 'schema.txt'
    prompt_content = read_prompt_file(schema_file)
    if not prompt_content:
        logging.error("Failed to read schema.txt file.")
        exit(1)

    # add_content is also read as a module-level global by execute_query_list.
    meta_file = 'metadata.txt'
    add_content = read_prompt_file(meta_file)
    if not add_content:
        logging.error("Failed to read metadata.txt file.")
        exit(1)
    
    # 2. Optimise the configured research question if one is set; otherwise
    #    ask the model to propose one.
    if CUSTOM_RESEARCH_QUESTION:
        logging.info("Custom research question provided. Proceeding to optimize it.\n")
        customized_optimize_prompt = PROMPTS["CUSTOMIZED_QUESTION_OPTIMIZE"].format(
            custom_research_question=CUSTOM_RESEARCH_QUESTION,
            prompt_content=prompt_content,
            add_content=add_content
        )
        research_question = generate_response(customized_optimize_prompt)
        if not research_question:
            logging.error("Failed to optimize the custom research question.")
            exit(1)
        logging.info(f"Optimized Research Question:\n\n{research_question}\n")
    else:
        research_prompt = PROMPTS["RESEARCH_QUESTION"].format(
            prompt_content=prompt_content,
            add_content=add_content
        )
        research_question = generate_response(research_prompt)
        if not research_question:
            logging.error("Failed to generate a research question.")
            exit(1)
        logging.info(f"Proposed Research Question:\n\n{research_question}\n")

    # 3. Plan the execution steps for the research question.
    steps_prompt = PROMPTS["EXECUTION_STEPS"].format(
        research_question=research_question,
        add_content=add_content
    )
    execution_steps = generate_response(steps_prompt)
    if not execution_steps:
        logging.error("Failed to plan execution steps.")
        exit(1)
    logging.info(f"Planned Execution Steps:\n\n{execution_steps}\n")

    # 4. Replace pseudo table names (Merged_Table_N) with real table names and
    #    collect the real tables involved.
    updated_steps, involved_tables = process_execution_steps_and_tables(execution_steps, prompt_content)
    logging.info("Updated execution steps:")
    logging.info(updated_steps)
    logging.info("\nInvolved tables:")
    logging.info(involved_tables)

    # 5. Fetch column structure and sample rows for every involved table. Tables
    #    with no columns, or whose sample query fails, are skipped. The run aborts
    #    if none remain. The connection is closed in `finally`, including when
    #    exit(1) raises SystemExit.
    conn = connect_database()
    if not conn:
        logging.error("Database connection failed.")
        exit(1)

    table_info = {}
    try:
        with conn.cursor() as cursor:
            for table in involved_tables:
                structure = get_table_structure(cursor, table)
                if structure is None or not structure:
                    logging.warning(f"Table '{table}' does not exist or has no columns. Skipping.")
                    continue

                samples = get_sample_data(cursor, table)
                if samples is None:
                    logging.warning(f"Unable to retrieve sample data for table: {table}")
                    continue

                table_info[table] = {
                    'structure': structure,
                    'samples': samples
                }

        if not table_info:
            logging.error("No valid tables found for analysis. Exiting.")
            exit(1)

    except psycopg2.Error as e:
        logging.error(f"An error occurred while accessing the database: {e}")
        exit(1)
    finally:
        try:
            conn.close()
            logging.info("Database connection closed.")
        except Exception as close_error:
            logging.error(f"Error closing database connection: {close_error}")

    # 6. Refine the execution steps using the real table structures and samples.
    table_info_json = json.dumps(table_info, ensure_ascii=False, indent=2, default=custom_json_handler)
    optimize_prompt = PROMPTS["OPTIMIZE_STEPS"].format(
        table_info_json=table_info_json,
        add_content=add_content,
        updated_steps=updated_steps
    )
    execution_steps_new = generate_response(optimize_prompt)
    if not execution_steps_new:
        logging.error("Failed to polish execution steps.")
        exit(1)
    logging.info(f"Optimized Execution Steps:\n\n{execution_steps_new}\n")

    # 7. Generate the advanced-analysis SQL (returned as ```sql fenced blocks).
    advanced_sql_prompt = PROMPTS["ADVANCED_SQL"].format(
        execution_steps_new=execution_steps_new,
        table_info=table_info_json,
        add_content=add_content
    )
    sql_query_full = generate_response(advanced_sql_prompt)
    if not sql_query_full:
        logging.error("Failed to generate SQL queries.")
        exit(1)
    logging.info(f"Generated SQL Queries:\n\n{sql_query_full}\n")

    # 8. Execute the advanced-analysis SQL. usage_label="Advanced Analysis"
    #    distinguishes these queries from DQC ones in logs, results and metrics.
    sql_queries = extract_sql_queries(sql_query_full)
    sql_results = execute_query_list(sql_queries, table_info, usage_label="Advanced Analysis")

    # 9. Have the model analyse the results against the research question.
    #    tmp_output is a plain-text summary (steps, queries, JSON-encoded
    #    results), not a JSON document.
    if sql_results:
        tmp_output = (
            f"Execution Steps: {execution_steps_new}\n\n"
            f"SQL Queries: {sql_queries}\n\n"
            f"SQL Execution Outcome:\n"
            f"{json.dumps(sql_results, ensure_ascii=False, indent=2, default=custom_json_handler)}\n\n"
        )
        analyze_data(research_question, tmp_output)
        
    # To run data quality control as well, call:
    # perform_data_quality_control(execution_steps_new, table_info, add_content)

    # -----------------------------
    # Record end time / total duration for the performance metrics
    # -----------------------------
    performance_metrics["script_end_time"] = time.time()
    performance_metrics["total_duration_seconds"] = round(
        performance_metrics["script_end_time"] - performance_metrics["script_start_time"], 
        3
    )

    # Write the performance metrics to a JSON file.
    try:
        with open("performance_metrics.json", "w", encoding="utf-8") as pm_file:
            json.dump(performance_metrics, pm_file, indent=2, ensure_ascii=False, default=custom_json_handler)
        logging.info("Performance metrics have been saved to performance_metrics.json.")
    except Exception as e:
        logging.error(f"Failed to write performance metrics to file: {e}")

    # -----------------------------
    # Write every successfully executed SQL statement to a standalone .sql file,
    # each preceded by a comment with its usage label and row count.
    # -----------------------------
    try:
        with open("final_queries.sql", "w", encoding="utf-8") as fq_file:
            for item in performance_metrics["final_success_queries"]:
                fq_file.write(f"-- Usage: {item['usage_label']}, Rows Returned: {item['rows_returned']}\n")
                fq_file.write(item["query"] + ";\n\n")
        logging.info("Final successful SQL queries have been saved to final_queries.sql.")
    except Exception as e:
        logging.error(f"Failed to write final queries to file: {e}")

    # -----------------------------
    # Token statistics: per-call printout
    # -----------------------------
    for idx, count in enumerate(token_counts, 1):
        print(f"调用 {idx}:")
        print(f"  Prompt Token Count: {count['prompt_token_count']}")
        print(f"  Candidates Token Count: {count['candidates_token_count']}")
        print(f"  Total Token Count: {count['total_token_count']}")
        print(f"  Call Duration (s): {count['call_duration_seconds']}")

    # Totals across all model calls.
    total_prompt = sum(item['prompt_token_count'] for item in token_counts)
    total_candidates = sum(item['candidates_token_count'] for item in token_counts)
    total_all = sum(item['total_token_count'] for item in token_counts)
    total_duration = sum(item.get('call_duration_seconds', 0) for item in token_counts)

    print("\n累积总计:")
    print(f"总 Prompt Token 数量: {total_prompt}")
    print(f"总 Candidates Token 数量: {total_candidates}")
    print(f"总 Token 数量: {total_all}")
    print(f"所有调用累计时长(秒): {total_duration}")

    logging.info("Script execution completed.")


2024-12-30 01:25:06,620 [INFO] Logging is configured correctly.
2024-12-30 01:25:06,621 [ERROR] This is a test error message.
2024-12-30 01:25:21,045 [INFO] Token Count:prompt_token_count: 23413
candidates_token_count: 1641
total_token_count: 25054

2024-12-30 01:25:21,046 [INFO] Proposed Research Question:

Okay, this is a rich dataset with a lot of potential! Let's craft a meaningful research question and strategy based on the provided MAUDE database information.

**Research Question:**

**How do manufacturer-reported medical device malfunctions, as categorized by specific device problem codes, correlate with patient problems and adverse events over time, and how does this relationship vary across different device classifications?**

**Rationale:**

*   **Focus on Malfunctions & Patient Harm:** The question directly tackles the core purpose of the MAUDE database – understanding device malfunctions and their impact on patients.
*   **Device Problem Codes as a Key:** Utilizing device p

调用 1:
  Prompt Token Count: 23413
  Candidates Token Count: 1641
  Total Token Count: 25054
  Call Duration (s): 14.42
调用 2:
  Prompt Token Count: 1668
  Candidates Token Count: 3380
  Total Token Count: 5048
  Call Duration (s): 24.992
调用 3:
  Prompt Token Count: 24342
  Candidates Token Count: 4822
  Total Token Count: 29164
  Call Duration (s): 36.508
调用 4:
  Prompt Token Count: 13270
  Candidates Token Count: 1830
  Total Token Count: 15100
  Call Duration (s): 14.202
调用 5:
  Prompt Token Count: 19367
  Candidates Token Count: 228
  Total Token Count: 19595
  Call Duration (s): 2.976
调用 6:
  Prompt Token Count: 19465
  Candidates Token Count: 317
  Total Token Count: 19782
  Call Duration (s): 3.712
调用 7:
  Prompt Token Count: 19501
  Candidates Token Count: 299
  Total Token Count: 19800
  Call Duration (s): 3.485
调用 8:
  Prompt Token Count: 19431
  Candidates Token Count: 625
  Total Token Count: 20056
  Call Duration (s): 5.845
调用 9:
  Prompt Token Count: 19792
  Candidates Toke

In [None]:
# Second notebook cell: run the data-quality-control pass on demand. It relies on
# `execution_steps_new`, `table_info` and `add_content` already existing as
# globals, which the first cell's `__main__` block creates (steps 1-6).
perform_data_quality_control(execution_steps_new, table_info, add_content)