# Lab. 1-1 Basic Implementation

## Database Test

In [1]:
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.orm import Session

In [None]:
# Engine for the Chinook SQLite file; the path is relative to the working directory.
# This module-level `engine` is reused by the later query-execution cells.
engine = create_engine('sqlite:///../Chinook.db')

# table names
inspector = inspect(engine)
table_names = inspector.get_table_names()
print("table_names:\n", table_names)

# simple query
# Smoke test: read up to ten rows from the Artist table and print each one.
with Session(engine) as session:
    result = session.execute(text("SELECT * FROM Artist LIMIT 10"))
    print("\nrows:")
    for row in result:
        print(row)

## Bedrock Model Access

In [3]:
import boto3
from botocore.config import Config

# AWS region passed to init_boto3_client() when the Bedrock client is created.
region_name = "us-west-2"
# Bedrock model ID used by converse_with_bedrock(); later cells reassign it to switch models.
llm_model = "anthropic.claude-3-5-haiku-20241022-v1:0"

def init_boto3_client(region: str):
    """Create a ``bedrock-runtime`` client for ``region`` with retries configured.

    Args:
        region: AWS region name; supplied both to the botocore ``Config`` and
            to ``boto3.client``.

    Returns:
        The object returned by ``boto3.client("bedrock-runtime", ...)``.
    """
    # Up to 10 attempts using the "standard" retry mode.
    retry_policy = {"max_attempts": 10, "mode": "standard"}
    client_config = Config(region_name=region, retries=retry_policy)
    return boto3.client(
        "bedrock-runtime",
        region_name=region,
        config=client_config,
    )

def converse_with_bedrock(boto3_client, sys_prompt, usr_prompt, model_id=None,
                          temperature=0.0, top_p=0.1):
    """Send one ``converse`` request and return the text of the first content block.

    Args:
        boto3_client: Object exposing ``converse(...)`` (the Bedrock Runtime client).
        sys_prompt: Value passed as ``system``; callers in this notebook pass
            ``[{"text": ...}]``.
        usr_prompt: Value passed as ``messages``; callers pass a list of
            ``{"role": ..., "content": [...]}`` dicts.
        model_id: Model to invoke. ``None`` (default) reads the module-level
            ``llm_model`` at call time, so reassigning that global keeps working.
        temperature: Sampling temperature placed in ``inferenceConfig``.
        top_p: Value sent as ``topP`` in ``inferenceConfig``.

    Returns:
        ``response['output']['message']['content'][0]['text']``.
    """
    inference_config = {"temperature": temperature, "topP": top_p}

    response = boto3_client.converse(
        # Resolve the global lazily so an explicit model_id never touches it.
        modelId=llm_model if model_id is None else model_id,
        messages=usr_prompt,
        system=sys_prompt,
        inferenceConfig=inference_config,
    )

    return response['output']['message']['content'][0]['text']


# Module-level Bedrock Runtime client shared by the model calls in later cells.
boto3_client = init_boto3_client(region_name)

In [None]:
# Minimal system prompt, in the list-of-{"text": ...} form passed as `system`.
test_sys_prompt = [{
    "text": "You are a cool assistant."
}]

# Single user turn, in the role/content form passed as `messages`.
test_user_prompt = [{
    "role": "user",
    "content": [{"text": "Hi! What's your name?"}]
}]

# Connectivity check: one round trip to the model, printing the reply text.
response = converse_with_bedrock(boto3_client, test_sys_prompt, test_user_prompt)
print(response)

## Basic Text-to-SQL Prompt

In [None]:
def get_schema_info(db_path):
    """Describe every table of the SQLite database at ``db_path`` as text.

    Args:
        db_path: Filesystem path interpolated into a ``sqlite:///`` URL.

    Returns:
        Dict mapping each table name to a string of the form
        ``"Table: <name>"`` followed by one ``"  - <column> (<type>)"`` line
        per column.
    """
    engine = create_engine(f'sqlite:///{db_path}')
    try:
        inspector = inspect(engine)
        schema_info = {}

        for table_name in inspector.get_table_names():
            column_lines = "\n".join(
                f"  - {col['name']} ({col['type']})"
                for col in inspector.get_columns(table_name)
            )
            schema_info[table_name] = f"Table: {table_name}\n{column_lines}"

        return schema_info
    finally:
        # This engine is private to the call; release its pooled connections
        # instead of leaving them to garbage collection.
        engine.dispose()

