In [1]:
import ast
import json

import boto3

from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_aws import ChatBedrock
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import StrOutputParser
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

In [3]:
### Connect DB & DB Info

# Local SQLite copy of the department_store database; 3 sample rows per table are
# appended to the schema text that LangChain puts into the prompt.
DB_URI = "sqlite:///data/department_store.sqlite"
db = SQLDatabase.from_uri(DB_URI, sample_rows_in_table_info=3)

# Only a cell's last expression is displayed, so intermediate results must be printed
# explicitly (previously these two values were computed and silently discarded).
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Customers LIMIT 10;"))

context = db.get_context()
print(list(context))
print(context["table_info"])

In [4]:
# Peek at the first 10 Customers rows; SQLDatabase.run returns them as one string.
customer_preview_sql = "SELECT * FROM Customers LIMIT 10;"
db.run(customer_preview_sql)

"[(1, 'Credit Card', '401', 'Ahmed', '75099 Tremblay Port Apt. 163\\nSouth Norrisland, SC 80546', '254-072-4068x33935', 'margarett.vonrueden@example.com'), (2, 'Credit Card', '665', 'Chauncey', '8408 Lindsay Court\\nEast Dasiabury, IL 72656-3552', '+41(8)1897032009', 'stiedemann.sigrid@example.com'), (3, 'Direct Debit', '844', 'Lukas', '7162 Rodolfo Knoll Apt. 502\\nLake Annalise, TN 35791-8871', '197-417-3557', 'joelle.monahan@example.com'), (4, 'Direct Debit', '662', 'Lexus', '9581 Will Flat Suite 272\\nEast Cathryn, WY 30751-4404', '+08(3)8056580281', 'gbrekke@example.com'), (5, 'Credit Card', '848', 'Tara', '5065 Mraz Fields Apt. 041\\nEast Chris, NH 41624', '1-064-498-6609x051', 'nicholas44@example.com'), (6, 'Credit Card', '916', 'Jon', '841 Goyette Unions\\nSouth Dionbury, NC 62021', '(443)013-3112x528', 'cconroy@example.net'), (7, 'Credit Card', '172', 'Cristobal', '8327 Christiansen Lakes Suite 409\\nSchneiderland, IA 93624', '877-150-8674x63517', 'shawna.cummerata@example.net

In [6]:
# get_context() returns a dict; its "table_info" entry is the CREATE TABLE text plus
# sample rows for every usable table.
context = db.get_context()
context_keys = list(context)
print(context_keys)
schema_text = context["table_info"]
print(schema_text)

['table_info', 'table_names']

CREATE TABLE "Addresses" (
	address_id INTEGER, 
	address_details VARCHAR(255), 
	PRIMARY KEY (address_id)
)

/*
3 rows from Addresses table:
address_id	address_details
1	28481 Crist Circle
East Burdettestad, IA 21232
2	0292 Mitchel Pike
Port Abefurt, IA 84402-4249
3	4062 Mante Place
West Lindsey, DE 76199-8015
*/


CREATE TABLE "Customer_Addresses" (
	customer_id INTEGER NOT NULL, 
	address_id INTEGER NOT NULL, 
	date_from DATETIME NOT NULL, 
	date_to DATETIME, 
	PRIMARY KEY (customer_id, address_id), 
	FOREIGN KEY(customer_id) REFERENCES "Customers" (customer_id), 
	FOREIGN KEY(address_id) REFERENCES "Addresses" (address_id)
)

/*
3 rows from Customer_Addresses table:
customer_id	address_id	date_from	date_to
2	9	2017-12-11 05:00:22	2018-03-20 20:52:34
1	6	2017-10-07 23:00:26	2018-02-28 14:53:52
10	8	2017-04-04 20:00:27	2018-02-27 20:08:33
*/


CREATE TABLE "Customer_Orders" (
	order_id INTEGER, 
	customer_id INTEGER NOT NULL, 
	order_status_code VARCHAR

### LangChain을 통한 DB 연결

https://python.langchain.com/v0.1/docs/use_cases/sql/prompting/

In [7]:
# Bedrock region hosting the model.
region_name = "us-west-2"

# Anthropic request parameters forwarded by ChatBedrock.
model_kwargs = dict(
    anthropic_version="bedrock-2023-05-31",
    max_tokens=2048,
    temperature=0,  # deterministic decoding for SQL generation
    stop_sequences=["\n\nHuman"],
)

# Claude 3 Sonnet chat model; tokens are streamed to stdout as they arrive.
llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # foundation model to call
    model_kwargs=model_kwargs,
    region_name=region_name,
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)

