https://github.com/aws-samples/aws-ai-ml-workshop-kr/blob/master/genai/aws-gen-ai-kr/06_CodeGeneration/Text2SQL/02_sql_based_qna_w_langchain.ipynb

In [1]:
import boto3
import json

from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_aws import ChatBedrock
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import StrOutputParser
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

### Connect DB

In [2]:
# Open the sample SQLite database file through LangChain's SQLDatabase wrapper.
db = SQLDatabase.from_uri("sqlite:///data/department_store.sqlite")
# Show which table names the wrapper exposes for querying.
db.get_usable_table_names()

['Addresses',
 'Customer_Addresses',
 'Customer_Orders',
 'Customers',
 'Department_Store_Chain',
 'Department_Stores',
 'Departments',
 'Order_Items',
 'Product_Suppliers',
 'Products',
 'Staff',
 'Staff_Department_Assignments',
 'Supplier_Addresses',
 'Suppliers']

In [3]:
# Sanity check: run a raw SQL statement and preview up to 10 rows of Customers.
db.run("SELECT * FROM Customers LIMIT 10;")

"[(1, 'Credit Card', '401', 'Ahmed', '75099 Tremblay Port Apt. 163\\nSouth Norrisland, SC 80546', '254-072-4068x33935', 'margarett.vonrueden@example.com'), (2, 'Credit Card', '665', 'Chauncey', '8408 Lindsay Court\\nEast Dasiabury, IL 72656-3552', '+41(8)1897032009', 'stiedemann.sigrid@example.com'), (3, 'Direct Debit', '844', 'Lukas', '7162 Rodolfo Knoll Apt. 502\\nLake Annalise, TN 35791-8871', '197-417-3557', 'joelle.monahan@example.com'), (4, 'Direct Debit', '662', 'Lexus', '9581 Will Flat Suite 272\\nEast Cathryn, WY 30751-4404', '+08(3)8056580281', 'gbrekke@example.com'), (5, 'Credit Card', '848', 'Tara', '5065 Mraz Fields Apt. 041\\nEast Chris, NH 41624', '1-064-498-6609x051', 'nicholas44@example.com'), (6, 'Credit Card', '916', 'Jon', '841 Goyette Unions\\nSouth Dionbury, NC 62021', '(443)013-3112x528', 'cconroy@example.net'), (7, 'Credit Card', '172', 'Cristobal', '8327 Christiansen Lakes Suite 409\\nSchneiderland, IA 93624', '877-150-8674x63517', 'shawna.cummerata@example.net

### LangChain을 통한 DB 연결

https://python.langchain.com/v0.1/docs/use_cases/sql/prompting/

In [4]:
# AWS region used for the Bedrock client.
region_name = "us-west-2"

# Inference parameters passed through to the Anthropic (Claude) model.
model_kwargs = dict(
    anthropic_version="bedrock-2023-05-31",
    max_tokens=2048,
    temperature=0,
    stop_sequences=["\n\nHuman"],
)

# Configure the Claude chat model; streaming is enabled and a stdout
# streaming callback is attached.
llm = ChatBedrock(
    region_name=region_name,
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # foundation model to use
    model_kwargs=model_kwargs,
    callbacks=[StreamingStdOutCallbackHandler()],
    streaming=True,
)

In [9]:
# Build the SQL-generation chain with its default prompt and pipe the model
# output through StrOutputParser.
chain = create_sql_query_chain(llm, db) | StrOutputParser()

In [10]:
# Inspect the first prompt template used by the chain (raw repr).
chain.get_prompts()[0]

PromptTemplate(input_variables=['input', 'table_info'], partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question in

In [11]:
# Same prompt as above, printed in a human-readable layout.
chain.get_prompts()[0].pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

In [None]:
# Ask the database wrapper for its prompt context.
context = db.get_context()
# Show which entries the context provides, then the table description text
# stored under "table_info".
print(list(context))
print(context["table_info"])

In [None]:
# Pre-fill the chain's first prompt with the table info gathered from the database.
prompt_with_context = chain.get_prompts()[0].partial(table_info=context["table_info"])
# Print only the first 1500 characters to keep the output short.
print(prompt_with_context.pretty_repr()[:1500])

### Few-shot

In [None]:
# Few-shot demonstrations: each entry pairs a natural-language question
# ("input") with the SQL statement that answers it ("query").
examples = [
    {"input": question, "query": sql}
    for question, sql in (
        (
            "What are the ids of the top three products that were purchased in the largest amount?",
            "SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3",
        ),
        (
            "Give the ids of the three products purchased in the largest amounts.",
            "SELECT product_id FROM product_suppliers ORDER BY total_amount_purchased DESC LIMIT 3",
        ),
        (
            "What are the product id and product type of the cheapest product?",
            "SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1",
        ),
        (
            "Give the id and product type of the product with the lowest price.",
            "SELECT product_id ,  product_type_code FROM products ORDER BY product_price LIMIT 1",
        ),
        (
            "Find the number of different product types.",
            "SELECT count(DISTINCT product_type_code) FROM products",
        ),
    )
]

In [None]:
# Template used to render each few-shot example as a question/SQL pair.
example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}")

# Few-shot prompt: instructions and table info (prefix), the rendered examples,
# then the user's question (suffix). The "specificed" typo in the instruction
# text has been corrected to "specified".
prompt = FewShotPromptTemplate(
    examples=examples[:5],
    example_prompt=example_prompt,
    prefix="You are a SQLite expert. Given an input question, create a syntactically correct SQLite query to run.\nUnless otherwise specified, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries.",
    suffix="User input: {input}\nSQL query: \nSkip the preamble and go straight into the sql.",
    input_variables=["input", "top_k", "table_info"],
)

In [None]:
# Render the few-shot prompt with placeholder values to check its layout.
print(prompt.format(input="How many artists are there?", top_k=3, table_info="foo"))

In [None]:
# Build the SQL-generation chain with the custom few-shot prompt.
chain = create_sql_query_chain(llm, db, prompt)
# Ask about a table that exists in this database (Customers); the earlier
# "artists" question referred to a table this database does not have.
chain.invoke({"question": "How many customers are there?"})

In [None]:
# Rebuild the few-shot SQL chain and generate a query for a customer lookup.
chain = create_sql_query_chain(llm, db, prompt)
result = chain.invoke({"question": "Return the address of customer 10."})
# Display the generated SQL text as the cell output.
result

In [None]:
# Tool that runs a SQL string against the connected database.
execute_query = QuerySQLDataBaseTool(db=db)
chain = create_sql_query_chain(llm, db, prompt)

# Feed the generated SQL straight into the execution tool so the chain
# returns the query result rather than the SQL text.
chain = chain | execute_query
# The question previously read "custor 10"; the typo is corrected so the
# model receives the intended request.
chain.invoke({"question": "Return the address of customer 10."})