# import 包

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import sqlite3

# 获得sql查询

In [88]:
# 字段信息:
# import sqlite3

# # 连接到数据库
# conn = sqlite3.connect('./db/Chinook.db')

# # 查询当前数据库中所有表名
# cursor = conn.cursor()
# cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
# table_names = [row[0] for row in cursor.fetchall()]

# # 查询每张表包含的字段
# for table_name in table_names:
#     cursor.execute("PRAGMA table_info('{}')".format(table_name))
#     columns = [row[1] for row in cursor.fetchall()]
#     print("{}: {}".format(table_name, columns))

# # 关闭数据库连接
# conn.close()

In [69]:
def get_prompt(cur_question):
    res = f"""
    当前数据库包含的表名和表中包含的字段如下:
    Album: ['AlbumId', 'Title', 'ArtistId']
    Artist: ['ArtistId', 'Name']
    Customer: ['CustomerId', 'FirstName', 'LastName', 'Company', 'Address', 'City', 'State', 'Country', 'PostalCode', 'Phone', 'Fax', 'Email', 'SupportRepId']
    Employee: ['EmployeeId', 'LastName', 'FirstName', 'Title', 'ReportsTo', 'BirthDate', 'HireDate', 'Address', 'City', 'State', 'Country', 'PostalCode', 'Phone', 'Fax', 'Email']
    Genre: ['GenreId', 'Name']
    Invoice: ['InvoiceId', 'CustomerId', 'InvoiceDate', 'BillingAddress', 'BillingCity', 'BillingState', 'BillingCountry', 'BillingPostalCode', 'Total']
    InvoiceLine: ['InvoiceLineId', 'InvoiceId', 'TrackId', 'UnitPrice', 'Quantity']
    MediaType: ['MediaTypeId', 'Name']
    Playlist: ['PlaylistId', 'Name']
    PlaylistTrack: ['PlaylistId', 'TrackId']
    Track: ['TrackId', 'Name', 'AlbumId', 'MediaTypeId', 'GenreId', 'Composer', 'Milliseconds', 'Bytes', 'UnitPrice']
    请基于这个数据库，编写sql语句，对这个数据库进行查询。
    要求sql语句必须能在这个数据库中正确执行，不能查询这个数据库中不存在的表或者字段。
    example:
    1. question:查询 Customer 表中客户的数量, answer:SELECT COUNT(*) FROM Customer;
    2. question:查询 Invoice 表中发票总金额的最大值, answer:SELECT MAX(Total) FROM Invoice;
    3. question:{cur_question}
    """
    return res

In [70]:
def get_sql(prompt, tokenizer, model):
    inputs = tokenizer.encode(prompt, return_tensors="pt")
    outputs = model.generate(inputs, max_length=2000)
    full_res = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sql = full_res.split(prompt)[1]
    return sql

In [71]:
def query_in_db(db, sql):
    
    conn = sqlite3.connect(db)
    cur = conn.cursor()
    cur.execute(sql)
    result = cur.fetchone()
    sql_res = result[0]
    conn.close()
    
    return sql_res

In [77]:
def get_answer(cur_question, sql_res, tokenizer, model):
    prompt = f"""
    question:查询音频类型的数量, 
    context: 19, 
    answer: 音频类型的数量为10.
    question:查询平均发票总金额, 
    context: 25, 
    answer: 平均发票总金额为25.
    question:{cur_question}, 
    context: {sql_res}, 
    answer: 
    """
    inputs = tokenizer.encode(prompt, return_tensors="pt")
    outputs = model.generate(inputs, max_length=100)
    full_res = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     print(full_res)
    answer = full_res.split(prompt)[1].strip()
    return answer

In [78]:
# cur_question = "查询 Employee 表中员工的总数"
# prompt = get_prompt(cur_question) # 获取prompt
# sql = get_sql(prompt, tokenizer, model)  # 获取sql
# print(sql)
# sql_res = query_in_db(db, sql)  # 执行sql
# answer = get_answer(cur_question, sql_res, tokenizer, model) # 基于sql结果返回答案
# answer

In [79]:
def chain(cur_question, tokenizer, model, db):
    prompt = get_prompt(cur_question) # 获取prompt
    sql = get_sql(prompt, tokenizer, model)  # 获取sql
    sql_res = query_in_db(db, sql)  # 执行sql
    answer = get_answer(cur_question, sql_res, tokenizer, model) # 基于sql结果返回答案
    return sql, answer

