In [1]:
import os
import psycopg2
import google.generativeai as genai
import json
import time
import re
import logging
from dotenv import load_dotenv

# Configure logging once at the beginning
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)

# Test logging configuration
logging.info("Logging is configured correctly.")
logging.error("This is a test error message.")

# 加载环境变量
load_dotenv()

# PostgreSQL Database Connection Parameters
DATABASE = {
    'database': os.getenv('DB_NAME'),
    'user': os.getenv('DB_USER'),
    'password': os.getenv('DB_PASSWORD'),
    'host': os.getenv('DB_HOST'),
    'port': os.getenv('DB_PORT')
}
SCHEMA = 'maude'

# Configure Google Generative AI API
genai.configure(api_key=os.getenv('GENAI_API_KEY'))  # Store your API key securely
model = genai.GenerativeModel("gemini-2.0-flash-exp")

def generate_response(prompt):
    """
    Generate a response using Google Generative AI based on the provided prompt.
    """
    try:
        response = model.generate_content(prompt)
        return response.text.strip()
    except Exception as e:
        logging.error(f"Error calling Google Generative AI API: {e}")
        return None

def read_prompt_file(file_path):
    """
    Read the content of the prompt file.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        return content
    except Exception as e:
        logging.error(f"Error reading file {file_path}: {e}")
        return None

def connect_database():
    """
    Establish a connection to the PostgreSQL database.
    
    Returns:
        psycopg2.connection: The active database connection.
    """
    try:
        conn = psycopg2.connect(**DATABASE)
        logging.info("Successfully connected to the database.")
        return conn
    except Exception as e:
        logging.error(f"Database connection failed: {e}")
        return None

def get_table_structure(cursor, table_name):
    """
    Retrieve the structure of a specified table.
    """
    try:
        query = f"""
        SELECT 
            column_name, 
            data_type, 
            character_maximum_length, 
            is_nullable
        FROM 
            information_schema.columns
        WHERE 
            table_schema = '{SCHEMA}' 
            AND table_name = '{table_name}';
        """
        cursor.execute(query)
        columns = cursor.fetchall()
        structure = []
        for col in columns:
            structure.append({
                'column_name': col[0],
                'data_type': col[1],
                'character_max_length': col[2],
                'is_nullable': col[3]
            })
        return structure
    except Exception as e:
        logging.error(f"Error retrieving table structure ({table_name}): {e}")
        return None

def get_sample_data(cursor, table_name, limit=2):
    """
    Retrieve sample data from a specified table.
    """
    try:
        query = f'SELECT * FROM {SCHEMA}."{table_name}" LIMIT {limit};'
        cursor.execute(query)
        rows = cursor.fetchall()
        # Retrieve column names
        col_names = [desc[0] for desc in cursor.description]
        sample_data = [dict(zip(col_names, row)) for row in rows]
        return sample_data
    except Exception as e:
        logging.error(f"Error retrieving sample data ({table_name}): {e}")
        return None

def clean_sql(sql_query):
    """
    清除 SQL 查询中的注释和多余空白。
    """
    # 去除单行注释
    sql_query = re.sub(r'--.*', '', sql_query)
    # 去除多行注释
    sql_query = re.sub(r'/\*.*?\*/', '', sql_query, flags=re.S)
    # 去除多余空格
    return sql_query.strip()

def execute_sql(conn, sql_query):
    """
    Execute an SQL query and handle results, including operations without result sets.
    Uses a context manager to ensure the cursor is properly closed after execution.
    
    Args:
        conn (psycopg2.connection): The active database connection.
        sql_query (str): The SQL query to be executed.
    
    Returns:
        tuple: (data, error) where data is the result set for SELECT queries or None,
               and error is the error message if an exception occurred, else None.
    """
    try:
        with conn.cursor() as cursor:
            # 清理 SQL 查询
            cleaned_query = clean_sql(sql_query)

            # 如果清理后 SQL 为空，则不执行
            if not cleaned_query:
                return None, "Empty or commented query."

            # 设置 search_path
            cursor.execute('SET search_path TO maude;')

            # 判断是否为 CREATE VIEW 操作
            if cleaned_query.lower().startswith("create view"):
                cursor.execute(cleaned_query)
                conn.commit()  # 提交事务
                return None, None  # CREATE VIEW 操作成功，但没有返回数据

            # 判断是否为查询语句（返回结果集）
            elif cleaned_query.lower().startswith(("select", "with", "show", "describe")):
                cursor.execute(cleaned_query)
                rows = cursor.fetchall()
                col_names = [desc[0] for desc in cursor.description]
                data = [dict(zip(col_names, row)) for row in rows]
                return data, None

            # 对于非查询操作（如 INSERT、UPDATE、DELETE 等）
            else:
                cursor.execute(cleaned_query)
                conn.commit()  # 提交事务
                return None, None  # 非查询操作成功，但没有返回数据

    except psycopg2.Error as e:
        logging.error(f"SQL Execution Error: {e}")
        try:
            conn.rollback()  # 回滚事务以重置连接状态
            logging.info("Transaction has been rolled back.")
        except Exception as rollback_error:
            logging.error(f"Failed to rollback transaction: {rollback_error}")
        return None, str(e)

def extract_table_names(prompt_content):
    """
    Extract real table names from the prompt content.
    Assumes that table names are mentioned in the format `merged table xxx` and replaces them with actual table names.
    For demonstration, uses a predefined mapping.
    """
    # Define a mapping from merged table names to actual table names
    # This should be updated based on actual mappings
    merged_to_real = {
        "Merged_Table_1": "mdrfoi",
        "Merged_Table_2": "patientproblemcode",
        "Merged_Table_3": "some_other_table",  # Replace with actual table names
        # Add all necessary mappings here
    }

    # Extract merged table names using regex
    merged_tables = re.findall(r'Merged_Table_\d+', prompt_content)

    # Replace merged table names with real table names
    involved_tables = []
    for mt in merged_tables:
        real_table = merged_to_real.get(mt, None)
        if real_table:
            involved_tables.append(real_table)
        else:
            logging.warning(f"No real table mapping found for {mt}. Please update the mapping.")

    # If no merged tables found, use default involved tables
    if not involved_tables:
        involved_tables = ["mdrfoi", "ASR_2019", "DEVICE", "deviceproblemcodes", "patient", "patientproblemcode", "patientproblemdata", "foiclass","foitext", "DISCLAIM"]

    return involved_tables

def analyze_data(research_question, data):
    """
    Analyze the data to validate the research question.
    This function can be expanded based on specific analysis requirements.
    For demonstration, it sends the data and research question to Generative AI for analysis.
    """
    if not data:
        logging.warning("No data available for analysis.")
        return

    # Convert data to JSON string
    data_json = json.dumps(data, ensure_ascii=False)

    # Create analysis prompt
    analysis_prompt = (
        f"Based on the following research question and data, list the research question firstly, make interpretation and insights on returned data (use Tables to present the insights if possible, **DO NOT FAKE DATA**) secondly and then analyze the validity and feasibility of the research question.\n\n"
        f"Research Question: {research_question}\n\n"
        f"Data: {data_json}\n\n"
        f"Provide a detailed analysis report:"
    )

    # Get analysis report from Generative AI
    analysis_report = generate_response(analysis_prompt)
    if analysis_report:
        logging.info(f"Analysis Report:\n{analysis_report}")
        # Write the analysis report to a file
        with open("finalreport.md", "w", encoding="utf-8") as file:
            file.write(f"# Analysis Report\n\n{analysis_report}\n")
        logging.info("Analysis report successfully written to finalreport.md.")
    else:
        logging.error("Failed to generate analysis report.")



import random
from typing import Dict, List, Tuple, Set

def parse_prompt_txt(table_info: str) -> Dict:
    """
    Parse prompt.txt file to extract pseudo table to real table and field mappings.
    
    Args:
        table_info (str): Content from prompt.txt
        
    Returns:
        Dict: Mapping of pseudo tables to real tables and fields
    """
    pseudo_tables = {}
    table_pattern = re.compile(
        r"Table\s+'(?P<pseudo_table>[^']+)'\s+\(merged from:\s+([^)]*)\):\s+Fields:\s+(?P<fields>[^.]+)\.",
        re.IGNORECASE
    )
    
    for match in table_pattern.finditer(table_info):
        pseudo_table = match.group('pseudo_table').strip()
        real_tables_str = match.group(2).strip()
        fields_str = match.group('fields').strip()
        
        real_tables = [tbl.strip() for tbl in real_tables_str.split(',')]
        fields = []
        field_pattern = re.compile(r"([A-Za-z0-9_]+)\s+\([^()]+\)")
        for field_match in field_pattern.finditer(fields_str):
            field_name = field_match.group(1).strip()
            fields.append(field_name)
            
        pseudo_tables[pseudo_table] = {
            'real_tables': real_tables,
            'fields': fields
        }
    
    return pseudo_tables

def get_real_table_name(pseudo_tables: Dict, pseudo_table: str) -> str:
    """
    Get a real table name from the pseudo table mapping.
    """
    if pseudo_table not in pseudo_tables:
        logging.error(f"Pseudo table '{pseudo_table}' not found in mapping.")
        return None
        
    real_tables = pseudo_tables[pseudo_table]['real_tables']
    if not real_tables:
        logging.error(f"Pseudo table '{pseudo_table}' has no corresponding real tables.")
        return None
        
    chosen_table = random.choice(real_tables)
    return chosen_table

def process_execution_steps_and_tables(execution_steps: str, table_info: str) -> Tuple[str, List[str]]:
    """
    Process execution steps and extract involved tables.
    
    Args:
        execution_steps (str): The execution steps text containing SQL queries
        table_info (str): Content from prompt.txt
        
    Returns:
        Tuple[str, List[str]]: (Updated execution steps, List of involved tables)
    """
    # Parse table mapping information
    pseudo_tables_mapping = parse_prompt_txt(table_info)
    
    # Regular expressions to find table references
    table_patterns = [
        r"(?:FROM|JOIN)\s+(?:Merged_Table_\d+)",
        r"<Result from (?:Join )?Step \d+>",
        r"Merged_Table_\d+",
    ]
    
    # Find all pseudo table references
    pseudo_tables = set()
    for pattern in table_patterns:
        matches = re.finditer(pattern, execution_steps, re.IGNORECASE)
        for match in matches:
            table_ref = re.search(r'(Merged_Table_\d+)', match.group(0))
            if table_ref:
                pseudo_tables.add(table_ref.group(1))
    
    # Map pseudo tables to real tables and update execution steps
    involved_tables = []
    updated_steps = execution_steps
    
    for pseudo_table in pseudo_tables:
        real_table = get_real_table_name(pseudo_tables_mapping, pseudo_table)
        if real_table:
            # Update execution steps
            updated_steps = re.sub(
                rf'\b{pseudo_table}\b',
                real_table,
                updated_steps
            )
            # Add to involved tables if not already present
            if real_table not in involved_tables:
                involved_tables.append(real_table)
    
    # Handle intermediate result references
    step_pattern = r'<Result from (?:Join )?Step \d+>'
    step_matches = re.finditer(step_pattern, updated_steps)
    for match in step_matches:
        step_num = re.search(r'\d+', match.group(0)).group(0)
        updated_steps = updated_steps.replace(
            match.group(0),
            f'step_{step_num}_result'
        )
    
    # Sort involved tables for consistency
    involved_tables = sorted(involved_tables)
    
    # Add standard tables that are always involved
    standard_tables = []
    
    # Merge and deduplicate while maintaining order
    final_tables = []
    for table in standard_tables:
        if table not in final_tables:
            final_tables.append(table)
    for table in involved_tables:
        if table not in final_tables:
            final_tables.append(table)
    
    return updated_steps, final_tables

def generate_dqc_plan(execution_steps, table_info, add_content):
    """
    Generate a Data Quality Check (DQC) plan using GenAI based on execution steps.
    """
    prompt = (
        "Please create a detailed Data Quality Check (DQC) plan based on the following execution steps. "
        "Focus only on the fields involved in the execution steps, not all fields of the tables.\n\n"
        f"Confirmed MAUDE Database Table Structures and Sample Data:\n{json.dumps(table_info, ensure_ascii=False, indent=2)}\n\n"
        f"Additional Table Information:\n{add_content}\n\n"
        f"Optimized Execution Steps:\n{execution_steps}\n\n"
        "The DQC plan should include the following aspects:\n"
        "1. Field Existence Checks\n"
        "2. Field Type Consistency Checks\n"
        "3. Logical Relationships Between Fields\n"
        "4. Data Cleaning and Transformation Validation\n"
        "5. Potential Data Quality Issues and Recommendations\n\n"
        "Please outline the strategies and corresponding SQL queries needed to perform these checks.\n "
        "Provide the SQL queries in code blocks within ```sql``` code fences.\n"
        "Ensure that table names are formatted as ** \"maude\".\"tablename\" ** and use the correct table and column names as per the Confirmed Table Structures and Data Samples.\n"
        "Do NOT generate any SQL queries that may modify data, such as UPDATE, DELETE, INSERT, CREATE TABLE...\n\n"
        "Please ensure that each SQL query does not return too many hits by including a LIMIT 10 clause, distinct function or other techniques.\n\n"
        "Provide the SQL queries in the following format:\n\n"
        "```sql\nSELECT * FROM \"maude\".\"tablename\" LIMIT 10;\n```\n"
    )
    
    dqc_plan = generate_response(prompt)
    if not dqc_plan:
        logging.error("Failed to generate DQC plan.")
    return dqc_plan

def extract_sql_queries(dqc_plan):
    """
    Extract SQL queries from the DQC plan.
    """
    pattern = r'```sql\n(.*?)```'
    matches = re.findall(pattern, dqc_plan, re.DOTALL)
    sql_queries = [match.strip() for match in matches]
    return sql_queries

def execute_query_list(sql_queries, table_info):
    """
    Execute the DQC SQL queries and collect results.
    Manages database connection and handles query execution with error correction.
    
    Args:
        sql_queries (list): List of SQL queries for data quality checks.
        table_info (dict): Confirmed table structures and sample data.
    
    Returns:
        dict: A dictionary with SQL queries as keys and their execution results or errors as values.
    """
    # Connect to the database
    conn = connect_database()
    if not conn:
        logging.error("Database connection failed.")
        return {}
    
    # Initialize a dictionary to store results and a list to accumulate all data
    dqc_results = {}
    dataall = []
    max_retries = 5  # Increased retries to handle persistent issues
    
    try:
        for idx, sql_query in enumerate(sql_queries, start=1):
            logging.info(f"Executing DQC SQL Query {idx}/{len(sql_queries)}:\n{sql_query}\n")
            attempt = 0
            while attempt < max_retries:
                data, error = execute_sql(conn, sql_query)
                if error:
                    logging.error(f"SQL Execution Error on DQC Query {idx}: {error}\n")
    
                    if "current transaction is aborted" in error.lower():
                        try:
                            conn.rollback()
                            logging.info("Rolled back the aborted transaction.")
                        except Exception as rollback_error:
                            logging.error(f"Failed to rollback transaction: {rollback_error}")
                            break
                        attempt += 1
                        continue
    
                    # Prepare the correction prompt for Generative AI with code fences
                    correction_prompt = (
                        f"The following SQL query resulted in an error. Please correct it based on the error message and the table information.\n\n"
                        f"Original SQL Query:\n{sql_query}\n\n"
                        f"Error Message: {error}\n\n"
                        f"Table Information: {json.dumps(table_info, ensure_ascii=False, indent=2)}\n\n"
                        f"Provide the corrected SQL query enclosed within ```sql``` code fences.\n"
                        f"Do NOT include any additional commentary or text."
                    )
                    time.sleep(3)  # Delay to avoid frequent requests
                    corrected_sql_full = generate_response(correction_prompt)
    
                    if not corrected_sql_full:
                        logging.warning("Failed to correct SQL query. Skipping to the next query.\n")
                        dqc_results[f"DQC Query {idx}"] = {"error": error}
                        break
    
                    logging.info(f"Corrected DQC SQL Query {idx}:\n\n{corrected_sql_full}\n")
    
                    # Extract the corrected SQL query from ```sql``` code fences
                    pattern = r'```sql\s*\n(.*?)```'
                    matches = re.findall(pattern, corrected_sql_full, re.DOTALL | re.IGNORECASE)
                    if matches:
                        # Assume the first match is the corrected query
                        corrected_query = matches[0].strip()
                        logging.info(f"Updating DQC Query {idx} with corrected SQL.")
                        sql_query = corrected_query
                    else:
                        logging.warning("No SQL code block found in the corrected response. Skipping to the next query.")
                        dqc_results[f"DQC Query {idx}"] = {"error": error}
                        break
    
                    attempt += 1
                    time.sleep(1)  # Brief pause before retrying
                else:
                    logging.info(f"DQC SQL Query {idx} executed successfully.\n")
                    if data:
                        dqc_results[f"DQC Query {idx}"] = {"data": data}
                        dataall.extend(data)
                        logging.info(f"Retrieved {len(data)} records from DQC Query {idx}.\n")
                    else:
                        dqc_results[f"DQC Query {idx}"] = {"data": None}
                        logging.info(f"No data returned from DQC Query {idx}.\n")
                    break
    
            # After retries, check if the last attempt resulted in an error
            if error and attempt == max_retries:
                logging.error(f"Reached maximum retry attempts for DQC Query {idx}. Unable to execute this query.\n")
                dqc_results[f"DQC Query {idx}"] = {"error": error}
    
    finally:
        # Ensure the database connection is closed properly
        try:
            conn.close()
            logging.info("Database connection closed.")
        except Exception as close_error:
            logging.error(f"Error closing database connection: {close_error}")
    
    # Correctly calculate the total records
    total_records = sum(len(v["data"]) for v in dqc_results.values() if v.get("data"))
    logging.info(f"Total records retrieved from all queries: {total_records}")
    
    return dqc_results

def generate_dqc_report(dqc_plan, dqc_results):
    """
    Generate a Data Quality Control report using Generative AI based on the DQC plan and results.
    """
    # Convert DQC results to a readable format
    dqc_results_str = json.dumps(dqc_results, ensure_ascii=False, indent=2)
    
    prompt = (
        "Please generate a detailed Data Quality Control (DQC) report based on the following plan and execution results. "
        "The report should be in Markdown format and include the following sections:\n"
        "1. Introduction\n"
        "2. Data Quality Check Plan\n"
        "3. Execution Results\n"
        "4. Analysis and Recommendations\n\n"
        "### Data Quality Check Plan:\n"
        f"{dqc_plan}\n\n"
        "### Execution Results:\n"
        f"```json\n{dqc_results_str}\n```\n\n"
        "### Analysis and Recommendations:\n"
        "Based on the execution results, analyze the data quality issues identified and provide recommendations for improvement."
    )
    
    dqc_report = generate_response(prompt)
    if not dqc_report:
        logging.error("Failed to generate DQC report.")
    return dqc_report

def perform_data_quality_control(execution_steps_new, table_info, add_content):
    """
    Perform the entire Data Quality Control process:
    1. Generate DQC plan
    2. Extract and execute SQL queries
    3. Generate DQC report
    """
    # 1. Generate DQC Plan
    dqc_plan = generate_dqc_plan(execution_steps_new, table_info, add_content)
    if not dqc_plan:
        logging.error("Failed to generate Data Quality Check plan.")
        return
    
    logging.info("Data Quality Check Plan:\n")
    logging.info(dqc_plan)
    
    # 2. Extract SQL Queries from DQC Plan
    sql_queries = extract_sql_queries(dqc_plan)
    if not sql_queries:
        logging.error("No SQL queries found in the DQC plan.")
        return
    
    logging.info("Extracted DQC SQL Queries:\n")
    for idx, query in enumerate(sql_queries, start=1):
        logging.info(f"--- DQC SQL Query {idx} ---")
        logging.info(query)
        logging.info("\n")
    
    # 3. Execute SQL Queries and Gather Results
    dqc_results = execute_query_list(sql_queries, table_info)
    
    # 4. Generate DQC Report
    dqc_report = generate_dqc_report(dqc_plan, dqc_results)
    if dqc_report:
        logging.info("\nData Quality Control Report:\n")
        logging.info(dqc_report)
        
        # Save the report to a Markdown file
        try:
            with open("data_quality_report.md", "w", encoding="utf-8") as report_file:
                report_file.write(dqc_report)
            logging.info("Data Quality Control Report has been successfully saved to data_quality_report.md.")
        except Exception as e:
            logging.error(f"Failed to write DQC report to file: {e}")
    else:
        logging.error("Failed to generate Data Quality Control report.")


2024-12-26 17:39:08,135 [INFO] Logging is configured correctly.
2024-12-26 17:39:08,136 [ERROR] This is a test error message.


In [None]:
# 1.1 Read table information from prompt.txt
prompt_file = 'prompt.txt'
prompt_content = read_prompt_file(prompt_file)
if not prompt_content:
    logging.error("Failed to read prompt.txt file.")
    exit(1)  # Exit if prompt.txt cannot be read

# 1.2 Read additional table information from metadata.txt
add_file = 'metadata.txt'
add_content = read_prompt_file(add_file)
if not add_content:
    logging.error("Failed to read metadata.txt file.")
    exit(1)  # Exit if metadata.txt cannot be read

# 2. Generate a research question
research_prompt = (
    f"Based on the following MAUDE database table information, propose a meaningful research question and strategy.\n\n\n\n"
    f"\n\n{prompt_content}\n\n"
    f"Additional Table Information: \n{add_content}\n\n"
)
research_question = generate_response(research_prompt)
if not research_question:
    logging.error("Failed to generate a research question.")
    exit(1)  # Exit if research question cannot be generated
logging.info(f"Proposed Research Question:\n\n{research_question}\n")

# 3. Plan execution steps based on the research question
planning_prompt = (
    f"Based on the following research question, outline specific execution steps, including which tables and fields need to be queried.\n\n\n\n"
    f"Research Question: \n{research_question}"
)
execution_steps = generate_response(planning_prompt)
if not execution_steps:
    logging.error("Failed to plan execution steps.")
    exit(1)  # Exit if execution steps cannot be planned
logging.info(f"Planned Execution Steps:\n\n{execution_steps}\n")
# 4. Identify involved tables by extracting from execution steps
updated_steps, involved_tables = process_execution_steps_and_tables(execution_steps, prompt_content)

logging.info("Updated execution steps:")
logging.info(updated_steps)
logging.info("\nInvolved tables:")
logging.info(involved_tables)

# 5. Acquire Involved Table Info, Check table structures and sample data
# Connect to the database
conn = connect_database()
if not conn:
    logging.error("Database connection failed.")
    exit(1)  # Exit the script if the connection fails

table_info = {}
try:
    with conn.cursor() as cursor:
        for table in involved_tables:
            # Check if table exists
            structure = get_table_structure(cursor, table)
            if structure is None or not structure:
                logging.warning(f"Table '{table}' does not exist or has no columns. Skipping.")
                continue  # Skip to the next table if structure is invalid
            
            # Retrieve sample data
            samples = get_sample_data(cursor, table)
            if samples is None:
                logging.warning(f"Unable to retrieve sample data for table: {table}")
                continue  # Skip to the next table if samples cannot be retrieved
            
            # Populate table_info dictionary
            table_info[table] = {
                'structure': structure,
                'samples': samples
            }
            
            # Optional: Log table structure and sample data for verification
            # logging.debug(f"Table: {table}")
            # logging.debug(f"Structure: {json.dumps(structure, ensure_ascii=False, indent=2)}")
            # logging.debug(f"Sample Data: {json.dumps(samples, ensure_ascii=False, indent=2, default=serialize)}\n")
    
    if not table_info:
        logging.error("No valid tables found for analysis. Exiting.")
        exit(1)  # Exit the script if no valid tables are found

except psycopg2.Error as e:
    logging.error(f"An error occurred while accessing the database: {e}")
    exit(1)  # Exit the script on database errors

finally:
    # Ensure the database connection is closed properly
    try:
        conn.close()
        logging.info("Database connection closed.")
    except Exception as close_error:
        logging.error(f"Error closing database connection: {close_error}")

logging.info(f"Table Information: {json.dumps(table_info, ensure_ascii=False, indent=2)}")

# 7. Polish Execution Steps w/ Confirmed Table Info and Add Info
planning_prompt_new = (
    f"Optimize Execution Steps, Prevent table, field name and value format from errors according to Table Structures and Data Samples: \n{json.dumps(table_info, ensure_ascii=False, indent=2)}\n\n"
    f"Additional Table Information: \n{add_content}\n\n"
    f"Based on the above Confirmed information of MAUDE Database Structures and Data Samples and the following Execution Steps, polish the specific execution steps, especially on correcting the name and logic of tables and columns that need to be queried.\n\n\n\n"
    f"Current Execution Steps: \n{updated_steps}\n\n"
)
# print(planning_prompt_new)
execution_steps_new = generate_response(planning_prompt_new)
if not execution_steps_new:
    logging.error("Failed to polish execution steps.")
    exit(1)
logging.info(f"Optimized Execution Steps:\n\n{execution_steps_new}\n")

In [None]:
# Generate SQL queries with instructions
sql_prompt = (
    f"Based on the following execution steps and confirmed table structures and data samples, generate SQL queries.\n\n"
    f"Execution Steps: \n{execution_steps_new}\n\n"
    f"Confirmed Table Structures and Data Samples: \n{json.dumps(table_info, ensure_ascii=False, indent=2)}\n\n"
    f"Ensure that table names are formatted as **\"maude\".\"tablename\"** and use the correct table and column names as per the Confirmed Table Structures and Data Samples.\n\n"
    f"Please ensure that each SQL query does not return more than 10 hits by including a LIMIT 10 clause.\n\n"
    f"Each generated SQL statement should be enclosed within ```sql``` code fences, be self-contained and independent, meaning they should not rely on the execution of other SQL statements. If there are dependencies between queries, combine them into a single, cohesive SQL statement.\n\n"
    #f"Use no more than 10 simple SQL queries to fulfill the execution.\n\n"
    f"Do NOT generate any SQL queries that may modify data, such as UPDATE, DELETE, INSERT, CREATE TABLE...\n\n"
    f"Provide the SQL queries in the following format:\n\n"
    f"```sql\nSELECT * FROM \"maude\".\"tablename\" LIMIT 10;\n```\n"
)

# Generate SQL queries with instructions in JSON format
sql_query_full = generate_response(sql_prompt)
if not sql_query_full:
    logging.error("Failed to generate SQL queries.")
    exit(1)  # Exit if SQL queries cannot be generated

logging.info(f"Generated SQL Queries JSON:\n\n{sql_query_full}\n")

sql_queries = extract_sql_queries(sql_query_full)

# Execute DQC Queries
sql_results = execute_query_list(sql_queries, table_info)

In [None]:
# 9. Analyze data to validate the research question
if sql_results:
    # Construct the formatted string and assign it to a new variable
    tmp = (
        f"Execution Steps: {execution_steps_new}\n\n"
        f"SQL Queries: {sql_queries}\n\n"
        f"SQL Execution Outcome:\n{json.dumps(sql_results, ensure_ascii=False, indent=2)}\n\n"
    )

    analyze_data(research_question, tmp)

In [None]:
perform_data_quality_control(execution_steps_new, table_info, add_content)

In [None]:
import json
import logging
import re
import sys

# 假设您已在别处定义 generate_response 与 logging 配置
# from your_openai_module import generate_response
# logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')

def sanitize_code(code_str: str) -> str:
    """
    去除可能导致语法错误的 Markdown 代码块标记和其他无效字符，
    并去掉前后空白，保证可以被 Python 解释器正确执行。
    """
    code_str = re.sub(r'```python\s*', '', code_str, flags=re.IGNORECASE)
    code_str = re.sub(r'```', '', code_str, flags=re.IGNORECASE)
    code_str = code_str.strip()
    # 若有 BOM 或不可见字符:
    # code_str = code_str.encode('utf-8', 'ignore').decode('utf-8', 'ignore')
    return code_str

def execute_generated_code(analysis_code, sql_results):
    """
    执行生成的Python分析代码。
    
    Args:
        analysis_code (str): 生成的Python代码片段。
        sql_results (dict): SQL 查询返回的数据结果。
    
    Returns:
        tuple: (success: bool, output: str)
    """
    try:
        local_namespace = {}
        exec_globals = globals().copy()
        # 使生成的代码能直接访问 sql_results
        exec_globals['sql_results'] = sql_results

        from io import StringIO
        old_stdout = sys.stdout
        sys.stdout = StringIO()
        try:
            exec(analysis_code, exec_globals, local_namespace)
            output = sys.stdout.getvalue()
        finally:
            sys.stdout = old_stdout
        
        return True, output

    except Exception as e:
        return False, str(e)

def perform_data_analysis(execution_steps_new, sql_queries, sql_results, max_retries=3):
    """
    自动生成、执行并调试分析代码：
      1) 基于 execution_steps_new 和 sql_results 
      2) 要求只使用已有的 sql_results
      3) 可视化结果用 plt.savefig('xxx.png') 方式保存
      4) 最多可重试 max_retries 次
    """
    # 限制 sql_results 的数据量，防止数据过大
    # limited_sql_results = {k: v[:2] for k, v in sql_results.items() if isinstance(v, list)}
    
    analysis_prompt = (
        "请使用现有的 sql_results 变量(已在执行环境中)，将其转换为 pandas.DataFrame 进行分析。"
        "不要在代码里赋值 sql_results。"
        "根据以下执行步骤(仅做参考，不用回显)、以及 sql_results 的数据结构，"
        "编写简洁的 Python 代码做一些基础分析和可视化:"
        "  - 分析可以包含: 统计描述、groupby、简单图表等"
        "  - 每个图表以 plt.savefig('xxx.png') 保存, 不要 plt.show()"
        "  - 代码要尽量精简，可读性高，注释简洁"
        "  - 不要包含任何三引号或 ``` 代码块标记\n\n"
        f"执行步骤(参考): {execution_steps_new}\n\n"
        f"示例数据结构(参考): {sql_queries}\n\n"
        f"示例 sql_results(部分): {json.dumps(sql_results, ensure_ascii=False, indent=2)}\n\n"
        "请生成纯 Python 代码，不要多余解释。"
    )

    retry_count = 0
    success = False
    analysis_code = ""
    original_code = ""
    error_message = ""

    while retry_count < max_retries and not success:
        if retry_count == 0:
            # 第一次生成
            analysis_code = generate_response(analysis_prompt)
            if not analysis_code:
                logging.error("生成分析代码失败。")
                return
            analysis_code = sanitize_code(analysis_code)
        else:
            # 调试修正
            debug_prompt = (
                f"上次执行的代码报错：{error_message}\n"
                f"原始代码:\n{original_code}\n\n"
                f"请直接使用已有的 sql_results，不要在代码里赋值 sql_results。"
                f"图表用 plt.savefig('xxx.png')。保持代码简洁，删除多余逻辑和说明。\n\n"
                f"执行步骤(参考): {execution_steps_new}\n\n"
                f"示例数据结构(参考): {sql_queries}\n\n"
                f"示例 sql_results(部分): {json.dumps(sql_results, ensure_ascii=False, indent=2)}\n\n"
            )
            analysis_code = generate_response(debug_prompt)
            if not analysis_code:
                logging.error("生成修正后的分析代码失败。")
                return
            analysis_code = sanitize_code(analysis_code)

        logging.info(f"第 {retry_count + 1} 次生成/修正的代码:\n{analysis_code}\n")
        
        success, output = execute_generated_code(analysis_code, sql_results)
        if success:
            with open("analysis_output.log", "w", encoding="utf-8") as f:
                f.write(output)
            logging.info("分析代码执行成功，输出已保存到 analysis_output.log。")
            break
        else:
            logging.error(f"分析代码执行失败，第 {retry_count + 1} 次。错误信息:\n{output}")
            original_code = analysis_code
            error_message = output
            retry_count += 1

    if not success:
        logging.error(f"已尝试 {max_retries} 次，仍未成功执行分析代码。")

# 在主流程中调用:
perform_data_analysis(execution_steps_new, sql_queries, sql_results, max_retries=3)
