In [4]:
import boto3
import json

from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_aws import ChatBedrock
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import StrOutputParser
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

In [5]:
### Connect DB & DB Info

db = SQLDatabase.from_uri("sqlite:///data/department_store.sqlite", sample_rows_in_table_info=3)
db.get_usable_table_names()

db.run("SELECT * FROM Customers LIMIT 10;")

context = db.get_context()
print(list(context))
print(context["table_info"])

['table_info', 'table_names']

CREATE TABLE "Addresses" (
	address_id INTEGER, 
	address_details VARCHAR(255), 
	PRIMARY KEY (address_id)
)

/*
3 rows from Addresses table:
address_id	address_details
1	28481 Crist Circle
East Burdettestad, IA 21232
2	0292 Mitchel Pike
Port Abefurt, IA 84402-4249
3	4062 Mante Place
West Lindsey, DE 76199-8015
*/


CREATE TABLE "Customer_Addresses" (
	customer_id INTEGER NOT NULL, 
	address_id INTEGER NOT NULL, 
	date_from DATETIME NOT NULL, 
	date_to DATETIME, 
	PRIMARY KEY (customer_id, address_id), 
	FOREIGN KEY(customer_id) REFERENCES "Customers" (customer_id), 
	FOREIGN KEY(address_id) REFERENCES "Addresses" (address_id)
)

/*
3 rows from Customer_Addresses table:
customer_id	address_id	date_from	date_to
2	9	2017-12-11 05:00:22	2018-03-20 20:52:34
1	6	2017-10-07 23:00:26	2018-02-28 14:53:52
10	8	2017-04-04 20:00:27	2018-02-27 20:08:33
*/


CREATE TABLE "Customer_Orders" (
	order_id INTEGER, 
	customer_id INTEGER NOT NULL, 
	order_status_code VARCHAR

### Langchain을 통한 DB 연결

https://python.langchain.com/v0.1/docs/use_cases/sql/prompting/

In [8]:
region_name = "us-west-2"

model_kwargs = {  # anthropic
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 2048,
    "temperature": 0,
    "stop_sequences": ["\n\nHuman"]
}

llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # 파운데이션 모델 지정
    model_kwargs=model_kwargs,
    region_name=region_name,
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()]
)  # Claude 속성 구성

### Fewshot

In [14]:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field


class SQLOutput(BaseModel):
    question: str = Field(description="User Question")
    sql_query: str = Field(description="SQL Query to run")
    
output_parser = JsonOutputParser(pydantic_object=SQLOutput)

In [15]:
examples = [
    {
        "input": "What are the ids of the top three products that were purchased in the largest amount?", 
         "query": "SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3"
    },
    {
        "input": "Give the ids of the three products purchased in the largest amounts.",
        "query": "SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3"
    },
    {
        "input": "What are the product id and product type of the cheapest product?",
        "query": "SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1"
    },
    {
        "input": "Give the id and product type of the product with the lowest price.",
        "query": "SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1",
    },
    {
        "input": "Find the number of different product types.",
        "query": "SELECT count(DISTINCT product_type_code) FROM products"
    }
]

In [16]:
prefix_template = """
You are a SQLite expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

Write an initial draft of the query. Then double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

Below are a number of examples of questions and their corresponding SQL queries.
"""

suffix_template = """
User input: {input}
SQL query: 

{output_format_instruction}
SQL query part should be written in one line.

If you don't have any appropriate answer, just return json with empty string values.

Skip the preamble and go straight into json
"""

example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")
prompt = FewShotPromptTemplate(
    examples=examples[:5],
    example_prompt=example_prompt,
    prefix=prefix_template,
    suffix=suffix_template,
    input_variables=["input", "top_k", "table_info", "dialect"],
    partial_variables={"output_format_instruction": output_parser.get_format_instructions()}
)

In [18]:
write_query = create_sql_query_chain(llm, db, prompt)
write_query.get_prompts()[0].pretty_print()