# Table-name -> description mapping for the Chinook database; reused by the prompt cells.
schema = get_schema_info("../Chinook.db")
# Show one entry as a sanity check.
print(schema['Employee'])

In [None]:
# Values interpolated into the system prompt below.
dialect = "sqlite"
top_k = 10  # default row cap the prompt asks the model to apply via LIMIT
# NOTE(review): only the Customer table description is given to the model here,
# while the question asked a few lines further down is about sales totals —
# confirm the omission of other tables is intentional for this baseline.
table_info = schema['Customer']

# System prompt in the list-of-{"text": ...} form passed as `system`.
sys_prompt = [{
    "text": f"""You are a {dialect} expert.
Given an input question, first create a syntactically correct SQLite query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date(\'now\') function to get the current date, if the question involves "today" 
    
Only use the following tables:
{table_info}
""" 
}]


def get_user_prompt(question):
    """Wrap ``question`` in a single-turn user message for ``converse``.

    Args:
        question: Natural-language question to place after the ``Question:`` label.

    Returns:
        A one-element list holding a ``{"role": "user", "content": [...]}`` dict
        whose text asks the model to return only SQL.
    """
    # The original text had a stray "]" right after the question; it was a typo
    # that ended up in the prompt sent to the model.
    return [{
        "role": "user",
        "content": [{
            "text": f"Question:\n{question}\n Skip the preamble and provide only the SQL."
        }]
    }]

# Ask the model for SQL answering the question, then wrap the raw reply as a
# SQLAlchemy text clause.
# NOTE(review): the reply is used verbatim; if the model adds anything besides
# SQL, executing `sql_query` will fail.
response = converse_with_bedrock(boto3_client, sys_prompt, get_user_prompt("List the total sales per country. Which country's customers spent the most?"))
sql_query = text(response)
print(sql_query)


In [None]:
# Run the model-generated statement against the Chinook database and print each row.
# NOTE(review): the SQL comes straight from the model response and is executed
# without validation.
with Session(engine) as session:
    result = session.execute(sql_query)
    for row in result:
        print(row)
    

## Chain-of-Thought Prompt

In [None]:
# Schema text for the next prompts: Customer and Invoice descriptions joined by a newline.
table_info = schema['Customer'] + "\n" + schema['Invoice']
print(table_info)

In [9]:
# Worked query / thought-process / SQL examples interpolated into the system
# prompts. The SQL uses SQLite date functions (strftime, date('now', ...)):
# the prompts require SQLite, and MySQL-only YEAR()/DATE_SUB()/CURDATE()
# do not exist there.
example = """
<example>
<query>
Find the top 3 customers who have spent the most money in 2023, showing their names and total spending.
</query> 
<thought_process> 
1. We need to join the Customer and Invoice tables. 
2. We'll sum up the Total from Invoice for each customer. 
3. We'll filter for invoices from the year 2023. 
4. We'll order by the total spending in descending order. 
5. We'll limit the results to the top 3 customers. 
</thought_process> 
<sql> 
SELECT c.FirstName, c.LastName, SUM(i.Total) AS TotalSpending FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId WHERE strftime('%Y', i.InvoiceDate) = '2023' GROUP BY c.CustomerId, c.FirstName, c.LastName ORDER BY TotalSpending DESC LIMIT 3; 
</sql> 
</example> 

<example> 
<query>
List all customers from the USA who have not made any purchases in the last 6 months.
</query> 
<thought_process> 
1. We need to use both the Customer and Invoice tables. 
2. We'll filter for customers from the USA. 
3. We'll use a LEFT JOIN to include customers with no invoices. 
4. We'll check for the absence of recent invoices (within the last 6 months). 
5. We'll return the customer's full name and email. 
</thought_process> 
<sql> 
SELECT c.FirstName, c.LastName, c.Email FROM Customer c LEFT JOIN Invoice i ON c.CustomerId = i.CustomerId AND i.InvoiceDate >= date('now', '-6 months') WHERE c.Country = 'USA' AND i.InvoiceId IS NULL; 
</sql> 
</example>
"""

In [None]:
# Switch the module-level model ID read by converse_with_bedrock().
llm_model = "anthropic.claude-3-5-sonnet-20241022-v2:0"

