# Text-to-SQL using Tool Use
This section imports the required libraries: boto3 for interacting with AWS, logging for capturing logs, and botocore for error handling.

In [None]:
!pip install "boto3>=1.34.116"
!pip install -q langchain langchain-aws langchain-core langchain-community
# Restart Kernel after installation

In [1]:
import boto3
import logging
from botocore.config import Config
from botocore.exceptions import ClientError



logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

### Create an Amazon Bedrock client
This section configures the AWS client with retry settings (up to 10 attempts in `standard` retry mode) and creates a boto3 client for the 'bedrock-runtime' service.

In [3]:
region_name = 'us-east-1'
retry_config = Config(
    region_name=region_name,
    retries={
        "max_attempts": 10,
        "mode": "standard",
    },
)
client = boto3.client("bedrock-runtime", region_name=region_name, config=retry_config)

INFO:botocore.credentials:Found credentials from IAM Role: SSMInstanceProfile


### Define tools for Text-to-SQL
- list_db_tables
- desc_table_columns
- query_checker
- query_executor


In [4]:
from typing import List, Dict
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_aws import ChatBedrock
import pandas as pd
import ast

def list_db_tables(uri: str) -> Dict[str, str]:
    try:
        engine = create_engine(uri)
        db = SQLDatabase(engine)
        
        table_names = db.get_usable_table_names()
        tables_dict = {table_name: "desc" for table_name in table_names}
        return tables_dict
    except SQLAlchemyError as e:
        print(f"Error: {e}")
        return {}

def desc_table_columns(uri: str, tables: List[str]) -> Dict[str, List[str]]:
    try:
        engine = create_engine(uri)
        db = SQLDatabase(engine)
        
        metadata = db._metadata
        metadata.reflect(bind=engine, only=tables)
        
        table_columns = {}
        
        for table in tables:
            if table in metadata.tables:
                table_obj = metadata.tables[table]
                column_names = [col.name for col in table_obj.columns]
                table_columns[table] = column_names
            else:
                table_columns[table] = []

        return table_columns
    except SQLAlchemyError as e:
        print(f"Error: {e}")
        return {}

def query_checker(query: str, dialect: str, model_id="meta.llama3-70b-instruct-v1:0"):
    chat = ChatBedrock(
        model_id=model_id,
        region_name=region_name,
        model_kwargs={"temperature": 0.1},
    )
    message = [
        SystemMessage(
            content="""
            Double check the following {dialect} query for common mistakes, including:
            - Using NOT IN with NULL values
            - Using UNION when UNION ALL should have been used
            - Using BETWEEN for exclusive ranges
            - Data type mismatch in predicates
            - Properly quoting identifiers
            - Using the correct number of arguments for functions
            - Casting to the correct data type
            - Using the proper columns for joins

            If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

            Output the final SQL query only. """.format(dialect=dialect)
        ),
        HumanMessage(
            content=query
        )
    ]
    res = chat.invoke(message).content
    return res

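def query_executor(uri: str, query: str, output_columns: List[str]) -> str:
    engine = create_engine(uri)
    db = SQLDatabase(engine)

    # run_no_throw returns the rows as the string form of a list of tuples,
    # "" when there are no rows, or "Error: ..." if the query failed.
    data = db.run_no_throw(query)
    if not data:
        return "The query returned no rows."
    if data.startswith("Error"):
        return data  # pass the error back so the model can rewrite the query
    try:
        rows = ast.literal_eval(data)
        df = pd.DataFrame(rows, columns=output_columns)
    except (ValueError, SyntaxError) as e:
        return f"Error: could not convert the result to CSV: {e}"
    return df.to_csv(index=False)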
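`query_executor` depends on `SQLDatabase.run_no_throw`, which returns the result rows as the string form of a list of tuples (an empty string when there are no rows, or an `"Error: ..."` string on failure); `ast.literal_eval` turns that string back into Python values. The following is a minimal stdlib sketch of the same conversion, using the `csv` module instead of pandas. The helper name `rows_str_to_csv` is hypothetical and the sample string is made up, not real Chinook output.

```python
import ast
import csv
import io
from typing import List, Optional


def rows_str_to_csv(result: str, output_columns: List[str]) -> Optional[str]:
    """Convert a "[(...), (...)]" result string into CSV text (hypothetical helper)."""
    if not result:
        return None  # an empty string means the query returned no rows
    if result.startswith("Error"):
        return result  # pass the database error back unchanged
    rows = ast.literal_eval(result)  # parses Python literals only, no arbitrary eval()
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(output_columns)
    writer.writerows(rows)
    return buf.getvalue()


print(rows_str_to_csv("[('USA', 102.98), ('Canada', 76.26)]", ["Country", "TotalSales"]))
```

`literal_eval` is used rather than `eval` because the string comes from the database and should never be executed as code.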

### Define the system prompt

In [5]:
system_prompts = [
    {
        "text":"""
            You are a helpful assistant tasked with answering user queries efficiently.
            Use the provided tools to progress towards answering the question. 
            Based on the user's question, compose a SQLite query if necessary, examine the results, and then provide an answer. 
            Provide a final answer to the user's question with specific data and include the SQL query used to obtain it within a Markdown code block. 
        """
    }
]

### Test each tool individually

In [6]:
input_text = "Tell me the top 10 countries by sales in 2022"
uri="sqlite:///Chinook.db"
dialect = "sqlite"

In [7]:
tables = list_db_tables(uri)
print("tables:", tables)

## ReAct - Select Tables (Invoice, Customer)

tables: {'Album': 'desc', 'Artist': 'desc', 'Customer': 'desc', 'Employee': 'desc', 'Genre': 'desc', 'Invoice': 'desc', 'InvoiceLine': 'desc', 'MediaType': 'desc', 'Playlist': 'desc', 'PlaylistTrack': 'desc', 'Track': 'desc'}


In [8]:
tables = ["Invoice", "Customer"]
columns = desc_table_columns(uri, tables)
print("columns:", columns)

## ReAct - Select Columns and Compose a query

columns: {'Invoice': ['InvoiceId', 'CustomerId', 'InvoiceDate', 'BillingAddress', 'BillingCity', 'BillingState', 'BillingCountry', 'BillingPostalCode', 'Total'], 'Customer': ['CustomerId', 'FirstName', 'LastName', 'Company', 'Address', 'City', 'State', 'Country', 'PostalCode', 'Phone', 'Fax', 'Email', 'SupportRepId']}
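`desc_table_columns` obtains these column names through SQLAlchemy reflection. For SQLite specifically, the same information is exposed by the `sqlite_master` catalog and the `PRAGMA table_info` statement. The sketch below uses only the stdlib `sqlite3` module and a tiny in-memory database with hypothetical, reduced `Invoice`/`Customer` tables, not the real `Chinook.db`; the helper names are also hypothetical.

```python
import sqlite3
from typing import Dict, List


def sqlite_tables(conn: sqlite3.Connection) -> List[str]:
    """List user tables, similar in spirit to list_db_tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' "
        "AND name NOT LIKE 'sqlite_%' ORDER BY name"
    )
    return [name for (name,) in rows]


def sqlite_columns(conn: sqlite3.Connection, tables: List[str]) -> Dict[str, List[str]]:
    """Map each table to its column names; unknown tables map to [] like desc_table_columns."""
    result = {}
    for table in tables:
        # PRAGMA arguments cannot be bound as parameters, so quote the identifier
        quoted = '"' + table.replace('"', '""') + '"'
        # Rows are (cid, name, type, notnull, dflt_value, pk); an unknown table yields no rows
        info = conn.execute(f"PRAGMA table_info({quoted})").fetchall()
        result[table] = [row[1] for row in info]
    return result


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT)")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, "
             "InvoiceDate TEXT, BillingCountry TEXT, Total REAL)")
print(sqlite_tables(conn))
print(sqlite_columns(conn, ["Invoice", "Missing"]))
```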


In [9]:
query = """SELECT c.Country, SUM(i.Total) AS TotalSales  
FROM Invoice i  
JOIN Customer c ON i.CustomerId = c.CustomerId  
WHERE strftime('%Y', i.InvoiceDate) = '2022'  
GROUP BY c.Country  
ORDER BY TotalSales DESC  
LIMIT 10;"""

final_query = query_checker(query=query, dialect=dialect)
answer = query_executor(uri, final_query, output_columns=['Country', 'TotalSales'])
print(answer)

INFO:botocore.credentials:Found credentials from IAM Role: SSMInstanceProfile


Country,TotalSales
USA,102.97999999999999
Canada,76.25999999999999
Brazil,41.6
France,39.599999999999994
Hungary,32.75
United Kingdom,30.69
Austria,27.77
Germany,25.740000000000002
Chile,17.91
India,17.83
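The floating-point noise in `TotalSales` (for example `102.97999999999999`) comes from summing REAL values; wrapping the aggregate in SQLite's `ROUND(..., 2)` is one way to get two-decimal output. The sketch below also illustrates the "Using BETWEEN for exclusive ranges" item checked by `query_checker`: assuming dates are stored as text with a time component (e.g. `'2022-12-31 00:00:00'`), `BETWEEN '2022-01-01' AND '2022-12-31'` silently drops rows from December 31, because the longer string compares greater, while the half-open range `>= '2022-01-01' AND < '2023-01-01'` keeps them. The three invoices are made-up data.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Invoice (InvoiceId INTEGER, BillingCountry TEXT, InvoiceDate TEXT, Total REAL)")
conn.executemany(
    "INSERT INTO Invoice VALUES (?, ?, ?, ?)",
    [
        (1, "USA", "2022-03-01 00:00:00", 0.1),
        (2, "USA", "2022-12-31 00:00:00", 0.2),     # last day of 2022, with a time part
        (3, "Canada", "2023-01-01 00:00:00", 5.0),  # outside 2022
    ],
)

between_sql = """
SELECT BillingCountry, ROUND(SUM(Total), 2) FROM Invoice
WHERE InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'
GROUP BY BillingCountry"""

half_open_sql = """
SELECT BillingCountry, ROUND(SUM(Total), 2) FROM Invoice
WHERE InvoiceDate >= '2022-01-01' AND InvoiceDate < '2023-01-01'
GROUP BY BillingCountry"""

print(conn.execute(between_sql).fetchall())    # the December 31 row is dropped
print(conn.execute(half_open_sql).fetchall())  # both 2022 rows are summed, rounded to 2 places
```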



### Define the tool specs (ToolConfig)

In [12]:
tool_config = {
    "tools": [
        {
            "toolSpec": {
                "name": "list_tables",
                "description": "Get table names and descriptions.",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "uri": {
                                "type": "string",
                                "description": "URI of the database to access"
                            }
                        },
                        "required": [
                            "uri"
                        ]
                    }
                }
            }
        },
        {
            "toolSpec": {
                "name": "desc_columns",
                "description": """
                Input is a list of tables, output is the list of column names for each of those tables.
                Use this tool before generating a query. Make sure the tables actually exist by calling the 'list_tables' tool first!
                """,
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "uri": {
                                "type": "string",
                                "description": "URI of the database to access"
                            },
                            "tables": {
                                "type": "array",
                                "items": {
                                    "type": "string"
                                },
                                "description": "list of table names for which you want to get column descriptions"
                            }
                        },
                        "required": [
                            "uri",
                            "tables"
                        ]
                    }
                }
            }
        },
        {
            "toolSpec": {
                "name": "query_checker",
                "description": """
                Use an LLM to check if a query is correct.
                Always use this tool before executing a query with query_executor!
                """,
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "The SQL query to check"
                            },
                            "dialect": {
                                "type": "string",
                                "description": "The SQL dialect of the database"
                            }
                        },
                        "required": [
                            "query",
                            "dialect"
                        ]
                    }
                }
            }
        },
        {
            "toolSpec": {
                "name": "query_executor",
                "description": """
                Execute a SQL query against the database and get back the result.
                If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again.
                Only one statement can be executed at a time, so if multiple queries need to be executed, use this tool repeatedly.
                """,
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "uri": {
                                "type": "string",
                                "description": "URI of the database to access"
                            },
                            "query": {
                                "type": "string",
                                "description": "The SQL query to execute"
                            },
                            "output_columns": {
                                "type": "array",
                                "items": {
                                    "type": "string"
                                },
                                "description": "The column names expected in the output"
                            }
                        },
                        "required": [
                            "uri",
                            "query",
                            "output_columns"
                        ]
                    }
                }
            }
        }
    ]
}
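Because the router dispatches on the `toolSpec` name and reads arguments by key, a mismatch between a spec and the code (a misspelled key such as `descriptions`, a required field missing from `properties`, or an array without `items`) only surfaces at run time. A small stdlib checker can catch these before calling the model. `check_tool_config` is a hypothetical helper, not part of the Bedrock SDK, and its allow-list of property keys is deliberately narrow for this notebook rather than the full JSON Schema vocabulary.

```python
from typing import Any, Dict, List

# Narrow allow-list (assumption): the only property keywords this notebook uses
ALLOWED_PROPERTY_KEYS = {"type", "description", "items", "enum"}


def check_tool_config(config: Dict[str, Any]) -> List[str]:
    """Return human-readable problems found in a Converse-style toolConfig dict."""
    problems = []
    for tool in config.get("tools", []):
        spec = tool["toolSpec"]
        name = spec["name"]
        schema = spec["inputSchema"]["json"]
        props = schema.get("properties", {})
        for req in schema.get("required", []):
            if req not in props:
                problems.append(f"{name}: required field '{req}' is not in properties")
        for prop, prop_schema in props.items():
            unknown = set(prop_schema) - ALLOWED_PROPERTY_KEYS
            if unknown:
                problems.append(f"{name}.{prop}: unexpected keys {sorted(unknown)}")
            if prop_schema.get("type") == "array" and "items" not in prop_schema:
                problems.append(f"{name}.{prop}: array property has no 'items'")
    return problems


sample = {"tools": [{"toolSpec": {
    "name": "query_executor",
    "inputSchema": {"json": {
        "type": "object",
        "properties": {"output_columns": {"type": "array", "descriptions": "expected columns"}},
        "required": ["uri", "output_columns"],
    }},
}}]}
for problem in check_tool_config(sample):
    print(problem)
```

Running `check_tool_config(tool_config)` on the real config is a cheap sanity check before the first `converse_stream` call.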


### Call the Converse API and parse the streaming response

In [14]:
import json
import logging

logger = logging.getLogger(__name__)

def stream_messages(bedrock_client, model_id, messages, tool_config, system=system_prompts):
    #logger.info("Streaming messages with model %s", model_id)
    response = bedrock_client.converse_stream(
        modelId=model_id,
        messages=messages,
        system=system,
        toolConfig=tool_config
    )

    stop_reason = ""
 
    message = {}
    content = []
    message['content'] = content
    text = ''
    tool_use = {}

    #stream the response into a message.
    for chunk in response['stream']:
        if 'messageStart' in chunk:
            message['role'] = chunk['messageStart']['role']
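        elif 'contentBlockStart' in chunk:
            # Record the id and name only when the block being started is a tool use
            start = chunk['contentBlockStart']['start']
            if 'toolUse' in start:
                tool_use['toolUseId'] = start['toolUse']['toolUseId']
                tool_use['name'] = start['toolUse']['name']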
        elif 'contentBlockDelta' in chunk:
            delta = chunk['contentBlockDelta']['delta']
            if 'toolUse' in delta:
                if 'input' not in tool_use:
                    tool_use['input'] = ''
                tool_use['input'] += delta['toolUse']['input']
            elif 'text' in delta:
                text += delta['text']
                print(delta['text'], end='')
        elif 'contentBlockStop' in chunk:
            if 'input' in tool_use:
                tool_use['input'] = json.loads(tool_use['input'])
                content.append({'toolUse': tool_use})
                tool_use = {}
            else:
                content.append({'text': text})
                text = ''

        elif 'messageStop' in chunk:
            stop_reason = chunk['messageStop']['stopReason']

    return stop_reason, message
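`converse_stream` emits a sequence of events: a tool call arrives as a `contentBlockStart` carrying `toolUseId` and `name`, then one or more `contentBlockDelta` events whose `toolUse.input` fields are fragments of a JSON string that only parse once concatenated, then `contentBlockStop`. The sketch below replays a hand-written list of such events (made up for illustration, not captured from Bedrock) through a stdlib accumulator with the same logic as `stream_messages` minus the printing, so the parsing can be tested without AWS credentials. `accumulate_stream` is a hypothetical name.

```python
import json
from typing import Any, Dict, Iterable, Tuple


def accumulate_stream(events: Iterable[Dict[str, Any]]) -> Tuple[str, Dict[str, Any]]:
    """Rebuild a Converse message from stream events (mirrors stream_messages)."""
    message: Dict[str, Any] = {"content": []}
    text, tool_use, stop_reason = "", {}, ""
    for event in events:
        if "messageStart" in event:
            message["role"] = event["messageStart"]["role"]
        elif "contentBlockStart" in event:
            start = event["contentBlockStart"]["start"]
            if "toolUse" in start:
                tool_use = {"toolUseId": start["toolUse"]["toolUseId"],
                            "name": start["toolUse"]["name"], "input": ""}
        elif "contentBlockDelta" in event:
            delta = event["contentBlockDelta"]["delta"]
            if "toolUse" in delta:
                tool_use["input"] += delta["toolUse"]["input"]  # partial JSON string
            elif "text" in delta:
                text += delta["text"]
        elif "contentBlockStop" in event:
            if tool_use:
                tool_use["input"] = json.loads(tool_use["input"] or "{}")
                message["content"].append({"toolUse": tool_use})
                tool_use = {}
            else:
                message["content"].append({"text": text})
                text = ""
        elif "messageStop" in event:
            stop_reason = event["messageStop"]["stopReason"]
    return stop_reason, message


events = [
    {"messageStart": {"role": "assistant"}},
    {"contentBlockDelta": {"delta": {"text": "Listing tables."}, "contentBlockIndex": 0}},
    {"contentBlockStop": {"contentBlockIndex": 0}},
    {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "t-1", "name": "list_tables"}},
                           "contentBlockIndex": 1}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"uri": "sqlite:'}}, "contentBlockIndex": 1}},
    {"contentBlockDelta": {"delta": {"toolUse": {"input": '///Chinook.db"}'}}, "contentBlockIndex": 1}},
    {"contentBlockStop": {"contentBlockIndex": 1}},
    {"messageStop": {"stopReason": "tool_use"}},
]
demo_stop, demo_message = accumulate_stream(events)
print(demo_stop, demo_message["content"][1]["toolUse"]["input"])
```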

### Define the router
- When the model's response contains a tool use request, route it to the corresponding function

In [15]:
def tool_router(tool, messages):
    """Dispatch a toolUse block to the matching function (match/case needs Python 3.10+)."""
    print(f"\n<Tool: {tool['name']}>")
    tool_input = tool['input']
    match tool['name']:
        case 'list_tables':
            content = [{"json": list_db_tables(tool_input['uri'])}]
        case 'desc_columns':
            content = [{"json": desc_table_columns(tool_input['uri'], tool_input['tables'])}]
        case 'query_checker':
            content = [{"text": query_checker(tool_input['query'], tool_input['dialect'])}]
        case 'query_executor':
            res = query_executor(tool_input['uri'], tool_input['query'], tool_input['output_columns'])
            # A text block must be a string, so never send None back to the model
            content = [{"text": res or "The query returned no rows."}]
        case _:
            content = [{"text": f"Error: unknown tool '{tool['name']}'"}]

    tool_result = {"toolUseId": tool['toolUseId'], "content": content}
    print(f"Result: {tool_result['content'][0]}\n")
    tool_result_message = {"role": "user", "content": [{"toolResult": tool_result}]}

    return tool_result_message
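`match`/`case` requires Python 3.10 or later. On older interpreters, or when tools are added often, a dictionary from tool name to handler is a common alternative. This sketch uses stub handlers (hypothetical stand-ins for `list_db_tables` and `query_checker`, so it runs without a database or Bedrock) and, for unknown names, returns a toolResult with the optional `status: "error"` field of the Converse tool result block instead of failing. `make_router` is a hypothetical name.

```python
from typing import Any, Callable, Dict


def make_router(handlers: Dict[str, Callable[[Dict[str, Any]], Any]]):
    """Build a router that maps a toolUse block to a Converse toolResult message."""
    def route(tool: Dict[str, Any]) -> Dict[str, Any]:
        handler = handlers.get(tool["name"])
        if handler is None:
            result = {"toolUseId": tool["toolUseId"], "status": "error",
                      "content": [{"text": f"Unknown tool: {tool['name']}"}]}
        else:
            output = handler(tool["input"])
            # dict results go in a json block, everything else is sent as text
            block = {"json": output} if isinstance(output, dict) else {"text": str(output)}
            result = {"toolUseId": tool["toolUseId"], "content": [block]}
        return {"role": "user", "content": [{"toolResult": result}]}
    return route


route = make_router({
    "list_tables": lambda args: {"Invoice": "desc", "Customer": "desc"},  # stub
    "query_checker": lambda args: args["query"].strip(),                 # stub: echoes the query
})
print(route({"toolUseId": "t-1", "name": "list_tables", "input": {"uri": "sqlite:///Chinook.db"}}))
print(route({"toolUseId": "t-2", "name": "drop_db", "input": {}}))
```

Registering a new tool then means adding one dictionary entry next to its `toolSpec`, rather than another `case` branch.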

### Define the user prompt (a natural-language question about the DB)

In [16]:
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
input_text = """
Write a query that finds the top 10 countries by sales in 2022.

DB_URI: sqlite:///Chinook.db
"""
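The next cell drives the conversation: it keeps calling the model while the stop reason is `tool_use`, answering each tool request through the router. That control flow can be exercised without Bedrock by substituting a scripted fake model, as in this sketch. `run_agent`, `fake_model`, and the `max_turns` guard are hypothetical (the guard is not in the original loop), and all tool results for one assistant turn are sent back in a single user message, since the Converse API expects user and assistant messages to alternate.

```python
from typing import Any, Callable, Dict, List, Tuple

Message = Dict[str, Any]


def run_agent(model: Callable[[List[Message]], Tuple[str, Message]],
              route: Callable[[Dict[str, Any]], Dict[str, Any]],
              question: str, max_turns: int = 10) -> List[Message]:
    """Alternate model calls and tool execution until the model stops requesting tools."""
    messages: List[Message] = [{"role": "user", "content": [{"text": question}]}]
    stop_reason, reply = model(messages)
    messages.append(reply)
    turns = 0
    while stop_reason == "tool_use" and turns < max_turns:
        results = [{"toolResult": route(block["toolUse"])}
                   for block in reply["content"] if "toolUse" in block]
        messages.append({"role": "user", "content": results})  # one user turn per assistant turn
        stop_reason, reply = model(messages)
        messages.append(reply)
        turns += 1
    return messages


script = [
    ("tool_use", {"role": "assistant", "content": [
        {"toolUse": {"toolUseId": "a", "name": "list_tables", "input": {}}},
        {"toolUse": {"toolUseId": "b", "name": "desc_columns", "input": {}}},
    ]}),
    ("end_turn", {"role": "assistant", "content": [{"text": "done"}]}),
]


def fake_model(messages: List[Message]) -> Tuple[str, Message]:
    """Return the next scripted reply; the real notebook calls stream_messages here."""
    return script.pop(0)


history = run_agent(fake_model,
                    lambda tool: {"toolUseId": tool["toolUseId"], "content": [{"text": "ok"}]},
                    "Top 10 countries by 2022 sales?")
print([m["role"] for m in history])
```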

In [17]:
messages = [{
    "role": "user",
    "content": [{"text": input_text}]
}]

stop_reason, message = stream_messages(client, model_id, messages, tool_config)
messages.append(message)

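while stop_reason == "tool_use":
    # Answer every toolUse block from this turn in ONE user message,
    # because the Converse API expects user and assistant turns to alternate.
    tool_results = []
    for c in message["content"]:
        if "toolUse" in c:
            tool_results.extend(tool_router(c["toolUse"], messages)["content"])
    messages.append({"role": "user", "content": tool_results})

    stop_reason, message = stream_messages(client, model_id, messages, tool_config)
    messages.append(message)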

print(f"\nFinal Response: {message['content'][0]['text']}")


<Tool: list_tables>
Result: {'json': {'Album': 'desc', 'Artist': 'desc', 'Customer': 'desc', 'Employee': 'desc', 'Genre': 'desc', 'Invoice': 'desc', 'InvoiceLine': 'desc', 'MediaType': 'desc', 'Playlist': 'desc', 'PlaylistTrack': 'desc', 'Track': 'desc'}}

First, I used the 'list_tables' function to check the tables in the Chinook database. It looks like querying the Invoices and Customers tables should satisfy the request.
<Tool: desc_columns>
Result: {'json': {'Invoice': ['InvoiceId', 'CustomerId', 'InvoiceDate', 'BillingAddress', 'BillingCity', 'BillingState', 'BillingCountry', 'BillingPostalCode', 'Total'], 'Customer': ['CustomerId', 'FirstName', 'LastName', 'Company', 'Address', 'City', 'State', 'Country', 'PostalCode', 'Phone', 'Fax', 'Email', 'SupportRepId']}}

I used the desc_columns tool to check the schemas of the Invoice and Customer tables. The Invoice table has a BillingCountry column that provides the country, and the Total column gives the sales amount.

Now let me write the query.

INFO:botocore.credentials:Found credentials from IAM Role: SSMInstanceProfile



<Tool: query_checker>
Result: {'text': "SELECT BillingCountry, SUM(Total) AS TotalSales\nFROM Invoice\nWHERE InvoiceDate >= '2022-01-01' AND InvoiceDate < '2023-01-01'\nGROUP BY BillingCountry\nORDER BY TotalSales DESC\nLIMIT 10;"}

I checked the query with the query_checker tool. It advised that, in the WHERE clause, using >= and < instead of BETWEEN is more efficient.

Now let me run the revised query.
<Tool: query_executor>
Result: {'text': 'BillingCountry,TotalSales\nUSA,102.97999999999999\nCanada,76.25999999999999\nBrazil,41.6\nFrance,39.599999999999994\nHungary,32.75\nUnited Kingdom,30.69\nAustria,27.77\nGermany,25.740000000000002\nChile,17.91\nIndia,17.83\n'}

The final query result is as follows:

```sql
SELECT BillingCountry, SUM(Total) AS TotalSales
FROM Invoice  
WHERE InvoiceDate >= '2022-01-01' AND InvoiceDate < '2023-01-01'
GROUP BY BillingCountry
ORDER BY TotalSales DESC
LIMIT 10;
```

The top 10 countries by sales are, in order: the USA, Canada, Brazil, France, Hungary, the United Kingdom, Austria, Germany, Chile, and India.
Final Response: The final query result is as follows:

```sql
SELECT BillingCountry, SUM(Total) AS To