# Lab 1-2: Function Calling

## Initialize the Bedrock Client

In [15]:
from typing import Optional

import boto3
from botocore.config import Config
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import Session

In [3]:
# Bedrock settings shared by the helper functions below: the AWS region,
# the model id passed to `converse`, and a bedrock-runtime client for that region.
region_name = 'us-west-2'
llm_model = "anthropic.claude-3-5-sonnet-20241022-v2:0"

boto3_client = boto3.client(service_name="bedrock-runtime", region_name=region_name)

In [5]:
def converse_with_bedrock(sys_prompt, usr_prompt, *, temperature=0.0, top_p=0.1):
    """Send one Converse request to Bedrock and return the model's text reply.

    Args:
        sys_prompt: System prompt in Converse ``system`` format, i.e. a list of
            content blocks such as ``[{"text": "..."}]``.
        usr_prompt: Conversation in Converse ``messages`` format, e.g.
            ``[{"role": "user", "content": [{"text": "..."}]}]``.
        temperature: Sampling temperature sent as ``inferenceConfig.temperature``
            (default 0.0).
        top_p: Nucleus-sampling cutoff sent as ``inferenceConfig.topP``
            (default 0.1).

    Returns:
        The concatenated text of every text block in the reply message.

    Raises:
        ValueError: If the reply message contains no text blocks.
    """
    inference_config = {"temperature": temperature, "topP": top_p}

    # Uses the module-level client and model id configured in the setup cell.
    response = boto3_client.converse(
        modelId=llm_model,
        messages=usr_prompt,
        system=sys_prompt,
        inferenceConfig=inference_config,
    )

    # The reply can hold several content blocks, and not every block is text.
    # Join all text parts instead of assuming block 0 is text.
    content = response["output"]["message"]["content"]
    texts = [part["text"] for part in content if "text" in part]
    if not texts:
        raise ValueError(
            "Bedrock returned no text content "
            f"(stopReason={response.get('stopReason')!r})"
        )
    return "".join(texts)

## Developing the Text2SQL modules

#### List Tables

In [10]:
import json

# SQLAlchemy URI for the Chinook SQLite database. Three slashes means the
# path "../Chinook.db" is relative, resolved against the working directory.
uri = "sqlite:///../Chinook.db"

def get_table_info(schema_path='chinook_schema.json'):
    """Return a mapping of table name to table description from the schema file.

    The schema file is a JSON list of objects. Each object maps one or more
    table names to a dict that carries (at least) a ``table_desc`` entry.

    Args:
        schema_path: Path of the JSON schema file. The default,
            ``chinook_schema.json``, is resolved against the current working
            directory.

    Returns:
        ``dict[str, str]`` of table name -> description, in file order. If a
        table name appears more than once, the last description wins.
    """
    with open(schema_path, 'r', encoding='utf-8') as file:
        schema_data = json.load(file)

    return {
        table_name: table_data['table_desc']
        for table_entry in schema_data
        for table_name, table_data in table_entry.items()
    }

# Test: print every table and its description, separated by a blank line.
table_info = get_table_info()
for table_name, table_desc in table_info.items():
    print(f"Table: {table_name}\nDescription: {table_desc}\n")

Table: Album
Description: Stores album data with unique ID, title, and links to artist via artist ID.

Table: Artist
Description: Holds artist information with an ID and name.

Table: Customer
Description: Contains customer details and links to their support representative.

Table: Employee
Description: Stores employee details, including their supervisory chain.

Table: Genre
Description: Catalogs music genres with a unique identifier and name.

Table: Invoice
Description: Records details of transactions, linked to customers.

Table: InvoiceLine
Description: Details each line item on an invoice, linked to tracks and invoices.

Table: MediaType
Description: Defines types of media for tracks.

Table: Playlist
Description: Organizes tracks into playlists.

Table: PlaylistTrack
Description: Links tracks to playlists.

Table: Track
Description: Stores detailed information about music tracks, linked to albums, genres, and media types.



#### List Columns

In [11]:
def get_table_columns(tables=None, schema_path='chinook_schema.json'):
    """Return per-table column descriptions from the schema file.

    Args:
        tables: Iterable of table names to include, or ``None`` for every
            table in the schema. A single table name may also be passed as a
            plain string.
        schema_path: Path of the JSON schema file. The default,
            ``chinook_schema.json``, is resolved against the current working
            directory.

    Returns:
        ``dict[str, dict[str, str]]`` mapping table name ->
        ``{column name: column description}``, built from each table's
        ``cols`` list (entries with ``col`` and ``col_desc`` keys).
    """
    with open(schema_path, 'r', encoding='utf-8') as file:
        schema_data = json.load(file)

    # A bare string would be matched by substring ("Track" in "PlaylistTrack"
    # is True), so treat it as a one-element selection. Building a set once
    # also makes one-shot iterables (generators) safe to pass.
    if isinstance(tables, str):
        tables = [tables]
    wanted = None if tables is None else set(tables)

    table_columns = {}
    for table_entry in schema_data:
        for table_name, table_data in table_entry.items():
            if wanted is not None and table_name not in wanted:
                continue
            table_columns[table_name] = {
                col['col']: col['col_desc'] for col in table_data['cols']
            }

    return table_columns

# Test: look up column descriptions for two tables (displayed as the cell output).
tables = ["Album", "Customer"]
get_table_columns(tables=tables)