# System prompt that embeds the schema text and the worked examples in XML-style tags.
sys_prompt = [{
    "text": f"""You are a {dialect} expert.
Given an input question, first create a syntactically correct SQLite query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date(\'now\') function to get the current date, if the question involves "today" 
    
<schema>
{table_info}
</schema>

<examples>
{example}
</examples>
""" 
}]


def get_user_prompt(question):
    """Wrap ``question`` in ``<query>`` tags as a single-turn user message.

    Args:
        question: Natural-language question to place inside the tags.

    Returns:
        A one-element list holding a ``{"role": "user", "content": [...]}`` dict.
    """
    # The original text ended with a stray "]" after the closing tag; it was a
    # typo that ended up in the prompt sent to the model.
    return [{
        "role": "user",
        "content": [{"text": f"<query>\n{question}</query>"}]
    }]

# Send the question with the schema-and-examples system prompt and print the raw reply.
response = converse_with_bedrock(boto3_client, sys_prompt, get_user_prompt("Find the average invoice total for each country, but only for countries with more than 5 customers, ordered by the average total descending."))
print(response)

In [None]:
# Pull the text between the first <thought_process>/<sql> opening tag and its
# closing tag. `split(...)[1]` raises IndexError when the reply has no such tag.
thought_process = response.split('<thought_process>')[1].split('</thought_process>')[0].strip()
sql = response.split('<sql>')[1].split('</sql>')[0].strip()

print("Thought:\n", thought_process)

print("\nSQL:\n", sql)

In [None]:
# Execute the extracted SQL against the Chinook database and print each row.
# NOTE(review): the statement is model-generated and is run without validation.
sql_query = text(sql)
with Session(engine) as session:
    result = session.execute(sql_query)
    for row in result:
        print(row)
    

## Dynamic Few Shot Samples

In [14]:
# Pool of question/SQL sample pairs. Each entry is a dict with the question
# under "input" and its SQL under "query".
examples = [
    {"input": question_text, "query": sql_text}
    for question_text, sql_text in (
        (
            "List all artists.",
            "SELECT * FROM Artist;",
        ),
        (
            "Find all albums for the artist 'AC/DC'.",
            "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
        ),
        (
            "List all tracks in the 'Rock' genre.",
            "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
        ),
        (
            "Find the total duration of all tracks.",
            "SELECT SUM(Milliseconds) FROM Track;",
        ),
        (
            "List all customers from Canada.",
            "SELECT * FROM Customer WHERE Country = 'Canada';",
        ),
        (
            "How many tracks are there in the album with ID 5?",
            "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
        ),
        (
            "Find the total number of invoices.",
            "SELECT COUNT(*) FROM Invoice;",
        ),
        (
            "List all tracks that are longer than 5 minutes.",
            "SELECT * FROM Track WHERE Milliseconds > 300000;",
        ),
        (
            "Who are the top 5 customers by total purchase?",
            "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
        ),
        (
            "How many employees are there",
            'SELECT COUNT(*) FROM "Employee"',
        ),
    )
]

In [None]:
import json

# Model ID passed to invoke_model() for text embeddings.
embed_model = "amazon.titan-embed-text-v2:0"
# Region name; the already-created boto3_client is not rebuilt from this value here.
region_name = "us-west-2"

def input_embedding(example, *, client=None, storage=None, model_id=None):
    """Embed each sample's question and append the enriched record to a store.

    The original loop reused the parameter name as its loop variable and
    iterated the global ``examples``, so the argument was silently ignored.
    The samples passed in are now the ones processed.

    Args:
        example: Iterable of dicts with ``"input"`` and ``"query"`` keys. The
            singular name is kept for backward compatibility.
        client: Object exposing ``invoke_model(modelId=..., body=...)``.
            ``None`` uses the module-level ``boto3_client``.
        storage: List that receives the records. ``None`` uses the
            module-level ``memory_storage``.
        model_id: Embedding model ID. ``None`` uses the module-level
            ``embed_model``.

    Returns:
        None. Each appended record has keys ``"input"``, ``"query"`` and
        ``"input_v"`` (the ``embedding`` field of the model response).
    """
    # Resolve module-level defaults lazily so explicit arguments never need them.
    client = boto3_client if client is None else client
    storage = memory_storage if storage is None else storage
    model_id = embed_model if model_id is None else model_id

    for sample in example:
        input_text = sample['input']
        query = sample['query']

        response = client.invoke_model(
            modelId=model_id,
            body=json.dumps({"inputText": input_text}),
        )

        # The response body is a readable stream holding a JSON document.
        storage.append({
            "input": input_text,
            "query": query,
            "input_v": json.loads(response['body'].read())['embedding'],
        })

