In [24]:
import boto3
import json

from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_aws import ChatBedrock

In [47]:
!pip list

Package                       Version
----------------------------- -------------------
aiohttp                       3.9.3
aiosignal                     1.3.1
alabaster                     0.7.16
annotated-types               0.7.0
anthropic                     0.28.1
anyio                         4.3.0
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
arrow                         1.3.0
astroid                       2.15.8
astropy                       6.0.0
astropy-iers-data             0.2024.3.25.0.29.50
asttokens                     2.4.1
async-lru                     2.0.4
async-timeout                 4.0.3
atomicwrites                  1.4.1
attrs                         23.2.0
audioread                     3.0.1
autopep8                      2.0.4
autovizwidget                 0.21.0
awscli                        1.33.12
Babel                         2.14.0
beautifulsoup4                4.12.3
binaryornot                   0.4.4
bitarray               

### Connect DB

In [20]:
# Open the SQLite file data/department_store.sqlite through LangChain's SQLDatabase wrapper.
db = SQLDatabase.from_uri("sqlite:///data/department_store.sqlite")
# Last expression of the cell: show the table names the wrapper reports as usable.
db.get_usable_table_names()

['Addresses',
 'Customer_Addresses',
 'Customer_Orders',
 'Customers',
 'Department_Store_Chain',
 'Department_Stores',
 'Departments',
 'Order_Items',
 'Product_Suppliers',
 'Products',
 'Staff',
 'Staff_Department_Assignments',
 'Supplier_Addresses',
 'Suppliers']

In [21]:
# Sanity check: fetch the first 10 Customers rows (the result is displayed as one string).
db.run("SELECT * FROM Customers LIMIT 10;")

"[(1, 'Credit Card', '401', 'Ahmed', '75099 Tremblay Port Apt. 163\\nSouth Norrisland, SC 80546', '254-072-4068x33935', 'margarett.vonrueden@example.com'), (2, 'Credit Card', '665', 'Chauncey', '8408 Lindsay Court\\nEast Dasiabury, IL 72656-3552', '+41(8)1897032009', 'stiedemann.sigrid@example.com'), (3, 'Direct Debit', '844', 'Lukas', '7162 Rodolfo Knoll Apt. 502\\nLake Annalise, TN 35791-8871', '197-417-3557', 'joelle.monahan@example.com'), (4, 'Direct Debit', '662', 'Lexus', '9581 Will Flat Suite 272\\nEast Cathryn, WY 30751-4404', '+08(3)8056580281', 'gbrekke@example.com'), (5, 'Credit Card', '848', 'Tara', '5065 Mraz Fields Apt. 041\\nEast Chris, NH 41624', '1-064-498-6609x051', 'nicholas44@example.com'), (6, 'Credit Card', '916', 'Jon', '841 Goyette Unions\\nSouth Dionbury, NC 62021', '(443)013-3112x528', 'cconroy@example.net'), (7, 'Credit Card', '172', 'Cristobal', '8327 Christiansen Lakes Suite 409\\nSchneiderland, IA 93624', '877-150-8674x63517', 'shawna.cummerata@example.net

### LangChain을 통한 DB 연결

https://python.langchain.com/v0.1/docs/use_cases/sql/prompting/

In [23]:
# Region passed to ChatBedrock below.
region_name = "us-west-2"

# Model parameters handed to ChatBedrock via model_kwargs.
model_kwargs = {  # anthropic
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 2048,
    "temperature": 0,
}

llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # select the foundation model
    model_kwargs=model_kwargs,
    region_name=region_name,
)  # configure the Claude model settings

In [26]:
# Build the SQL query chain from the LLM and the database; no custom prompt is passed here.
chain = create_sql_query_chain(llm, db)

In [None]:
# Inspect which prompt class the chain's first prompt is.
type(chain.get_prompts()[0])

langchain_core.prompts.few_shot.FewShotPromptTemplate

In [31]:
# Print the chain's first prompt in human-readable form.
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [32]:
# Collect the database context that will be fed into the prompt.
context = db.get_context()
# Show the available context keys, then the table_info text (schema plus sample rows).
print(list(context))
print(context["table_info"])

['table_info', 'table_names']

CREATE TABLE "Addresses" (
	address_id INTEGER, 
	address_details VARCHAR(255), 
	PRIMARY KEY (address_id)
)

/*
3 rows from Addresses table:
address_id	address_details
1	28481 Crist Circle
East Burdettestad, IA 21232
2	0292 Mitchel Pike
Port Abefurt, IA 84402-4249
3	4062 Mante Place
West Lindsey, DE 76199-8015
*/


CREATE TABLE "Customer_Addresses" (
	customer_id INTEGER NOT NULL, 
	address_id INTEGER NOT NULL, 
	date_from DATETIME NOT NULL, 
	date_to DATETIME, 
	PRIMARY KEY (customer_id, address_id), 
	FOREIGN KEY(customer_id) REFERENCES "Customers" (customer_id), 
	FOREIGN KEY(address_id) REFERENCES "Addresses" (address_id)
)

/*
3 rows from Customer_Addresses table:
customer_id	address_id	date_from	date_to
2	9	2017-12-11 05:00:22	2018-03-20 20:52:34
1	6	2017-10-07 23:00:26	2018-02-28 14:53:52
10	8	2017-04-04 20:00:27	2018-02-27 20:08:33
*/


