# Bedrock Model Routing - LLM routing

## Intro and Goal
This Jupyter notebook tests a large language model (LLM) routing system built on Amazon Bedrock.
The goal is to have a router model rate the difficulty of a prompt, and then send the prompt to a smaller or larger LLM based on that rating.

The notebook is structured as follows:
1. Define and test the LLM router prompt.
2. Define and test the LLM router.

In [1]:
# Import necessary libraries
import json
import os

import boto3
from dotenv import load_dotenv, find_dotenv

# Load environment variables stored in the local file bedrock-router-eval.env
local_env_filename = 'bedrock-router-eval.env'
load_dotenv(find_dotenv(local_env_filename), override=True)

# Read the AWS region from the environment and fail early if it is missing
REGION = os.getenv('REGION')
if REGION is None:
    raise ValueError(f"REGION is not set; add it to {local_env_filename}")

# Bedrock Runtime client used to invoke all models in this notebook
client = boto3.client(service_name='bedrock-runtime', region_name=REGION)

In [2]:
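# Optional: list the foundation models available in this region.
# Note that this uses the 'bedrock' control-plane client, not 'bedrock-runtime'.
# bedrock_client = boto3.client(service_name='bedrock', region_name=REGION)
# bedrock_client.list_foundation_models()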
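If you do run `list_foundation_models()`, the response is a dict whose `modelSummaries` list holds one entry per model, with keys such as `modelId` and `providerName`. The sketch below is a hypothetical helper that filters such a response offline; the sample dict is hand-written to match that shape and is not real API output (the API also accepts a `byProvider` filter if you prefer to filter server-side).

```python
def model_ids_by_provider(response, provider):
    """Hypothetical helper: return the model IDs in a list_foundation_models() response for one provider."""
    return [
        summary["modelId"]
        for summary in response.get("modelSummaries", [])
        if summary.get("providerName") == provider
    ]


# Hand-written sample with the same shape as the API response (not real output).
sample_response = {
    "modelSummaries": [
        {"modelId": "anthropic.claude-3-haiku-20240307-v1:0", "providerName": "Anthropic"},
        {"modelId": "mistral.mixtral-8x7b-instruct-v0:1", "providerName": "Mistral AI"},
    ]
}
print(model_ids_by_provider(sample_response, "Anthropic"))
```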

In [3]:
# Step 1: Define the LLM router prompt and the model that scores difficulty

# Built from adjacent string literals so the code indentation is not sent to the model.
router_prompt = (
    "Give this question a difficulty rating from 1 to 3, "
    "where 3 is the most difficult and 1 is the easiest.\n"
    "Return the difficulty inside <score></score> tags.\n"
    "Do not include anything else in your response.\n"
)
router_model = "anthropic.claude-3-haiku-20240307-v1:0"

In [4]:
# Step 2: Evaluate the prompt
from botocore.exceptions import ClientError

def eval(prompt):
    """Ask the router model to rate the difficulty of `prompt`.

    Note: this name shadows Python's built-in eval() in the notebook namespace.
    """
    # Format the request payload using the model's native (Anthropic Messages) structure.
    native_request = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 512,
        "temperature": 0,
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": router_prompt + prompt}],
            }
        ],
    }

    # Convert the native request to JSON.
    request = json.dumps(native_request)
    try:
        # Invoke the router model with the request.
        response = client.invoke_model(modelId=router_model, body=request)
    except ClientError as e:
        # Raise instead of calling exit(), which would stop the notebook kernel.
        raise RuntimeError(f"Can't invoke '{router_model}'. Reason: {e}") from e

    # Decode the response body.
    model_response = json.loads(response["body"].read())

    # Extract and return the response text, e.g. '<score>1</score>'.
    response_text = model_response["content"][0]["text"]
    return response_text
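Because `eval()` depends on a live Bedrock client, it can help to exercise the same request/response handling offline. The sketch below is an assumption-laden test double: `rate_difficulty` mirrors `eval()` but takes the client as a parameter, and `FakeBedrockClient` is a hypothetical stand-in whose reply imitates the Anthropic Messages response shape (`content[0]["text"]`) and exposes the body as a readable stream, the way `invoke_model` does.

```python
import io
import json


def rate_difficulty(client, model_id, router_prompt, prompt):
    """Same request/response handling as eval(), with the client passed in explicitly."""
    native_request = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 512,
        "temperature": 0,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": router_prompt + prompt}]}
        ],
    }
    response = client.invoke_model(modelId=model_id, body=json.dumps(native_request))
    model_response = json.loads(response["body"].read())
    return model_response["content"][0]["text"]


class FakeBedrockClient:
    """Hypothetical stand-in for the bedrock-runtime client; records the last request."""

    def __init__(self, reply_text):
        self.reply_text = reply_text
        self.last_request = None

    def invoke_model(self, modelId, body):
        self.last_request = json.loads(body)
        # Imitate the Messages response shape, with a body that supports .read().
        payload = {"content": [{"type": "text", "text": self.reply_text}]}
        return {"body": io.BytesIO(json.dumps(payload).encode("utf-8"))}


fake = FakeBedrockClient("<score>2</score>")
print(rate_difficulty(fake, "router-model", "Rate this question.\n", "What is 2 + 2?"))
```

Injecting the client this way lets you check that the router prompt and the question are concatenated as intended without spending any tokens.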



In [5]:
# Step 3: Define test prompt

sql_prompt ='''
Question:
<user_question>
What is the total number of customers?
</user_question>

SQL schema:
<sql_database_schema>
CREATE TABLE categories (
    category_id smallint NOT NULL PRIMARY KEY,
    category_name character varying(15) NOT NULL,
    description text,
    picture bytea
);

CREATE TABLE customer_demographics (
    customer_type_id bpchar NOT NULL PRIMARY KEY,
    customer_desc text
);

CREATE TABLE customers (
    customer_id bpchar NOT NULL PRIMARY KEY,
    company_name character varying(40) NOT NULL,
    contact_name character varying(30),
    contact_title character varying(30),
    address character varying(60),
    city character varying(15),
    region character varying(15),
    postal_code character varying(10),
    country character varying(15),
    phone character varying(24),
    fax character varying(24)
);

CREATE TABLE customer_customer_demo (
    customer_id bpchar NOT NULL,
    customer_type_id bpchar NOT NULL,
    PRIMARY KEY (customer_id, customer_type_id),
    FOREIGN KEY (customer_type_id) REFERENCES customer_demographics,
    FOREIGN KEY (customer_id) REFERENCES customers
);

CREATE TABLE employees (
    employee_id smallint NOT NULL PRIMARY KEY,
    last_name character varying(20) NOT NULL,
    first_name character varying(10) NOT NULL,
    title character varying(30),
    title_of_courtesy character varying(25),
    birth_date date,
    hire_date date,
    address character varying(60),
    city character varying(15),
    region character varying(15),
    postal_code character varying(10),
    country character varying(15),
    home_phone character varying(24),
    extension character varying(4),
    photo bytea,
    notes text,
    reports_to smallint,
    photo_path character varying(255),
	FOREIGN KEY (reports_to) REFERENCES employees
);

CREATE TABLE suppliers (
    supplier_id smallint NOT NULL PRIMARY KEY,
    company_name character varying(40) NOT NULL,
    contact_name character varying(30),
    contact_title character varying(30),
    address character varying(60),
    city character varying(15),
    region character varying(15),
    postal_code character varying(10),
    country character varying(15),
    phone character varying(24),
    fax character varying(24),
    homepage text
);

CREATE TABLE products (
    product_id smallint NOT NULL PRIMARY KEY,
    product_name character varying(40) NOT NULL,
    supplier_id smallint,
    category_id smallint,
    quantity_per_unit character varying(20),
    unit_price real,
    units_in_stock smallint,
    units_on_order smallint,
    reorder_level smallint,
    discontinued integer NOT NULL,
	FOREIGN KEY (category_id) REFERENCES categories,
	FOREIGN KEY (supplier_id) REFERENCES suppliers
);

CREATE TABLE region (
    region_id smallint NOT NULL PRIMARY KEY,
    region_description bpchar NOT NULL
);

CREATE TABLE shippers (
    shipper_id smallint NOT NULL PRIMARY KEY,
    company_name character varying(40) NOT NULL,
    phone character varying(24)
);

CREATE TABLE orders (
    order_id smallint NOT NULL PRIMARY KEY,
    customer_id bpchar,
    employee_id smallint,
    order_date date,
    required_date date,
    shipped_date date,
    ship_via smallint,
    freight real,
    ship_name character varying(40),
    ship_address character varying(60),
    ship_city character varying(15),
    ship_region character varying(15),
    ship_postal_code character varying(10),
    ship_country character varying(15),
    FOREIGN KEY (customer_id) REFERENCES customers,
    FOREIGN KEY (employee_id) REFERENCES employees,
    FOREIGN KEY (ship_via) REFERENCES shippers
);

CREATE TABLE territories (
    territory_id character varying(20) NOT NULL PRIMARY KEY,
    territory_description bpchar NOT NULL,
    region_id smallint NOT NULL,
	FOREIGN KEY (region_id) REFERENCES region
);

CREATE TABLE employee_territories (
    employee_id smallint NOT NULL,
    territory_id character varying(20) NOT NULL,
    PRIMARY KEY (employee_id, territory_id),
    FOREIGN KEY (territory_id) REFERENCES territories,
    FOREIGN KEY (employee_id) REFERENCES employees
);

CREATE TABLE order_details (
    order_id smallint NOT NULL,
    product_id smallint NOT NULL,
    unit_price real NOT NULL,
    quantity smallint NOT NULL,
    discount real NOT NULL,
    PRIMARY KEY (order_id, product_id),
    FOREIGN KEY (product_id) REFERENCES products,
    FOREIGN KEY (order_id) REFERENCES orders
);

CREATE TABLE us_states (
    state_id smallint NOT NULL PRIMARY KEY,
    state_name character varying(100),
    state_abbr character varying(2),
    state_region character varying(50)
);


</sql_database_schema>
              
'''
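The same schema is pasted in full into both `sql_prompt` and `sql_prompt2`. One way to avoid that duplication is a small builder that assembles the tagged layout from a single schema string. `build_sql_prompt` below is a hypothetical helper, and `tiny_schema` is a shortened excerpt of the `customers` table used only for illustration.

```python
def build_sql_prompt(question, schema, instructions=None):
    """Hypothetical helper: assemble a text-to-SQL prompt in the same tagged layout as sql_prompt."""
    parts = [
        "Question:",
        "<user_question>",
        question,
        "</user_question>",
        "",
        "SQL schema:",
        "<sql_database_schema>",
        schema.strip(),
        "</sql_database_schema>",
    ]
    # Optional instructions block, as used in sql_prompt2.
    if instructions:
        parts += ["", "Instructions:", instructions.strip()]
    return "\n".join(parts) + "\n"


# Shortened excerpt of the customers table, for illustration only.
tiny_schema = """
CREATE TABLE customers (
    customer_id bpchar NOT NULL PRIMARY KEY,
    company_name character varying(40) NOT NULL
);
"""
print(build_sql_prompt("What is the total number of customers?", tiny_schema))
```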

In [6]:
# Step 4: Test eval
eval(sql_prompt)

'<score>1</score>'
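The router is asked to reply with only a `<score>` tag, but nothing forces the model to comply. A stricter parser can reject missing, non-numeric, or out-of-range scores before they reach the routing logic. This `parse_score` function is a hypothetical sketch, assuming valid scores are exactly 1, 2, and 3 as the router prompt specifies.

```python
import re


def parse_score(response_text, valid=(1, 2, 3)):
    """Return the integer inside <score></score>, or None if missing, non-numeric, or out of range."""
    match = re.search(r"<score>\s*(\d+)\s*</score>", response_text)
    if match is None:
        return None
    score = int(match.group(1))
    return score if score in valid else None


print(parse_score("<score>1</score>"))
print(parse_score("Difficulty: hard"))
print(parse_score("<score>7</score>"))
```

Returning `None` instead of raising leaves the caller free to choose a fallback, such as sending the prompt to the largest model.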

In [7]:
# Step 5: Construct the router
import re

from botocore.exceptions import ClientError

SCORE_PATTERN = r'<score>(.*?)</score>'
SQL_PATTERN = r'<SQL>(.*?)</SQL>'  # used to pull generated SQL out of a response

# Strip out the portion of the response that matches the regex.
def extract_with_regex(response, regex):
    matches = re.search(regex, response, re.DOTALL)
    # Return the matched content, if any.
    return matches.group(1).strip() if matches else None

def invoke_bedrock(model_id, native_request):
    """Send a native request to a Bedrock model and return the decoded response body."""
    try:
        response = client.invoke_model(modelId=model_id, body=json.dumps(native_request))
    except ClientError as e:
        # Raise instead of calling exit(), which would stop the notebook kernel.
        raise RuntimeError(f"Can't invoke '{model_id}'. Reason: {e}") from e
    return json.loads(response["body"].read())

def claude_request(prompt):
    """Build an Anthropic Messages API request body for a Claude model on Bedrock."""
    return {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 512,
        "temperature": 0,
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}],
            }
        ],
    }

def route_prompt(prompt):
    score = extract_with_regex(eval(prompt), SCORE_PATTERN)
    if score is None or not score.isdigit():
        raise ValueError(f"Router returned no numeric <score>: {score!r}")
    route = int(score)
    print(f'route: {route}')

    if route == 1:
        # Easiest difficulty: Mixtral 8x7B, which uses Mistral's [INST] prompt format.
        model_1 = "mistral.mixtral-8x7b-instruct-v0:1"
        native_request_1 = {
            "prompt": '<s>[INST] ' + prompt + ' [/INST]',
            "max_tokens": 512,
            "temperature": 0,
        }
        model_response = invoke_bedrock(model_1, native_request_1)
        return model_response["outputs"][0]["text"]
    elif route == 2:
        # Medium difficulty: Claude 3 Haiku.
        model_2 = "anthropic.claude-3-haiku-20240307-v1:0"
        model_response = invoke_bedrock(model_2, claude_request(prompt))
        return model_response["content"][0]["text"]
    else:
        # Most difficult (and any unexpected score): Claude 3 Sonnet.
        # Alternative: "anthropic.claude-3-5-sonnet-20240620-v1:0"
        model_3 = 'anthropic.claude-3-sonnet-20240229-v1:0'
        model_response = invoke_bedrock(model_3, claude_request(prompt))
        return model_response["content"][0]["text"]
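The routing decision itself is just a mapping from score to model ID, so it can also be expressed as a lookup table that is easy to test and to extend with more tiers. The sketch below is a hypothetical, table-driven version of that decision only (it does not call Bedrock); it reuses the three model IDs from the router and, like the router's final branch, sends any unrecognized score to the largest model.

```python
# Hypothetical table-driven version of the routing decision in route_prompt().
MODEL_BY_SCORE = {
    1: "mistral.mixtral-8x7b-instruct-v0:1",       # easiest
    2: "anthropic.claude-3-haiku-20240307-v1:0",   # medium
    3: "anthropic.claude-3-sonnet-20240229-v1:0",  # most difficult
}
FALLBACK_MODEL = MODEL_BY_SCORE[3]


def choose_model(score):
    """Map a difficulty score to a model ID; unknown scores go to the largest model."""
    return MODEL_BY_SCORE.get(score, FALLBACK_MODEL)


for score in (1, 2, 3, None):
    print(score, "->", choose_model(score))
```

Falling back to the largest model trades some cost for safety: a misparsed score never sends a hard question to the smallest model.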

In [8]:
# Step 6: Test router
route_prompt(sql_prompt)

route: 1


' To find the total number of customers, you can use the following SQL query:\n```\nSELECT COUNT(*) FROM customers;\n```\nThis query will count all the rows in the `customers` table, which corresponds to the total number of customers.\n\nAlternatively, if you want to count the number of distinct customers, you can use the following query:\n```\nSELECT COUNT(DISTINCT customer_id) FROM customers;\n```\nThis query will count the number of unique customer IDs in the `customers` table.\n\nI hope this helps! Let me know if you have any other questions.'
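Because `sql_prompt` contains no output-format instructions, the route-1 answer above wraps its SQL in Markdown code fences rather than `<SQL>` tags, so `SQL_PATTERN` would not find anything in it. A hypothetical helper like `extract_fenced_code` can recover the queries instead; the `sample` string is an abbreviated copy of the response shown above.

```python
import re

FENCE = "`" * 3  # Markdown code fence, written as a repeated character to avoid a literal fence here


def extract_fenced_code(response_text):
    """Return the stripped contents of every Markdown code fence in the text."""
    # An opening fence may carry a language tag (e.g. sql) before the newline.
    pattern = FENCE + r"[^\n]*\n(.*?)" + FENCE
    return [block.strip() for block in re.findall(pattern, response_text, re.DOTALL)]


# Abbreviated from the route-1 response shown above.
sample = (
    "To find the total number of customers, you can use the following SQL query:\n"
    + FENCE + "\nSELECT COUNT(*) FROM customers;\n" + FENCE + "\n"
    + "Alternatively, ... you can use the following query:\n"
    + FENCE + "\nSELECT COUNT(DISTINCT customer_id) FROM customers;\n" + FENCE + "\n"
)
print(extract_fenced_code(sample))
```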

In [9]:
# Step 7: Evaluate the router with a more difficult question
sql_prompt2 = '''
Question:
<user_question>
List all products that have a higher than average unit price in their category.
</user_question>

SQL schema:
<sql_database_schema>
CREATE TABLE categories (
    category_id smallint NOT NULL PRIMARY KEY,
    category_name character varying(15) NOT NULL,
    description text,
    picture bytea
);

CREATE TABLE customer_demographics (
    customer_type_id bpchar NOT NULL PRIMARY KEY,
    customer_desc text
);

CREATE TABLE customers (
    customer_id bpchar NOT NULL PRIMARY KEY,
    company_name character varying(40) NOT NULL,
    contact_name character varying(30),
    contact_title character varying(30),
    address character varying(60),
    city character varying(15),
    region character varying(15),
    postal_code character varying(10),
    country character varying(15),
    phone character varying(24),
    fax character varying(24)
);

CREATE TABLE customer_customer_demo (
    customer_id bpchar NOT NULL,
    customer_type_id bpchar NOT NULL,
    PRIMARY KEY (customer_id, customer_type_id),
    FOREIGN KEY (customer_type_id) REFERENCES customer_demographics,
    FOREIGN KEY (customer_id) REFERENCES customers
);

CREATE TABLE employees (
    employee_id smallint NOT NULL PRIMARY KEY,
    last_name character varying(20) NOT NULL,
    first_name character varying(10) NOT NULL,
    title character varying(30),
    title_of_courtesy character varying(25),
    birth_date date,
    hire_date date,
    address character varying(60),
    city character varying(15),
    region character varying(15),
    postal_code character varying(10),
    country character varying(15),
    home_phone character varying(24),
    extension character varying(4),
    photo bytea,
    notes text,
    reports_to smallint,
    photo_path character varying(255),
	FOREIGN KEY (reports_to) REFERENCES employees
);

CREATE TABLE suppliers (
    supplier_id smallint NOT NULL PRIMARY KEY,
    company_name character varying(40) NOT NULL,
    contact_name character varying(30),
    contact_title character varying(30),
    address character varying(60),
    city character varying(15),
    region character varying(15),
    postal_code character varying(10),
    country character varying(15),
    phone character varying(24),
    fax character varying(24),
    homepage text
);

CREATE TABLE products (
    product_id smallint NOT NULL PRIMARY KEY,
    product_name character varying(40) NOT NULL,
    supplier_id smallint,
    category_id smallint,
    quantity_per_unit character varying(20),
    unit_price real,
    units_in_stock smallint,
    units_on_order smallint,
    reorder_level smallint,
    discontinued integer NOT NULL,
	FOREIGN KEY (category_id) REFERENCES categories,
	FOREIGN KEY (supplier_id) REFERENCES suppliers
);

CREATE TABLE region (
    region_id smallint NOT NULL PRIMARY KEY,
    region_description bpchar NOT NULL
);

CREATE TABLE shippers (
    shipper_id smallint NOT NULL PRIMARY KEY,
    company_name character varying(40) NOT NULL,
    phone character varying(24)
);

CREATE TABLE orders (
    order_id smallint NOT NULL PRIMARY KEY,
    customer_id bpchar,
    employee_id smallint,
    order_date date,
    required_date date,
    shipped_date date,
    ship_via smallint,
    freight real,
    ship_name character varying(40),
    ship_address character varying(60),
    ship_city character varying(15),
    ship_region character varying(15),
    ship_postal_code character varying(10),
    ship_country character varying(15),
    FOREIGN KEY (customer_id) REFERENCES customers,
    FOREIGN KEY (employee_id) REFERENCES employees,
    FOREIGN KEY (ship_via) REFERENCES shippers
);

CREATE TABLE territories (
    territory_id character varying(20) NOT NULL PRIMARY KEY,
    territory_description bpchar NOT NULL,
    region_id smallint NOT NULL,
	FOREIGN KEY (region_id) REFERENCES region
);

CREATE TABLE employee_territories (
    employee_id smallint NOT NULL,
    territory_id character varying(20) NOT NULL,
    PRIMARY KEY (employee_id, territory_id),
    FOREIGN KEY (territory_id) REFERENCES territories,
    FOREIGN KEY (employee_id) REFERENCES employees
);

CREATE TABLE order_details (
    order_id smallint NOT NULL,
    product_id smallint NOT NULL,
    unit_price real NOT NULL,
    quantity smallint NOT NULL,
    discount real NOT NULL,
    PRIMARY KEY (order_id, product_id),
    FOREIGN KEY (product_id) REFERENCES products,
    FOREIGN KEY (order_id) REFERENCES orders
);

CREATE TABLE us_states (
    state_id smallint NOT NULL PRIMARY KEY,
    state_name character varying(100),
    state_abbr character varying(2),
    state_region character varying(50)
);
</sql_database_schema>

Instructions:
Generate a SQL query that answers the original user question.
Using the schema, first create a syntactically correct SQLite query to answer the question.
Never query for all the columns from a specific table; only ask for a few relevant columns given the question.
Use only the column names that you can see in the schema description.
Be careful not to query for columns that do not exist.
Pay attention to which column is in which table.
Also, qualify column names with the table name when needed.
If you cannot answer the user question with the help of the provided SQL database schema,
then output that this question cannot be answered based on the information stored in the database.
Use the following format: return the SQL query inside <SQL></SQL> tags.

'''

In [10]:
route_prompt(sql_prompt2)

route: 3


'<SQL>\nSELECT p.product_name, p.unit_price, c.category_name\nFROM products p\nJOIN categories c ON p.category_id = c.category_id\nWHERE p.unit_price > (\n    SELECT AVG(unit_price)\n    FROM products p2\n    WHERE p2.category_id = p.category_id\n);\n</SQL>\n\nThis query first joins the products and categories tables to get the product name, unit price, and category name. It then filters the results to only include products where the unit price is greater than the average unit price for that category. The average unit price per category is calculated in a subquery.'
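One way to sanity-check the route-3 answer is to extract the query from its `<SQL>` tags (the same regex as `SQL_PATTERN`) and run it against a small in-memory SQLite database. The two-table schema below is a simplified SQLite subset of `categories` and `products`, and the rows are hypothetical toy data chosen so the expected result is easy to verify by hand: Beverages average 20.0, so only Coffee qualifies; Condiments average 22.0, so only Saffron qualifies.

```python
import re
import sqlite3

# The <SQL>...</SQL> portion of the route-3 response shown above.
response_text = (
    "<SQL>\nSELECT p.product_name, p.unit_price, c.category_name\n"
    "FROM products p\nJOIN categories c ON p.category_id = c.category_id\n"
    "WHERE p.unit_price > (\n    SELECT AVG(unit_price)\n    FROM products p2\n"
    "    WHERE p2.category_id = p.category_id\n);\n</SQL>"
)
sql = re.search(r"<SQL>(.*?)</SQL>", response_text, re.DOTALL).group(1).strip()

# Hypothetical toy data in a simplified SQLite version of two tables from the schema.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE categories (category_id INTEGER PRIMARY KEY, category_name TEXT NOT NULL);
CREATE TABLE products (
    product_id INTEGER PRIMARY KEY,
    product_name TEXT NOT NULL,
    category_id INTEGER REFERENCES categories,
    unit_price REAL
);
INSERT INTO categories VALUES (1, 'Beverages'), (2, 'Condiments');
INSERT INTO products VALUES
    (1, 'Tea', 1, 10.0), (2, 'Coffee', 1, 30.0),
    (3, 'Salt', 2, 2.0), (4, 'Pepper', 2, 4.0), (5, 'Saffron', 2, 60.0);
""")
# The query has no ORDER BY, so sort the rows for a stable comparison.
rows = sorted(conn.execute(sql).fetchall())
print(rows)
```

A passing check on toy data only shows the query is well-formed and logically plausible; it does not validate it against the full PostgreSQL-style schema in the prompt.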