{'Album': {'AlbumId': 'Primary key, unique identifier for the album.',
  'Title': 'Title of the album.',
  'ArtistId': 'Foreign key that references the artist of the album.'},
 'Customer': {'CustomerId': 'Primary key, unique customer identifier.',
  'FirstName': 'First name of the customer.',
  'LastName': 'Last name of the customer.',
  'Company': 'Company of the customer.',
  'Address': 'Address of the customer.',
  'City': 'City of the customer.',
  'State': 'State of the customer.',
  'Country': 'Country of the customer.',
  'PostalCode': 'Postal code of the customer.',
  'Phone': 'Phone number of the customer.',
  'Fax': 'Fax number of the customer.',
  'Email': 'Email address of the customer.',
  'SupportRepId': 'Foreign key that references the employee who supports this customer.'}}

#### Query Evaluation

In [13]:
def query_checker(question: str, sql_query: str, dialect: str):
    """Ask the LLM to review a SQL query for common mistakes and return the result.

    Args:
        question: Natural-language question the query is meant to answer.
        sql_query: Candidate SQL query to review.
        dialect: SQL dialect name (e.g. ``"sqlite"``) inserted into the prompt.

    Returns:
        The model's rewritten query, or the original query if it found no
        mistakes. Surrounding whitespace is stripped. A single enclosing
        Markdown code fence (```sql ... ```) is also removed, so the result
        can be executed directly.
    """
    import re
    import textwrap

    sys_prompt = [{
        # dedent so the prompt is not sent with this function's indentation.
        "text": textwrap.dedent(f"""\
            Double check the {dialect} query provided by the user for common mistakes, including:
            - Using NOT IN with NULL values
            - Using UNION when UNION ALL should have been used
            - Using BETWEEN for exclusive ranges
            - Data type mismatch in predicates
            - Properly quoting identifiers
            - Using the correct number of arguments for functions
            - Casting to the correct data type
            - Using the proper columns for joins

            If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.""")
    }]

    user_prompt = [{
        "role": "user",
        "content": [{
            "text": (
                f"question: {question}\n"
                f"query: {sql_query}\n\n"
                "Skip the preamble and provide only the final SQL query."
            )
        }],
    }]

    response = converse_with_bedrock(sys_prompt, user_prompt)

    # Models sometimes wrap the SQL in a Markdown fence despite the
    # instruction; unwrap it so callers get executable SQL.
    checked = response.strip()
    fenced = re.fullmatch(r"```[\w-]*[ \t]*\n(.*?)\s*```", checked, re.DOTALL)
    if fenced:
        checked = fenced.group(1).strip()
    return checked


# Demo: a sample question plus a candidate query for the checker to review.
question = (
    "Find the average invoice total for each country, but only for countries "
    "with more than 5 customers, ordered by the average total descending."
)
sql_query = """SELECT 
    c."Country",
    COUNT(DISTINCT c."CustomerId") as "CustomerCount",
    ROUND(AVG(i."Total"), 2) as "AverageTotal"
FROM Customer c
LEFT JOIN Invoice i ON c."CustomerId" = i."CustomerId"
GROUP BY c."Country"
HAVING COUNT(DISTINCT c."CustomerId") > 5
ORDER BY "AverageTotal" DESC;"""
dialect = "sqlite"

response = query_checker(question=question, sql_query=sql_query, dialect=dialect)
print(response)

SELECT 
    c."Country",
    COUNT(DISTINCT c."CustomerId") as "CustomerCount",
    ROUND(AVG(i."Total"), 2) as "AverageTotal"
FROM Customer c
LEFT JOIN Invoice i ON c."CustomerId" = i."CustomerId"
GROUP BY c."Country"
HAVING COUNT(DISTINCT c."CustomerId") > 5
ORDER BY "AverageTotal" DESC;


#### Query Execution

In [16]:
import pandas as pd
from typing import List

def query_executor(query: str, output_columns: Optional[List[str]] = None):
    """Execute a SQL statement against the database at ``uri`` and return CSV text.

    Args:
        query: SQL statement to execute (wrapped in ``sqlalchemy.text``).
        output_columns: Column names for the CSV header. If omitted, the
            column names reported by the result set are used.

    Returns:
        The result set as CSV text (header row, no index column) when the
        statement produces rows. Returns ``None`` when the statement produces
        no result set. Also returns ``None`` when any error occurs; the error
        message is printed in that case.
    """
    engine = create_engine(uri)

    try:
        with engine.connect() as connection:
            result = connection.execute(text(query))

            if not result.returns_rows:
                return None

            columns = (
                output_columns if output_columns is not None else list(result.keys())
            )
            df = pd.DataFrame(result.fetchall(), columns=columns)
            return df.to_csv(index=False)
    except Exception as e:
        # Best-effort contract: report the failure and signal it with None.
        print(f"An error occurred: {str(e)}")
        return None
    finally:
        # A fresh engine is created per call; release its connection pool.
        engine.dispose()

# Demo: run the reviewed query and print the result as CSV.
query = """SELECT 
    c."Country",
    COUNT(DISTINCT c."CustomerId") as "CustomerCount",
    ROUND(AVG(i."Total"), 2) as "AverageTotal"
FROM Customer c
LEFT JOIN Invoice i ON c."CustomerId" = i."CustomerId"
GROUP BY c."Country"
HAVING COUNT(DISTINCT c."CustomerId") > 5
ORDER BY "AverageTotal" DESC;"""
output_columns = ["Country", "CustomerCount", "AverageTotal"]

response = query_executor(query=query, output_columns=output_columns)
print(response)

Country,CustomerCount,AverageTotal
USA,13,5.75
Canada,8,5.43