CREATE TABLE "Customer_Orders" (
	order_id INTEGER, 
	customer_id INTEGER NOT NULL, 
	order_status_code VARCHAR

In [35]:
# Pre-fill the prompt's table_info variable with the schema text from db.get_context().
prompt_with_context = chain.get_prompts()[0].partial(table_info=context["table_info"])
# Only the first 1500 characters are printed to keep the cell output short.
print(prompt_with_context.pretty_repr()[:1500])

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

### Few-shot prompting

In [37]:
# Few-shot examples: each entry pairs a natural-language question ("input")
# with the SQL that answers it ("query"). Paraphrased questions intentionally
# share the same query.
examples = [
    {"input": question, "query": sql}
    for question, sql in (
        (
            "What are the ids of the top three products that were purchased in the largest amount?",
            "SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3",
        ),
        (
            "Give the ids of the three products purchased in the largest amounts.",
            "SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3",
        ),
        (
            "What are the product id and product type of the cheapest product?",
            "SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1",
        ),
        (
            "Give the id and product type of the product with the lowest price.",
            "SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1",
        ),
        (
            "Find the number of different product types.",
            "SELECT count(DISTINCT product_type_code) FROM products",
        ),
    )
]

In [38]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

# Template used to render each entry of `examples` inside the few-shot prompt.
example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")

# Custom few-shot prompt: the prefix carries the instructions and the table info,
# the suffix carries the new user question.
prompt = FewShotPromptTemplate(
    input_variables=["input", "top_k", "table_info"],
    example_prompt=example_prompt,
    examples=examples[:5],
    prefix=(
        "You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. "
        "Unless otherwise specificed, do not return more than {top_k} rows.\n\n"
        "Here is the relevant table info: {table_info}\n\n"
        "Below are a number of examples of questions and their corresponding SQL queries."
    ),
    suffix="User input: {input}\nSQL query: ",
)

In [40]:
# Render the few-shot prompt with placeholder values to check the final text layout.
print(prompt.format(input="How many artists are there?", top_k=3, table_info="foo"))

You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run. Unless otherwise specificed, do not return more than 3 rows.

Here is the relevant table info: foo

Below are a number of examples of questions and their corresponding SQL queries.

User input: What are the ids of the top three products that were purchased in the largest amount?
SQL query: SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3

User input: Give the ids of the three products purchased in the largest amounts.
SQL query: SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3

User input: What are the product id and product type of the cheapest product?
SQL query: SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1

User input: Give the id and product type of the product with the lowest price.
SQL query: SELECT product_id ,  product_type_code FROM products ORDER BY product_price L

In [41]:
# Rebuild the chain, this time passing the custom few-shot prompt.
chain = create_sql_query_chain(llm, db, prompt)
# Ask about something that is not among the database's tables to see how the model responds.
chain.invoke({"question": "how many artists are there?"})

'Unfortunately, there is no table related to artists in the provided database schema. The tables seem to be related to a department store chain, customers, orders, products, suppliers, staff, and addresses. Without a table containing artist information, I cannot provide a meaningful SQL query to count the number of artists.'

In [44]:
# Chain is re-created with the same arguments as in the previous cell.
chain = create_sql_query_chain(llm, db, prompt)
# Ask a question about the Customers data and keep the raw response for later display.
result = chain.invoke({"question": "Return the address of customer 10."})
result

'To return the address of customer with customer_id 10, we can join the Customers and Customer_Addresses tables and filter for customer_id = 10:\n\nSELECT a.address_details\nFROM Customers c\nJOIN Customer_Addresses ca ON c.customer_id = ca.customer_id  \nJOIN Addresses a ON ca.address_id = a.address_id\nWHERE c.customer_id = 10\nLIMIT 1;\n\nThis query joins the Customers, Customer_Addresses, and Addresses tables to get the address_details for the given customer_id. The LIMIT 1 ensures we only return one row.'

In [45]:
from io import StringIO
import sys
import textwrap


def print_ww(*args, width: int = 100, **kwargs):
    """Like print(), but wraps output to `width` characters (default 100).

    Args:
        *args: Objects to print; they are formatted exactly as ``print()`` would.
        width: Maximum length of each emitted line.
        **kwargs: Keyword arguments accepted by ``print()``. ``sep`` and ``end``
            shape the text before it is wrapped. ``file`` and ``flush`` apply to
            the wrapped output; ``file`` defaults to the current ``sys.stdout``.
    """
    # Pull out the destination-related options so they apply to the wrapped
    # text; forwarding `file` to the capturing print below would bypass the
    # buffer and emit the text unwrapped.
    target = kwargs.pop("file", None)
    flush = kwargs.pop("flush", False)

    # Render into an in-memory buffer directly rather than swapping the global
    # sys.stdout, so nothing global has to be restored afterwards.
    buffer = StringIO()
    print(*args, file=buffer, **kwargs)

    # Wrap line by line so existing line breaks (including blank lines) survive:
    # textwrap.wrap("") returns [], which joins to "" and prints an empty line.
    for line in buffer.getvalue().splitlines():
        print("\n".join(textwrap.wrap(line, width=width)), file=target, flush=flush)

In [46]:
# Print the model response wrapped to the helper's default width instead of as a one-line repr.
print_ww(result)

To return the address of customer with customer_id 10, we can join the Customers and
Customer_Addresses tables and filter for customer_id = 10:

SELECT a.address_details
FROM Customers c
JOIN Customer_Addresses ca ON c.customer_id = ca.customer_id
JOIN Addresses a ON ca.address_id = a.address_id
WHERE c.customer_id = 10
LIMIT 1;

This query joins the Customers, Customer_Addresses, and Addresses tables to get the address_details
for the given customer_id. The LIMIT 1 ensures we only return one row.
