In [261]:
import json
import os

import mysql.connector
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.vectorstores import Chroma
from mysql.connector import Error

from config import api_key_qwen
from utils import logger, clean_create_statement

# Database connection configuration.
# Every setting can be overridden with a TENDER_DB_* environment variable so
# credentials do not have to live in the source file; the literals below are
# kept only as backward-compatible defaults.
# NOTE(review): these defaults look like real credentials committed to source.
# Rotate them and drop the fallbacks once the environment variables are set.
DB_CONFIG = {
    "host": os.environ.get("TENDER_DB_HOST", "gz-cdb-5scrcjb5.sql.tencentcdb.com"),
    "user": os.environ.get("TENDER_DB_USER", "db"),
    "password": os.environ.get("TENDER_DB_PASSWORD", "dbdb905905"),
    "database": os.environ.get("TENDER_DB_NAME", "sele"),
    # The connector expects an integer port, so convert the env string.
    "port": int(os.environ.get("TENDER_DB_PORT", "63432")),
}

def connect_db():
    """Open a MySQL connection using ``DB_CONFIG``.

    Returns:
        The live connection object, or ``None`` when the connector raises
        ``Error`` (which is logged) or when the new connection does not
        report itself as connected.
    """
    try:
        conn = mysql.connector.connect(**DB_CONFIG)
        # Only hand back a connection that is actually usable.
        return conn if conn.is_connected() else None
    except Error as exc:
        logger.error(f"Error while connecting to MySQL: {exc}")
        return None

def sql_query(query: str):
    """Execute an SQL statement and return its rows as one string.

    Args:
        query: The SQL text to run. It is executed verbatim.

    Returns:
        The ``str()`` of every fetched row joined with ``", "``; the string
        ``"Database connection failed."`` when no connection could be opened;
        or ``"Query failed: <error>"`` when the connector raises ``Error``.
        Failures are returned as text (not raised) so the caller can pass them
        back to the model / grader just like a normal tool result.
    """
    connection = connect_db()
    if not connection:
        return "Database connection failed."
    # NOTE(review): the statement comes from the model and nothing here
    # enforces that it is read-only; prefer a database account that only has
    # SELECT privileges.
    try:
        cursor = connection.cursor()
        try:
            cursor.execute(query)
            rows = cursor.fetchall()
        finally:
            cursor.close()
    except Error as e:
        logger.error(f"Error while executing query: {e}")
        return f"Query failed: {e}"
    finally:
        # Runs on every path, including a failure inside connection.cursor().
        connection.close()
    return ", ".join(str(row) for row in rows)

def handle_tool_calls(raw_response):
    """Run the first ``sql_query`` tool call found in a model response.

    Args:
        raw_response: Response object whose
            ``output.choices[0].message`` supports ``in`` and exposes
            ``tool_calls`` as an attribute.

    Returns:
        A ``{"role": "tool", ...}`` message dict carrying the query result,
        or ``None`` when the message has no tool calls or none of them is
        named ``sql_query``.
    """
    if 'tool_calls' not in raw_response.output.choices[0].message:
        return None

    for call in raw_response.output.choices[0].message.tool_calls:
        if call['function']['name'] != 'sql_query':
            continue
        # Arguments arrive as a JSON string; only the first match is executed.
        arguments = json.loads(call['function']['arguments'])
        return {
            "role": "tool",
            "content": sql_query(arguments['query']),
            "tool_call_id": call.get('id', 'N/A'),
        }
    return None

def table_schema():
    """Return the cleaned ``CREATE TABLE`` statement of ``tender_key_detail``.

    The table name is hard-coded. The second column of the
    ``SHOW CREATE TABLE`` row is passed through ``clean_create_statement``.

    Returns:
        The cleaned statement, or the string ``"Connection error."`` when no
        database connection could be opened.
    """
    connection = connect_db()
    if not connection:
        return "Connection error."
    # Release the cursor and the connection on every path; the previous
    # version leaked both on each call.
    try:
        cursor = connection.cursor()
        try:
            cursor.execute("SHOW CREATE TABLE tender_key_detail")
            create_info = cursor.fetchone()
        finally:
            cursor.close()
    finally:
        connection.close()
    return clean_create_statement(create_info[1])


