In [20]:
import os
import re
from urllib.parse import quote_plus

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

from local_chain.selfQuery import create_sql_query_chain

In [21]:
# Database connection settings.
# Values come from the standard libpq environment variables so credentials are
# not hard-coded in the notebook; the fallbacks are the local-development defaults.
db_user = os.environ.get("PGUSER", "postgres")
db_password = os.environ.get("PGPASSWORD", "postgres")  # set PGPASSWORD rather than editing this cell
db_host = os.environ.get("PGHOST", "localhost")
db_port = os.environ.get("PGPORT", "5432")
db_name = os.environ.get("PGDATABASE", "postgres")

# Initialise LangChain's SQLDatabase wrapper.
# User and password are URL-quoted so characters such as '@', ':' or '/' in a
# password cannot break the connection URI.
db = SQLDatabase.from_uri(
    f"postgresql://{quote_plus(db_user)}:{quote_plus(db_password)}"
    f"@{db_host}:{db_port}/{db_name}"
)

# Print basic database information to confirm the connection works.
print("本次数据库类型:", db.dialect)
print("可用的表格:", db.get_usable_table_names())

本次数据库类型: postgresql
可用的表格: ['assets', 'companies', 'contracts', 'debts', 'departments', 'employees', 'projects', 'transactions']


In [22]:
# Post-process raw LLM output into a bare SQL string.
def clean_query(query):
    """Strip formatting noise from LLM-generated SQL text.

    1. Deletes every occurrence of these substrings, in this order:
       "SQLResult:", "```", "sql", "SQLQuery:", "SQL", "Answer:".
       NOTE(review): this is plain substring removal, so identifiers that
       contain "sql"/"SQL" are altered too (e.g. "postgresql" -> "postgre").
    2. While the text ends with "." or "。", cuts off the trailing sentence.
       Everything from the last uppercase ASCII letter through the end is
       removed. When there is no uppercase letter, the cut starts at the first
       non-ASCII character (e.g. a Chinese character) instead. If neither
       exists, only the final period is dropped.
       NOTE(review): a query that itself ends in "." is truncated at its last
       uppercase letter (e.g. "SELECT a FROM t." -> "SELECT a FRO").
    3. Strips surrounding whitespace.

    Args:
        query: Raw text returned by the model.

    Returns:
        The cleaned query string.
    """
    for noise in ("SQLResult:", "```", "sql", "SQLQuery:", "SQL", "Answer:"):
        query = query.replace(noise, "")

    trailing_sentence = re.compile(r'([A-Z]|[^\x00-\x7F])[^A-Z]*[。\.]$')
    while query.endswith(('.', '。')):
        match = trailing_sentence.search(query)
        if match:
            # Remove from the anchor character through the final period.
            query = query[:match.start()]
        else:
            # No anchor character: drop just the period so the loop always
            # makes progress (otherwise e.g. "abc." would never terminate).
            query = query[:-1]

    return query.strip()

In [23]:
# Project-local LLM wrapper; `llm` is the model used by the chains in this notebook.
import LLM_CUSTOM
# NOTE(review): the meaning of n=10 is defined by LLM_CUSTOM.CustomLLM and is not
# visible here — presumably a sampling / candidate count; confirm before tuning.
llm = LLM_CUSTOM.CustomLLM(n=10)

In [24]:
# Single-step text-to-SQL prompt: ask the model for a {dialect} query in which
# exact numbers are replaced by a ##CONSTANT## placeholder.
template_prompt = PromptTemplate.from_template(
"""
First, Generate query statements in {dialect} format language for the given questions

Second, If there is an exact number in the query statements you generated, please replace it with A placeholder like##CONSTANT##

Third, Use the following format to OutPut:

SQLQuery: SQL Query to run

Only use the following SQL tables:

{table_info}

limit the query's row to {top_k} unless user have clear instruction

Question: {input}
"""
)

# Row limit the prompt asks for unless the question says otherwise.
TOP_K = 5

