In [None]:
# Requires GPU Strictly. This tutorial was implemented with g5.2xlarge Notebook Instance

In [1]:
import sagemaker

role = sagemaker.get_execution_role()
print(role)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml
arn:aws:iam::648758970526:role/service-role/AmazonSageMakerServiceCatalogProductsUseRole


In [None]:
%pip install transformers tritonclient[http]

Collecting transformers
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
Collecting tritonclient[http]
  Downloading tritonclient-2.63.0-py3-none-manylinux1_x86_64.whl.metadata (2.9 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Downloading safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface-hub<1.0,>=0.34.0->transformers)
  Downloading hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)
Collecting perf-analyzer (from tritonclient[http])
  Downloading perf_analyzer-0.1.0-py3-none-any.whl.metadata (135 bytes)
Collecting python-rapidjson>=0.9.1

In [3]:
import transformers
import boto3, json, sagemaker, time
from sagemaker import get_execution_role

sess = boto3.Session()
sm = sess.client("sagemaker")
sagemaker_session = sagemaker.Session(boto_session=sess)
role = get_execution_role()
client = boto3.client("sagemaker-runtime")
bucket = sagemaker_session.default_bucket()
default_bucket_prefix = sagemaker_session.default_bucket_prefix
print(bucket)

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


sagemaker-us-east-1-648758970526


In [4]:
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}

In [5]:
region = boto3.Session().region_name
if region not in account_id_map.keys():
    raise ("UNSUPPORTED REGION")

In [6]:
base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
triton_image_uri = "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:23.02-py3".format(
    account_id=account_id_map[region], region=region, base=base
)

In [7]:
triton_image_uri

'785573368785.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tritonserver:23.02-py3'

In [8]:
import tritonclient.http as httpclient
from transformers import BertTokenizer
import numpy as np


def get_tokenizer():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    return tokenizer


def tokenize_text(text):
    enc = get_tokenizer()
    encoded_text = enc(text, padding="max_length", max_length=128)
    return encoded_text["input_ids"], encoded_text["attention_mask"]


def _get_sample_tokenized_text_binary(text, input_names, output_names):
    inputs = []
    outputs = []
    inputs.append(httpclient.InferInput(input_names[0], [1, 128], "INT32"))
    inputs.append(httpclient.InferInput(input_names[1], [1, 128], "INT32"))
    indexed_tokens, attention_mask = tokenize_text(text)

    indexed_tokens = np.array(indexed_tokens, dtype=np.int32)
    indexed_tokens = np.expand_dims(indexed_tokens, axis=0)
    inputs[0].set_data_from_numpy(indexed_tokens, binary_data=True)

    attention_mask = np.array(attention_mask, dtype=np.int32)
    attention_mask = np.expand_dims(attention_mask, axis=0)
    inputs[1].set_data_from_numpy(attention_mask, binary_data=True)

    outputs.append(httpclient.InferRequestedOutput(output_names[0], binary_data=True))
    outputs.append(httpclient.InferRequestedOutput(output_names[1], binary_data=True))
    request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
        inputs, outputs=outputs
    )
    return request_body, header_length


def get_sample_tokenized_text_binary_pt(text):
    return _get_sample_tokenized_text_binary(
        text, ["INPUT__0", "INPUT__1"], ["OUTPUT__0", "1634__1"]
    )


def get_sample_tokenized_text_binary_trt(text):
    return _get_sample_tokenized_text_binary(
        text, ["token_ids", "attn_mask"], ["output", "pooled_output"]
    )

In [9]:
!mkdir -p workspace

In [12]:
%%writefile workspace/generate_models.sh
#!/bin/bash
python -m pip install transformers==4.26.1
python onnx_exporter.py

trtexec --onnx=model.onnx --saveEngine=model_bs16.plan --minShapes=token_ids:1x128,attn_mask:1x128 --optShapes=token_ids:16x128,attn_mask:16x128 --maxShapes=token_ids:128x128,attn_mask:128x128 --fp16 --verbose --workspace=14000 | tee conversion_bs16_dy.txt

