# Run Multiple Models on the Same GPU with Amazon SageMaker Multi-Model Endpoints Powered by NVIDIA Triton Inference Server

This notebook was run on an `ml.g4dn.xlarge` SageMaker Notebook instance with the `conda_pytorch_p38` kernel.

## Prerequisites

Install the necessary Python modules to use and interact with [NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server/).

In [None]:
! pip install torch==1.10.0 sagemaker transformers==4.9.1 tritonclient[all]

# Part 1 - Setup

In [None]:
import argparse
import boto3
import copy
import datetime
import json
import numpy as np
import os
import pandas as pd
import pprint
import re
import sagemaker
import sys
import time
from time import gmtime, strftime
import tritonclient.http as http_client

In [None]:
# Create a single boto3 session and derive every client and the region from it, so the
# SageMaker control-plane client, the runtime client and the region always agree on
# credentials and region configuration.
session = boto3.Session()
role = sagemaker.get_execution_role()

sm_client = session.client("sagemaker")
sagemaker_session = sagemaker.Session(boto_session=session)
sm_runtime_client = session.client("sagemaker-runtime")

region = session.region_name

In [None]:
# (region, AWS account id) pairs; the account id is the registry owner used further down
# to build the `sagemaker-tritonserver` ECR image URI for the current region.
_TRITON_ACCOUNT_PAIRS = (
    # United States
    ("us-east-1", "785573368785"),
    ("us-east-2", "007439368137"),
    ("us-west-1", "710691900526"),
    ("us-west-2", "301217895009"),
    # Europe
    ("eu-west-1", "802834080501"),
    ("eu-west-2", "205493899709"),
    ("eu-west-3", "254080097072"),
    ("eu-north-1", "601324751636"),
    ("eu-south-1", "966458181534"),
    ("eu-central-1", "746233611703"),
    # Asia Pacific
    ("ap-east-1", "110948597952"),
    ("ap-south-1", "763008648453"),
    ("ap-northeast-1", "941853720454"),
    ("ap-northeast-2", "151534178276"),
    ("ap-southeast-1", "324986816169"),
    ("ap-southeast-2", "355873309152"),
    # China
    ("cn-northwest-1", "474822919863"),
    ("cn-north-1", "472730292857"),
    # Other regions
    ("sa-east-1", "756306329178"),
    ("ca-central-1", "464438896020"),
    ("me-south-1", "836785723513"),
    ("af-south-1", "774647643957"),
)

# Lookup table keyed by region name (insertion order follows the pairs above).
account_id_map = dict(_TRITON_ACCOUNT_PAIRS)

***

# Part 2 - Save Model and tokenizer

We now save the tokenizer and the model to folders within the model repository.

### Parameters:

* `model_id`: Model identifier from the Hugging Face model hub

In [None]:
from transformers import AutoModel, AutoTokenizer

# Hugging Face Hub identifier of the sentence-embedding model to package.
model_id = "sentence-transformers/all-MiniLM-L6-v2"

# Fetch the tokenizer and the model for that identifier.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

# Store both inside the `e2e` entry of the local model repository.
tokenizer.save_pretrained("model_repo/e2e/tokenizer")
model.save_pretrained("model_repo/e2e/model")

# Part 3 - Run Local Triton Inference Server

> **WARNING**: The cells under part 3 will only work if run within a SageMaker Notebook Instance!




The following cells run the Triton Inference Server container in the background and load all the models found in the local `model_repo` folder, which is mounted into the container as `/model_repository`. The container won't exit if one or more of the models fail to load because of `--exit-on-error=false`, which is useful when iteratively building the code and the model repository. Remove `-d` to see the logs.

In [None]:
# Clean up unused Docker data before starting the server; -f skips the confirmation prompt.
!sudo docker system prune -f

In [None]:
# Start Triton detached (-d) with all GPUs, publish ports 8000-8002 and mount ./model_repo
# into the container as /model_repository. --rm removes the container once it stops.
!docker run --gpus=all -d --shm-size=4G --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd)/model_repo:/model_repository nvcr.io/nvidia/tritonserver:22.09-py3 tritonserver --model-repository=/model_repository --exit-on-error=false --strict-model-config=false
# `docker run -d` returns right away; presumably the server still needs time to load the
# models, so uncomment the pause below if the following cells run too early.
# time.sleep(20)

In [None]:
CONTAINER_ID=!docker container ls -q
FIRST_CONTAINER_ID = CONTAINER_ID[0]

Run the next cell to view the container logs and understand how Triton loads the models. To keep following the logs, use the commented-out `-f` variant instead.

In [None]:
# Alternative that keeps streaming the logs (-f) until the cell is interrupted:
# !docker logs $FIRST_CONTAINER_ID -f
# Print the logs the container has produced so far.
!docker logs $FIRST_CONTAINER_ID

## Test the `e2e` model by invoking the local Triton Server

