# Deploy Flan-T5-XL with dynamic batching on SageMaker using the TensorRT-LLM DLC

In this notebook, we deploy the open-source Flan-T5-XL model across the GPUs of an ml.g5.12xlarge instance using the TensorRT-LLM deep learning container (DLC).

TensorRT-LLM has two separate runtime components: a Python runtime and a C++ runtime for Triton Inference Server. The Python runtime has the components needed to build and execute TensorRT engines, but it does not support inflight (continuous) batching.

More details are available in the [TensorRT-LLM GitHub repository](https://github.com/NVIDIA/TensorRT-LLM/blob/118b3d7e7bab720d8ea9cd95338da60f7512c93a/docs/source/gpt_runtime.md).

As a result, to use T5 models in the djl-trtllm DLC (as of 0.27.0), you can only use dynamic batching for now.
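To make the idea of dynamic batching concrete, here is a minimal, illustrative sketch in plain Python. It is not how DJL Serving or TensorRT-LLM implement batching; the function name `form_dynamic_batches`, the wait window, and the request tuples are hypothetical. Requests that arrive close together are grouped, up to a maximum batch size, and each batch is processed as a unit. With inflight batching, by contrast, new requests could join a batch that is already running.

```python
from typing import List, Tuple

Request = Tuple[float, str]  # (arrival time in seconds, prompt)


def form_dynamic_batches(
    requests: List[Request], max_batch_size: int, max_wait_s: float
) -> List[List[str]]:
    """Group requests into batches by arrival time.

    A batch is closed when it already holds max_batch_size requests, or when
    the next request arrives more than max_wait_s after the first request
    in the batch.
    """
    batches: List[List[str]] = []
    current: List[str] = []
    window_start = 0.0
    for arrival, prompt in sorted(requests):
        batch_full = len(current) >= max_batch_size
        window_expired = arrival - window_start > max_wait_s
        if current and (batch_full or window_expired):
            batches.append(current)
            current = []
        if not current:
            # The first request of a new batch opens a new wait window.
            window_start = arrival
        current.append(prompt)
    if current:
        batches.append(current)
    return batches


# Hypothetical arrival times: three requests close together, one much later.
requests = [(0.00, "a"), (0.01, "b"), (0.02, "c"), (0.50, "d")]
batches = form_dynamic_batches(requests, max_batch_size=2, max_wait_s=0.1)
print(batches)
```

In this toy model, a larger maximum batch size improves throughput only when enough requests arrive within the same window, which is why dynamic batching pays off mainly under concurrent load.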

Flan-T5-XL is available on the [Hugging Face Model Hub](https://huggingface.co/google/flan-t5-xl).

## Step 1: Install and import the required libraries, and set some variables

In [None]:
%pip install sagemaker boto3 awscli --upgrade --quiet

In [2]:
import time

In [11]:
import boto3
import json
import jinja2
import sagemaker
from pathlib import Path
from sagemaker import Model, image_uris, serializers, deserializers

#role = "arn:aws:iam::125045733377:role/service-role/AmazonSageMaker-ExecutionRole-20240326T111860"
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
smr_client = boto3.client("sagemaker-runtime")
jinja_env = jinja2.Environment()

## Step 2: Start preparing model artifacts
The LMI container expects some artifacts to help set up the model:
- serving.properties (required): defines the model server settings

The files below are optional:
- model.py (optional): a Python file that defines the core inference logic
- requirements.txt (optional): any additional pip packages that need to be installed


### Low-code no-code experience

We have recently introduced a low-code, no-code (LCNC) experience. You only need to specify the model_id, and the remaining configuration is inferred for you.

If you provide a Hugging Face model, the default configuration is as follows:
1. option.dtype = fp32
2. option.tensor_parallel_degree = <max-number-of-gpus>
3. max_dynamic_batch_size=32 (same as batch_size)
4. engine=MPI
5. option.entryPoint=djl_python.tensorrt_llm
    

    
    
As of now, this LCNC experience is only available for just-in-time compilation. That is, you provide the Hugging Face model, and the container compiles it to TensorRT engine format at runtime and then loads those engines for inference.

If you provide a pre-compiled model, you need to specify all of these settings in serving.properties yourself.
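For reference, the defaults listed above can be spelled out as an explicit serving.properties file. The sketch below only formats key/value pairs and makes no claim about how the container resolves its defaults internally. `render_serving_properties` is a hypothetical helper, and the value `max` for the tensor parallel degree is an assumption taken from the commented-out configuration elsewhere in this notebook.

```python
def render_serving_properties(settings: dict) -> str:
    """Format settings as key=value lines, one setting per line."""
    return "\n".join(f"{key}={value}" for key, value in settings.items()) + "\n"


# Defaults described above for a Hugging Face model ID.
lcnc_defaults = {
    "engine": "MPI",
    "option.model_id": "google/flan-t5-xl",
    "option.entryPoint": "djl_python.tensorrt_llm",
    "option.tensor_parallel_degree": "max",  # assumption: use all available GPUs
    "option.dtype": "fp32",
    "max_dynamic_batch_size": 32,
}

properties_text = render_serving_properties(lcnc_defaults)
print(properties_text)
```

Writing the settings out this way makes it easier to see which values you would need to change, for example a smaller `max_dynamic_batch_size`, when you stop relying on the defaults.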

#### Just-in-time compilation (HF model)

In [29]:
%%writefile serving.properties
#option.model_id = {{s3url}}
option.model_id=google/flan-t5-xl
#option.entryPoint = {{model_handler_url}}

Writing serving.properties


#### Pre-compiled model (TRT engine format)

In [None]:
# %%writefile serving.properties
# engine=MPI
# option.model_id=google/flan-t5-xl
# option.entryPoint=djl_python.tensorrt_llm
# option.tensor_parallel_degree=max
# option.dtype=fp32
# max_dynamic_batch_size=8

In [None]:
%%sh
mkdir -p flant5_trtllm
mv serving.properties flant5_trtllm/
# remove any tarball left over from a previous run (name must match the tar command below)
rm -f flant5-trtllm.tar.gz
tar czvf flant5-trtllm.tar.gz -C flant5_trtllm .
rm -rf flant5_trtllm
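If a shell is not available, the same packaging step could be done in Python with the standard library. This is a sketch and not part of the original notebook flow: `package_model_dir` is a hypothetical helper, it runs in a temporary directory so nothing in the working tree changes, and unlike the shell cell it does not delete the source directory.

```python
import tarfile
import tempfile
from pathlib import Path


def package_model_dir(model_dir: Path, tarball: Path) -> Path:
    """Create a gzip tarball with the directory contents at the archive root."""
    with tarfile.open(tarball, "w:gz") as tar:
        for path in sorted(model_dir.iterdir()):
            # arcname drops the parent directory so files sit at the root.
            tar.add(path, arcname=path.name)
    return tarball


# Demonstrate with a throwaway serving.properties file.
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "flant5_trtllm"
    model_dir.mkdir()
    (model_dir / "serving.properties").write_text(
        "option.model_id=google/flan-t5-xl\n"
    )
    tarball = package_model_dir(model_dir, Path(tmp) / "flant5-trtllm.tar.gz")
    with tarfile.open(tarball) as tar:
        archived_names = tar.getnames()

print(archived_names)
```

Listing the archive members before uploading is a quick way to confirm that serving.properties sits at the root of the tarball rather than inside a subdirectory.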

## Step 3: Create the endpoint

The steps are:

1. Create the model using the container image
2. Create the endpoint config with the following key parameter:

    a) Instance type: ml.g5.12xlarge

3. Create the endpoint using the endpoint config
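These three steps map onto three SageMaker API calls in boto3: `create_model`, `create_endpoint_config`, and `create_endpoint`. The sketch below only builds the request payloads, so it runs without AWS credentials. The helper `build_endpoint_requests`, the resource name, the role ARN, the image URI, and the S3 path are all hypothetical placeholders. The high-level `Model.deploy` call used in this notebook performs the equivalent steps for you.

```python
def build_endpoint_requests(
    name, image_uri, model_data_url, role_arn, instance_type="ml.g5.12xlarge"
):
    """Return keyword arguments for the three boto3 SageMaker calls."""
    create_model = {
        "ModelName": name,
        "ExecutionRoleArn": role_arn,
        "PrimaryContainer": {"Image": image_uri, "ModelDataUrl": model_data_url},
    }
    create_endpoint_config = {
        "EndpointConfigName": name,
        "ProductionVariants": [
            {
                "VariantName": "AllTraffic",
                "ModelName": name,
                "InstanceType": instance_type,
                "InitialInstanceCount": 1,
            }
        ],
    }
    create_endpoint = {"EndpointName": name, "EndpointConfigName": name}
    return create_model, create_endpoint_config, create_endpoint


# Placeholder values for illustration only.
model_req, config_req, endpoint_req = build_endpoint_requests(
    name="example-flan-t5-xl",
    image_uri="<account>.dkr.ecr.<region>.amazonaws.com/djl-serving:<tag>",
    model_data_url="s3://<bucket>/flan-t5-trt-llm/flant5-trtllm.tar.gz",
    role_arn="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
)

# With credentials, the payloads would be passed to the SageMaker client:
#   sm_client = boto3.client("sagemaker")
#   sm_client.create_model(**model_req)
#   sm_client.create_endpoint_config(**config_req)
#   sm_client.create_endpoint(**endpoint_req)
print(config_req["ProductionVariants"][0]["InstanceType"])
```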

#### Create the Model
Use the image URI for the DJL container and the S3 location to which the tarball was uploaded.

The container downloads the model into `/tmp` on the container, because SageMaker maps `/tmp` to the Amazon Elastic Block Store (Amazon EBS) volume that is mounted when we specify the endpoint creation parameter VolumeSizeInGB. It uses [`s5cmd`](https://github.com/peak/s5cmd), which offers very fast download speeds and is therefore especially useful for large models.


### 3.1 Getting the container image URI

[Deep learning containers for large model inference](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers) has more details about each version and framework in the container.

In [17]:
# image_uri = image_uris.retrieve(
#         framework="djl-tensorrtllm",
#         region=sess.boto_session.region_name,
#         version="0.27.0"
#     )

image_uri = "125045733377.dkr.ecr.us-west-2.amazonaws.com/djl-serving:tensorrt-llm-nightly"

### 3.2 Upload artifact on S3 and create SageMaker model

In [31]:
s3_code_prefix = "flan-t5-trt-llm/"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("flant5-trtllm.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")
env = {"HUGGINGFACE_HUB_CACHE": "/tmp", "TRANSFORMERS_CACHE": "/tmp"}

model = Model(image_uri=image_uri, model_data=code_artifact, env=env, role=role)

S3 Code or Model tar ball uploaded to --- > s3://sagemaker-us-west-2-125045733377/flan-t5-trt-llm//flant5-trtllm.tar.gz


### 3.3 Create SageMaker endpoint

You need to specify the instance type and the endpoint name.

In [None]:
instance_type = "ml.g5.12xlarge"
endpoint_name = sagemaker.utils.name_from_base("djl-trtllm-g5-fp32-flan-t5-xl")
print(endpoint_name)

model.deploy(
  initial_instance_count=1,
  instance_type=instance_type,
  endpoint_name=endpoint_name,
)

In [35]:
# our requests and responses will be in json format so we specify the serializer and the deserializer
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

## Step 4: Invoke the endpoint

In [None]:
input_prompt = "translate English to German: The house is wonderful."
start = time.time()
response = predictor.predict(
    {
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 1024, "details": True},
    }
)
end = time.time()
print(f"Time elapsed {end - start}")
print(response)
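The notebook also creates `smr_client`, a low-level `sagemaker-runtime` client, which could send the same request without the `Predictor` wrapper. The sketch below assumes the endpoint accepts and returns JSON, as configured for the predictor above. `invoke_json` is a hypothetical helper, and `FakeRuntimeClient` is a stub included only so that the example runs offline: its response body is made up and does not show real model output.

```python
import io
import json


def invoke_json(client, endpoint_name, prompt, max_new_tokens=1024):
    """Send a JSON request through a sagemaker-runtime style client."""
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}}
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    # The response body is a stream; read it fully and decode the JSON.
    return json.loads(response["Body"].read())


class FakeRuntimeClient:
    """Offline stand-in for boto3.client("sagemaker-runtime")."""

    def invoke_endpoint(self, EndpointName, ContentType, Body):
        request = json.loads(Body)
        # Echo the prompt back; a real endpoint returns model output instead.
        body = json.dumps({"echo": request["inputs"]}).encode("utf-8")
        return {"Body": io.BytesIO(body)}


result = invoke_json(
    FakeRuntimeClient(),
    "example-endpoint",
    "translate English to German: The house is wonderful.",
)
print(result)
```

Against a live endpoint, you would pass `smr_client` and `endpoint_name` from the cells above in place of the stub and the placeholder name.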

## Clean up the environment

In [37]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()