In [8]:
# LangChain's stock text-to-SQL chain; the parser reduces the chat message to a plain string.
sql_query_chain = create_sql_query_chain(llm, db)
chain = sql_query_chain | StrOutputParser()

In [9]:
# First prompt template of the default chain, shown via its repr.
default_sql_prompt = chain.get_prompts()[0]
default_sql_prompt

PromptTemplate(input_variables=['input', 'table_info'], partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question in

In [10]:
# Same prompt rendered as readable text with its placeholders.
default_sql_prompt = chain.get_prompts()[0]
default_sql_prompt.pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [11]:
# Re-fetch the schema context; it is used below to pre-fill the prompt's table_info.
context = db.get_context()
for item in (list(context), context["table_info"]):
    print(item)

['table_info', 'table_names']

CREATE TABLE "Addresses" (
	address_id INTEGER, 
	address_details VARCHAR(255), 
	PRIMARY KEY (address_id)
)

/*
3 rows from Addresses table:
address_id	address_details
1	28481 Crist Circle
East Burdettestad, IA 21232
2	0292 Mitchel Pike
Port Abefurt, IA 84402-4249
3	4062 Mante Place
West Lindsey, DE 76199-8015
*/


CREATE TABLE "Customer_Addresses" (
	customer_id INTEGER NOT NULL, 
	address_id INTEGER NOT NULL, 
	date_from DATETIME NOT NULL, 
	date_to DATETIME, 
	PRIMARY KEY (customer_id, address_id), 
	FOREIGN KEY(customer_id) REFERENCES "Customers" (customer_id), 
	FOREIGN KEY(address_id) REFERENCES "Addresses" (address_id)
)

/*
3 rows from Customer_Addresses table:
customer_id	address_id	date_from	date_to
2	9	2017-12-11 05:00:22	2018-03-20 20:52:34
1	6	2017-10-07 23:00:26	2018-02-28 14:53:52
10	8	2017-04-04 20:00:27	2018-02-27 20:08:33
*/


CREATE TABLE "Customer_Orders" (
	order_id INTEGER, 
	customer_id INTEGER NOT NULL, 
	order_status_code VARCHAR

In [12]:
# Pre-fill the schema so the full prompt can be read before any question is asked.
base_sql_prompt = chain.get_prompts()[0]
prompt_with_context = base_sql_prompt.partial(table_info=context["table_info"])
rendered_prompt = prompt_with_context.pretty_repr()
print(rendered_prompt)

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

### Few-shot prompting

In [13]:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field


class SQLOutput(BaseModel):
    # Structured reply expected from the LLM. Deliberately no class docstring: pydantic
    # would add it to the JSON schema that output_parser.get_format_instructions()
    # embeds into the prompt, changing the prompt text.
    question: str = Field(description="User Question")  # echo of the user's question
    sql_query: str = Field(description="SQL Query to run")  # generated SQL (may be "" when no answer)
    
# Parses the model's JSON reply into a dict (callers index it, e.g. result["sql_query"])
# and provides the format instructions injected into the few-shot prompt.
output_parser = JsonOutputParser(pydantic_object=SQLOutput)

In [14]:
# Few-shot (question, gold SQL) pairs for the department_store schema.
# NOTE(review): these look like Spider gold queries; they do not follow the prompt's
# "wrap column names in double quotes" rule.
_example_pairs = [
    (
        "What are the ids of the top three products that were purchased in the largest amount?",
        "SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3",
    ),
    (
        "Give the ids of the three products purchased in the largest amounts.",
        "SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3",
    ),
    (
        "What are the product id and product type of the cheapest product?",
        "SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1",
    ),
    (
        "Give the id and product type of the product with the lowest price.",
        "SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1",
    ),
    (
        "Find the number of different product types.",
        "SELECT count(DISTINCT product_type_code) FROM products",
    ),
]

# FewShotPromptTemplate expects one dict per example keyed by the example_prompt variables.
examples = [{"input": user_text, "query": gold_sql} for user_text, gold_sql in _example_pairs]

In [15]:
# Instructions placed before the few-shot examples. {dialect}, {top_k} and {table_info}
# are template variables filled in when the few-shot prompt is formatted.
prefix_template = """
You are a SQLite expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

Write an initial draft of the query. Then double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

Below are a number of examples of questions and their corresponding SQL queries.
"""

# Text placed after the examples: the user's question plus the JSON format instructions
# ({output_format_instruction}, pre-filled from output_parser). It asks for empty
# string values when no query fits, so downstream code must handle sql_query == "".
suffix_template = """
User input: {input}
SQL query: 

{output_format_instruction}
SQL query part should be written in one line.

If you don't have any appropriate answer, just return json with empty string values.

Skip the preamble and go straight into json
"""

# Each few-shot example is rendered as a "User input / SQL query" pair.
example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")

# JSON-schema instructions derived from SQLOutput, baked in as a partial variable.
format_instructions = output_parser.get_format_instructions()

# input/top_k/table_info/dialect are left open; create_sql_query_chain is expected to
# supply them at invoke time.
prompt = FewShotPromptTemplate(
    prefix=prefix_template,
    examples=examples[:5],
    example_prompt=example_prompt,
    suffix=suffix_template,
    partial_variables={"output_format_instruction": format_instructions},
    input_variables=["input", "top_k", "table_info", "dialect"],
)

In [16]:
# Text-to-SQL chain driven by the few-shot prompt instead of LangChain's default one.
write_query = create_sql_query_chain(
    llm,
    db,
    prompt=prompt,
)

In [17]:
# Show the few-shot prompt as the chain will use it.
few_shot_prompt = write_query.get_prompts()[0]
few_shot_prompt.pretty_print()


You are a SQLite expert. Given an input question, first create a syntactically correct sqlite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Write an initial draft of the qu

In [18]:
# Sample natural-language question used to exercise the few-shot text-to-SQL chain.
question = "How many customers are there?"

### 잘되는 쿼리 실행

In [19]:
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Few-shot SQL writer followed by the JSON parser -> dict with "question" / "sql_query".
chain = write_query | output_parser
result = chain.invoke({"question": question})

# Prompt used to turn the SQL result into a natural-language answer.
advanced_query_generation_template = '''

H: You are a expert.
Given the following user Question, corresponding SQLQuery, and SQLResult, Answer the user Question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: 

A: 
'''

# The few-shot prompt tells the model to return empty strings when no query fits, so an
# empty sql_query is an expected outcome: skip the DB call and the answer LLM call then.
sql_query = (result.get("sql_query") or "").strip()

if not sql_query:
    query_result = ""
    response = ""
    print("No SQL query was generated for this question.")
else:
    query_result = db.run(sql_query)
    # Only a cell's last expression is displayed, so show the raw result explicitly.
    print(f"SQL result: {query_result}")

    answer_prompt = PromptTemplate.from_template(
        advanced_query_generation_template,
        partial_variables={"query": sql_query, "result": query_result},
    )
    answer = answer_prompt | llm | StrOutputParser()
    response = answer.invoke({"question": question})

print("\n================")
print("\n\nRESPONSE:\n")
print(response)

{"question": "How many customers are there?", "sql_query": "SELECT COUNT(*) FROM \"Customers\""}

### 없는 결과 쿼리 예시

In [23]:
# A question with no sensible mapping to this schema; the prompt asks for empty strings here.
unanswerable_question = "Return the backpacker's loving item"
write_query.invoke({"question": unanswerable_question})

{"question": "Return the backpacker's loving item", "sql_query": ""}

'{"question": "Return the backpacker\'s loving item", "sql_query": ""}'

## Evaluation

In [None]:
import json
from typing import List, Tuple
from sqlalchemy import create_engine

def load_spider_dataset(file_path: str) -> List[dict]:
    """Load a Spider-style JSON dataset file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The parsed JSON content (expected to be a list of example dicts).
    """
    # JSON text is UTF-8 (RFC 8259); without an explicit encoding, open() uses the
    # platform default (e.g. cp1252 on Windows) and fails on non-ASCII content.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def calculate_exact_match_accuracy(predictions: List[str], references: List[str]) -> float:
    """Fraction of references whose prediction is string-identical.

    Args:
        predictions: Predicted SQL strings, aligned by position with ``references``.
        references: Gold SQL strings.

    Returns:
        Number of exact matches divided by ``len(references)``; missing predictions
        count as wrong. Returns 0.0 when ``references`` is empty (instead of raising
        ZeroDivisionError).
    """
    if not references:
        return 0.0
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(references)

def execute_query(database: str, query: str):
    """Run a raw SQL string against a SQLite file and return all result rows.

    Args:
        database: Filesystem path of the SQLite database.
        query: SQL text to execute verbatim.

    Returns:
        The list of rows from ``fetchall()``.
    """
    engine = create_engine(f'sqlite:///{database}')
    try:
        with engine.connect() as connection:
            # SQLAlchemy 2.x no longer accepts plain strings in Connection.execute();
            # exec_driver_sql sends the text as-is (no ":name" bind-param parsing,
            # which matters for arbitrary LLM-generated SQL).
            return connection.exec_driver_sql(query).fetchall()
    finally:
        # A new engine is built per call; release its connection pool each time.
        engine.dispose()

def calculate_execution_accuracy(
    database: str,
    predictions: List[str],
    references: List[str],
    ignore_order: bool = False,
) -> float:
    """Fraction of prediction/reference pairs whose execution results are equal.

    Args:
        database: Path of the SQLite database both queries are run against.
        predictions: Predicted SQL strings, aligned by position with ``references``.
        references: Gold SQL strings.
        ignore_order: When True, compare results as multisets of rows, so queries
            without ORDER BY that return the same rows in a different order still
            match. Defaults to False (exact, order-sensitive comparison).

    Returns:
        Matches divided by ``len(references)``; 0.0 when ``references`` is empty.
        A pair whose execution fails counts as wrong.
    """
    from collections import Counter

    if not references:
        return 0.0

    correct = 0
    for pred_query, ref_query in zip(predictions, references):
        try:
            pred_result = execute_query(database, pred_query)
            ref_result = execute_query(database, ref_query)
            if ignore_order:
                matched = Counter(map(tuple, pred_result)) == Counter(map(tuple, ref_result))
            else:
                matched = pred_result == ref_result
            if matched:
                correct += 1
        except Exception as e:
            # A query that fails to execute is counted as incorrect.
            print(f"Query execution error: {e}")
    return correct / len(references)

import os

# Example: load the Spider dataset. The path is a placeholder, so guard it to keep
# "Restart & Run All" from stopping here with FileNotFoundError.
SPIDER_DATASET_PATH = 'path_to_spider_dataset.json'  # TODO: point at a real Spider split
spider_dataset = (
    load_spider_dataset(SPIDER_DATASET_PATH) if os.path.exists(SPIDER_DATASET_PATH) else []
)

# Example predicted / reference queries (plain lists for illustration).
predicted_queries = [
    "SELECT * FROM Customers WHERE customer_id = 1",
    "SELECT store_name FROM Department_Stores WHERE dept_store_id = 3"
    # add model predictions here
]

reference_queries = [
    "SELECT * FROM Customers WHERE customer_id = 1",
    "SELECT store_name FROM Department_Stores WHERE dept_store_id = 3"
    # add gold queries here
]

# String-level accuracy.
exact_match_accuracy = calculate_exact_match_accuracy(predicted_queries, reference_queries)
print(f"Exact Match Accuracy: {exact_match_accuracy:.2%}")

# Execution-based accuracy must run against the database the queries target. A
# non-existent path would make SQLite create an empty file, and every query would fail.
database_path = 'data/department_store.sqlite'
execution_accuracy = calculate_execution_accuracy(database_path, predicted_queries, reference_queries)
print(f"Execution Accuracy: {execution_accuracy:.2%}")

In [29]:
# Question re-used for the end-to-end prediction below (same as the earlier example).
question =  "How many customers are there?"

In [31]:
# Generate SQL for the question, then execute it; db.run returns the rows as a string.
result = chain.invoke({"question": question})
generated_sql = result["sql_query"]
query_result = db.run(generated_sql)
query_result


{"question": "How many customers are there?", "sql_query": "SELECT COUNT(*) FROM \"Customers\""}

'[(15,)]'

In [35]:
# db.run() returns rows as a Python-literal string such as '[(15,)]' ("" when there are
# no rows). ast.literal_eval only accepts literals, whereas eval() would execute any code
# that reaches this string via database content or LLM-generated SQL.
final_pred = ast.literal_eval(query_result) if query_result else []

In [39]:
# Reduce the parsed rows to the first column of the first row (e.g. the COUNT(*) value).
# The isinstance guard keeps the cell idempotent: re-running it must not index the
# already-extracted scalar. An empty result yields None instead of an IndexError.
if isinstance(final_pred, list):
    final_pred = final_pred[0][0] if final_pred else None