In [None]:
# Start a local Triton client
try:
    triton_client = http_client.InferenceServerClient(url="localhost:8000", verbose=True)
except Exception as e:
    print("context creation failed: " + str(e))
    sys.exit()

# The container was started in detached mode, so the server may still be loading its
# models. Poll the readiness endpoint rather than sending requests to a server that is
# not up yet (up to 30 attempts, 2 seconds apart).
for _attempt in range(30):
    try:
        if triton_client.is_server_ready():
            break
    except Exception:
        # The HTTP port is not accepting connections yet; keep waiting.
        pass
    time.sleep(2)
else:
    raise RuntimeError("Triton server did not become ready within 60 seconds")

In [None]:
# Create inputs to send to Triton
model_name = "e2e"

text_inputs = ["Sentence 1", "Sentence 2"]

# Text is passed to Trtion as BYTES
inputs = []
inputs.append(http_client.InferInput("INPUT0", [len(text_inputs), 1], "BYTES"))

# We need to structure batch inputs as such
batch_request = [[text_inputs[i]] for i in range(len(text_inputs))]
input0_real = np.array(batch_request, dtype=np.object_)

inputs[0].set_data_from_numpy(input0_real, binary_data=False)

In [None]:
# Ask the server to return the SENT_EMBED output tensor.
outputs = [http_client.InferRequestedOutput("SENT_EMBED")]

In [None]:
# Send the request to the local Triton server and keep the response for decoding.
results = triton_client.infer(model_name=model_name, inputs=inputs, outputs=outputs)

In [None]:
# Extract the SENT_EMBED output from the response as a NumPy array.
outputs0 = results.as_numpy("SENT_EMBED")


In [None]:
# Print each input sentence followed by the output row at the same position.
for row_number, embedding in enumerate(outputs0):
    print(text_inputs[row_number])
    print(embedding)

In [None]:
# Use this to stop the container that was started in detached mode
# (it was launched with --rm, so it is removed as soon as it stops).
!docker kill $FIRST_CONTAINER_ID

***

# Part 4 - Deploy Triton to SageMaker MME Endpoint

# MME Experiments

In [None]:
# Fail fast with a real exception when no Triton image account is known for this region.
# (Raising a plain string is itself a TypeError, which would hide the actual problem.)
if region not in account_id_map:
    raise ValueError(f"UNSUPPORTED REGION: {region}")

# Regions starting with "cn-" use the amazonaws.com.cn domain suffix.
base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"

# ECR URI of the SageMaker Triton image for the current region.
triton_image_uri = (
    f"{account_id_map[region]}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.09-py3"
)

triton_image_uri

In [None]:
# Default S3 bucket of the SageMaker session; used below to build the model data URL.
bucket = sagemaker_session.default_bucket()
print(bucket)

In [None]:
# Package the `e2e` folder (relative to model_repo/) into a gzip-compressed tarball.
!tar -C model_repo/ -czf e2e.tar.gz e2e
# Key prefix under which the archive is uploaded to S3.
prefix = 'bert_mme_gpu'
e2e_uri = sagemaker_session.upload_data(path="e2e.tar.gz", key_prefix=prefix)

In [None]:
# S3 location built from the bucket and prefix used for the upload; list it to check
# that the archive is there.
model_data_url = f"s3://{bucket}/{prefix}/"
!aws s3 ls $model_data_url

In [None]:
# S3 prefix holding the model archive(s) for this endpoint.
model_data_url = f"s3://{bucket}/{prefix}/"

# Container definition for a multi-model deployment: the Triton image plus the S3 prefix
# above; individual archives are selected per request through `TargetModel`.
container = dict(
    Image=triton_image_uri,
    ModelDataUrl=model_data_url,
    Mode="MultiModel",
)

In [None]:
# A UTC timestamp suffix gives every run a distinct model name.
model_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
sm_model_name = f"triton-e2e-{model_timestamp}"

# Register the model (execution role + container definition) with SageMaker.
create_model_response = sm_client.create_model(
    ModelName=sm_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=container,
)

print(f"Model Arn: {create_model_response['ModelArn']}")

In [None]:
# A UTC timestamp suffix gives every run a distinct endpoint configuration name.
config_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
endpoint_config_name = f"triton-e2e-{config_timestamp}"

# One production variant: a single ml.g4dn.xlarge instance that receives all traffic.
all_traffic_variant = {
    "InstanceType": "ml.g4dn.xlarge",
    "InitialVariantWeight": 1,
    "InitialInstanceCount": 1,
    "ModelName": sm_model_name,
    "VariantName": "AllTraffic",
}

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[all_traffic_variant],
)

print(f"Endpoint Config Arn: {create_endpoint_config_response['EndpointConfigArn']}")

In [None]:
# A UTC timestamp suffix gives every run a distinct endpoint name.
endpoint_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
endpoint_name = f"triton-e2e-{endpoint_timestamp}"

# Request creation of the endpoint from the configuration defined above.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)