def column_comments():
    """Describe each ``tender_key_detail`` column for the model.

    Returns:
        A pretty-printed (2-space indent, non-ASCII preserved) JSON object
        mapping column name to its Chinese description, in declaration order.
    """
    annotated_columns = (
        ('tender_id', '招标项目ID'),
        ('bid_price', '招标价格（元）'),
        ('construction_duration', '工期（天）'),
        ('construction_area', '建筑面积（平方米）'),
        ('construction_cost', '建安费（元）'),
        ('qualification_type', '监理企业资质类型,必须是来自于以下：房屋建筑工程、冶炼工程、矿山工程、化工石油工程、水利水电工程、电力工程、农林工程、铁路工程、公路工程、港口与航道工程、航天航空工程、通信工程、市政公用工程、机电安装工程，综合资质'),
        ('qualification_level', '监理企业资质等级'),
        ('qualification_profession', '总监注册资格证书专业'),
        ('title_level', '总监职称等级'),
        ('education', '总监学历'),
        ('performance_requirements', '总监相关业绩要求,例如，“至少担任过2项类似工程的监理负责人”'),
        ('simultaneous_projects_limit', '总监兼任项目限制,例如，“在任职期间能参与的其他在施项目不得超过2个”'),
        ('qualification_profession_addition', '附加信息'),
    )
    return json.dumps(dict(annotated_columns), ensure_ascii=False, indent=2)

def sample_entries(question):
    """Return example table rows that are semantically close to ``question``.

    The three nearest documents are looked up in the persisted Chroma store
    (``updated_tender_vector_store``); each match's ``tender_id`` metadata is
    then used to fetch the corresponding rows from ``tender_key_detail``.

    Args:
        question: Natural-language question used for the similarity search.

    Returns:
        Newline-terminated text: a comma-separated header of column names
        followed by one comma-separated line per fetched row. An empty string
        is returned when the database connection cannot be opened.
    """
    embeddings = DashScopeEmbeddings(model="text-embedding-v1", dashscope_api_key=api_key_qwen)
    persist_dir = "updated_tender_vector_store"
    vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
    results = vectorstore.similarity_search(question, k=3)

    connection = connect_db()
    if not connection:
        # Previously this path fell through to connection.close() on None.
        return ""

    lines = []
    try:
        cursor = connection.cursor()
        try:
            # A query that matches no rows still exposes the column names.
            cursor.execute("SELECT * FROM tender_key_detail WHERE FALSE")
            lines.append(', '.join(desc[0] for desc in cursor.description))

            for result in results:
                # Make sure any earlier result set has been fully consumed.
                while cursor.nextset():
                    pass

                # Bind the id as a parameter instead of formatting it into the SQL.
                cursor.execute(
                    "SELECT * FROM tender_key_detail WHERE tender_id = %s",
                    (result.metadata['tender_id'],),
                )
                lines.extend(', '.join(map(str, row)) for row in cursor.fetchall())
        finally:
            cursor.close()
    finally:
        # Close the connection even when a query fails part-way through.
        connection.close()
    return ''.join(f"{line}\n" for line in lines)

