# Lab 1-2: Function Calling

## Initialize the Bedrock Client

In [1]:
import boto3

from sqlalchemy import create_engine, text

In [3]:
# Model ID used for every Bedrock call in this notebook.
llm_model = "anthropic.claude-3-5-sonnet-20241022-v2:0"

# Take the region from boto3's default session configuration and build a
# Bedrock Runtime client in that region.
boto_session = boto3.Session()
region_name = boto_session.region_name
boto3_client = boto3.client(
    "bedrock-runtime",
    region_name=region_name,
)

In [5]:
def converse_with_bedrock(sys_prompt, usr_prompt):
    """Send one non-streaming Converse request and return the reply text.

    Args:
        sys_prompt: List of system content blocks, e.g. ``[{"text": ...}]``.
        usr_prompt: List of conversation messages, passed as ``messages``.

    Returns:
        The ``text`` of the first content block of the model's reply.

    Uses the module-level ``boto3_client`` and ``llm_model``. Sampling is
    kept near-deterministic (temperature 0.0, topP 0.1).
    """
    generation_settings = dict(temperature=0.0, topP=0.1)
    reply = boto3_client.converse(
        modelId=llm_model,
        messages=usr_prompt,
        system=sys_prompt,
        inferenceConfig=generation_settings,
    )
    first_block = reply["output"]["message"]["content"][0]
    return first_block["text"]

## Developing the Text-to-SQL Modules

#### List Tables

In [None]:
import json

# SQLAlchemy connection URL for the Chinook SQLite database. The path after
# "sqlite:///" is relative, so it resolves against the current working
# directory (one level up).
uri = "sqlite:///../Chinook.db"

def get_table_info(schema_path='chinook_schema.json'):
    """Return a mapping of table name -> table description.

    The schema file must contain a JSON list of objects. Each object maps
    table names to a dict that has (at least) a ``table_desc`` key.

    Args:
        schema_path: Path of the schema JSON file. Defaults to
            ``chinook_schema.json`` in the current working directory.

    Returns:
        dict mapping each table name to its ``table_desc`` value.
    """
    # Read explicitly as UTF-8. The platform default encoding (e.g. cp1252
    # on Windows) can mis-decode or fail on non-ASCII descriptions.
    with open(schema_path, 'r', encoding='utf-8') as file:
        schema_data = json.load(file)

    return {
        table_name: table_data['table_desc']
        for table_info in schema_data
        for table_name, table_data in table_info.items()
    }

# Test: print every table together with its description.
table_info = get_table_info()
for table_name, table_desc in table_info.items():
    print(f"Table: {table_name}\nDescription: {table_desc}\n")

#### List Columns

In [None]:
def get_table_columns(tables=None, schema_path='chinook_schema.json'):
    """Return column descriptions for the requested tables.

    Args:
        tables: Iterable of table names to include, a single table name as a
            string, or ``None`` for every table in the schema. Names are
            matched exactly.
        schema_path: Path of the schema JSON file. Defaults to
            ``chinook_schema.json`` in the current working directory.

    Returns:
        dict mapping table name -> {column name: column description}. It
        contains only the requested tables that exist in the schema.
    """
    # A bare string would otherwise be searched as a substring
    # ("Invoice" in "InvoiceLine"), silently pulling in extra tables.
    if isinstance(tables, str):
        tables = [tables]
    wanted = None if tables is None else set(tables)

    # Explicit UTF-8 so descriptions decode the same on every platform.
    with open(schema_path, 'r', encoding='utf-8') as file:
        schema_data = json.load(file)

    table_columns = {}
    for table_info in schema_data:
        for table_name, table_data in table_info.items():
            if wanted is None or table_name in wanted:
                table_columns[table_name] = {
                    col['col']: col['col_desc'] for col in table_data['cols']
                }

    return table_columns

# Test
# Test: column descriptions for two sample tables (displayed as cell output).
tables = ["Album", "Customer"]
get_table_columns(tables=tables)

#### Query Evaluation

