In [None]:
# Automatically reload edited local Python modules before each cell executes.
# %reload_ext (instead of %load_ext) makes re-running this cell idempotent and quiet.
# NOTE(review): no local modules are imported in the visible cells — drop if unused.
%reload_ext autoreload
%autoreload 2

In [None]:
# Step 1 — export the PyTorch model inside NVIDIA's PyTorch container.
# Adapted from:
# https://github.com/aws/amazon-sagemaker-examples/blob/main/multi-model-endpoints/mme-on-gpu/cv/resnet50_mme_with_gpu.ipynb
# ./workspace is mounted at /workspace; the packaging cell later expects the script
# to have produced workspace/model.pt.
!docker run --rm -it --gpus=all \
  -v "${PWD}/workspace:/workspace" \
  nvcr.io/nvidia/pytorch:22.07-py3 \
  /bin/bash generate_model_pytorch.sh

In [None]:
# Triton model repository for the PyTorch variant: <repo>/<model-name>/
!mkdir -p triton-serve-pt/resnet-50

In [None]:
%%writefile triton-serve-pt/resnet-50/config.pbtxt
# Triton model configuration for the PyTorch (libtorch backend) ResNet-50 model.
# `name` must equal the model's repository directory ("resnet-50");
# Triton rejects a model whose configured name differs from its directory.
name: "resnet-50"
platform: "pytorch_libtorch"
# Up to 128 images per request; the dims below exclude the batch dimension.
max_batch_size: 128
input {
  name: "INPUT__0"
  data_type: TYPE_FP32
  dims: 3
  dims: 224
  dims: 224
}
output {
  name: "OUTPUT__0"
  data_type: TYPE_FP32
  dims: 1000
}

In [None]:
# Step 2 — generate the TensorRT plan using the same NVIDIA PyTorch image.
# ./workspace is mounted at /workspace; the packaging cell later expects the script
# to have produced workspace/model.plan.
!docker run --rm -it --gpus=all \
  -v "${PWD}/workspace:/workspace" \
  nvcr.io/nvidia/pytorch:22.07-py3 \
  /bin/bash generate_model_trt.sh

In [None]:
# Triton model repository for the TensorRT variant: <repo>/<model-name>/
!mkdir -p triton-serve-trt/resnet-50

In [None]:
%%writefile triton-serve-trt/resnet-50/config.pbtxt
# Triton model configuration for the TensorRT ResNet-50 plan.
# `name` must equal the model's repository directory ("resnet-50");
# Triton rejects a model whose configured name differs from its directory.
name: "resnet-50"
platform: "tensorrt_plan"
# Up to 128 images per request; the dims below exclude the batch dimension.
max_batch_size: 128
input {
  # Must match the engine's input tensor name (presumably set by
  # generate_model_trt.sh — confirm).
  name: "input"
  data_type: TYPE_FP32
  dims: 3
  dims: 224
  dims: 224
}
output {
  name: "output"
  data_type: TYPE_FP32
  dims: 1000
}
# Run one full-size batch at load time to absorb cold-start cost before real traffic.
model_warmup {
    name: "bs128 Warmup"
    batch_size: 128
    inputs: {
        key: "input"
        value: {
            data_type: TYPE_FP32
            dims: 3
            dims: 224
            dims: 224
            # Triton sends zero-filled input whenever `zero_data` is set (the value
            # itself is not checked), so `true` states the actual behavior.
            zero_data: true
        }
    }
}

In [None]:
# Place each exported artifact under version "1" of its model directory, then
# tar the model directory (relative to its repository root) into a .tar.gz
# archive for upload.
!mkdir -p triton-serve-pt/resnet-50/1
!mv -f workspace/model.pt triton-serve-pt/resnet-50/1/
!tar -C triton-serve-pt -czf resnet_pt_v0.tar.gz resnet-50

!mkdir -p triton-serve-trt/resnet-50/1
!mv -f workspace/model.plan triton-serve-trt/resnet-50/1/
!tar -C triton-serve-trt -czf resnet_trt_v0.tar.gz resnet-50

In [None]:
# Upload both packaged model archives to S3 under a shared key prefix.
# With no bucket given, upload_data() targets the session's default bucket
# and returns the resulting s3:// URI.

import boto3
import sagemaker

model_name = "resnet-50"
sagemaker_session = sagemaker.Session(boto_session=boto3.Session())

# Uploaded in order: PyTorch archive first, then TensorRT.
model_uri_pt, model_uri_trt = (
    sagemaker_session.upload_data(path=archive, key_prefix=model_name)
    for archive in ("resnet_pt_v0.tar.gz", "resnet_trt_v0.tar.gz")
)