# Deploy and benchmark reranker models on Amazon SageMaker

In information retrieval and natural language processing applications, rerankers are commonly used to improve the accuracy and relevance of search results. A reranker takes a query and a set of candidate items returned by a first-stage retriever, scores how relevant each candidate is to the query, and reorders the candidates by that score so that the most relevant results appear first.
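To make the reranking step concrete, here is a minimal sketch. The `score` function is a hypothetical stand-in (simple word overlap) for a real reranker. A model such as a cross-encoder would instead predict a relevance score for each (query, text) pair, but the surrounding logic of scoring every candidate and sorting by score is the same.

```python
def score(query: str, text: str) -> float:
    """Hypothetical relevance score: fraction of query words that also appear in the text.
    A real reranker replaces this with a model prediction."""
    query_words = set(query.lower().split())
    if not query_words:
        return 0.0
    text_words = set(text.lower().split())
    return len(query_words & text_words) / len(query_words)


def rerank(query: str, candidates: list[str]) -> list[tuple[str, float]]:
    """Score every candidate against the query and return (text, score) pairs, most relevant first."""
    scored = [(text, score(query, text)) for text in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


retrieved = [
    "Pandas are known for their playful behavior.",
    "The giant panda is a bear species endemic to China.",
    "Bamboo grows quickly.",
]
for text, s in rerank("giant panda bear species", retrieved):
    print(f"{s:.2f}  {text}")
```

Because Python's sort is stable, candidates with equal scores keep their original retrieval order.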

The objective of this notebook is to demonstrate how you can deploy reranker models on Amazon SageMaker and benchmark the latency and throughput of the resulting endpoint.

## Setup

Install or upgrade the required libraries.

In [None]:
! pip install -U transformers hf_transfer sagemaker

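Set up the SageMaker session and the parameters used throughout the notebook: the AWS Region, the IAM execution role, and the default S3 bucket.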

In [None]:
import sagemaker
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

import boto3
from botocore.config import Config

# SageMaker Runtime client with retries for transient errors. max_pool_connections is raised
# from the botocore default (10) so the concurrent benchmark below is not throttled by the
# client-side HTTP connection pool.
sagemaker_session = sagemaker.session.Session(
    sagemaker_runtime_client=boto3.client(
        "sagemaker-runtime",
        config=Config(
            connect_timeout=10,
            retries={"mode": "standard", "total_max_attempts": 20},
            max_pool_connections=256,
        ),
    )
)
region = sagemaker_session.boto_region_name
role = sagemaker.get_execution_role()
bucket = sagemaker_session.default_bucket()

## Create Model Objects

In this section, you create the SageMaker Model object. Choose one of the four options below:
1. Create a SageMaker JumpStart Model Object
2. Create a Hugging Face SageMaker Model Object from a model downloaded from the Hugging Face Hub
3. Create a Hugging Face SageMaker Model Object from model artifacts stored in Amazon S3
4. Create a DJL Model Object that uses the SageMaker Large Model Inference (LMI) container image

Each option assigns its result to `model_bge_rerank`, so the deployment cell that follows uses whichever option you ran last.

### Option 1: Create a SageMaker JumpStart Model Object

Amazon SageMaker JumpStart is a machine learning (ML) hub that can help you accelerate your ML journey. You can use it to compare and select foundation models (FMs) for use cases such as article summarization and image generation, and you can fine-tune or deploy JumpStart FMs through SageMaker Studio or the SageMaker Python SDK. You can find the full list of models [here](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html).

In [None]:
from sagemaker.jumpstart.model import JumpStartModel

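# create JumpStart Model object.
# Note: this option deploys Cohere Rerank, not a BGE model. The variable name model_bge_rerank is reused
# so that the deployment cell below works with any option. Cohere's request format may differ from the
# payloads built in the test cells below, so check the model's listing for its expected input.
model_bge_rerank = JumpStartModel(
    model_id='cohere-rerank-multilingual-v2',
    sagemaker_session=sagemaker_session,
)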

### Option 2: Create a HuggingFace SageMaker Model Object from a model downloaded from HuggingFace Hub

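To deploy a model directly from the 🤗 Hub to SageMaker, define environment variables when you create a `HuggingFaceModel`. This option and Option 3 serve the model with the Hugging Face Text Embeddings Inference (TEI) container.

* `HF_MODEL_ID` defines the model ID, which is automatically loaded from [huggingface.co/models](https://huggingface.co/models) when you create a SageMaker endpoint.
* `DTYPE` sets the precision used to load the model weights (`float16` here).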

In [None]:
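config_hf = {
    'HF_MODEL_ID': 'BAAI/bge-reranker-v2-m3',  # model to load from the Hugging Face Hub
    'DTYPE': 'float16',                        # load the weights in half precision
}

model_name = sagemaker.utils.name_from_base(config_hf["HF_MODEL_ID"].replace("/", "-"))

# create Hugging Face Model object served by the TEI container
model_bge_rerank = HuggingFaceModel(
    image_uri=get_huggingface_llm_image_uri("huggingface-tei"),
    env=config_hf,
    role=role,
    name=model_name,
    sagemaker_session=sagemaker_session,
)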

### Option 3: Create a HuggingFace SageMaker Model Object from a model downloaded from Amazon S3

If you don't want Amazon SageMaker to download the model artifacts from the Hugging Face Hub when the inference endpoint starts, you can store the model in Amazon S3 instead.

First, download the model artifacts locally.

In [None]:
from pathlib import Path
import os
 
# set HF_HUB_ENABLE_HF_TRANSFER env var to enable hf-transfer for faster downloads
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from huggingface_hub import snapshot_download
 
HF_MODEL_ID = "BAAI/bge-reranker-v2-m3"
# create model dir
model_tar_dir = Path(HF_MODEL_ID.split("/")[-1])
model_tar_dir.mkdir(exist_ok=True)
 
# Download model from Hugging Face into model_dir
snapshot_download(
    HF_MODEL_ID,
    local_dir=str(model_tar_dir), # download to model dir
    revision="main", # use a specific revision, e.g. refs/pr/21
    ignore_patterns=["*.msgpack*", "*.h5*", "*.bin*", "assets*"], # to load safetensor weights
)
 
# check if safetensor weights are downloaded and available
assert len(list(model_tar_dir.glob("*.safetensors"))) > 0, "Model download failed"
 

Compress the model artifacts.

In [None]:
parent_dir = os.getcwd()
# change to model dir so the artifacts sit at the root of the archive
os.chdir(str(model_tar_dir))

In [None]:
# use pigz for faster, parallel compression (requires the pigz binary to be installed)
!tar -cf model.tar.gz --use-compress-program=pigz *

In [None]:
# change back to parent dir
os.chdir(parent_dir)
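If `pigz` is not available in your environment, Python's standard-library `tarfile` module can produce an equivalent gzip-compressed archive, although it compresses on a single thread and is slower for large models. A minimal sketch, assuming the same layout as the shell command: files are stored relative to the model directory so SageMaker extracts them directly into `/opt/ml/model`, hidden entries (like the shell glob `*`) and the archive itself are skipped. `create_model_archive` is a hypothetical helper name.

```python
import tarfile
from pathlib import Path


def create_model_archive(model_dir, archive_name="model.tar.gz"):
    """Compress the contents of model_dir into model_dir/archive_name.

    Entries are stored without a leading folder, so they land at the archive root.
    """
    model_dir = Path(model_dir)
    archive_path = model_dir / archive_name
    with tarfile.open(archive_path, "w:gz") as tar:
        for path in sorted(model_dir.iterdir()):
            # skip the archive being written and hidden entries such as .cache
            if path.name == archive_name or path.name.startswith("."):
                continue
            tar.add(path, arcname=path.name)  # directories are added recursively
    return archive_path


# usage: create_model_archive(model_tar_dir)
```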

Upload the compressed model artifacts to Amazon S3.

In [None]:
from sagemaker.s3 import S3Uploader
 
# upload model.tar.gz to s3
s3_model_uri = S3Uploader.upload(local_path=str(model_tar_dir.joinpath("model.tar.gz")), desired_s3_uri=f"s3://{bucket}/{HF_MODEL_ID}")
 
print(f"model uploaded to: {s3_model_uri}")
 

Now, you can create the Model Object.

In [None]:
model_name = sagemaker.utils.name_from_base(HF_MODEL_ID.replace("/", "-"))

config_s3 = {
    'HF_MODEL_ID': "/opt/ml/model",  # SageMaker extracts model.tar.gz into this directory
}

# create Hugging Face Model object
model_bge_rerank = HuggingFaceModel(
    image_uri=get_huggingface_llm_image_uri("huggingface-tei"),
    env=config_s3,
    role=role,
    name=model_name,
    model_data=s3_model_uri,  # S3 URI where the model artifacts are found
    sagemaker_session=sagemaker_session,
)

### Option 4: Create a DJL Model Object that uses SageMaker Large Model Inference (LMI) container image

DJL Serving is a high-performance, universal, stand-alone model serving solution. It takes a deep learning model, several models, or workflows and makes them available through an HTTP endpoint.

You can use one of the DJL Serving [Deep Learning Containers (DLCs)](https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/what-is-dlc.html) to serve your models on AWS. To learn about the supported model types and frameworks, see the [DJL Serving documentation](https://docs.djl.ai/master/index.html).

In this notebook, we use the Large Model Inference (LMI) containers, a set of high-performance Docker containers purpose-built for large language model (LLM) inference. With these containers, you can use high-performance open-source inference libraries such as vLLM, TensorRT-LLM, and Transformers NeuronX to deploy LLMs on Amazon SageMaker endpoints. Here, the container serves the reranker with the `text-classification` task and `ARGS_RERANKING` set to `true`.

In [None]:
from sagemaker.djl_inference.model import DJLModel

model_id = "BAAI/bge-reranker-v2-m3"  # model will be downloaded from the Hugging Face Hub
image_uri = "763104351884.dkr.ecr.{}.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124".format(region)

env = {
    "SERVING_MIN_WORKERS": "1",  # make sure min and max workers are equal when deploying on GPU
    "SERVING_MAX_WORKERS": "1",
    "ARGS_RERANKING": "true",
    "OPTION_ROLLING_BATCH": "disable",
}

# create DJL Model object
model_bge_rerank = DJLModel(
    model_id=model_id,
    task="text-classification",
    image_uri=image_uri,
    env=env,
    role=role,
    sagemaker_session=sagemaker_session,
)

## Deploy the Model to an endpoint

In [None]:
model_bge_rerank_predictor = model_bge_rerank.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.xlarge",
    container_startup_health_check_timeout=300,
    wait=True,
)
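The predictor returned by `deploy()` calls the SageMaker Runtime `InvokeEndpoint` API for you. If you need to call the endpoint from code that does not use the SageMaker Python SDK (for example, an AWS Lambda function), a minimal boto3 sketch follows. `invoke_reranker` is a hypothetical helper name, and it assumes the endpoint accepts and returns JSON, as the payloads in this notebook do.

```python
import json


def invoke_reranker(endpoint_name, payload, client=None):
    """Send a JSON payload to a SageMaker endpoint and return the decoded JSON response.

    `client` is a boto3 "sagemaker-runtime" client; one is created if not supplied.
    """
    if client is None:
        import boto3  # imported lazily so the helper can be exercised with a stub client
        client = boto3.client("sagemaker-runtime")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps(payload),
    )
    # response["Body"] is a streaming body: read the bytes and decode the JSON
    return json.loads(response["Body"].read())


# usage: invoke_reranker(model_bge_rerank_predictor.endpoint_name, payload)
```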

Once the model is deployed, test the model invocation.

In [None]:
# import DJLModel here too, so the isinstance check works even if you skipped Option 4
from sagemaker.djl_inference.model import DJLModel

query = "what is panda?"
texts = [
    "A panda is a type of bear that is known for its distinctive black and white coloring.",
    "Pandas are native to China and are known for their diet, which consists mostly of bamboo.",
    "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.",
    "Pandas have become a symbol of conservation due to their status as an endangered species.",
    "The panda's distinctive black and white coat serves as camouflage in its natural habitat.",
    "Pandas are known for their playful behavior and are a favorite among zoo visitors.",
    "There are two main species of panda: the giant panda and the red panda, which are not closely related.",
    "Pandas primarily live in temperate forests high in the mountains of southwest China.",
    "The giant panda has a large head, heavy body, rounded ears, and a short tail.",
    "Efforts to protect panda habitats have led to the establishment of several panda reserves in China.",
]

if isinstance(model_bge_rerank, DJLModel):
    # LMI (Option 4) expects a list of {"text": query, "text_pair": candidate} pairs
    payload = {"inputs": [{"text": query, "text_pair": text} for text in texts]}
else:
    # TEI (Options 2 and 3) expects the query and the list of candidate texts
    payload = {"query": query, "texts": texts}
In [None]:
model_bge_rerank_predictor.predict(payload)
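For the TEI options (2 and 3), our understanding is that the rerank response is a list of objects, each with an `index` into `texts` and a relevance `score`. A minimal sketch that maps such a response back to the original texts, highest score first. The response shape is an assumption, so inspect the actual output above before relying on it; the LMI container (Option 4) may return a different format. `order_texts_by_score` is a hypothetical helper name.

```python
def order_texts_by_score(texts, response):
    """Return (text, score) pairs sorted from most to least relevant.

    Assumes `response` is a list of {"index": int, "score": float} dicts,
    where `index` is a position in `texts`.
    """
    ranked = sorted(response, key=lambda item: item["score"], reverse=True)
    return [(texts[item["index"]], item["score"]) for item in ranked]


# usage with the TEI payload built above:
# order_texts_by_score(payload["texts"], model_bge_rerank_predictor.predict(payload))
```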

## Benchmark the endpoint

Create a benchmark script that sends concurrent requests to the endpoint, records the latency and throughput at each concurrency level, and plots the results.
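The benchmark reports the mean latency at each concurrency level. A mean can hide slow outliers, so you may also want tail percentiles. A minimal sketch, assuming you collect raw per-request latencies (for example, the `latencies_batch` values inside `benchmark_predictor`) into a list; `latency_percentiles` is a hypothetical helper, not part of the original notebook.

```python
import numpy as np


def latency_percentiles(latencies, percentiles=(50, 90, 99)):
    """Return a dict mapping each requested percentile (e.g. 90 -> "p90") to a latency in seconds."""
    values = np.percentile(np.asarray(latencies, dtype=float), percentiles)
    return {f"p{p}": float(v) for p, v in zip(percentiles, values)}


# usage: latency_percentiles([0.12, 0.15, 0.11, 0.40])
```

`np.percentile` uses linear interpolation between samples by default, so with few requests per level the high percentiles are rough estimates.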

In [None]:
import time
import concurrent.futures
import numpy as np
import matplotlib.pyplot as plt

# Assuming predictor is already defined and initialized
# and predictor.predict(data=payload) is the method to be benchmarked

def benchmark_predictor(predictor, payload, steps, iterations=5):
    """
    Benchmarks a predictor's performance by measuring latency and throughput 
    under varying levels of concurrent requests.
    
    Args:
        predictor (object): The predictor object with a `predict` method.
        payload (any): The input data to be sent to the predictor.
        steps (list): A list of different numbers of concurrent requests to test.
        iterations (int, optional): The number of iterations for each concurrency level. Default is 5.
    
    Returns:
        tuple: Three lists containing the request counts, latencies, and throughputs.
    """
    latencies = []
    throughputs = []
    request_counts = []

    def send_request():
        """Sends a single request to the predictor and measures its latency."""
        # perf_counter is a monotonic, high-resolution clock suited to measuring intervals
        start_time = time.perf_counter()
        predictor.predict(data=payload)
        return time.perf_counter() - start_time

    for num_requests in steps:
        iter_latencies = []
        iter_throughputs = []
        
        for _ in range(iterations):
            start_time = time.perf_counter()
            
            # Use ThreadPoolExecutor to send concurrent requests
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_requests) as executor:
                futures = [executor.submit(send_request) for _ in range(num_requests)]
                latencies_batch = [future.result() for future in concurrent.futures.as_completed(futures)]
            
            total_time = time.perf_counter() - start_time
            
            # Calculate average latency for this iteration
            latency = np.mean(latencies_batch)
            # Calculate throughput for this iteration
            throughput = num_requests / total_time
            
            iter_latencies.append(latency)
            iter_throughputs.append(throughput)
        
        # Calculate average latency and throughput over all iterations
        avg_latency = np.mean(iter_latencies)
        avg_throughput = np.mean(iter_throughputs)
        
        latencies.append(avg_latency)
        throughputs.append(avg_throughput)
        request_counts.append(num_requests)
        
        # Print results for the current number of requests
        print(f"Requests: {num_requests}, Average Latency: {avg_latency:.4f}s, Average Throughput: {avg_throughput:.2f} req/s")

    return request_counts, latencies, throughputs

def plot_metrics(request_counts, latencies, throughputs):
    """
    Plots the benchmarking results, showing the average latency and throughput 
    as a function of the number of concurrent requests.
    
    Args:
        request_counts (list): The list of different numbers of concurrent requests tested.
        latencies (list): The list of average latencies corresponding to the request counts.
        throughputs (list): The list of average throughputs corresponding to the request counts.
    """
    fig, ax1 = plt.subplots()

    color = 'tab:blue'
    ax1.set_xlabel('Number of Concurrent Requests')
    ax1.set_ylabel('Average Latency (s)', color=color)
    ax1.plot(request_counts, latencies, color=color)
    ax1.tick_params(axis='y', labelcolor=color)

    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis

    color = 'tab:green'
    ax2.set_ylabel('Throughput (requests/s)', color=color)  # we already handled the x-label with ax1
    ax2.plot(request_counts, throughputs, color=color)
    ax2.tick_params(axis='y', labelcolor=color)

    plt.title('Latency and Throughput Benchmarking')
    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    plt.show()

def plot_latency_vs_throughput(latencies, throughputs, request_counts):
    """
    Plots latency against throughput, with annotations for the number of concurrent requests.
    
    Args:
        latencies (list): The list of average latencies.
        throughputs (list): The list of average throughputs.
        request_counts (list): The list of different numbers of concurrent requests tested.
    """
    plt.figure()
    plt.plot(throughputs, latencies, 'o-')
    plt.xlabel('Throughput (requests/s)')
    plt.ylabel('Average Latency (s)')
    plt.title('Latency vs Throughput')
    plt.grid(True)
    
    # Label each point with the request count
    for i, request_count in enumerate(request_counts):
        plt.annotate(request_count, (throughputs[i], latencies[i]), textcoords="offset points", xytext=(0,10), ha='center')
    
    plt.show()
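A quick consistency check on the results: in each iteration, all `num_requests` requests are sent at once and the batch cannot finish before its slowest request, so throughput multiplied by mean latency (a Little's-law-style estimate of how many requests were actually in flight) should come out at or below the concurrency level. The bound holds exactly per iteration; averaging over iterations can push it slightly above. A ratio well below 1.0 suggests that requests were queued rather than served in parallel, either on the client or on the endpoint. A minimal sketch, where `effective_concurrency` is a hypothetical helper:

```python
def effective_concurrency(request_counts, latencies, throughputs):
    """Estimate requests in flight at each concurrency level as throughput * mean latency.

    Returns (concurrency, estimate, ratio) tuples, where ratio = estimate / concurrency.
    A ratio near 1.0 means the requests overlapped almost fully; a low ratio suggests queueing.
    """
    results = []
    for n, latency, throughput in zip(request_counts, latencies, throughputs):
        estimate = throughput * latency
        results.append((n, estimate, estimate / n))
    return results


# usage after running the benchmark:
# for n, est, ratio in effective_concurrency(request_counts, latencies, throughputs):
#     print(f"{n:>4} concurrent -> ~{est:.1f} in flight ({ratio:.0%})")
```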

Benchmark the endpoint and plot the results.

In [None]:
import pandas as pd
df_benchmark = pd.DataFrame(columns=["client_batch_size", "concurrent_request_counts", "latencies", "throughputs"])

In [None]:
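# import DJLModel here too, so the isinstance check below works even if you skipped Option 4
from sagemaker.djl_inference.model import DJLModel

min_requests = 0  # smallest step: 2^0 = 1 concurrent request
max_requests = 9  # exclusive upper bound: the largest step is 2^8 = 256 concurrent requests
step_size = 1
iterations = 20
client_batch_size = 32  # number of query/text pairs sent in each request

# concurrency levels to test: [1, 2, 4, ..., 256]
steps = [2**x for x in range(min_requests, max_requests, step_size)]

# each request repeats the same query/text pair client_batch_size times
query = "what is panda?"
text = "A panda is a type of bear that is known for its distinctive black and white coloring."

if isinstance(model_bge_rerank, DJLModel):
    payload = {"inputs": [{"text": query, "text_pair": text}] * client_batch_size}
else:
    payload = {"query": query, "texts": [text] * client_batch_size}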


request_counts, latencies, throughputs = benchmark_predictor(model_bge_rerank_predictor, payload, steps, iterations)
plot_metrics(request_counts, latencies, throughputs)
plot_latency_vs_throughput(latencies, throughputs, request_counts)

new_data = {
    'client_batch_size': [client_batch_size] * len(request_counts),
    'concurrent_request_counts': request_counts,
    'latencies': latencies,
    'throughputs' : throughputs
}

# DataFrame.append was removed in pandas 2.0; pd.concat works in all recent versions
df_benchmark = pd.concat([df_benchmark, pd.DataFrame(new_data)], ignore_index=True)
df_benchmark
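With the results collected, you might pick the concurrency level that gives the highest throughput while keeping average latency within a target. A minimal sketch that works on the lists returned by `benchmark_predictor`; `best_under_latency_target` and the 0.5-second target in the usage line are illustrative assumptions, not recommendations.

```python
def best_under_latency_target(request_counts, latencies, throughputs, max_latency_s):
    """Return (concurrency, latency, throughput) with the highest throughput whose
    average latency is at or below max_latency_s, or None if no level qualifies."""
    candidates = [
        (n, latency, throughput)
        for n, latency, throughput in zip(request_counts, latencies, throughputs)
        if latency <= max_latency_s
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda row: row[2])


# usage: best_under_latency_target(request_counts, latencies, throughputs, max_latency_s=0.5)
```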

## Cleanup

In [None]:
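# delete the model, then the endpoint and its endpoint configuration, so the instance stops incurring charges
model_bge_rerank_predictor.delete_model()
model_bge_rerank_predictor.delete_endpoint()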