## Using LMI Containers with SageMaker Async Endpoints

This notebook demonstrates how to use LMI DLCs to host models on SageMaker Async Inference Endpoints. Async Inference with LMI requires container version 0.31.0 or later.

### Install and Update Dependencies

In [None]:
%pip install -U sagemaker boto3

### Create and deploy a Model for Async Inference

To deploy an async endpoint, you need to create an [AsyncInferenceConfig](https://sagemaker.readthedocs.io/en/stable/api/inference/async_inference.html). This example uses the default AsyncInferenceConfig, but you can customize it as needed.
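As a rough illustration of what customizing the config could look like, the sketch below assembles keyword arguments for `AsyncInferenceConfig`. The helper name `build_async_config_kwargs`, the bucket name, and the prefix are hypothetical. The keyword names `output_path`, `failure_path`, and `max_concurrent_invocations_per_instance` are taken from the SDK documentation linked above; check them against your installed SDK version.

```python
def build_async_config_kwargs(bucket, prefix="async_lmi", max_concurrent=4):
    """Assemble keyword arguments for AsyncInferenceConfig (hypothetical helper)."""
    base = f"s3://{bucket}/{prefix.strip('/')}"
    return {
        # S3 location where successful inference responses are written
        "output_path": f"{base}/output",
        # S3 location where details of failed requests are written
        "failure_path": f"{base}/failure",
        # Upper bound on concurrent requests sent to each instance
        "max_concurrent_invocations_per_instance": max_concurrent,
    }


kwargs = build_async_config_kwargs("my-example-bucket")  # placeholder bucket name
print(kwargs["output_path"])

# In the notebook, these would be passed to the SDK class:
# from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
# async_inference_config = AsyncInferenceConfig(**kwargs)
```

Keeping output and failure locations under one prefix makes it easier to inspect or clean up everything a single endpoint produced.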

This example deploys the [Llama3.1-8b-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model. This is a gated model: you need a HuggingFace account that has been granted access to it, and a valid hub access token. If you do not have access to this model, you can use another text generation model. This example also sends a request in the OpenAI Chat Completions format, so the model you choose must have a chat template.

In [None]:
import sagemaker
from sagemaker.djl_inference import DJLModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.async_inference.waiter_config import WaiterConfig
from sagemaker.session import Session

role = sagemaker.get_execution_role()
image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124"
session = Session()

model = DJLModel(
    image_uri=image_uri,
    env = {
        "HF_MODEL_ID": "meta-llama/Llama-3.1-8B-Instruct",
        "HF_TOKEN": "<your hub token>"
    },
    role=role,
    sagemaker_session=session,
)
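Rather than pasting the hub token into the notebook, one option is to read it from an environment variable. This is a minimal sketch, assuming the token has been exported as `HF_TOKEN` in the notebook environment; the helper `build_lmi_env` is hypothetical and not part of the SageMaker SDK.

```python
import os


def build_lmi_env(model_id, token_var="HF_TOKEN"):
    """Build the container env dict, reading the hub token from the environment."""
    token = os.environ.get(token_var)
    if not token:
        # Fail early instead of deploying a container that cannot download the model
        raise RuntimeError(f"Set the {token_var} environment variable before deploying")
    return {"HF_MODEL_ID": model_id, "HF_TOKEN": token}


# In the notebook, the result would replace the literal env dict:
# model = DJLModel(image_uri=image_uri, env=build_lmi_env("meta-llama/Llama-3.1-8B-Instruct"), ...)
```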

In [None]:
endpoint_name = sagemaker.utils.name_from_base("my-lmi-async-endpoint")
async_inference_config = AsyncInferenceConfig()

model.deploy(
    initial_instance_count=1,
    instance_type="ml.g6.12xlarge",
    endpoint_name=endpoint_name,
    async_inference_config=async_inference_config,
    container_startup_health_check_timeout=2400,
)

### Create Sample inputs and upload to S3

Async Endpoints are invoked with an S3 object that contains your inference request. We will create two sample inference requests and upload them to S3.

Async inference is not compatible with streaming. You cannot specify `"stream": true` in the payload.
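Since a streaming flag would only surface as a failure after the request has been queued, it can help to check payloads locally before uploading them. The sketch below is one way to do that; the helper `check_async_payload` is hypothetical, and it assumes the flag appears as a top-level `"stream"` key in the request body.

```python
import json


def check_async_payload(payload_text):
    """Parse a JSON request body and reject it if streaming is requested."""
    payload = json.loads(payload_text)  # raises ValueError on malformed JSON
    # Assumption: the streaming flag is a top-level key in the request
    if payload.get("stream") is True:
        raise ValueError('"stream": true is not supported with async inference')
    return payload


# Example with an inline payload; in the notebook you could pass the contents
# of sample_inputs.json or sample_messages.json instead.
ok = check_async_payload('{"inputs": "Hello", "parameters": {"max_new_tokens": 16}}')
print(sorted(ok))
```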

In [None]:
%%writefile sample_inputs.json
{
    "inputs": "Please give me a 10 day itinerary for my trip to New York. Sure, starting on day 1 ",
    "parameters": {
        "temperature": 0.6,
        "top_p": 0.9,
        "max_new_tokens": 1024
    }
}

In [None]:
%%writefile sample_messages.json
{
    "messages": [
        {"role": "user", "content": "Please give me a 10 day itinerary for my trip to New York. Sure, starting on day 1"}
    ],
    "temperature": 0.6,
    "top_p": 0.9,
    "max_tokens": 1024
}

In [None]:
bucket = session.default_bucket()

# Upload the request following the default LMI schema
sample_input_path = session.upload_data(
    "sample_inputs.json",
    bucket=bucket,
    key_prefix="async_lmi_inputs"
)
# Upload the request following the OpenAI Chat Completions schema
sample_messages_path = session.upload_data(
    "sample_messages.json",
    bucket=bucket,
    key_prefix="async_lmi_inputs"
)

### Create the Async Predictor and make inference requests

The [AsyncPredictor](https://sagemaker.readthedocs.io/en/stable/api/inference/predictor_async.html) provides utility methods for interacting with the async endpoint and making inference requests.

In this example, we use the `predict_async` method because it is non-blocking. We also specify `initial_args={"ContentType": "application/json"}` so that the request is sent with a JSON content type and can be handled by the container.

In [None]:
predictor = sagemaker.Predictor(endpoint_name=endpoint_name)
async_predictor = sagemaker.predictor_async.AsyncPredictor(predictor)

In [None]:
async_response_sample_inputs = async_predictor.predict_async(
    input_path=sample_input_path,
    initial_args={"ContentType": "application/json"},
)
async_response_sample_messages = async_predictor.predict_async(
    input_path=sample_messages_path,
    initial_args={"ContentType": "application/json"},
)
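For callers that do not use the SageMaker Python SDK, the same request can be submitted with the `invoke_endpoint_async` operation of the boto3 `sagemaker-runtime` client. The sketch below wraps that call; the helper name `submit_async_request` is hypothetical, and the client is passed in as a parameter so the function can be exercised without AWS credentials.

```python
def submit_async_request(runtime_client, endpoint_name, input_location):
    """Queue one async inference request and return the S3 URI of its future output."""
    response = runtime_client.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=input_location,  # S3 URI of the uploaded request body
        ContentType="application/json",
    )
    # The output object appears at this location once inference completes
    return response["OutputLocation"]


# In the notebook (not executed here):
# import boto3
# runtime = boto3.client("sagemaker-runtime")
# output_location = submit_async_request(runtime, endpoint_name, sample_input_path)
```

The call returns as soon as the request is queued, so the returned location has to be polled (or a notification awaited) before the result can be read.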

### Poll for Inference Completion

You can use the [WaiterConfig](https://sagemaker.readthedocs.io/en/stable/api/inference/async_inference.html#sagemaker.async_inference.waiter_config.WaiterConfig) to configure the polling cadence for inference results. We will use the default WaiterConfig in this example.
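To make the polling cadence concrete, the sketch below shows a generic retry loop driven by the two values a `WaiterConfig` carries, `max_attempts` and `delay` (in seconds). It illustrates the idea only and is not the SDK's implementation; `poll_for_result` and its `fetch` callback are hypothetical names, and the defaults of 60 attempts and 15 seconds mirror those listed in the linked documentation.

```python
import time


def poll_for_result(fetch, max_attempts=60, delay=15, sleep=time.sleep):
    """Call fetch() until it returns a non-None value or attempts run out."""
    for _ in range(max_attempts):
        result = fetch()
        if result is not None:
            return result
        # Result not ready yet: wait before the next attempt
        sleep(delay)
    raise TimeoutError(f"No result after {max_attempts} attempts, {delay}s apart")


# In the notebook, a longer wait could be configured on the SDK class instead:
# waiter_config = WaiterConfig(max_attempts=120, delay=30)
```

With this loop, the longest wait is roughly `max_attempts * delay` seconds, which is a useful number to compare against the expected generation time of long requests.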

In [None]:
import json
waiter_config = WaiterConfig()

inputs_result = async_response_sample_inputs.get_result(waiter_config=waiter_config)
messages_result = async_response_sample_messages.get_result(waiter_config=waiter_config)

print(f"Result from LMI style request is:\n {json.loads(inputs_result)}")
print("--------------------------")
print(f"Result from OpenAI style request is:\n {json.loads(messages_result)}")
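The two request formats return differently shaped JSON, so pulling out just the generated text takes a small amount of branching. The sketch below assumes the LMI-style response carries a top-level `generated_text` field and the OpenAI-style response follows the Chat Completions layout (`choices[0].message.content`); verify both against the actual output printed above. The helper `extract_text` is hypothetical.

```python
def extract_text(result):
    """Return the generated text from either response shape (assumed layouts)."""
    if "choices" in result:
        # Assumed OpenAI Chat Completions layout
        return result["choices"][0]["message"]["content"]
    # Assumed default LMI layout
    return result["generated_text"]


# Illustrative payloads only, not real model output
lmi_style = {"generated_text": "Day 1: arrive in New York."}
openai_style = {"choices": [{"message": {"role": "assistant", "content": "Day 1: arrive."}}]}
print(extract_text(lmi_style))
print(extract_text(openai_style))

# In the notebook:
# print(extract_text(json.loads(inputs_result)))
# print(extract_text(json.loads(messages_result)))
```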

### Clean up Resources

In [None]:
async_predictor.delete_endpoint()
model.delete_model()