# In-memory list that input_embedding() appends records to.
memory_storage = []
input_embedding(examples)

# Print each stored record, showing only the first three vector components
# (as a string) so the output stays readable.
for item in memory_storage:
    truncated_item = item.copy()
    truncated_item['input_v'] = str(item['input_v'][:3]) + '...' 
    print(json.dumps(truncated_item, indent=2))
    print()

In [None]:
# User question to answer with dynamically selected samples.
question = "Let me know the 10 customers who purchased the most"

# Embed the question with the same embedding model used for the stored samples.
response = boto3_client.invoke_model(
    modelId=embed_model,
    body=json.dumps({"inputText": question})
)
question_v = json.loads(response['body'].read())['embedding']

# Show only the first five components of the vector.
print(str(question_v[:5]) + '...')

In [None]:
!pip install scipy numpy

In [None]:
import numpy as np
from scipy.spatial.distance import cosine
import heapq

def find_most_similar_samples(question_v, memory_storage, top_k=3):
    """Return the ``top_k`` stored records most similar to ``question_v``.

    Similarity is ``1 - scipy.spatial.distance.cosine(question_v, doc['input_v'])``.

    Args:
        question_v: Embedding vector of the question.
        memory_storage: Iterable of dicts, each with an ``'input_v'`` vector.
        top_k: Maximum number of records to return; ``0`` yields ``[]``.

    Returns:
        List of ``(similarity, doc)`` tuples ordered by similarity, highest
        first. Records with equal similarity keep their storage order.
    """
    scored = [
        (1 - cosine(question_v, doc['input_v']), doc)
        for doc in memory_storage
    ]
    # Rank on the score only. A heap of raw (score, dict) tuples falls back to
    # comparing the dicts when two scores tie, which raises TypeError.
    return heapq.nlargest(top_k, scored, key=lambda pair: pair[0])

# Pick the stored samples closest to the question. The count is passed inline
# instead of being assigned to the module-level `top_k`, because that global is
# interpolated into the system prompts as the default LIMIT and would
# otherwise be silently changed from 10 to 3.
top_similar_samples = find_most_similar_samples(question_v, memory_storage, top_k=3)

# Render the selected samples as a numbered text block for the prompt.
samples = "".join(
    f"\n{i}. Score: {similarity:.4f}\n"
    f"Input: {doc['input']}\n"
    f"Query: {doc['query']}\n"
    for i, (similarity, doc) in enumerate(top_similar_samples, 1)
)

print(samples)


In [40]:
# Model ID read by converse_with_bedrock().
llm_model = "anthropic.claude-3-5-sonnet-20241022-v2:0"

# System prompt embedding the schema, the worked examples and the dynamically
# selected samples. The <samples_queries> section is now closed with
# </samples_queries>; the original repeated the opening tag instead.
sys_prompt = [{
    "text": f"""You are a {dialect} expert.
Given an input question, first create a syntactically correct SQLite query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date(\'now\') function to get the current date, if the question involves "today" 
    
<schema>
{table_info}
</schema>

<examples>
{example}
</examples>

<samples_queries>
{samples}
</samples_queries>
""" 
}]

In [None]:
def get_user_prompt(question):
    """Wrap ``question`` in ``<query>`` tags as a single-turn user message.

    Args:
        question: Natural-language question to place inside the tags.

    Returns:
        A one-element list holding a ``{"role": "user", "content": [...]}`` dict.
    """
    # The original text ended with a stray "]" after the closing tag; it was a
    # typo that ended up in the prompt sent to the model.
    return [{
        "role": "user",
        "content": [{"text": f"<query>\n{question}</query>"}]
    }]

# Ask the model with the samples-augmented system prompt and print the raw reply.
response = converse_with_bedrock(boto3_client, sys_prompt, get_user_prompt(question))
print(response)

In [None]:
# Take the text between the first <sql> tag and its closing tag.
# `split(...)[1]` raises IndexError when the reply has no <sql> tag.
sql = response.split('<sql>')[1].split('</sql>')[0].strip()

# Execute the extracted statement and print each row.
# NOTE(review): the statement is model-generated and is run without validation.
sql_query = text(sql)
with Session(engine) as session:
    result = session.execute(sql_query)
    for row in result:
        print(row) 