# Bedrock Model Routing - custom semantic routing

## Intro and Goal

This Jupyter Notebook is designed to test an LLM (Large Language Model) routing system. The goal is to take a prompt, embed it using a vector embedding in Bedrock, and then measure the distance with two specific vectors that represent the domain for two specific LLMs. Based on the distance, the prompt will be routed to the appropriate LLM.

The notebook is structured as follows:
1. Create the samples for the 2 domains that we'll route to (e.g., code generation and summarization).
2. Generate the embeddings for the 2 domain prompts.
3. Create a 3rd prompt, generate its embedding, and measure the distance to select which domain it relates to.
4. Construct the router that will take the prompt and automatically generate the answer from the LLM the prompt is routed to based on the distance.

In [10]:
# Import necessary libraries
import numpy as np
from scipy.spatial.distance import cdist
import json
from dotenv import load_dotenv, find_dotenv
import os
import boto3

# loading environment variables that are stored in local file dev.env
local_env_filename = 'bedrock-router-eval.env'
load_dotenv(find_dotenv(local_env_filename),override=True)

os.environ['REGION'] = os.getenv('REGION')
REGION = os.environ['REGION']

client = boto3.client(service_name='bedrock-runtime', region_name=REGION)

model_id = "amazon.titan-embed-text-v2:0" #"anthropic.claude-3-haiku-20240307-v1:0" # "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

In [11]:
# Step 1: Create the samples for the 2 domains
# Code Generation Domain
code_gen_prompt = "Write a Python function that calculates the factorial of a given number."
code_gen_model = "code_generation_llm"

In [12]:
# Summarization Domain
summarization_prompt = "Summarize the key points of the given article in a concise paragraph."
summarization_model = "summarization_llm"

In [13]:
# Step 2: Generate the embeddings for the 2 domain prompts
# Create the request for the model.
native_request_code = {"inputText": code_gen_prompt}
native_request_sum = {"inputText": summarization_prompt}

# Convert the native request to JSON.
request_code = json.dumps(native_request_code)
request_sum = json.dumps(native_request_sum)

# Invoke the model with the request.
response_code = client.invoke_model(modelId=model_id, body=request_code)
response_sum = client.invoke_model(modelId=model_id, body=request_sum)

# Decode the model's native response body.
model_response_code = json.loads(response_code["body"].read())
model_response_sum = json.loads(response_sum["body"].read())

# Extract and print the generated embedding and the input text token count.
code_gen_embedding = model_response_code["embedding"]
summarization_embedding = model_response_sum["embedding"]

print("Embedding Code:")
print(code_gen_embedding)

print("Embedding Summarization:")
print(summarization_embedding)

Embedding Code:
[0.0008610966615378857, -0.0039164600893855095, 0.01125982217490673, 0.002255461411550641, -0.0038814914878457785, -0.0025002399925142527, 0.03482851758599281, -0.0509139783680439, -0.031611427664756775, -0.010280707851052284, 0.06406209617853165, -0.002482755808159709, 0.03720637038350105, -0.0077280146069824696, 0.02825446054339409, 0.023638634011149406, 0.005000479985028505, -0.022379770874977112, -0.014406978152692318, 0.02307913824915886, 0.04308106005191803, -0.05371145159006119, 0.017414258792996407, -0.014966472052037716, -0.05790765956044197, -0.03776586428284645, 0.029793070629239082, -0.03888485208153725, 0.02545698918402195, -0.03804561123251915, 0.039444345980882645, -0.02825446054339409, 0.0013375410344451666, -0.03342978283762932, 0.009791149757802486, 0.049515243619680405, 0.031191805377602577, -0.025876609608530998, 0.024757621809840202, 0.02755509316921234, -0.004860606510192156, 0.0255968626588583, -0.0031646394636482, 0.012099063955247402, 0.03580763

In [14]:
# Step 3: Create a 3rd prompt, generate its embedding, and measure the distance
third_prompt = "Explain the concept of object-oriented programming in Python."
native_request = {"inputText": third_prompt}
request = json.dumps(native_request)
response = client.invoke_model(modelId=model_id, body=request)
model_response = json.loads(response["body"].read())
third_prompt_embedding = model_response_code["embedding"]

In [15]:
# Measure the distance to the 2 domain prompts
third_promp_array = np.array(third_prompt_embedding)
code_gen_array = np.array(code_gen_embedding)
sum_array = np.array(summarization_embedding)
code_gen_distance = np.linalg.norm(third_promp_array - code_gen_array)
sum_distance = np.linalg.norm(third_promp_array - sum_array)

print(code_gen_distance)
print(sum_distance)

0.0
1.4240657676293242


In [16]:
# Determine the domain based on the distance
if code_gen_distance < sum_distance:
    print(f"The prompt '{third_prompt}' is routed to the {code_gen_model}.")
else:
    print(f"The prompt '{third_prompt}' is routed to the {summarization_model}.")

The prompt 'Explain the concept of object-oriented programming in Python.' is routed to the code_generation_llm.


In [17]:
# Step 4: Construct the router
def route_prompt(prompt):
    native_request = {"inputText": prompt}
    request = json.dumps(native_request)
    response = client.invoke_model(modelId=model_id, body=request)
    model_response = json.loads(response["body"].read())
    prompt_embedding = model_response_code["embedding"]
    
    code_gen_distance = np.linalg.norm(np.array(prompt_embedding) - code_gen_array)
    summarization_distance = np.linalg.norm(np.array(prompt_embedding) - sum_array)
    
    if code_gen_distance < summarization_distance:
        return code_gen_model, code_gen_model.generate(prompt)
    else:
        return summarization_model, summarization_model.generate(prompt)

# other
scaling the amount of prompts for the 2 domains, taking the average as the anchor point for routing. 
increase the number of domains
testing accuracy over many prompts

In [18]:
def plot_similarity_heatmap(embeddings_a, embeddings_b):
    inner_product = np.inner(embeddings_a, embeddings_b)
    sns.set(font_scale=1.1)
    graph = sns.heatmap(
        inner_product,
        vmin=np.min(inner_product),
        vmax=1,
        cmap="OrRd",
    )