def sqlAgent(user_input):
    """Ask the model for an SQL query answering ``user_input`` and run it.

    The model is given the table schema, the column comments and a few sample
    rows, and is instructed to call the ``sql_query`` tool.

    Args:
        user_input: The user's natural-language question.

    Returns:
        A ``(query, query_result)`` tuple. ``query`` is the raw JSON argument
        string of the first tool call and ``query_result`` is the text
        produced by executing it. When the model returns no tool call the
        result is ``(None, "No result.")``.
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "sql_query",
                "description": "此函数负责执行传入的MySQL查询语句，并返回查询结果。它专门用于检索数据，不支持修改、删除或更新数据库的操作。这包括但不限于仅执行SELECT查询语句。",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "description": "一个有效的MySQL SELECT查询语句，用于从数据库中检索数据。",
                            "type": "string"
                        }
                    },
                    "required": ["query"]
                },
            },
        }
    ]

    system_prompt = """
    基于用户的问题及元数据表的结构信息，使用sql_query工具去查询可回答问题的数据。
    注意：不用直接输出，而是调用“sql_query”函数，传入查询语句。
    """

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"用户的问题是: {user_input}"},
        {"role": "user", "content": f"元数据表的结构： {table_schema()}"},
        {"role": "user", "content": f"元数据注释： {column_comments()}"},
        {"role": "user", "content": f"元数据表数据示例： {sample_entries(user_input)}"}
    ]

    import dashscope
    dashscope.api_key = api_key_qwen
    model_name = "qwen-plus"
    temperature = 0.2

    logger.info('AI model initialized: %s', model_name)

    raw_response = dashscope.Generation.call(
        messages=messages,
        model=model_name,
        tools=tools,
        temperature=temperature,
        result_format='message'
    )

    message = raw_response.output.choices[0].message
    # The model may answer in plain text instead of calling the tool.
    if 'tool_calls' not in message or not message.tool_calls:
        logger.warning('Model response contained no tool call.')
        return None, "No result."

    query = message.tool_calls[0]['function']['arguments']
    # Execute the tool call exactly once; calling the handler twice would run
    # the SQL against the database twice.
    tool_response = handle_tool_calls(raw_response)
    query_result = tool_response['content'] if tool_response else "No result."

    return query, query_result


In [270]:
# Example question (in Chinese): how many water-conservancy-related projects
# exist, ordered by bid price, including the required enterprise qualification
# type and level.
user_input = "和水利相关的项目总共有多少，按照招标价格大小排列？记得附上其企业资质要求和等级"

# sample_entries(user_input)
# sqlAgent returns (tool-call arguments, query result); both are reused by the cells below.
aa,bb = sqlAgent(user_input)

In [271]:
# Show the tool-call arguments returned by sqlAgent (a JSON string holding the generated SQL).
aa

'{"query": "SELECT tender_id, bid_price, qualification_type, qualification_level FROM tender_key_detail WHERE qualification_type = \'水利水电工程\' ORDER BY bid_price DESC"}'

In [272]:
# Print the query result text returned by sqlAgent.
print(bb)

(799, 12098700000.0, '水利水电工程', '乙级'), (598, 5754220000.0, '水利水电工程', '甲级'), (553, 3633170000.0, '水利水电工程', '甲级'), (707, 2033150000.0, '水利水电工程', '丙级'), (597, 1768330000.0, '水利水电工程', '乙级'), (812, 1424400000.0, '水利水电工程', '乙级'), (753, 1360870000.0, '水利水电工程', '乙级'), (736, 1216460000.0, '水利水电工程', '乙级'), (848, 1163620000.0, '水利水电工程', '乙级'), (805, 1000050000.0, '水利水电工程', '乙级'), (110, 68743400.0, '水利水电工程', '乙级'), (580, 60316700.0, '水利水电工程', '丙级'), (573, 29818000.0, '水利水电工程', '乙级'), (741, 6583900.0, '水利水电工程', '乙级'), (601, 6326890.0, '水利水电工程', '甲级'), (772, 6067400.0, '水利水电工程', '乙级'), (773, 6067400.0, '水利水电工程', '乙级'), (708, 5647100.0, '水利水电工程', '丙级'), (822, 5647100.0, '水利水电工程', '乙级'), (264, 5500000.0, '水利水电工程', '乙级'), (58, 4190000.0, '水利水电工程', '乙级'), (717, 4170760.0, '水利水电工程', '甲级'), (648, 3981880.0, '水利水电工程', '乙级'), (56, 2789800.0, '水利水电工程', '乙级'), (433, 2595200.0, '水利水电工程', '乙级'), (728, 2550000.0, '水利水电工程', '丙级'), (859, 2544200.0, '水利水电工程', '乙级'), (860, 2544200.0, '水利水电工程', '乙级'), (182, 2271200.0,

In [299]:
# 现在要写一个AI agent， sql_retrieval_grader 判断上述query 和 query result 是否和问题相关

def sql_retrieval_grader(user_input, query, query_result):
    """Ask a DashScope model whether a query and its result answer the question.

    Args:
        user_input: The user's original question.
        query: The SQL query (or tool-call arguments) generated for it.
        query_result: The text produced by executing the query.

    Returns:
        A ``(verdict, ai_response)`` tuple. ``verdict`` is ``True`` for a
        "是"/"yes" answer, ``False`` for "否"/"no", and ``None`` when the reply
        is anything else. ``ai_response`` is the model's reply, stripped of
        surrounding whitespace and lower-cased.
    """
    import dashscope

    # Configure the DashScope API key.
    dashscope.api_key = api_key_qwen

    # Build the grading prompt: question, query, result, then the instruction
    # to answer with a bare yes/no.
    evaluation_prompt = (
    f"用户的问题是：'{user_input}'\n"
    f"为此问题设计的SQL查询是：\n{query}\n"
    f"执行此查询后得到的结果概要是：\n{query_result}\n"
    "\n请根据上述信息，判断这个SQL查询及其结果是否紧密相关且能够准确回答用户的问题。"
    "如果你认为查询结果和问题相关，回答'是'；若不符合或有偏差，请回答'否'。"
    "\n 回答基于你的判断，仅仅包含“是”或“否”，不得有其他解释内容"
    )

    # Call the model to grade the query.
    model_name = "qwen-plus"
    temperature = 0.2
    raw_response = dashscope.Generation.call(
        messages=[{"role": "user", "content": evaluation_prompt}],
        model=model_name,
        temperature=temperature,
        result_format='message'
    )

    # Normalise the reply.
    ai_response = raw_response.output.choices[0].message.content.strip().lower()

    # Models often wrap the one-word answer in quotes or add a full stop
    # ("是。", "'否'"); strip that decoration before comparing so such replies
    # are not reported as undecided.
    verdict = ai_response.strip(" \t\r\n\"'“”‘’。．.!！,，")

    if verdict in ('是', 'yes'):
        return True, ai_response
    if verdict in ('否', 'no'):
        return False, ai_response
    return None, ai_response

# # 示例调用（请用实际值替换下面的占位符）
# user_input_test = "最近一周内新增的招标公告详情"
# query_test = "SELECT * FROM tender_announcements WHERE date_added BETWEEN DATE_SUB(CURDATE(), INTERVAL 1 WEEK) AND CURDATE()"
# query_result_test = "示例输出，具体数据省略..."

# relevance, feedback = sql_retrieval_grader(user_input_test, query_test, query_result_test)
# print(f"查询与问题的相关性判断结果：{relevance}. 反馈信息：{feedback}")

In [305]:
# Example call: grade a query whose "result" is a MySQL error message.
# NOTE(review): the error text names table 'tender_key_details' while the query
# uses 'tender_key_detail' — the result string looks hand-written; confirm.
user_input_test = "给我三条最近三天的市政工程项目的标讯"
query_test = "select * from tender_key_detail where qualification_type = '市政公用工程' and date_added BETWEEN DATE_SUB(CURDATE(), INTERVAL 3 DAY) AND CURDATE() limit 3"
query_result_test = "ERROR 1146 (42S02): Table 'sele.tender_key_details' doesn't exist"

relevance, feedback = sql_retrieval_grader(user_input_test, query_test, query_result_test)
print(f"查询与问题的相关性判断结果：{relevance}. 反馈信息：{feedback}")

查询与问题的相关性判断结果：False. 反馈信息：否


In [306]:
# Grade the query/result pair produced by sqlAgent in the earlier cell.
print(sql_retrieval_grader(user_input, aa, bb))

(True, '是')


In [None]:
# 现将元数据表（tender_key_detail）内容embedding

import os
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.vectorstores import Chroma

def embed_and_persist_data(api_key=None, persist_dir="updated_tender_vector_store"):
    """Embed every ``tender_key_detail`` row and persist it in a Chroma store.

    Each row is turned into a JSON object keyed by the Chinese column labels,
    embedded with DashScope ``text-embedding-v1``, and stored together with
    its ``tender_id`` as metadata. An existing collection in ``persist_dir``
    is deleted first so the store is rebuilt from scratch.

    Args:
        api_key: DashScope API key. Defaults to ``api_key_qwen`` from the
            project config, so no key needs to be written in this file.
        persist_dir: Directory in which the Chroma store is persisted.

    Returns:
        The populated Chroma vector store, or ``None`` when the database
        query fails.
    """
    if api_key is None:
        api_key = api_key_qwen

    def sql_query(query):
        """Run ``query`` and return ``(column_names, rows)`` or ``(None, None)``."""
        connection = connect_db()
        if not connection:
            # connect_db has already logged the reason.
            print("Database Error occurred: connection failed")
            return None, None
        try:
            cursor = connection.cursor()
            try:
                cursor.execute(query)
                results = cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]  # column names
            finally:
                cursor.close()
            return columns, results
        except mysql.connector.Error as err:
            print(f"Database Error occurred: {err}")
            return None, None
        finally:
            # Release the connection on success and on failure alike.
            connection.close()

    query = "SELECT * FROM tender_key_detail;"
    columns, raw_results = sql_query(query)
    if columns is None or raw_results is None:
        print("数据库查询失败，无法继续执行。")
        return

    chinese_columns = {
        'tender_id': '招标项目ID',
        'bid_price': '招标金额（元）',
        'construction_duration': '工期（天）',
        'construction_area': '建筑面积（平方米）',
        'construction_cost': '建安费（元）',
        'qualification_type': '企业资质类型',
        'qualification_level': '企业资质等级',
        'qualification_profession': '总监专业类型',
        'title_level': '总监职称',
        'education': '总监学历',
        'performance_requirements': '总监业绩要求',
        'simultaneous_projects_limit': '总监兼任项目限制',
        'qualification_profession_addition': '附加信息'
    }
    # Columns without a Chinese label keep their original name.
    chinese_column_names = [chinese_columns.get(col, col) for col in columns]

    tender_announcements = [dict(zip(chinese_column_names, row)) for row in raw_results]

    # Prepare the texts to embed and their metadata.
    texts = [json.dumps(announcement, ensure_ascii=False) for announcement in tender_announcements]
    metadata_list = [{"tender_id": announcement["招标项目ID"]} for announcement in tender_announcements]

    # Initialise the DashScope embedding model.
    embeddings = DashScopeEmbeddings(
        model="text-embedding-v1", dashscope_api_key=api_key
    )

    if os.path.exists(persist_dir):
        # Drop the existing collection so the store below starts empty.
        vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        vectorstore.delete_collection()
    else:
        os.makedirs(persist_dir)

    # Create the vector store and add the embeddings with their metadata.
    vectorstore = Chroma.from_texts(
        texts=texts,
        embedding=embeddings,
        metadatas=metadata_list,
        persist_directory=persist_dir
    )

    print("数据嵌入及持久化完成。")
    return vectorstore


In [150]:
# Demo: fetch the table rows behind the nearest vector-store matches for a
# sample question and print them (header line first).
quest = "总共有多少市政工程甲级资质要求的招标项目？"

samples = sample_entries(quest)

print(samples)

tender_id, bid_price, construction_duration, construction_area, construction_cost, qualification_type, qualification_level, qualification_profession, title_level, education, performance_requirements, simultaneous_projects_limit, qualification_profession_addition
801, 1654300.0, None, None, 1654300.0, 公路工程, 乙级, 公路工程, 高级工程师, None, None, None, None
815, 8148360000.0, None, None, 583463000.0, 公路工程, 乙级, 公路工程, 高级工程师, None, None, None, None
802, 530100.0, 540, None, 56062600.0, 公路工程, 乙级, 公路工程, 工程师, None, None, None, None