In [8]:
# 配置模型信息
checkpoint = "bigscience/bloomz-7b1"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# 配置数据库
db = './db/Chinook.db'

In [80]:
cur_question = "查询 Employee 表中员工的总数"
sql, answer = chain(cur_question, tokenizer, model, db)
print(sql)
print(answer)

 SELECT COUNT(*) FROM Employee;
Employee 表中员工的总数为8.


In [81]:
cur_question = "查询员工的总数"
sql, answer = chain(cur_question, tokenizer, model, db)
print(sql)
print(answer)

 SELECT COUNT(*) FROM Employee
员工的总数为8.


In [82]:
# cur_question = "查询 Invoice 表中最早的发票日期"
# prompt = get_prompt(cur_question) # 获取prompt
# inputs = tokenizer.encode(prompt, return_tensors="pt")
# outputs = model.generate(inputs, max_length=2000)
# full_res = tokenizer.decode(outputs[0], skip_special_tokens=True)
# sql = full_res.split(prompt)[1]
# print(full_res)

In [83]:
cur_question = "查询 Invoice 表中最早的发票日期"
sql, answer = chain(cur_question, tokenizer, model, db)
print(sql)
print(answer)

 SELECT MIN(InvoiceDate) FROM Invoice
2009-01-01 00:00:00


In [84]:
cur_question = "查询 Invoice 表中发票的平均总金额"
sql, answer = chain(cur_question, tokenizer, model, db)
print(sql)
print(answer)

 SELECT AVG(Total) FROM Invoice;
(5.651941747572825 / 5)


In [85]:
cur_question = "查询 Genre 表中音乐风格的数量"
sql, answer = chain(cur_question, tokenizer, model, db)
print(sql)
print(answer)

 SELECT COUNT(*) FROM Genre;
(Genre.music_style_id, Genre.music_style_name) = (1, 'rock')


In [86]:
cur_question = '查询 Artist 表中歌手名字为 "Queen" 的歌手的 ID'
sql, answer = chain(cur_question, tokenizer, model, db)
print(sql)
print(answer)

 SELECT Artist.ArtistId FROM Artist WHERE Artist.Name = 'Queen';
(51, 'Queen')


In [87]:
cur_question = "查询 Track 表中售价最便宜的音乐的售价"
sql, answer = chain(cur_question, tokenizer, model, db)
print(sql)
print(answer)

 SELECT MIN(UnitPrice) FROM Track WHERE Name = 'price'
(None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None


In [9]:
cur_question = "查询员工的总数"

prompt = get_input(cur_question)
sql = get_sql(prompt, tokenizer, model)
sql
# SELECT COUNT(*) FROM Employee;

' SELECT COUNT(*) FROM Employee;'

In [15]:
cur_question = '查询 Artist 表中歌手名字为 "Queen" 的歌手的 ID'

prompt = get_input(cur_question)
sql = get_sql(prompt, tokenizer, model)
sql
# SELECT ArtistId FROM Artist WHERE Name = 'Queen';

" SELECT Artist.ArtistId FROM Artist WHERE Artist.Name = 'Queen';"

In [16]:
cur_question = "查询 Track 表中售价最贵的音乐的售价"

prompt = get_input(cur_question)
sql = get_sql(prompt, tokenizer, model)
sql
# SELECT MAX(UnitPrice) FROM Track;

' SELECT MAX(UnitPrice) FROM Track WHERE MediaTypeId = 1;'

In [17]:
cur_question = "查询 Invoice 表中发票的平均总金额"

prompt = get_input(cur_question)
sql = get_sql(prompt, tokenizer, model)
sql
# SELECT AVG(Total) FROM Invoice;

' SELECT AVG(Total) FROM Invoice;'

In [18]:
cur_question = "查询 Track 表中售价最便宜的音乐的售价"

prompt = get_input(cur_question)
sql = get_sql(prompt, tokenizer, model)
sql
# SELECT MIN(UnitPrice) FROM Track;

' SELECT MIN(UnitPrice) FROM Track WHERE Track.MediaTypeId = 1 AND Track.Milliseconds > 0'

In [19]:
cur_question = "查询 Invoice 表中最早的发票日期"

prompt = get_input(cur_question)
sql = get_sql(prompt, tokenizer, model)
sql
# SELECT MIN(InvoiceDate) FROM Invoice;

' SELECT MAX(InvoiceDate) FROM Invoice;'

In [None]:
# How many employees are there? 
# ' There are 8 employees.'


# "How many employees are also customers?"
# 59 employees are also customers.