Overwriting workspace/generate_models.sh


In [11]:
%%writefile workspace/onnx_exporter.py
import torch
from transformers import BertModel
import argparse
import os

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save", default="model.onnx")
    args = parser.parse_args()

    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

    bs = 1
    seq_len = 128
    dummy_inputs = (torch.randint(1000, (bs, seq_len)), torch.zeros(bs, seq_len, dtype=torch.int))

    torch.onnx.export(
        model,
        dummy_inputs,
        args.save,
        export_params=True,
        opset_version=10,
        input_names=["token_ids", "attn_mask"],
        output_names=["output","pooled_output"],
        dynamic_axes={"token_ids": [0, 1], "attn_mask": [0, 1], "output": [0]},
    )

    print("Saved {}".format(args.save))

Overwriting workspace/onnx_exporter.py


In [13]:
!docker run --gpus=all --rm -it \
            -v `pwd`/workspace:/workspace nvcr.io/nvidia/pytorch:23.02-py3 \
            /bin/bash generate_models.sh

Unable to find image 'nvcr.io/nvidia/pytorch:23.02-py3' locally
23.02-py3: Pulling from nvidia/pytorch

[1Bf31133a9: Pulling fs layer 
[1B5fe751be: Pulling fs layer 
[1Bc9e56778: Pulling fs layer 
[1B9c08065c: Pulling fs layer 
[1B129e9daa: Pulling fs layer 
[1Bddf5daef: Pulling fs layer 
[1Bb1e8d85a: Pulling fs layer 
[1Bbd04fbf5: Pulling fs layer 
[1B0bd8cf35: Pulling fs layer 
[1B15e856a0: Pulling fs layer 
[1B618a5cab: Pulling fs layer 
[1Bc8820090: Pulling fs layer 
[1B305027fa: Pulling fs layer 
[1Bb700ef54: Pulling fs layer 
[1B21f602b1: Pulling fs layer 
[1B08dbe941: Pulling fs layer 
[1B6a08c26e: Pulling fs layer 
[1B6dbcf033: Pulling fs layer 
[1B753787ae: Pulling fs layer 
[1B11624008: Pulling fs layer 
[1Bc6e7f54b: Pulling fs layer 
[1B6f066e66: Pulling fs layer 
[1B1dd6989f: Pulling fs layer 
[1B0ace1bc4: Pulling fs layer 
[1B46083198: Pulling fs layer 
[1Ba7a88287: Pulling fs layer 
[1B73547e7d: Pulling fs layer 
[1B6116a3dd: Pulling fs layer 


In [14]:
!mkdir -p model_repo_0/bert_0


In [15]:
%%writefile model_repo_0/bert_0/config.pbtxt
name: "bert"
platform: "tensorrt_plan"
max_batch_size: 128
input [
  {
    name: "token_ids"
    data_type: TYPE_INT32
    dims: [128]
  },
  {
    name: "attn_mask"
    data_type: TYPE_INT32
    dims: [128]
  }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [128, 768]
  },
  {
    name: "pooled_output"
    data_type: TYPE_FP32
    dims: [768]
  }
]
instance_group {
  count: 1
  kind: KIND_GPU
}
dynamic_batching {
  preferred_batch_size: 16
}

Writing model_repo_0/bert_0/config.pbtxt


In [16]:
!mkdir -p model_repo_0/bert_0/1/
!cp workspace/model_bs16.plan model_repo_0/bert_0/1/model.plan

In [17]:
import os
import shutil

N = 5
prefix = "bert-mme"

# If a default bucket prefix is specified, append it to the s3 path
if default_bucket_prefix:
    prefix = f"{default_bucket_prefix}/{prefix}"
    
model_repo_base = "model_repo"


# Get model names from model_repo_0
model_names = [
    name
    for name in os.listdir(f"{model_repo_base}_0")
    if os.path.isdir(f"{model_repo_base}_0/{name}")
]

for i in range(N):
    # Make copy of previous model repo, increment # id
    shutil.copytree(f"{model_repo_base}_0", f"{model_repo_base}_{i+1}")
    time.sleep(5)
    for name in model_names:
        model_dirs_path = f"{model_repo_base}_{i+1}/{name}"

        # Open each model's config file to increment model # id there
        fin = open(f"{model_dirs_path}/config.pbtxt", "rt")
        data = fin.read()
        data = data.replace(name, name[:-1] + str(i + 1))
        fin.close()
        fin = open(f"{model_dirs_path}/config.pbtxt", "wt")
        fin.write(data)
        fin.close()

        # Change model directory name to match new config
        os.rename(model_dirs_path, model_dirs_path[:-1] + str(i + 1))
        time.sleep(2)

    if i == 0:
        tar_file_name = f"bert-{i}.tar.gz"
        model_repo_target = f"{model_repo_base}_{i}/"
        !tar -C $model_repo_target -czf $tar_file_name .
        sagemaker_session.upload_data(path=tar_file_name, key_prefix=prefix)

    tar_file_name = f"bert-{i+1}.tar.gz"
    model_repo_target = f"{model_repo_base}_{i+1}/"
    !tar -C $model_repo_target -czf $tar_file_name .
    sagemaker_session.upload_data(path=tar_file_name, key_prefix=prefix)
    !sudo rm -r "$tar_file_name" "$model_repo_target"

In [18]:
sm_model_name = "triton-nlp-bert-trt-mme-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
model_data_uri = f"s3://{bucket}/{prefix}/"
container = {
    "Image": triton_image_uri,
    "ModelDataUrl": model_data_uri,
    #     "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "bert"},
    "Mode": "MultiModel",
}

create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Model Arn: arn:aws:sagemaker:us-east-1:648758970526:model/triton-nlp-bert-trt-mme-2025-12-14-20-58-57


In [19]:
endpoint_config_name = "triton-nlp-bert-trt-mme-" + time.strftime(
    "%Y-%m-%d-%H-%M-%S", time.gmtime()
)

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.g5.xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Endpoint Config Arn: arn:aws:sagemaker:us-east-1:648758970526:endpoint-config/triton-nlp-bert-trt-mme-2025-12-14-20-59-06


In [20]:
endpoint_name = "triton-nlp-bert-trt-mme-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

Endpoint Arn: arn:aws:sagemaker:us-east-1:648758970526:endpoint/triton-nlp-bert-trt-mme-2025-12-14-20-59-17


In [21]:
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Status: InService
Arn: arn:aws:sagemaker:us-east-1:648758970526:endpoint/triton-nlp-bert-trt-mme-2025-12-14-20-59-17
Status: InService


In [22]:
text_triton = "Triton Inference Server provides a cloud and edge inferencing solution optimized for both CPUs and GPUs."
input_ids, attention_mask = tokenize_text(text_triton)

payload = {
    "inputs": [
        {"name": "token_ids", "shape": [1, 128], "datatype": "INT32", "data": input_ids},
        {"name": "attn_mask", "shape": [1, 128], "datatype": "INT32", "data": attention_mask},
    ]
}

for i in range(N):
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/octet-stream",
        Body=json.dumps(payload),
        TargetModel=f"bert-{i}.tar.gz",
    )

    print(json.loads(response["Body"].read().decode("utf8")))

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

{'model_name': '724d9b62ee3afe8e6275eb469ecfd7c9', 'model_version': '1', 'outputs': [{'name': 'output', 'datatype': 'FP32', 'shape': [1, 128, 768], 'data': [-1.0244140625, -0.303955078125, 0.1134033203125, 0.07061767578125, -0.298583984375, -0.126220703125, -0.1038818359375, 0.281982421875, 0.0196990966796875, -0.5478515625, -0.061126708984375, -0.074951171875, -0.354736328125, 0.5732421875, 0.3017578125, -0.1473388671875, -0.339599609375, 0.6943359375, 0.26025390625, 0.1678466796875, -0.6796875, -0.666015625, 0.054229736328125, -0.08807373046875, -0.3212890625, 0.09197998046875, -0.2015380859375, -0.416259765625, 0.20361328125, 0.348388671875, -0.415283203125, 0.283935546875, -0.2177734375, -0.9375, 0.71826171875, -0.37255859375, -0.2578125, -0.162353515625, -0.2125244140625, -0.0102386474609375, -0.476806640625, -0.18310546875, 0.439697265625, -0.08349609375, 0.1617431640625, -0.08245849609375, -2.955078125, -0.1722412109375, -0.625, -0.52294921875, -0.035888671875, 0.24267578125, -0

In [23]:
text_sm = "Amazon SageMaker helps data scientists and developers to prepare, build, train, and deploy high-quality machine learning (ML) models quickly by bringing together a broad set of capabilities purpose-built for ML."
request_body, header_length = get_sample_tokenized_text_binary_trt(text_sm)

response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(
        header_length
    ),
    Body=request_body,
    TargetModel="bert-0.tar.gz",
)

# Parse json header size length from the response
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
header_length_str = response["ContentType"][len(header_length_prefix) :]

# Read response body
result = httpclient.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=int(header_length_str)
)
# print(response)
# print(result)
output0_data = result.as_numpy("output")
output1_data = result.as_numpy("pooled_output")
print(output0_data)
print(output1_data)

[[[-0.9008789  -0.30004883  0.11590576 ... -0.56103516 -0.5209961
    0.47387695]
  [-0.17492676  0.8105469   0.05801392 ... -0.13537598  0.00822449
   -0.32055664]
  [-0.14343262 -0.0848999   1.0869141  ... -0.90625    -0.21081543
   -0.20910645]
  ...
  [-0.6088867  -0.53125     0.20507812 ...  0.1517334   0.07647705
   -0.5444336 ]
  [-0.59765625 -0.57714844  0.21948242 ...  0.16455078  0.12237549
   -0.4350586 ]
  [-0.53125    -0.5488281   0.31274414 ...  0.16809082  0.20324707
   -0.14196777]]]
[[-4.46044922e-01 -1.72729492e-01 -7.30957031e-01 -8.07495117e-02
   4.81933594e-01 -2.93212891e-01 -7.35351562e-01  1.06445312e-01
  -5.34179688e-01 -9.88281250e-01 -4.12841797e-01  5.55664062e-01
   8.27148438e-01 -1.84082031e-01  1.14257812e-01  1.34643555e-01
   1.12426758e-01 -2.86376953e-01  3.17626953e-01  9.35058594e-01
   9.72900391e-02  1.00000000e+00 -2.39746094e-01  3.68164062e-01
   3.50830078e-01  5.79101562e-01 -3.80371094e-01  5.43457031e-01
   6.80664062e-01  5.91308594e-01

In [24]:
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)

{'ResponseMetadata': {'RequestId': 'a1b42fb4-3bc3-451e-b1c6-e5b9185597bb',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': 'a1b42fb4-3bc3-451e-b1c6-e5b9185597bb',
   'strict-transport-security': 'max-age=47304000; includeSubDomains',
   'x-frame-options': 'DENY',
   'content-security-policy': "frame-ancestors 'none'",
   'cache-control': 'no-cache, no-store, must-revalidate',
   'x-content-type-options': 'nosniff',
   'content-type': 'application/x-amz-json-1.1',
   'date': 'Sun, 14 Dec 2025 21:08:36 GMT',
   'content-length': '0'},
  'RetryAttempts': 0}}