You are a SQLite expert. Given an input question, first create a syntactically correct sqlite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Write an initial draft of the qu

### 잘되는 쿼리 실행

In [None]:
question = "How many customers are there?"

In [22]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

chain = write_query | output_parser
result = chain.invoke({"question": question})

query_result = db.run(result["sql_query"])
query_result

{"question": "How many customers are there?", "sql_query": "SELECT COUNT(*) FROM \"Customers\""}

In [23]:
result

{'question': 'How many customers are there?',
 'sql_query': 'SELECT COUNT(*) FROM "Customers"'}

In [21]:
chain.invoke({"question": question})

{"question": "How many customers are there?", "sql_query": "SELECT COUNT(*) FROM \"Customers\""}

{'question': 'How many customers are there?',
 'sql_query': 'SELECT COUNT(*) FROM "Customers"'}

In [24]:


advanced_query_generation_template='''

Human: You are a expert.
Given the following user Question, corresponding SQLQuery, and SQLResult, Answer the user Question.
|
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: 

Assistant: 
'''

answer_prompt = PromptTemplate.from_template(
    advanced_query_generation_template,
    partial_variables={"query": result["sql_query"], "result": query_result}
)

answer = answer_prompt | llm | StrOutputParser()

response = answer.invoke({"question": question})

print ("\n================")
print ("\n\nRESPONSE:\n")
print (response)

{"question": "How many customers are there?", "sql_query": "SELECT COUNT(*) FROM \"Customers\""}Based on the provided SQL query `SELECT COUNT(*) FROM "Customers"` and the SQL result `[(15,)]`, the answer to the question "How many customers are there?" is:

Answer: There are 15 customers.

The `COUNT(*)` function in SQL is used to count the number of rows in a table. When applied to the "Customers" table, it returns the total number of rows, which represents the number of customers. The result `[(15,)]` indicates that the count is 15, so there are 15 customers in the database.


RESPONSE:

Based on the provided SQL query `SELECT COUNT(*) FROM "Customers"` and the SQL result `[(15,)]`, the answer to the question "How many customers are there?" is:

Answer: There are 15 customers.

