# Pre/Post-Processing for Hugging Face (HF) Text Embeddings Inference (TEI)

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/generative_ai|sm-text_embedding_custom_processing.ipynb)

---

## Overview

In this demo notebook, we demonstrate how to implement pre/post-processing logic for [HF TEI](https://huggingface.co/docs/text-embeddings-inference/en/index) use cases. While it provides improved performance and convenient SageMaker integrations, the [SageMaker TEI image](https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-containers-support-policy.html#pre-built-containers-support-policy-dlc) does not support custom pre/post-processing logic out-of-the-box. For customers looking to customize their logic, we show how to leverage the [SageMaker Transformers Image](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#huggingface-inference-containers) to load and invoke a HF TEI model with a custom InferenceSpec so that customers can still pre/post-process their results from their SageMaker endpoint(s).

## Prerequisites

This notebook was tested in region `us-west-2` with kernel `conda_python3 (3.10.15 | packaged by conda-forge | (main, Sep 20 2024, 16:37:05) [GCC 13.3.0])` and uses the following (versioned) resources:

| Resource | Value |
| :-------- | :----- |
| TEI Model | jinaai/jina-embeddings-v2-small-en |
| TEI Image (as a control for performance testing) | 246618743249.dkr.ecr.us-west-2.amazonaws.com/tei-cpu:2.0.1-tei1.4.0-cpu-py310-ubuntu22.04 |
| Transformers Image | 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:2.1.0-transformers4.37.0-cpu-py310-ubuntu22.04 |
| Instance Type | ml.m5.4xlarge |

- The account ID in each image URI must be the account that publishes that image in the target region. This varies by resource. For example, `246618743249` is the account providing the TEI Image for the `us-west-2` region. Please contact the resource providers for more info about target resources.
- The images used are based on `Python3.10`. The **image versions must match the PySDK version** for reasons related to [pickling](https://docs.python.org/3/library/pickle.html) performed by the PySDK.
- The images used target the instance type `ml.m5.4xlarge`, i.e. `CPU`. Please be sure to use the appropriate TEI and Transformers images for your target hardware.
- The TEI model used must fit on the target instance type; otherwise, the inferences will fail (even if the endpoint still reports "online").
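
As a quick sanity check on the points above, the sketch below splits the two image URIs from the table into their parts so that the region, hardware, and Python build can be compared. It assumes the usual private ECR URI layout (`<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>`); `parse_image_uri` is a hypothetical helper and not part of the SageMaker SDK.

```python
import re

# Assumed layout: <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>
ECR_URI = re.compile(
    r"^(?P<account>\d{12})\.dkr\.ecr\.(?P<region>[a-z0-9-]+)\.amazonaws\.com/"
    r"(?P<repository>[^:]+):(?P<tag>.+)$"
)


def parse_image_uri(uri):
    """Split an ECR image URI into account, region, repository, and tag (hypothetical helper)."""
    match = ECR_URI.match(uri)
    if match is None:
        raise ValueError(f"Not an ECR image URI: {uri}")
    return match.groupdict()


tei_image = parse_image_uri(
    "246618743249.dkr.ecr.us-west-2.amazonaws.com/tei-cpu:2.0.1-tei1.4.0-cpu-py310-ubuntu22.04"
)
transformers_image = parse_image_uri(
    "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
    "huggingface-pytorch-inference:2.1.0-transformers4.37.0-cpu-py310-ubuntu22.04"
)

# Both images should come from the deployment region and name the same
# hardware target and Python build in their tags.
for image in (tei_image, transformers_image):
    assert image["region"] == "us-west-2"
    assert "cpu" in image["tag"] and "py310" in image["tag"]
```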

## Demonstration

### Dependencies



In [14]:
!pip install sagemaker numpy transformers datasets --upgrade

Collecting numpy
  Using cached numpy-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
INFO: pip is looking at multiple versions of pathos to determine which version is compatible with other requirements. This could take a while.
Collecting pathos (from sagemaker)
  Downloading pathos-0.3.2-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
Downloading pathos-0.3.2-py3-none

### Role

To host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook as the AWS account role with SageMaker access.

In [3]:
from sagemaker import get_execution_role

ROLE = get_execution_role()

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


### Constants for our Intended Resources

In [7]:
HF_TEI_MODEL = "jinaai/jina-embeddings-v2-small-en"
ROLE = "arn:aws:iam::987461069402:role/service-role/AmazonSageMaker-ExecutionRole-20240417T141916"
INSTANCE_TYPE = "ml.m5.4xlarge"
TEI_IMAGE = (
    "246618743249.dkr.ecr.us-west-2.amazonaws.com/tei-cpu:2.0.1-tei1.4.0-cpu-py310-ubuntu22.04"
)
TRANSFORMERS_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:2.1.0-transformers4.37.0-cpu-py310-ubuntu22.04"

### Standard Helper Functions to Deploy Endpoints

This notebook will deploy two different endpoints:
1. Endpoint using standard TEI image (as a control to compare against)
2. Endpoint using Transformers image

To simplify this, we define a few helper functions.

##### **deploy()**

To deploy a SageMaker endpoint, we need to first build a SageMaker Model. Since we will later use [InferenceSpec](https://sagemaker.readthedocs.io/en/v2.208.0/api/inference/model_builder.html#sagemaker.serve.spec.inference_spec.InferenceSpec), we should use the SageMaker Python SDK construct `ModelBuilder` to build a Model (rather than directly use the `Model` construct, which does not support defining a custom InferenceSpec).

`ModelBuilder` requires a `SchemaBuilder` to understand the input and output data types, which is essential for properly (de)serializing the data. `SchemaBuilder` infers which (de)serializers to use from the example inputs and outputs it is given. For the TEI case, the input and output look like:
- Input: a (stringified) JSON object with the sentences to get embeddings for
- Output: a (nested) list, which represents the embeddings

The exact shapes and names used within these samples are not critical: they are only used to infer the (de)serializers, and actual inputs/outputs are not reshaped or fitted to them.
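
To make the sample shapes concrete, here is a minimal sketch using only the standard library. It shows that the schema samples and a differently sized request share the same Python types (a JSON string in, a nested list out), which is the only thing the samples need to convey. The `real_*` values are made up for illustration.

```python
import json

# Samples like the ones handed to SchemaBuilder
sample_input = json.dumps({"inputs": ["hello", "world"]})
sample_output = [[1, 2, 3], [4, 5, 6]]

# A "real" request with a different number of sentences and vector width (illustrative values)
real_input = json.dumps({"inputs": ["a first sentence", "a second one", "and a third"]})
real_output = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]

# Same types on both sides, even though the shapes differ
assert type(sample_input) is type(real_input) is str
assert all(isinstance(row, list) for row in sample_output + real_output)

# Decoding the request body recovers the list of sentences, one embedding row each
sentences = json.loads(real_input)["inputs"]
assert len(sentences) == len(real_output)
```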

Below, we use some Python shorthand involving `**` to define default arguments and allow for overrides if they are passed.
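
For readers unfamiliar with the `**` shorthand, the following standalone sketch shows the merge behavior on plain dictionaries. The function and key values are placeholders and not the arguments used in this notebook.

```python
def build_kwargs(overrides=None):
    """Merge caller overrides on top of defaults; later keys win."""
    defaults = dict(role_arn="default-role", env_vars={"A": "1"})
    return {**defaults, **(overrides or {})}


# No overrides: the defaults come back unchanged
unchanged = build_kwargs()

# An override replaces an existing key, and a new key is simply added
merged = build_kwargs({"role_arn": "other-role", "name": "demo"})
assert merged == {"role_arn": "other-role", "env_vars": {"A": "1"}, "name": "demo"}
```

Note that the merge is shallow: passing a new `env_vars` dictionary replaces the default one entirely rather than combining the two.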

In [8]:
import json

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder


def deploy(model_builder_kwargs=None):
    # Build a SageMaker Model from the defaults below; caller-supplied kwargs win on conflict
    model = ModelBuilder(
        **{
            **dict(
                role_arn=ROLE,
                schema_builder=SchemaBuilder(
                    json.dumps({"inputs": ["hello", "world"]}), [[1, 2, 3], [4, 5, 6]]
                ),
                env_vars={
                    "TS_DISABLE_TOKEN_AUTHORIZATION": "true"  # See https://github.com/pytorch/serve/blob/master/docs/README.md
                },
            ),
            **(model_builder_kwargs or {}),  # allow for overriding these arguments if desired
        }
    ).build()
    endpoint = model.deploy(
        initial_instance_count=1,
        instance_type=INSTANCE_TYPE,
    )
    return (model, endpoint)

##### **clean()**


In [9]:
def clean(model, endpoint):
    # Best-effort cleanup: report a failure and continue so the other resource is still deleted
    try:
        endpoint.delete_endpoint()
    except Exception as e:
        print(e)

    try:
        model.delete_model()
    except Exception as e:
        print(e)

#### Endpoint-Specific Usage of Helpers

##### TEI Image (as-is, used as control to compare against)

The SageMaker Python SDK and SageMaker TEI image make it convenient to simply specify the intended `HF_TEI_MODEL` with the SageMaker `TEI_IMAGE`.

In [10]:
import time


def deploy_tei():
    return deploy(
        model_builder_kwargs=dict(
            name=f"tei-{int(time.time())}",
            image_uri=TEI_IMAGE,
            model=HF_TEI_MODEL,
        )
    )

##### Transformers Image

Since we will be defining custom logic using the Transformers Image, we need to provide our own [`InferenceSpec`](https://sagemaker.readthedocs.io/en/v2.208.0/api/inference/model_builder.html#sagemaker.serve.spec.inference_spec.InferenceSpec). This construct requires the following implementations:
- `load()` : how to load your intended model
- `invoke()` : how to call your loaded model

To do this, we create our own subclass of `InferenceSpec` called `CustomerInferenceSpec`. Then, we provide our implementations of each by overriding the expected methods with the same signature in the base class. 

##### **load()**

[We can load our TEI model `jinaai/jina-embeddings-v2-small-en` using the transformers library as documented by the HF model owners](https://huggingface.co/jinaai/jina-embeddings-v2-base-en#usage).

##### **invoke()**

[We can invoke our loaded TEI model `jinaai/jina-embeddings-v2-small-en` with a single `encode()` call as documented by the HF model owners](https://huggingface.co/jinaai/jina-embeddings-v2-base-en#usage). 

Note: The inputs to `encode()` must be extracted from the input data sent to the endpoint. This is done with a simple JSON load and reference to `inputs`, which is what the endpoint was invoked with.
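
To illustrate where custom logic could sit around the `encode()` call, here is a minimal sketch that runs locally without SageMaker or the real model. Everything in it is an assumption made for illustration: `FakeEmbeddingModel` is a stand-in for the loaded HF model, and the whitespace cleanup and L2 normalization are example pre/post-processing steps, not steps this notebook performs.

```python
import json
import math


class FakeEmbeddingModel:
    """Stand-in for the loaded HF model; returns a toy 2-d vector per sentence."""

    def encode(self, sentences):
        return [[float(len(s)), 1.0] for s in sentences]


def preprocess(raw_body):
    """Decode the JSON request body and collapse extra whitespace in each sentence."""
    sentences = json.loads(raw_body)["inputs"]
    return [" ".join(s.split()) for s in sentences]


def postprocess(embeddings):
    """L2-normalize each embedding so every vector has unit length."""
    normalized = []
    for vector in embeddings:
        norm = math.sqrt(sum(v * v for v in vector)) or 1.0  # avoid dividing by zero
        normalized.append([v / norm for v in vector])
    return normalized


def invoke(raw_body, model):
    """Same (input, model) argument order as InferenceSpec.invoke, with custom steps added."""
    return postprocess(model.encode(preprocess(raw_body)))


body = json.dumps({"inputs": ["  hello   world ", "hi"]})
result = invoke(body, FakeEmbeddingModel())
```

In an actual `InferenceSpec` subclass, the same three steps would live in the class methods, and the real model's `encode()` output may be an array type rather than plain lists.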

##### **Dependencies**
The SageMaker Python SDK can auto-detect dependencies at the class level of an InferenceSpec. It will install them in the intended serving container while provisioning our endpoint. 

For this demo, however, we would like a specific version of the HF `transformers` library that complies with the version of Python and other dependencies already in our target image. Therefore, we opt to use the `dependencies` parameter explicitly to specify a particular version of `transformers` to ensure compatibility.

In [43]:
from sagemaker.serve.spec.inference_spec import InferenceSpec


def deploy_transformer():
    class CustomerInferenceSpec(InferenceSpec):
        def load(self, model_dir):
            from transformers import AutoModel

            return AutoModel.from_pretrained(HF_TEI_MODEL, trust_remote_code=True)

        def invoke(self, x, model):
            return model.encode(json.loads(x)["inputs"])

        def preprocess(self, input_data):
            return json.loads(input_data)["inputs"]

        def postprocess(self, predictions):
            assert predictions is not None
            return predictions

    return deploy(
        dict(
            name=f"transformers-{int(time.time())}",
            image_uri=TRANSFORMERS_IMAGE,
            inference_spec=CustomerInferenceSpec(),
            dependencies={
                "custom": [
                    "transformers==4.38.0"  # so we don't override the DLC dependency versions
                ],
            },
        )
    )

### Running the Demo

In order to run the demo, we will need to
1. deploy our (models and) endpoints
2. invoke against our (models and) endpoints
3. clean up our (models and) endpoints 

We can define just a few more helpers for better reuse and customization if needed in the future.

#### Helper Functions to Run This Specific Demo

##### **invoke_many()**

This is a synchronous function that will be sent to a thread using asyncio. 

It is a simple invoker that calls the endpoint once for each sample in a given list.

We intentionally do not use batching here for a basic performance test later in this notebook.

Note: This is tightly coupled with our sample [dataset](https://huggingface.co/datasets/sentence-transformers/stsb). Specifically, the way we run invocations with `predict` against our endpoints is based on how the dataset is structured. Each sample has two sentences. See the dataset page for more info.

Note: The `initial_args` are needed only for the endpoint with the TEI Image.

In [12]:
def invoke_many(endpoint, samples):  # intentionally not batching
    results = []
    for sample in samples:
        start = time.perf_counter()
        res = endpoint.predict(
            json.dumps({"inputs": [sample["sentence1"], sample["sentence2"]]}),
            initial_args=(
                {"ContentType": "application/json"} if "tei" in endpoint.endpoint_name else None
            ),
        )
        end = time.perf_counter()
        results.append(
            {
                "latency": end - start,
                "embeddings": res,
                "endpoint": endpoint.endpoint_name,
                "sample": sample,
            }
        )
    return results

##### **basic_performance_test()**

This is an async function that will leverage asyncio to create, invoke, and clean up endpoints.

We use asyncio here for better concurrency since the endpoint (de)provisioning can take a long, indeterminate amount of time.
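
The pattern can be tried in isolation with the sketch below, which swaps the real deployments for a hypothetical `slow_deploy` function that only sleeps. `asyncio.to_thread` (Python 3.9+) runs each blocking call in a worker thread, and `asyncio.gather` waits for all of them and returns their results in the order the calls were passed.

```python
import asyncio
import time


def slow_deploy(name, seconds):
    """Stand-in for a blocking deploy call (hypothetical; it only sleeps)."""
    time.sleep(seconds)
    return name


async def deploy_all():
    # Each blocking call runs in its own worker thread; gather waits for both
    return await asyncio.gather(
        asyncio.to_thread(slow_deploy, "tei", 0.2),
        asyncio.to_thread(slow_deploy, "transformers", 0.2),
    )


start = time.perf_counter()
names = asyncio.run(deploy_all())
elapsed = time.perf_counter() - start
print(names, f"{elapsed:.2f}s")
```

Because the two calls overlap, the total wait should be close to the longer of the two rather than their sum. Inside a notebook, where an event loop is already running, use `await deploy_all()` instead of `asyncio.run(...)`.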

In this demo, we will 
1. Create two endpoints:
    1. One with TEI Image 
    2. One with Transformers Image
2. Load sample sentences from the [`sentence-transformers/stsb` dataset](https://huggingface.co/datasets/sentence-transformers/stsb).
3. Invoke each endpoint with the samples.
4. Analyze the latencies of the sample invocations.
5. Clean up all of the resources.
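
The latency analysis in step 4 reports a `tm99` value, which is a trimmed mean: the slowest 1% of invocations are dropped before averaging so that a few outliers do not dominate. A minimal pure-Python sketch of the same calculation follows; `trimmed_mean` is a hypothetical helper and the latencies are made-up values.

```python
def trimmed_mean(values, trim_fraction=0.01):
    """Mean of the values after dropping the slowest `trim_fraction` of them."""
    ordered = sorted(values)
    keep = len(ordered) - int(len(ordered) * trim_fraction)
    return sum(ordered[:keep]) / keep


# 99 fast requests and one slow outlier (illustrative numbers, in seconds)
latencies = [0.02] * 99 + [1.0]

plain_mean = sum(latencies) / len(latencies)
tm99 = trimmed_mean(latencies)

# The outlier pulls the plain mean up, while the trimmed mean ignores it
assert tm99 < plain_mean
```

With the 500 samples used in this demo, `int(500 * 0.01)` is 5, so the five slowest invocations are excluded from `tm99`.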

In [47]:
from datasets import load_dataset
import numpy as np
import asyncio
import pprint


async def basic_performance_test():
    deployments = []
    try:
        ######################################################
        # Deploy
        ######################################################
        print("Deploying endpoints...")
        deployments = await asyncio.gather(
            *[asyncio.to_thread(deploy_tei), asyncio.to_thread(deploy_transformer)]
        )

        ######################################################
        # Invoke
        ######################################################
        print("Invoking endpoints...")
        samples = load_dataset("sentence-transformers/stsb", streaming=True, split="test").take(500)

        results = await asyncio.gather(
            *[asyncio.to_thread(invoke_many, endpoint, samples) for _, endpoint in deployments]
        )

        ######################################################
        # Analyze
        ######################################################
        print("Analyzing invocations...")
        for invocations in results:
            latencies = np.array([invocation["latency"] for invocation in invocations])
            pprint.pp(
                {
                    "shape": np.shape(latencies),
                    "tm99": np.mean(
                        np.sort(latencies)[: (len(latencies) - int(len(latencies) * 0.01))]
                    ),
                    "p90": np.percentile(latencies, 90),
                    "avg": np.mean(latencies),
                    "max": np.max(latencies),
                    "min": np.min(latencies),
                    "endpoint": invocations[0]["endpoint"],
                },
                indent=4,
            )
    finally:
        ######################################################
        # Clean
        ######################################################
        print("Cleaning resources...")
        errors = await asyncio.gather(
            *[asyncio.to_thread(clean, model, endpoint) for model, endpoint in deployments],
            return_exceptions=True
        )
        for error in errors:
            if error:
                print(error)

    print("Complete!")

### Run the Demo

Note: To see live output from CloudWatch, please ensure your role has `logs:FilterLogEvents` permissions for the created endpoints.

Note: The logging output from the endpoints will be in <span style="background-color:#ffdddd"> **red** because it is logged to stderr </span> by default.

In [48]:
await basic_performance_test()

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole
INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole


Deploying endpoints...


ModelBuilder: INFO:     Either inference spec or model is provided. ModelBuilder is not handling MLflow model input
ModelBuilder: INFO:     Either inference spec or model is provided. ModelBuilder is not handling MLflow model input
ModelBuilder: INFO:     Skipping auto detection as the image uri is provided 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:2.1.0-transformers4.37.0-cpu-py310-ubuntu22.04
ModelBuilder: DEBUG:     Uploading the model resources to bucket=sagemaker-us-west-2-987461069402, key_prefix=huggingface-pytorch-inference-2024-11-02-19-03-52-724.
Uploading model artifacts: 100%|██████████████████████████| 7730/7730 [00:00<00:00, 53197.34bytes/s]
ModelBuilder: DEBUG:     Model resources uploaded to: s3://sagemaker-us-west-2-987461069402/huggingface-pytorch-inference-2024-11-02-19-03-52-724/serve.tar.gz
ModelBuilder: INFO:     ModelBuilder will collect telemetry to help us better understand our user's needs, diagnose issues, and deliver addition

---------!

ModelBuilder: DEBUG:     ModelBuilder metrics emitted.


---!

ModelBuilder: DEBUG:     ModelBuilder metrics emitted.


Invoking endpoints...


INFO:sagemaker:Deleting endpoint configuration with name: transformers-1730574231-2024-11-02-19-03-54-797
INFO:sagemaker:Deleting endpoint configuration with name: tei-1730574231-2024-11-02-19-03-56-898


Analyzing invocations...
{   'shape': (500,),
    'tm99': 0.021322881935375702,
    'p90': 0.027897494498756715,
    'avg': 0.021767640456022492,
    'max': 0.17949875600061205,
    'min': 0.013735082000494003,
    'endpoint': 'tei-1730574231-2024-11-02-19-03-56-898'}
{   'shape': (500,),
    'tm99': 0.03614241738792367,
    'p90': 0.04087031010003557,
    'avg': 0.03672696926404751,
    'max': 0.26481046400112973,
    'min': 0.030846894998830976,
    'endpoint': 'transformers-1730574231-2024-11-02-19-03-54-797'}
Cleaning resources...


INFO:sagemaker:Deleting endpoint with name: transformers-1730574231-2024-11-02-19-03-54-797
INFO:sagemaker:Deleting endpoint with name: tei-1730574231-2024-11-02-19-03-56-898
INFO:sagemaker:Deleting model with name: transformers-1730574231
INFO:sagemaker:Deleting model with name: tei-1730574231


Complete!


### Results

This demo demonstrates:
1. How to achieve custom pre/post-processing using the Transformers image alternative and customizing InferenceSpec
2. Reduced, yet comparable performance versus the SageMaker TEI image while using the Transformers image. If pre/post-processing is needed for the TEI endpoint, then this may be a viable alternative for production endpoints.
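
To put a number on "reduced, yet comparable", the sketch below compares the latencies printed in the run above. The values are copied from that output; the variable names are only for illustration, and a different run, region, or instance type would give different figures.

```python
# Latencies in seconds, copied from the analysis output of the run above
tei_stats = {
    "tm99": 0.021322881935375702,
    "p90": 0.027897494498756715,
    "avg": 0.021767640456022492,
}
transformers_stats = {
    "tm99": 0.03614241738792367,
    "p90": 0.04087031010003557,
    "avg": 0.03672696926404751,
}

ratios = {}
for metric, tei_value in tei_stats.items():
    ratios[metric] = transformers_stats[metric] / tei_value
    extra_ms = (transformers_stats[metric] - tei_value) * 1000
    print(f"{metric}: {ratios[metric]:.2f}x the TEI latency, +{extra_ms:.1f} ms per request")
```

For this single run, the Transformers endpoint comes out at roughly 1.5x to 1.7x the TEI latency, which is on the order of 13 to 15 ms of added time per request.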

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/generative_ai|sm-text_embedding_custom_processing.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/generative_ai|sm-text_embedding_custom_processing.ipynb)