In [1]:
import os
import re
from urllib.parse import quote_plus

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

from local_chain.selfQuery import create_sql_query_chain

successful load embedding model


In [2]:
# Database connection settings.
# Every value can be overridden through an environment variable so that real
# credentials never have to be written into the notebook; the fallbacks are the
# local development defaults.
db_user = os.getenv("DB_USER", "postgres")
db_password = os.getenv("DB_PASSWORD", "postgres")  # replace via DB_PASSWORD for any non-local database
db_host = os.getenv("DB_HOST", "localhost")
db_port = os.getenv("DB_PORT", "5432")
db_name = os.getenv("DB_NAME", "postgres")

# Percent-encode user and password: characters such as '@', ':' or '/' inside a
# credential would otherwise be parsed as URL delimiters and corrupt the URI.
# (A driver can be selected explicitly with the "postgresql+psycopg2://" scheme.)
db_uri = (
    f"postgresql://{quote_plus(db_user)}:{quote_plus(db_password)}"
    f"@{db_host}:{db_port}/{db_name}"
)

# Initialise the LangChain SQLDatabase wrapper around the connection.
db = SQLDatabase.from_uri(db_uri)

# Print basic database information to verify that the connection works.
print("本次数据库类型:", db.dialect)
print("可用的表格:", db.get_usable_table_names())
# To inspect the full schema as well, call db.get_table_info().

本次数据库类型: postgresql
可用的表格: ['assets', 'companies', 'contracts', 'debts', 'departments', 'employees', 'projects', 'transactions']


In [3]:
# clean_query: strip prompt-format residue from LLM output so only the SQL text remains.
def clean_query(query):
    """Remove formatting markers and a trailing prose sentence from generated SQL.

    Steps:
      1. Delete every occurrence of the markers ``SQLResult:``, `````` ``` ``````,
         ``sql``, ``SQLQuery:``, ``SQL`` and ``Answer:`` (in that order).
      2. While the text ends with an ASCII or Chinese full stop, cut off the
         trailing sentence. The cut starts at the leftmost character that is an
         uppercase ASCII letter or a non-ASCII character and has no uppercase
         ASCII letter between it and the final full stop. If no such character
         exists, only the trailing full stop(s) are removed.
      3. Strip surrounding whitespace.

    Args:
        query: Raw text produced by the SQL-generation step.

    Returns:
        The cleaned query string (possibly empty).
    """
    # Order matters: "SQLQuery:" must be removed before the bare "SQL" marker.
    for marker in ("SQLResult:", "```", "sql", "SQLQuery:", "SQL", "Answer:"):
        query = query.replace(marker, "")

    # Earlier, more aggressive heuristics are intentionally disabled: cutting at
    # the first ASCII/Chinese colon, dropping everything before the first
    # upper-case SELECT, truncating after "LIMIT <n>", and cutting at the first
    # semicolon.

    # Drop a trailing sentence such as "... This query returns the names."
    while query.endswith(('.', '。')):
        match = re.search(r'([A-Z]|[^\x00-\x7F])[^A-Z]*[。\.]$', query)
        if match:
            # Remove everything from the matched character through the full stop.
            query = query[:match.start()]
        else:
            # No sentence start found (e.g. an all-lowercase query ending in '.').
            # Without this branch the loop would never make progress and hang.
            query = query.rstrip('.。')

    return query.strip()  # trim leading/trailing whitespace

In [4]:
# Project-local module providing the LLM wrapper; the resulting `llm` is used both
# for SQL generation and for phrasing the final answer.
import LLM_CUSTOM
# NOTE(review): the meaning of n=10 is defined inside LLM_CUSTOM.CustomLLM and is
# not visible in this notebook — confirm before changing it.
llm = LLM_CUSTOM.CustomLLM(n=10)

In [5]:
# Build the database query chain.
from langchain_core.prompts import BasePromptTemplate
# Prompt for the SQL-generation step. Placeholders: {dialect}, {top_k},
# {relevant_msg}, {table_info} and {input}.
# NOTE(review): the placeholders are presumably filled in by the project-local
# create_sql_query_chain (local_chain.selfQuery) — confirm there which ones the
# caller must supply.
template = '''
Given an input question, create a syntactically correct {dialect} query to run. Use the following format to OutPut:

SQLQuery: SQL Query to run

limit the query's row to {top_k} unless user have clear instruction

Pay attention to follow the relevant rules. 

{relevant_msg}

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:

{table_info}.

You can order the results by a relevant column to return the most interesting examples in the database.


Question: {input}
'''

prompt = PromptTemplate.from_template(template)
# Runnable that turns a question into SQL text using the custom prompt above.
generate_query = create_sql_query_chain(llm, db, prompt=prompt)

# Tool that executes a SQL string against `db`.
execute_query = QuerySQLDataBaseTool(db=db)

# Prompt for the final step: phrase the SQL result as an answer to the question.
# Expects the keys 'question', 'query' and 'result'.
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: 
"""
)


# Define the end-to-end pipeline:
# question -> generated SQL -> cleaned SQL -> query result -> natural-language answer.
def _generate_sql(context):
    """Generate SQL for ``context['question']``.

    Returns a dict with the raw LLM output under 'query' and the original
    question under 'question'.
    """
    raw_query = generate_query.invoke(context)
    # Logged once the query has been generated.
    print("源问题:", context)
    return {'query': raw_query, 'question': context['question']}


def _clean_sql(context):
    """Clean the raw LLM output into an executable SQL string and log both forms."""
    cleaned_query = clean_query(context['query'])
    print("源生成答案:", context['query'])  # raw text as produced by the LLM
    print("SQL查询语句:", cleaned_query)     # SQL that will actually be executed
    return {'query': cleaned_query, 'question': context['question']}


def _run_sql(context):
    """Execute the cleaned SQL and attach its result under 'result'."""
    # NOTE(review): the SQL text comes straight from the LLM and is executed as-is;
    # consider connecting with a read-only database role.
    return {
        'result': execute_query.invoke(context['query']),
        'query': context['query'],
        'question': context['question'],
    }


# Each step returns exactly the keys the next step needs; answer_prompt consumes
# 'question', 'query' and 'result'.
chain = (
    RunnableLambda(_generate_sql)
    | RunnableLambda(_clean_sql)
    | RunnableLambda(_run_sql)
    | answer_prompt   # build the final answer-rewriting prompt
    | llm
    | StrOutputParser()
)

In [6]:
import os
import dashscope

# Take the DashScope API key from the environment (None when the variable is
# unset) and hand it to the dashscope client module.
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
dashscope.api_key = DASHSCOPE_API_KEY

In [8]:
# Invoke the chain to produce the final answer.
# Known issue 1: id columns share the same name across different tables.
final_answer = chain.invoke({"question": "Find The most recent time contract's value greater than 20 million yuan and find the name of the company that signed the contract"})
print("千问回答:", final_answer)

text="\nGiven an input question, create a syntactically correct postgresql query to run. Use the following format to OutPut:\n\nSQLQuery: SQL Query to run\n\nlimit the query's row to 5 unless user have clear instruction\n\nPay attention to follow the relevant rules. \n\nthe price column in the contracts table have been divide 10000, when compare with constant please multiply 10000 to recovery\n    The unit of the revenue column in the companies table is RMB\n    \n    In the transaction table, the unit of total amount is yuan\n    \n\nNever query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n\nOnly use the following tables:\n\n\nCREATE TABLE assets (\n\tasset_id SERIAL NOT NULL, \n\tcompany_id INTEGER, \n\tdescription VARCHAR(255), 