# Prompt variables; each lambda receives the dict produced by the first step of `chain`.
inputs = {
    "input": lambda x: x["Question"] + "\nSQLQuery: ",
    # Show only `table_names_to_use` when given; None presumably falls back to
    # all usable tables (SQLDatabase default) — confirm against the library.
    "table_info": lambda x: db.get_table_info(
        table_names=x.get("table_names_to_use")
    ),
    "dialect": lambda x: db.dialect,
    "top_k": lambda x: TOP_K,
}

chain = (
    # Keep only the keys the prompt needs. `table_names_to_use` must be
    # forwarded here, otherwise the `table_info` lambda would always see None.
    RunnableLambda(
        lambda x: {
            "Question": x["Question"],
            "table_names_to_use": x.get("table_names_to_use"),
        }
    )
    | RunnablePassthrough.assign(**inputs)
    | template_prompt
    # Cut generation off at the "SQLResult:" marker.
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [25]:
# Example: run the single-step chain on a question containing a numeric constant.
chain.invoke(
    dict(
        Question="Find The most recent time contract's value greater than 20 million yuan and find the name of the company that signed the contract"
    )
)

'SQLQuery: SELECT c.company_id, c.name FROM contracts c WHERE c.value > 20000000 ORDER BY c.start_date DESC LIMIT 1'

In [49]:


# Step 1 prompt: have the model rewrite the question with every constant
# replaced by an @column_name@ placeholder.
First_Prompt = PromptTemplate.from_template(
    """
    Replace all constants in the user question with placeholders of column name @COLUMN NAME@,for example @company_name@, and output the processed question

    Question: {input}
    """
)

# Step 2 prompt: generate a {dialect} query for the rewritten question.
# Instructions are numbered consecutively (First, Second) so the model is not
# led to look for a missing step.
Second_Prompt = PromptTemplate.from_template(
    """
    First, Generate query statements in {dialect} format language for the given questions

    Second, Use the following format to OutPut, in Raw Text without format:

    SQLQuery: SQL Query to run

    Only use the following SQL tables:

    {table_info}

    limit the query's row to {top_k} unless user have clear instruction

    Question: {input}
    """
)

def _output(text) -> str:
    print(text.text)
    return text

# Two-step chain:
#   1) rewrite the question with constants replaced by placeholders,
#   2) generate SQL for the rewritten question.
# `_output` prints each rendered prompt before it is sent to the model.
chain = (
    First_Prompt
    | _output
    | llm
    | StrOutputParser()
    # Wrap the rewritten question together with the schema context for step 2.
    | RunnableLambda(
        lambda rewritten: dict(
            dialect=db.dialect,
            table_info=db.get_table_info(),
            top_k=5,
            input=rewritten,
        )
    )
    | Second_Prompt
    | _output
    | llm
    | StrOutputParser()
)

In [50]:
# Example: run the two-step (placeholder rewrite -> SQL) chain.
chain.invoke(
    "Find a company with revenue exceeding 2 million, and find the chairman of this company"
)


    Replace all constants in the user question with placeholders of column name @COLUMN NAME@,for example @company_name@, and output the processed question

    Question: Find a company with revenue exceeding 2 million, and find the chairman of this company
    

    First, Generate query statements in postgresql format language for the given questions

    Third, Use the following format to OutPut, in Raw Text without format:

    SQLQuery: SQL Query to run

    Only use the following SQL tables:

    
CREATE TABLE assets (
	asset_id SERIAL NOT NULL, 
	company_id INTEGER, 
	description VARCHAR(255), 
	purchase_date DATE, 
	cost NUMERIC, 
	condition VARCHAR(100), 
	CONSTRAINT assets_pkey PRIMARY KEY (asset_id), 
	CONSTRAINT assets_company_id_fkey FOREIGN KEY(company_id) REFERENCES companies (company_id)
)

/*
3 rows from assets table:
asset_id	company_id	description	purchase_date	cost	condition
1	2	Asset 1	2019-01-02	12495	Used
2	3	Asset 2	2019-01-03	128977	Used
3	4	Asset 3	2019-01-04

"SQLQuery: SELECT c.name AS company_name, e.name AS chairman_name FROM companies c JOIN employees e ON c.company_id = e.company_id WHERE c.revenue > @revenue_threshold@ AND c.name = '@company_name@' LIMIT 1;"