In [None]:
!pip install -qU awscli boto3 sagemaker --quiet
!pip install tritonclient[http] --quiet

In [None]:
import boto3, json, sagemaker, time
from sagemaker import get_execution_role

sess = boto3.Session()
sm = sess.client("sagemaker")
sagemaker_session = sagemaker.Session(boto_session=sess)
role = get_execution_role()
smr_client = boto3.client("sagemaker-runtime")  # runtime client used later to invoke the endpoint

In [None]:
!sudo amazon-linux-extras install epel -y
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash
!sudo yum install git-lfs -y

In [None]:
MODEL_NAME="t5-small"
MODEL_TYPE="t5"

# For BART
# MODEL_NAME="bart-base"
# MODEL_TYPE="bart"

In [None]:
!git lfs install

In [None]:
!git clone https://huggingface.co/google-t5/t5-small workspace/hf_models/$MODEL_NAME
# !git clone https://huggingface.co/facebook/bart-base workspace/hf_models/$MODEL_NAME

The [generate_trtllm_triton_model_repo.sh](trtllm_backend_sagemaker/workspace/generate_trtllm_triton_model_repo.sh) script builds the TRT-LLM engine for the encoder-decoder T5/BART model and prepares the Triton model repository. In this example we build a single-GPU engine (TP size = 1) with beam search (max beam width = 2), input length = 1024, and output length = 200. To change these settings, edit the same script.

In [None]:
TRITON_IMAGE_URI="nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3"

In [None]:
!docker run --gpus all --ulimit memlock=-1 --shm-size=12g -v ${PWD}/workspace:/workspace -w /workspace $TRITON_IMAGE_URI \
/bin/bash generate_trtllm_triton_model_repo.sh

In [None]:
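# tree is a system utility, not a Python package, so install it with yum
!sudo yum install -y tree
!tree workspace/triton_model_repo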

```
triton_model_repo/
├── ensemble
│   ├── 1
│   └── config.pbtxt
├── postprocessing
│   ├── 1
│   │   └── model.py
│   └── config.pbtxt
├── preprocessing
│   ├── 1
│   │   └── model.py
│   └── config.pbtxt
└── tensorrt_llm
    ├── 1
    │   ├── engines
    │   │   └── t5-small
    │   │       ├── decoder
    │   │       └── encoder
    │   ├── hf_models
    │   │   └── t5-small
    │   │       ├── config.json
    │   │       ├── flax_model.msgpack
    │   │       ├── generation_config.json
    │   │       ├── model.safetensors
    │   │       ├── onnx
    │   │       ├── pytorch_model.bin
    │   │       ├── README.md
    │   │       ├── rust_model.ot
    │   │       ├── spiece.model
    │   │       ├── tf_model.h5
    │   │       ├── tokenizer_config.json
    │   │       └── tokenizer.json
    │   └── model.py
    └── config.pbtxt
```

Next, we tag the Triton image and push it to Amazon ECR.

In [None]:
!docker tag nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3 triton-trtllm
!bash push_ecr.sh triton-trtllm

In [None]:
# Replace ACCOUNT_ID with your AWS account ID
triton_image_uri = "ACCOUNT_ID.dkr.ecr.us-east-1.amazonaws.com/triton-trtllm:latest"
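
As an alternative to typing the URI by hand, it can be composed from its parts. The helper below is a minimal sketch: `ecr_image_uri` is a hypothetical name, the account ID in the example is a placeholder, and the `latest` tag is an assumption based on the URI above rather than on the contents of `push_ecr.sh`. The URI pattern applies to the standard AWS partition.

```python
def ecr_image_uri(account_id, region, repository, tag="latest"):
    """Compose a private ECR image URI (standard AWS partition) from its parts."""
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


# In the notebook, the account ID and region could be looked up instead of typed in:
#   account_id = boto3.client("sts").get_caller_identity()["Account"]
#   region = boto3.Session().region_name
# Placeholder account ID used here so that the example runs without AWS credentials.
example_uri = ecr_image_uri("123456789012", "us-east-1", "triton-trtllm")
print(example_uri)
```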

For this use case we take the pre-trained T5 (or BART) model from Hugging Face and deploy it on SageMaker with Triton as the model server. The TRT-LLM engines and the Triton model repository, including the config.pbtxt files that Triton uses to load each model, were generated by the generate_trtllm_triton_model_repo.sh script in an earlier cell. We now package the repository in the format that Triton and SageMaker expect: we tar the model directory and upload it to S3 so that we can later create a SageMaker Model from it.

## Packaging model files and uploading to S3

In [None]:
!tar --exclude='.ipynb_checkpoints' --exclude='*.bin' \
--exclude='*.h5' --exclude='*.safetensors' --exclude="onnx" \
--exclude='.git*' --exclude='.gitignore' --exclude='.gitattributes' --exclude='.gitmodules' \
-czvf model.tar.gz -C workspace/triton_model_repo/ .

In [None]:
# Example of the resulting S3 URI; replace ACCOUNTID with your AWS account ID
model_uri = "s3://sagemaker-us-east-1-ACCOUNTID/triton-trtllm-model/model.tar.gz"

In [None]:
model_uri = sagemaker_session.upload_data(path="model.tar.gz", key_prefix="triton-trtllm-model")

## Create SageMaker Endpoint

We start by creating a SageMaker model from the model files we uploaded to S3 in the previous step.

In this step we also set the environment variable SAGEMAKER_TRITON_DEFAULT_MODEL_NAME, which specifies the name of the model that Triton loads. Its value must match a folder name in the model package uploaded to S3 (here, ensemble). The variable is optional when the package contains a single model; for ensemble models it must be set for Triton to start up in SageMaker.
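
Because the variable must name a folder at the top level of the archive, it can be worth checking the tarball before creating the model. The following is a minimal sketch using only the standard library; `top_level_model_dirs` is a hypothetical helper, not part of SageMaker or Triton.

```python
import tarfile
from pathlib import PurePosixPath


def top_level_model_dirs(tar_path):
    """Return the sorted names of the top-level directories inside a tar archive."""
    names = set()
    with tarfile.open(tar_path, "r:*") as tar:
        for member in tar.getmembers():
            # PurePosixPath drops the leading "./" that `tar -C dir .` produces
            parts = PurePosixPath(member.name).parts
            # Count a name if it is a directory entry or the parent of a deeper entry
            if len(parts) > 1 or (len(parts) == 1 and member.isdir()):
                names.add(parts[0])
    return sorted(names)


# Usage in the notebook (assumes model.tar.gz was created by the tar cell above):
#   assert "ensemble" in top_level_model_dirs("model.tar.gz")
```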

Additionally, customers can set SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT and SAGEMAKER_TRITON_THREAD_COUNT for optimizing the thread counts.
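
A minimal sketch of how these variables could be added to the container environment is shown below. `build_triton_environment` is a hypothetical helper, and the thread counts in the example are placeholder values, not tuned recommendations. SageMaker expects every environment value to be a string, so the counts are converted.

```python
def build_triton_environment(default_model_name, thread_count=None,
                             buffer_manager_thread_count=None):
    """Build the Environment mapping for the SageMaker container definition."""
    env = {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": default_model_name}
    if thread_count is not None:
        env["SAGEMAKER_TRITON_THREAD_COUNT"] = str(thread_count)
    if buffer_manager_thread_count is not None:
        env["SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT"] = str(buffer_manager_thread_count)
    return env


# Placeholder values for illustration only
example_env = build_triton_environment("ensemble", thread_count=4,
                                       buffer_manager_thread_count=2)
print(example_env)
```

The resulting dictionary could be passed as the "Environment" entry of the container definition in the next cell.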

In [None]:
sm_model_name = "triton-trtllm-model-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": triton_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "ensemble"},
}

create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Using the model above, we create an endpoint configuration where we can specify the type and number of instances we want in the endpoint.

In [None]:
endpoint_config_name = "triton-trtllm-model-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.g5.xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Using the endpoint configuration above, we create a new SageMaker endpoint and wait for the deployment to finish. The status changes to InService once the deployment succeeds.

In [None]:
endpoint_name = "triton-trtllm-model-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
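
The loop above exits on any status other than Creating, so a failed deployment is easy to miss. The variant below is a sketch, not part of the original notebook: `wait_for_endpoint` is a hypothetical helper that raises on failure and on timeout, and it takes the describe call as an argument so that it can be exercised without AWS access. The default timeout is an arbitrary assumption.

```python
import time


def wait_for_endpoint(describe_endpoint, endpoint_name, poll_seconds=60,
                      timeout_seconds=3600, sleep=time.sleep):
    """Poll until the endpoint is InService; raise if it fails or takes too long."""
    waited = 0
    while True:
        resp = describe_endpoint(EndpointName=endpoint_name)
        status = resp["EndpointStatus"]
        if status == "InService":
            return status
        if status != "Creating":
            # FailureReason is present when the deployment has failed
            reason = resp.get("FailureReason", "unknown")
            raise RuntimeError(f"Endpoint {endpoint_name} is {status}: {reason}")
        if waited >= timeout_seconds:
            raise TimeoutError(f"Endpoint {endpoint_name} still Creating after {waited}s")
        sleep(poll_seconds)
        waited += poll_seconds


# Usage in the notebook:
#   wait_for_endpoint(sm.describe_endpoint, endpoint_name)
```

boto3 also provides a built-in waiter for this, `sm.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)`, which may be preferable when no custom handling is needed.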

## Run inference
Once the endpoint is running, we can send a sample text for inference, using JSON as the payload format.

In [None]:
import numpy as np
from tritonclient.utils import np_to_triton_dtype


def invoke_endpoint_test(text_input, max_tokens, beam_width, temperature, repetition_penalty,
                         min_length, bad_words, stop_words, endpoint_name):
    """Send one JSON request to the endpoint and return the list of output tensors."""
    # np_to_triton_dtype maps NumPy dtypes to Triton names, e.g. np.int32 -> "INT32".
    # Each datatype must match the corresponding input in ensemble/config.pbtxt.
    payload = {}
    payload["inputs"] = [
        {"name": "text_input", "data": [text_input], "datatype": "BYTES", "shape": [1, 1]},
        {"name": "beam_width", "data": [beam_width], "datatype": np_to_triton_dtype(np.int32), "shape": [1, 1]},
        {"name": "max_tokens", "data": [max_tokens], "datatype": np_to_triton_dtype(np.int32), "shape": [1, 1]},
        {"name": "temperature", "data": [temperature], "datatype": np_to_triton_dtype(np.float32), "shape": [1, 1]},
        {"name": "repetition_penalty", "data": [repetition_penalty], "datatype": np_to_triton_dtype(np.float32), "shape": [1, 1]},
        {"name": "min_length", "data": [min_length], "datatype": np_to_triton_dtype(np.float32), "shape": [1, 1]},
        {"name": "bad_words", "data": [bad_words], "datatype": "BYTES", "shape": [1, 1]},
        {"name": "stop_words", "data": [stop_words], "datatype": "BYTES", "shape": [1, 1]},
    ]
    response = smr_client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=json.dumps(payload)
    )
    response_str = response["Body"].read().decode()
    json_object = json.loads(response_str)
    return json_object["outputs"]
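
The function returns the list of output tensors from the JSON response, each a dictionary with name, datatype, shape, and data fields. The sketch below shows one way to pull a tensor's data out of that list. `extract_output` is a hypothetical helper, the output name `text_output` is an assumption about the ensemble configuration, and the sample response is made up for illustration rather than captured from a real endpoint.

```python
def extract_output(outputs, name="text_output"):
    """Return the data list of the output tensor called `name`."""
    for tensor in outputs:
        if tensor["name"] == name:
            return tensor["data"]
    available = [tensor["name"] for tensor in outputs]
    raise KeyError(f"No output named {name!r}; available outputs: {available}")


# Made-up response in the same shape as the "outputs" list (not real model output)
sample_outputs = [
    {"name": "text_output", "datatype": "BYTES", "shape": [1, 2],
     "data": ["first beam text", "second beam text"]},
]
print(extract_output(sample_outputs))
```

With a live endpoint, the same helper could be applied to the return value of `invoke_endpoint_test(...)`.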

## Terminate endpoint and clean up artifacts

In [None]:
# Delete the endpoint first so that the instance stops running, then its config and model
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)