print(f"Endpoint Arn: {create_endpoint_response['EndpointArn']}")

In [None]:
# Poll once a minute until the endpoint leaves the "Creating" state, printing the status
# after every check.
while True:
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(60)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Stop here when creation did not succeed, instead of letting the invocation cells fail
# later with a less helpful error.
if status != "InService":
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {status}: "
        f"{resp.get('FailureReason', 'no failure reason reported')}"
    )

## Test endpoint

In [None]:
# Show the sentences currently held in `text_inputs` (first defined in Part 3 and
# redefined two cells below).
text_inputs

In [None]:
# Display the input descriptor for the current batch: name INPUT0,
# shape [number of sentences, 1], datatype BYTES.
http_client.InferInput("INPUT0", [len(text_inputs), 1], "BYTES")

In [None]:
# Sentences to embed through the SageMaker endpoint.
text_inputs = ["Sentence 1", "Sentence 2"]

# One BYTES input named INPUT0 with shape [number of sentences, 1].
inputs = [http_client.InferInput("INPUT0", [len(text_inputs), 1], "BYTES")]

# Wrap every sentence in its own single-element row so the array matches that shape.
batch_request = [[sentence] for sentence in text_inputs]

input0_real = np.array(batch_request, dtype=np.object_)

# binary_data=False places the data in the JSON part of the request.
inputs[0].set_data_from_numpy(input0_real, binary_data=False)

len(input0_real)

In [None]:
# Ask the endpoint to return the SENT_EMBED output tensor.
outputs = [http_client.InferRequestedOutput("SENT_EMBED")]

In [None]:
# Inspect the list of requested outputs.
outputs

In [None]:
# Serialize the inference request without sending it; the call returns the payload and a
# header length, and the payload is what gets passed to `invoke_endpoint` below.
request_body, header_length = http_client.InferenceServerClient.generate_request_body(
    inputs, outputs=outputs
)

print(request_body)

In [None]:
# Invoke the multi-model endpoint; `TargetModel` names the archive uploaded under the
# model data prefix.
response = sm_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    # NOTE(review): the inputs were serialized with binary_data=False, so the binary+json
    # content type (which needs a JSON header size) is left disabled here; presumably
    # `header_length` is None in that case — confirm before re-enabling.
    # ContentType="application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(
        # header_length
    # ),
    ContentType='application/octet-stream',
    Body=request_body,
    TargetModel='e2e.tar.gz'
)

In [None]:
# When the response carries binary tensor data, its content type ends with the size of
# the JSON header that precedes the binary section.
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
response_content_type = response["ContentType"]

# Only slice the size off when the prefix is really there; for any other content type the
# body is parsed without a header length instead of failing inside int().
if response_content_type.startswith(header_length_prefix):
    header_length_str = response_content_type[len(header_length_prefix) :]
else:
    header_length_str = None
response_header_length = int(header_length_str) if header_length_str else None

# Read response body
result = http_client.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=response_header_length
)

outputs_data = result.as_numpy("SENT_EMBED")

# Print each input sentence followed by the output row at the same position.
for idx, output in enumerate(outputs_data):
    print(text_inputs[idx])
    print(output)

# Part 5 - Test SageMaker Endpoint with Java Client

## Build Java App Docker Container

Get credentials first

In [None]:
!curl http://169.254.169.254/latest/meta-data/iam/security-credentials/BaseNotebookInstanceEc2InstanceRole>tmp.json
f = open('tmp.json')
metadata=json.load(f)
os.remove('tmp.json')

In [None]:
# Write the temporary credentials as a `[default]` profile for the Java client.
# Mode 'w' overwrites the file so that re-running this cell cannot append a second
# `[default]` section (with stale keys) to a file left over from an earlier run.
with open('./java_client/credentials', 'w') as credentials_file:
    credentials_file.write("[default]\n")
    credentials_file.write(f"aws_access_key_id = {metadata['AccessKeyId']}\n")
    credentials_file.write(f"aws_secret_access_key = {metadata['SecretAccessKey']}\n")
    credentials_file.write(f"aws_session_token = {metadata['Token']}\n")

### Build the Docker Image

In [None]:
# Build the Java client image from the ./java_client build context.
# NOTE(review): the temporary `credentials` file is inside this build context, so the
# Dockerfile presumably copies it into the image — treat the image as containing secrets; confirm.
!docker build  -t sagemaker-runtime-java-example ./java_client

In [None]:
# Delete the temporary credentials file from the build context once the image is built.
os.remove('./java_client/credentials')

### Run the Docker Container to invoke the endpoint from Java Client

In [None]:
!docker run -e AWS_REGION=us-east-1 -e ENDPOINT_NAME={endpoint_name} sagemaker-runtime-java-example

# Part 6 - Delete the Endpoint

In [None]:
# Uncomment to delete the endpoint created in Part 4 once you are done testing.
#sm_client.delete_endpoint(EndpointName=endpoint_name)