The `COUNT(*)` function in SQL is used to count the number of rows in a table. When applied to the "Customers" table, it returns the total number of rows, which represents the number of customers. The result `[

## Evaluation QA set 생성

In [26]:
## 전체 테스트 데이터에서 target db 관련된것만 추출 
target_db = 'department_store'
with open('./evaluation/examples/train_spider.json', 'rb') as ofp:
    train = json.load(ofp)

res = []
for t in train:
    if t['db_id'] == target_db:
        res.append((t['question'], t['query']))


# question file 저장할 파일
question_path = './evaluation/question.txt'

# answer file  저장할 파일
answer_path = './evaluation/answer.txt'

# 파일1과 파일2에 데이터를 쓰는 함수
def write_to_files(data, file1_path, file2_path):
    with open(file1_path, 'w') as file1, open(file2_path, 'w') as file2:
        for item in data:
            file1.write(item[0] + '\n')
            file2.write(item[1]+f'\t{target_db}' + '\n')

# 함수 호출하여 파일 생성
write_to_files(res, question_path, answer_path)

print(f"Data has been written to {question_path} and {answer_path}")


[nltk_data] Downloading package punkt to /home/ec2-user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

### pred data 생성

In [24]:
import re

def normalize_sql(query):
    # Remove double quotes from column and table names
    query = re.sub(r'"(\w+)"', r'\1', query)
    # Normalize whitespace
    query = re.sub(r'\s+', ' ', query).strip()
    return query



# question file 저장할 파일
question_path = './evaluation/question.txt'

# answer file  저장할 파일
answer_path = './evaluation/answer.txt'

with open(question_path) as f:
    qlist = [l.strip().split('\t')[0] for l in f.readlines() if len(l.strip()) > 0]


test = False
file_prefix = 'All_ver2'
if test:
    qlist = qlist[:5]
    

In [None]:
pred = []

for q in qlist:
    print('question===>', q)
    p = chain.invoke({"question": q})
    pred.append(normalize_sql(p['sql_query']))

## save file

pred_path = f'./evaluation/{file_prefix}_pred.txt'
print(pred_path)
with open(pred_path, 'w') as file1:
    for item in pred:
        file1.write(item + '\n')


question===> What are the ids of the top three products that were purchased in the largest amount?
{"question": "What are the ids of the top three products that were purchased in the largest amount?", "sql_query": "SELECT \"product_id\" FROM \"Product_Suppliers\" ORDER BY \"total_amount_purchased\" DESC LIMIT 3;"}question===> Give the ids of the three products purchased in the largest amounts.
{"question": "Give the ids of the three products purchased in the largest amounts.", "sql_query": "SELECT \"product_id\" FROM \"Product_Suppliers\" ORDER BY \"total_amount_purchased\" DESC LIMIT 3;"}question===> What are the product id and product type of the cheapest product?
{"question": "What are the product id and product type of the cheapest product?", "sql_query": "SELECT \"product_id\", \"product_type_code\" FROM \"Products\" ORDER BY \"product_price\" LIMIT 1;"}question===> Give the id and product type of the product with the lowest price.
{"question": "Give the id and product type of the

### get score

In [None]:
import subprocess

# Python 파일 실행 명령어
command = f'''python evaluation.py --gold '{answer_path}' --pred '{pred_path}' --db './data/database' --table './tables.json' --etype exec '''

print(command)
# subprocess 모듈을 사용하여 명령어 실행
result = subprocess.run(command, shell=True, capture_output=True, text=True)

# 출력 결과 확인
print(result.stdout)
print(result.stderr)


In [3]:
import subprocess

# Python 파일 실행 명령어
command = f'''python evaluation.py --gold '{answer_path}' --pred '{pred_path}' --db './data/database' --table './tables.json' --etype exec '''

print(command)
# subprocess 모듈을 사용하여 명령어 실행
result = subprocess.run(command, shell=True, capture_output=True, text=True)

# 출력 결과 확인
print(result.stdout)
print(result.stderr)


python evaluation.py --gold './evaluation/answer.txt' --pred './evaluation/All_pred.txt' --db './data/database' --table './tables.json' --etype exec 
[pred]
[('7308 Joan Lake Suite 346\nLizethtown, DE 56522',)]
SELECT customer_address FROM Customers WHERE customer_id = 10;
[gold]
[("36594 O'Keefe Lock\nNew Cali, RI 42319",)]
SELECT T1.address_details FROM addresses AS T1 JOIN customer_addresses AS T2 ON T1.address_id  =  T2.address_id WHERE T2.customer_id  =  10
eval_err_num:1
------------------
[pred]
[('7308 Joan Lake Suite 346\nLizethtown, DE 56522',)]
SELECT c.customer_address FROM Customers c WHERE c.customer_id = 10 LIMIT 1;
[gold]
[("36594 O'Keefe Lock\nNew Cali, RI 42319",)]
SELECT T1.address_details FROM addresses AS T1 JOIN customer_addresses AS T2 ON T1.address_id  =  T2.address_id WHERE T2.customer_id  =  10
eval_err_num:2
------------------
[pred]
[(13, 3)]
SELECT product_id, COUNT(*) AS order_count FROM Order_Items GROUP BY product_id ORDER BY order_count DESC LIMIT 1;
[g

In [9]:
a = db.run(''' SELECT customer_id, customer_name FROM Customers WHERE customer_address LIKE '%WY%' AND payment_method_code <> 'Credit Card' ''')

In [10]:
b = db.run(''' SELECT customer_id ,  customer_name FROM customers WHERE customer_address LIKE "%WY%" AND payment_method_code != "Credit Card" ''')