In [None]:
def query_checker(question: str, sql_query: str, dialect: str = "sqlite"):
    """Ask the LLM to review a SQL query for common mistakes.

    Args:
        question: The user's natural-language question that the query answers.
        sql_query: The SQL query to review.
        dialect: SQL dialect named in the review prompt. Defaults to
            ``"sqlite"``. Previously this came from a global defined in a
            later cell.

    Returns:
        The model's reply: the corrected query, or the original query when no
        mistakes are found (as requested by the prompt).
    """
    sys_prompt = [{
        "text": f"""
            Double check the {dialect} query provided by the user for common mistakes, including:
            - Using NOT IN with NULL values
            - Using UNION when UNION ALL should have been used
            - Using BETWEEN for exclusive ranges
            - Data type mismatch in predicates
            - Properly quoting identifiers
            - Using the correct number of arguments for functions
            - Casting to the correct data type
            - Using the proper columns for joins

            If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
    }]

    # The stray " ]" that used to end this prompt has been removed.
    user_prompt = [{
        "role": "user",
        "content": [{"text": f"question: {question}\n query: {sql_query}\n\n Skip the preamble and provide only the final SQL query."}]
    }]

    return converse_with_bedrock(sys_prompt, user_prompt)

In [None]:
# Sample question plus a candidate query, used to exercise query_checker.
question = "Find the average invoice total for each country, but only for countries with more than 5 customers, ordered by the average total descending."
dialect = "sqlite"  # target SQL dialect
sql_query = """SELECT 
    c."Country",
    COUNT(DISTINCT c."CustomerId") as "CustomerCount",
    ROUND(AVG(i."Total"), 2) as "AverageTotal"
FROM Customer c
LEFT JOIN Invoice i ON c."CustomerId" = i."CustomerId"
GROUP BY c."Country"
HAVING COUNT(DISTINCT c."CustomerId") > 5
ORDER BY "AverageTotal" DESC;"""

response = query_checker(question, sql_query)
print(response)

#### Query Execution

In [None]:
import pandas as pd
from typing import List

def query_executor(query: str, output_columns: List[str]):
    """Run a single SQL statement against the database at ``uri``.

    Args:
        query: One SQL statement.
        output_columns: Column names to label the result with. If this is
            empty/``None``, or its length differs from the number of columns
            the query actually returns, the names reported by the database
            are used instead.

    Returns:
        A CSV string (header plus rows) when the statement returns rows.
        Otherwise a short status string. On failure, an error message string,
        so a tool-calling model can read the error and retry (the result is
        never ``None``).
    """
    engine = create_engine(uri)

    try:
        with engine.connect() as connection:
            result = connection.execute(text(query))

            if not result.returns_rows:
                return "Statement executed; it returned no rows."

            column_names = list(result.keys())
            # Use the caller's labels only when they line up with the result.
            # A count mismatch would make DataFrame construction fail.
            if output_columns and len(output_columns) == len(column_names):
                column_names = list(output_columns)

            df = pd.DataFrame(result.fetchall(), columns=column_names)
            return df.to_csv(index=False)
    except Exception as e:
        # Tool boundary: report the error back to the caller instead of raising.
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return error_message
    finally:
        engine.dispose()

# Smoke test: run the sample aggregation directly against the database.
query = """SELECT 
    c."Country",
    COUNT(DISTINCT c."CustomerId") as "CustomerCount",
    ROUND(AVG(i."Total"), 2) as "AverageTotal"
FROM Customer c
LEFT JOIN Invoice i ON c."CustomerId" = i."CustomerId"
GROUP BY c."Country"
HAVING COUNT(DISTINCT c."CustomerId") > 5
ORDER BY "AverageTotal" DESC;"""
output_columns = ["Country", "CustomerCount", "AverageTotal"]

response = query_executor(query, output_columns)
print(response)

## Tool Use Configurations

In [17]:
# System prompt for the tool-using agent: one Converse system content block.
sys_prompts = [
    dict(text="""
        You are a helpful assistant tasked with answering user queries efficiently.
        Use the provided tools to progress towards answering the question. 
        Based on the user's question, compose a SQLite query if necessary, examine the results, and then provide an answer. 
        Provide a final answer to the user's question with specific data and include the SQL query used to obtain it within a Markdown code block. 
    """),
]

In [48]:
# Converse toolConfig: one toolSpec per function dispatched by tool_router.
# Each inputSchema is a JSON Schema object describing the tool's arguments.
tool_config = {
    "tools": [
        {
            "toolSpec": {
                "name": "list_tables",
                "description": "Get the names and descriptions of all tables in the database.",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                        }
                    }
                }
            }
        },
        {
            "toolSpec": {
                "name": "desc_columns",
                # get_table_columns returns column names and descriptions
                # only, so the description no longer promises sample rows.
                "description": """Use this tool before generating a query. Make sure the tables actually exist by using the 'list_tables' tool first. Input is a list of table names; output is the column names and column descriptions for those tables.""",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "tables": {
                                "type": "array",
                                "items": {
                                    "type": "string"
                                },
                                "description": "list of table names for which you want to get column descriptions"
                            }
                        },
                        "required": [
                            "tables"
                        ]
                    }
                }
            }
        },
        {
            "toolSpec": {
                "name": "query_checker",
                # Must name the real execution tool ('query_executor').
                "description": """Use an LLM to check if a query is correct. Always use this tool before executing a query with query_executor!""",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "question": {
                                "type": "string",
                                "description": "user's question for query generation"
                            },
                            "sql_query": {
                                "type": "string",
                                "description": "The SQL query to check"
                            }
                        },
                        "required": [
                            "question",
                            "sql_query"
                        ]
                    }
                }
            }
        },
        {
            "toolSpec": {
                "name": "query_executor",
                "description": """Execute a SQL query against the database and get back the result. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. Only one statement can be executed at a time, so if multiple queries need to be executed, use this tool repeatedly.""",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "The SQL query to execute"
                            },
                            "output_columns": {
                                # JSON Schema keyword is "description";
                                # "items" tells the model each entry is a string.
                                "type": "array",
                                "items": {
                                    "type": "string"
                                },
                                "description": "The column names expected in the output, in SELECT order"
                            }
                        },
                        "required": [
                            "query",
                            "output_columns"
                        ]
                    }
                }
            }
        }
    ]
}

In [43]:
import json

def _parse_tool_input(raw):
    """Decode the streamed JSON tool input; empty input means "no arguments"."""
    if not raw.strip():
        return {}
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return {"error": "Invalid JSON"}


def stream_messages(bedrock_client, model_id, messages, tool_config, system_prompts=None):
    """Call ConverseStream and assemble the streamed events into one message.

    Text deltas are echoed to stdout as they arrive.

    Args:
        bedrock_client: Client exposing ``converse_stream`` (bedrock-runtime).
        model_id: Bedrock model ID.
        messages: Conversation history passed as ``messages``.
        tool_config: Converse ``toolConfig`` dict.
        system_prompts: System content blocks. ``None`` (the default) uses
            the module-level ``sys_prompts``.

    Returns:
        ``(stop_reason, message)``. ``message`` has ``role`` and ``content``.
        ``content`` holds ``{'text': ...}`` and
        ``{'toolUse': {'toolUseId', 'name', 'input'}}`` blocks in stream order.
    """
    if system_prompts is None:
        system_prompts = sys_prompts

    response = bedrock_client.converse_stream(
        modelId=model_id,
        messages=messages,
        system=system_prompts,
        toolConfig=tool_config,
    )

    stop_reason = ""
    content = []
    message = {'content': content}
    text = ''
    # The toolUse block currently being streamed (None while streaming text),
    # plus its accumulated raw JSON input.
    tool_use = None
    tool_input = ''

    for chunk in response['stream']:
        if 'messageStart' in chunk:
            message['role'] = chunk['messageStart']['role']
        elif 'contentBlockStart' in chunk:
            start = chunk['contentBlockStart'].get('start', {})
            if 'toolUse' in start:
                tool = start['toolUse']
                tool_use = {'toolUseId': tool['toolUseId'], 'name': tool['name']}
                tool_input = ''
        elif 'contentBlockDelta' in chunk:
            delta = chunk['contentBlockDelta']['delta']
            if 'toolUse' in delta:
                tool_input += delta['toolUse'].get('input', '')
            elif 'text' in delta:
                text += delta['text']
                print(delta['text'], end='')
        elif 'contentBlockStop' in chunk:
            # Decide by whether a toolUse block is open, not by whether input
            # arrived. A no-argument tool call may stream no input deltas and
            # must not be dropped.
            if tool_use is not None:
                tool_use['input'] = _parse_tool_input(tool_input)
                content.append({'toolUse': tool_use})
                tool_use = None
                tool_input = ''
            elif text:
                # Skip empty text blocks; they are not valid message content.
                content.append({'text': text})
                text = ''
        elif 'messageStop' in chunk:
            stop_reason = chunk['messageStop']['stopReason']

    return stop_reason, message

In [49]:
def tool_router(tool, messages):
    """Run the tool requested by a ``toolUse`` block and wrap its result.

    Args:
        tool: The ``toolUse`` dict built by ``stream_messages``
            (``toolUseId``, ``name`` and a parsed ``input`` dict).
        messages: The conversation so far. Not used here; kept for
            interface compatibility.

    Returns:
        A ``user`` message containing one ``toolResult`` block for
        ``tool['toolUseId']``. An unknown tool, or an exception raised by the
        tool, is reported with ``status: "error"`` instead of being raised,
        so the model can see the failure and react.
    """
    print(f"\n<Tool: {tool['name']}>")

    tool_input = tool.get('input')
    if not isinstance(tool_input, dict):
        tool_input = {}

    # Lambdas keep dispatch lazy: only the selected tool's function is called.
    tool_functions = {
        'list_tables': lambda: get_table_info(),
        'desc_columns': lambda: get_table_columns(tool_input.get('tables')),
        'query_checker': lambda: query_checker(tool_input.get('question'), tool_input.get('sql_query')),
        'query_executor': lambda: query_executor(tool_input.get('query'), tool_input.get('output_columns')),
    }
    # These tools return dicts, sent as JSON content; the others return text.
    json_tools = {'list_tables', 'desc_columns'}

    tool_result = {"toolUseId": tool['toolUseId']}
    handler = tool_functions.get(tool['name'])

    if handler is None:
        tool_result["content"] = [{"text": "Unknown tool"}]
        tool_result["status"] = "error"
    else:
        try:
            res = handler()
        except Exception as e:
            # Tool boundary: surface the failure to the model instead of
            # aborting the whole conversation loop.
            tool_result["content"] = [{"text": f"Tool '{tool['name']}' failed: {e}"}]
            tool_result["status"] = "error"
        else:
            if tool['name'] in json_tools:
                tool_result["content"] = [{"json": res}]
            else:
                # A text block needs a string, so never send None.
                tool_result["content"] = [{"text": "No result returned." if res is None else str(res)}]

    print(f"Result: {tool_result['content'][0]}\n")
    return {"role": "user", "content": [{"toolResult": tool_result}]}

## Text-to-SQL - modular approach (workflow)

In [None]:
llm_model = "anthropic.claude-3-5-sonnet-20241022-v2:0"

question = "Top 10 customers who purchased the most in 2022"
# Upper bound on model <-> tool round trips so a confused model cannot loop forever.
max_tool_rounds = 15

messages = [{
    "role": "user",
    "content": [{"text": question}]
}]

stop_reason, message = stream_messages(boto3_client, llm_model, messages, tool_config)
messages.append(message)

tool_rounds = 0
while stop_reason == "tool_use" and tool_rounds < max_tool_rounds:
    tool_rounds += 1

    # Every toolResult answering this assistant turn goes into ONE user
    # message. The Converse API expects user/assistant roles to alternate, so
    # one message per tool call breaks when the model requests several tools.
    tool_results = []
    for block in message["content"]:
        if "toolUse" in block:
            tool_results.extend(tool_router(block["toolUse"], messages)["content"])

    if not tool_results:
        print("\nModel requested tool use but sent no toolUse blocks; stopping.")
        break

    messages.append({"role": "user", "content": tool_results})

    stop_reason, message = stream_messages(boto3_client, llm_model, messages, tool_config)
    messages.append(message)

if stop_reason == "tool_use":
    print(f"\nStopped after {tool_rounds} tool round(s) without a final answer.")

# The final message can hold several text blocks (and no guarantee that the
# first block is text), so join every text block instead of indexing [0].
final_text = "\n".join(block["text"] for block in message["content"] if "text" in block)
print(f"\nFinal Response: {final_text}")