In [1]:
import os
import psycopg2
import re

In [2]:
# 数据库连接参数
dbname = ""
user = ""
password = ""
host = ""  # 或者是你 Docker 容器的 IP 地址，如果你在不同的机器上
port = ""

# 连接字符串
conn_string = f"host={host} dbname={dbname} user={user} password={password} port={port}"
# 连接到数据库
conn = psycopg2.connect(conn_string)
cur = conn.cursor()

In [3]:
# 获取表名
def get_table_name(cur):
    # 执行 SQL 查询以获取所有表的列表
    cur.execute("""
        SELECT table_name 
        FROM information_schema.tables 
        WHERE table_schema = 'public';
    """)
    
    # 获取查询结果
    tables = cur.fetchall()
    tname = []
    # 打印表名
    for table in tables:
        tname.append(table[0])
    
    return tname

In [4]:
# 获取数据库列名。
def get_table_columns(cur,table_name):
    try:
        # 执行 SQL 查询以获取表的列名
        cur.execute(f"""
            SELECT column_name 
            FROM information_schema.columns 
            WHERE table_name = '{table_name}';
        """)
        # 获取查询结果
        cols = [desc[0] for desc in cur.fetchall()]
        columns = []
        for col in cols:
            # column = "'"+col+"'"
            columns.append(col)
    except Exception as e:
        print(f"An error occurred: {e}")
    return columns

In [5]:
# 获取数据库表结构，表名对应列名
schema = dict()

table_names = get_table_name(cur)
for name in table_names:
    schema[name] = get_table_columns(cur,name)

In [6]:
def get_column_types(cur, table_name):
    # 执行 SQL 查询以获取表的列名和数据类型
    cur.execute("""
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name = %s;
    """, (table_name,))

    # 获取查询结果
    results = cur.fetchall()
    return results

In [7]:
col = get_column_types(cur,"tb_process")
print(col)

[('share', 'integer'), ('deleted', 'smallint'), ('update_time', 'timestamp without time zone'), ('create_time', 'timestamp without time zone'), ('create_user', 'bigint'), ('reference_process_id', 'bigint'), ('parent_id', 'bigint'), ('node_type', 'smallint'), ('extend', 'json'), ('id', 'bigint'), ('version', 'bigint'), ('last_change', 'bigint'), ('category_id', 'bigint'), ('infrastructure_process', 'smallint'), ('quantitative_reference_id', 'bigint'), ('location_id', 'bigint'), ('process_doc_id', 'bigint'), ('dq_system_id', 'bigint'), ('exchange_dq_system_id', 'bigint'), ('social_dq_system_id', 'bigint'), ('last_internal_id', 'integer'), ('currency_id', 'bigint'), ('product_system_id', 'bigint'), ('process_source', 'smallint'), ('node', 'jsonb'), ('ref_id', 'character varying'), ('name', 'character varying'), ('default_allocation_method', 'character varying'), ('synonyms', 'text'), ('dq_entry', 'character varying'), ('tags', 'character varying'), ('library', 'character varying'), ('desc

In [20]:
def query_prompt(text, schema):
    # 将表结构信息包含在提示中
    schema_info = "\n".join([f"Table {table}: {', '.join(columns)}" for table, columns in schema.items()])

    # 定义 message 模板
    messages =  f'''你是一个 SQL 生成助手，能够根据用户提供的描述生成符合 SQLite 语法的 SQL 查询。
            请根据以下描述生成一个符合 SQLite 数据库语法的 SQL 查询，并且不能修改给出的数据表列名。
            描述：{text}。
            要求输出的 SQL 以 # 开头，以 # 结尾，参数类型一律按照字符串处理，样例如下：
            #SELECT * FROM table#
            #SELECT COUNT(*) FROM table where Column_name='abc'#
            注意不要输出分析过程和其他内容，直接给出 SQL 语句。'''

    return messages

In [21]:
import pandas as pd
df = pd.read_excel(r"./data/sql_ques_clear.xlsx")
df.head()

Unnamed: 0,SQL,Question,table
0,select count( * ) as total from tb_process whe...,数据库tb_process表中有多少未删除的数据？,tb_process
1,select tags from tb_process where product_syst...,在数据库tb_process表中，product_system_id为12983111603...,tb_process
2,"select id,ref_id,name,synonyms,category_id,pro...",在数据库tb_process表中，process_type为‘unit_process’且p...,tb_process
3,"select id,ref_id,name,synonyms,category_id,des...",在数据库tb_process表中，id为1300755365750636544的记录的id、...,tb_process
4,"select id,ref_id,name,synonyms,category_id,des...",在数据库tb_process表中，id为9237375的记录的id、ref_id、name、...,tb_process


In [22]:
ques_test = "数据库tb_process表中有多少未删除的数据？"
tmp_dict = dict()
tmp_dict["tb_process"] = schema["tb_process"]
# print(tmp_dict)
res = query_prompt(ques_test,tmp_dict)
print("res:",res)

res: 你是一个 SQL 生成助手，能够根据用户提供的描述生成符合 SQLite 语法的 SQL 查询。
            请根据以下描述生成一个符合 SQLite 数据库语法的 SQL 查询，并且不能修改给出的数据表列名。
            描述：数据库tb_process表中有多少未删除的数据？。
            要求输出的 SQL 以 # 开头，以 # 结尾，参数类型一律按照字符串处理，样例如下：
            #SELECT * FROM table#
            #SELECT COUNT(*) FROM table where Column_name='abc'#
            注意不要输出分析过程和其他内容，直接给出 SQL 语句。


In [29]:
import time
sql_gold = []
sql_pred = []
prompts = {}

fail_index = [12, 13, 14, 15, 16, 17, 18, 20, 21, 23, 24, 25, 27, 28, 29, 31, 32, 33, 35, 36, 39, 40, 41, 43, 46, 47, 48, 49, 51, 53, 54, 56, 58, 59, 61, 62, 64, 65]

for index,item in df.iterrows():
    if index not in fail_index:
        continue
    ques = item['Question']
    table = item['table']
    gold = item['SQL']
    
    if table not in table_names:
        print(table)
        continue
    if 'join' in gold:
        continue
        
    colum_type = get_column_types(cur,table)
    tmp = dict()
    tmp[table] = schema[table]
    prompt = query_prompt(ques,tmp)
    prompts[index+1]=prompt
    sql_gold.append(gold)

In [30]:
len(prompts)

38

In [31]:
len(sql_gold)

38

In [32]:
def save_txt(save_list,path):
    with open(path,"w",encoding="utf-8") as f:
        for item in save_list:
            f.write(item+"\n")

In [33]:
def save_txts(save_dict,path):
    with open(path,"w",encoding="utf-8") as f:
        for key,value in save_dict.items():
            f.write(str(key)+" "+value+"\n")

In [34]:
# save_txt(sql_gold,"./data/DeepSeek_enhance/gold.txt")
save_txts(prompts,"./data/DeepSeek/prompts.txt")