end2end.engine to Triton #465

Closed
austinmw opened this issue May 12, 2022 · 17 comments

@austinmw

Hi, I built the MMDeploy GPU Dockerfile and installed MMDetection on top of it to generate an end2end.engine file.

I tried to run this in NVIDIA’s Triton docker image, but got errors due to missing plugins.

I see that there’s some documentation on installing plugins, but it does not seem straightforward to add those instructions to a Dockerfile that uses Triton’s image as a base.

Since this is probably a very common use case, could I request documentation, or file a feature request, for a Dockerfile for Triton server with the necessary plugins added?

Appreciate any assistance or advice, thanks!

@lvhan028
Collaborator

@austinmw Is it related to #460?
Regarding "got errors due to missing plugins", could you paste the error log?
Can you let me know how to reproduce it? I'll fix it as soon as possible. Or you can make a PR if you'd like; we'd appreciate it a lot.

lvhan028 self-assigned this May 12, 2022
@austinmw
Author

austinmw commented May 12, 2022

@lvhan028 Thanks for your help, really appreciate it!

(As for #460, I don't think this is related, since that issue only occurs when I attempt to use a dynamic-size mmdeploy config, whereas in this example I use a fixed size, which successfully produces an engine file. However, you could reproduce that issue by following all of the steps below and using detection_tensorrt-fp16_dynamic-320x320-1344x1344.py instead of the detection_tensorrt-fp16_static-800x1344.py I use.)


Here are all of the steps to reproduce what I've done (the Triton failure log is at the bottom of this post):

1. MMDeploy docker build

This is nearly the same as the base MMDeploy GPU image, except that I've updated the versions of tensorrt, torch, onnx, mmcv, and pplcv, and added an mmdetection install at the end.

Dockerfile.mmdeploy:

FROM nvcr.io/nvidia/tensorrt:22.04-py3

ARG CUDA=11.3
ARG PYTHON_VERSION=3.8
ARG TORCH_VERSION=1.11.0
ARG TORCHVISION_VERSION=0.12.0
ARG ONNXRUNTIME_VERSION=1.11.1
ARG MMCV_VERSION=1.5.0
ARG PPLCV_VERSION=0.6.3
ENV FORCE_CUDA="1"

ENV DEBIAN_FRONTEND=noninteractive

### change the system source for installing libs
ARG USE_SRC_INSIDE=false
RUN if [ ${USE_SRC_INSIDE} == true ] ; \
    then \
        sed -i s/archive.ubuntu.com/mirrors.aliyun.com/g /etc/apt/sources.list ; \
        sed -i s/security.ubuntu.com/mirrors.aliyun.com/g /etc/apt/sources.list ; \
        echo "Use aliyun source for installing libs" ; \
    else \
        echo "Keep the download source unchanged" ; \
    fi

### update apt and install libs
RUN apt-get update &&\
    apt-get install -y vim libsm6 libxext6 libxrender-dev libgl1-mesa-glx git wget libssl-dev libopencv-dev libspdlog-dev --no-install-recommends &&\
    rm -rf /var/lib/apt/lists/*

RUN curl -fsSL -v -o ~/miniconda.sh -O  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh  && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=${PYTHON_VERSION} conda-build pyyaml numpy ipython cython typing typing_extensions mkl mkl-include ninja && \
    /opt/conda/bin/conda clean -ya

### pytorch
RUN /opt/conda/bin/conda install pytorch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} cudatoolkit=${CUDA} -c pytorch
ENV PATH /opt/conda/bin:$PATH

### install mmcv-full
RUN /opt/conda/bin/pip install mmcv-full==${MMCV_VERSION} -f https://download.openmmlab.com/mmcv/dist/cu${CUDA//./}/torch${TORCH_VERSION}/index.html

WORKDIR /root/workspace
### get onnxruntime
RUN wget https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}.tgz \
    && tar -zxvf onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}.tgz &&\
    pip install onnxruntime-gpu==${ONNXRUNTIME_VERSION}

### cp trt from pip to conda
RUN cp -r /usr/local/lib/python${PYTHON_VERSION}/dist-packages/tensorrt* /opt/conda/lib/python${PYTHON_VERSION}/site-packages/

### install mmdeploy
ENV ONNXRUNTIME_DIR=/root/workspace/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}
ENV TENSORRT_DIR=/workspace/tensorrt
ARG VERSION
RUN git clone https://github.com/open-mmlab/mmdeploy &&\
    cd mmdeploy &&\
    if [ -z ${VERSION} ] ; then echo "No MMDeploy version passed in, building on master" ; else git checkout tags/v${VERSION} -b tag_v${VERSION} ; fi &&\
    git submodule update --init --recursive &&\
    mkdir -p build &&\
    cd build &&\
    cmake -DMMDEPLOY_TARGET_BACKENDS="ort;trt" .. &&\
    make -j$(nproc) &&\
    cd .. &&\
    pip install -e .

### build sdk
RUN git clone https://github.com/openppl-public/ppl.cv.git &&\
    cd ppl.cv &&\
    git checkout tags/v${PPLCV_VERSION} -b v${PPLCV_VERSION} &&\
    ./build.sh cuda

ENV BACKUP_LD_LIBRARY_PATH=$LD_LIBRARY_PATH
ENV LD_LIBRARY_PATH=/usr/local/cuda-11.6/compat/lib.real/:$LD_LIBRARY_PATH

RUN cd /root/workspace/mmdeploy &&\
    rm -rf build/CM* build/cmake-install.cmake build/Makefile build/csrc &&\
    mkdir -p build && cd build &&\
    cmake .. \
        -DMMDEPLOY_BUILD_SDK=ON \
        -DCMAKE_CXX_COMPILER=g++ \
        -Dpplcv_DIR=/root/workspace/ppl.cv/cuda-build/install/lib/cmake/ppl \
        -DTENSORRT_DIR=${TENSORRT_DIR} \
        -DONNXRUNTIME_DIR=${ONNXRUNTIME_DIR} \
        -DMMDEPLOY_BUILD_SDK_PYTHON_API=ON \
        -DMMDEPLOY_TARGET_DEVICES="cuda;cpu" \
        -DMMDEPLOY_TARGET_BACKENDS="ort;trt" \
        -DMMDEPLOY_CODEBASES=all &&\
    make -j$(nproc) && make install &&\
    cd install/example  && mkdir -p build && cd build &&\
    cmake -DMMDeploy_DIR=/root/workspace/mmdeploy/build/install/lib/cmake/MMDeploy .. &&\
    make -j$(nproc) && export SPDLOG_LEVEL=warn &&\
    if [ -z ${VERSION} ] ; then echo "Built MMDeploy master for GPU devices successfully!" ; else echo "Built MMDeploy version v${VERSION} for GPU devices successfully!" ; fi

ENV LD_LIBRARY_PATH="/root/workspace/mmdeploy/build/lib:${BACKUP_LD_LIBRARY_PATH}"


# Add mmdetection

# install mmcv and mmdetection
RUN cd / && \
    git clone -b v2.24.1 https://github.com/open-mmlab/mmdetection.git && \
    cd mmdetection && \
    pip install -r requirements/build.txt && \
    pip install -v -e .

Build command I use:

nvidia-docker build -f Dockerfile.mmdeploy -t mmdeploy:latest .

2. Engine file creation

# create directory and download checkpoint into it
mkdir volume_share
wget https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth -O volume_share/checkpoint.pth

# run conversion command with checkpoint dir mounted
nvidia-docker run -it -v $(pwd)/volume_share:/volume_share mmdeploy:latest python /root/workspace/mmdeploy/tools/deploy.py \
    /root/workspace/mmdeploy/configs/mmdet/detection/detection_tensorrt-fp16_static-800x1344.py \
    /mmdetection/configs/yolox/yolox_s_8x8_300e_coco.py \
    /volume_share/checkpoint.pth \
    /mmdetection/demo/demo.jpg \
    --work-dir /volume_share \
    --show \
    --device cuda:0

For me this results in an end2end.engine file created inside ./volume_share.


3. Triton server build

I'd like to use Triton with Amazon SageMaker, so I've added the provided serve file to the Dockerfile:

FROM nvcr.io/nvidia/tritonserver:22.04-py3

# Get /bin/serve for SageMaker
RUN wget https://raw.githubusercontent.com/triton-inference-server/server/main/docker/sagemaker/serve -P /bin/ && \
    chmod +x /bin/serve

And here's my build command:

nvidia-docker build -f Dockerfile.serve -t triton:latest .

4. Package model

I create a model directory using the following config.pbtxt:

name: "yolox"
platform: "tensorrt_plan"
max_batch_size: 128
input {
  name: "input"
  data_type: TYPE_FP32
  dims: [ 3, 800, 1344 ]
}
output [
  {
    name: "dets"
    data_type: TYPE_FP32
    dims: [ 100, 5 ]
  },    
  {
    name: "labels"
    data_type: TYPE_INT32
    dims: [ 100 ]
  }    
]
instance_group {
  count: 1
  kind: KIND_GPU
}
dynamic_batching {
  preferred_batch_size: 128
  max_queue_delay_microseconds: 100
}

default_model_filename: "end2end.engine"

And run these commands to create a model directory:

mkdir -p triton-serve/yolox/1/
cp volume_share/end2end.engine triton-serve/yolox/1/
cp config.pbtxt triton-serve/yolox/

Which results in this directory structure for my model contents:

./triton-serve:

yolox
├── 1
│   └── end2end.engine
└── config.pbtxt

5. Testing serving locally

I test serving by running this command, which simulates how it would be run on SageMaker:

nvidia-docker run -it -v $(pwd)/triton-serve:/opt/ml/model triton:latest serve

Result

After going through these steps, here's the error output I get:

=============================
== Triton Inference Server ==
=============================

NVIDIA Release 22.04 (build 36821869)

Copyright (c) 2018-2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.

Various files include modifications (c) NVIDIA CORPORATION.  All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

NOTE: Legacy NVIDIA Driver detected.  Compatibility mode ENABLED.

WARNING: No SAGEMAKER_TRITON_DEFAULT_MODEL_NAME provided.
         Starting with the only existing model directory yolox
I0512 13:01:07.013668 52 libtorch.cc:1381] TRITONBACKEND_Initialize: pytorch
I0512 13:01:07.013766 52 libtorch.cc:1391] Triton TRITONBACKEND API version: 1.9
I0512 13:01:07.013784 52 libtorch.cc:1397] 'pytorch' TRITONBACKEND API version: 1.9
2022-05-12 13:01:07.218204: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
I0512 13:01:07.262541 52 tensorflow.cc:2181] TRITONBACKEND_Initialize: tensorflow
I0512 13:01:07.262578 52 tensorflow.cc:2191] Triton TRITONBACKEND API version: 1.9
I0512 13:01:07.262588 52 tensorflow.cc:2197] 'tensorflow' TRITONBACKEND API version: 1.9
I0512 13:01:07.262603 52 tensorflow.cc:2221] backend configuration:
{}
I0512 13:01:07.264346 52 onnxruntime.cc:2400] TRITONBACKEND_Initialize: onnxruntime
I0512 13:01:07.264377 52 onnxruntime.cc:2410] Triton TRITONBACKEND API version: 1.9
I0512 13:01:07.264387 52 onnxruntime.cc:2416] 'onnxruntime' TRITONBACKEND API version: 1.9
I0512 13:01:07.264406 52 onnxruntime.cc:2446] backend configuration:
{}
I0512 13:01:07.285557 52 openvino.cc:1207] TRITONBACKEND_Initialize: openvino
I0512 13:01:07.285584 52 openvino.cc:1217] Triton TRITONBACKEND API version: 1.9
I0512 13:01:07.285601 52 openvino.cc:1223] 'openvino' TRITONBACKEND API version: 1.9
I0512 13:01:09.095814 52 pinned_memory_manager.cc:240] Pinned memory pool is created at '0x7f5544000000' with size 268435456
I0512 13:01:09.096292 52 cuda_memory_manager.cc:105] CUDA memory pool is created on device 0 with size 67108864
I0512 13:01:09.098458 52 model_repository_manager.cc:1077] loading: yolox:1
I0512 13:01:09.199354 52 tensorrt.cc:5294] TRITONBACKEND_Initialize: tensorrt
I0512 13:01:09.199418 52 tensorrt.cc:5304] Triton TRITONBACKEND API version: 1.9
I0512 13:01:09.199437 52 tensorrt.cc:5310] 'tensorrt' TRITONBACKEND API version: 1.9
I0512 13:01:09.199563 52 tensorrt.cc:5353] backend configuration:
{}
I0512 13:01:09.199619 52 tensorrt.cc:5405] TRITONBACKEND_ModelInitialize: yolox (version 1)
I0512 13:01:09.201159 52 tensorrt.cc:5454] TRITONBACKEND_ModelInstanceInitialize: yolox_0 (GPU device 0)
I0512 13:01:09.586191 52 logging.cc:49] [MemUsageChange] Init CUDA: CPU +252, GPU +0, now: CPU 1411, GPU 1013 (MiB)
I0512 13:01:09.622173 52 logging.cc:49] Loaded engine size: 21 MiB
E0512 13:01:09.693721 52 logging.cc:43] 1: [pluginV2Runner.cpp::load::290] Error Code 1: Serialization (Serialization assertion creator failed.Cannot deserialize plugin since corresponding IPluginCreator not found in Plugin Registry)
E0512 13:01:09.693781 52 logging.cc:43] 4: [runtime.cpp::deserializeCudaEngine::50] Error Code 4: Internal Error (Engine deserialization failed.)
I0512 13:01:09.694681 52 tensorrt.cc:5492] TRITONBACKEND_ModelInstanceFinalize: delete instance state
I0512 13:01:09.694721 52 tensorrt.cc:5431] TRITONBACKEND_ModelFinalize: delete model state
E0512 13:01:09.695238 52 model_repository_manager.cc:1234] failed to load 'yolox' version 1: Internal: unable to create TensorRT engine
I0512 13:01:09.695515 52 tritonserver.cc:2123] 
+----------------------------------+------------------------------------------+
| Option                           | Value                                    |
+----------------------------------+------------------------------------------+
| server_id                        | triton                                   |
| server_version                   | 2.21.0                                   |
| server_extensions                | classification sequence model_repository |
|                                  |  model_repository(unload_dependents) sch |
|                                  | edule_policy model_configuration system_ |
|                                  | shared_memory cuda_shared_memory binary_ |
|                                  | tensor_data statistics trace             |
| model_repository_path[0]         | /opt/ml/model/                           |
| model_control_mode               | MODE_EXPLICIT                            |
| startup_models_0                 | yolox                                    |
| strict_model_config              | 1                                        |
| rate_limit                       | OFF                                      |
| pinned_memory_pool_byte_size     | 268435456                                |
| cuda_memory_pool_byte_size{0}    | 67108864                                 |
| response_cache_byte_size         | 0                                        |
| min_supported_compute_capability | 6.0                                      |
| strict_readiness                 | 1                                        |
| exit_timeout                     | 30                                       |
+----------------------------------+------------------------------------------+

I0512 13:01:09.695612 52 server.cc:247] No server context available. Exiting immediately.
error: creating server: Invalid argument - load failed for model 'yolox': version 1: Internal: unable to create TensorRT engine;

So it looks like the primary error is:

E0512 13:01:09.693721 52 logging.cc:43] 1: [pluginV2Runner.cpp::load::290] Error Code 1: Serialization (Serialization assertion creator failed.Cannot deserialize plugin since corresponding IPluginCreator not found in Plugin Registry)

Maybe I need to identify some shared libraries in my mmdeploy docker image, copy them into my triton docker image, and then prepend LD_PRELOAD ... to the tritonserver run command? Or maybe it's something else entirely?

If there's any way I can help to make reproducing this issue faster for you, please just let me know. Thanks!


Edit: It looks like I can launch the server successfully by adding the following:

  1. Copying /root/workspace/mmdeploy/build/lib/libmmdeploy_tensorrt_ops.so from the mmdeploy docker image into /opt/tritonserver/lib/ in the triton docker image
  2. Prepending LD_PRELOAD=libmmdeploy_tensorrt_ops.so to the tritonserver command inside /bin/serve

I guess I can do the first step with a multi-stage build like:

FROM mmdeploy:latest AS mmdeploy
FROM nvcr.io/nvidia/tritonserver:22.04-py3
COPY --from=mmdeploy /root/workspace/mmdeploy/build/lib/libmmdeploy_tensorrt_ops.so /opt/tritonserver/lib/   

And I can do the second step with a sed command, though is there any way to force tritonserver to use .so files that I place in some location without requiring LD_PRELOAD? Or more generally, is there a more recommended way to add libmmdeploy_tensorrt_ops.so to Triton?
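
In the meantime, here's a rough sketch of the combined Dockerfile.serve I have in mind; the sed pattern is an assumption about how /bin/serve invokes tritonserver, so it may need adjusting:

FROM mmdeploy:latest AS mmdeploy
FROM nvcr.io/nvidia/tritonserver:22.04-py3

# Copy the MMDeploy custom TensorRT plugin library out of the build image
COPY --from=mmdeploy /root/workspace/mmdeploy/build/lib/libmmdeploy_tensorrt_ops.so /opt/tritonserver/lib/

# Get /bin/serve for SageMaker
RUN wget https://raw.githubusercontent.com/triton-inference-server/server/main/docker/sagemaker/serve -P /bin/ && \
    chmod +x /bin/serve

# Prepend LD_PRELOAD to the tritonserver invocation inside /bin/serve
# (assumes the script launches the binary as "tritonserver ..."; adjust the pattern if it differs)
RUN sed -i 's|tritonserver |LD_PRELOAD=/opt/tritonserver/lib/libmmdeploy_tensorrt_ops.so tritonserver |' /bin/serve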

@lvhan028
Collaborator

I got it.
MMDeploy develops a bunch of custom TensorRT plugins for model deployment, especially for detection models. Those plugins are built into a dynamic library, 'libmmdeploy_tensorrt_ops.so', which has to be loaded before inference.

Let me find out if there is any way to improve this.

@lvhan028
Collaborator

lvhan028 commented May 23, 2022

Hi @austinmw, sorry for the late reply.
I am afraid "LD_PRELOAD" has to be used; otherwise, the custom ops cannot be found.
Providing a prebuilt MMDeploy package might be a better way to improve the experience, so users would not have to go through the pain of building MMDeploy themselves.
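
For example, launching the server with something like this should pick up the custom ops (the library path here is illustrative):

LD_PRELOAD=/path/to/libmmdeploy_tensorrt_ops.so tritonserver --model-repository=/opt/ml/model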

@austinmw
Author

Thanks for your response!

@manhtd98

manhtd98 commented Jun 6, 2022

@austinmw I used your pipeline, but on tritonserver it only supports batch-size=1, and on SageMaker it cannot receive and send responses. Do you face the same problem?

@austinmw
Author

austinmw commented Jun 7, 2022

@manhtd98 You need to set the batch dimension of opt_shape and max_shape in the mmdeploy config file to the max batch size for it to allow batches.

@manhtd98

manhtd98 commented Jun 7, 2022

@austinmw I get an error when creating the TensorRT file: Error Code 4: Internal Error (input: kMAX dimensions in profile 0 are [128,3,800,1344] but input has static dimensions [1,3,800,1344].) Here is my config file:

_base_ = ['./base_static.py', '../../_base_/backends/tensorrt-fp16.py']

onnx_config = dict(input_shape=(1344, 800))

backend_config = dict(
    common_config=dict(max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 800, 1344],
                    opt_shape=[128, 3, 800, 1344],
                    max_shape=[128, 3, 800, 1344])))
    ])

@austinmw
Author

austinmw commented Jun 7, 2022

I didn't use an onnx_config like that; it may be incorrect. I think you need to use one of the dynamic config files as a base instead of the static config, for example starting with this and then increasing the batch dimension.

Also, probably unrelated, but you may want to increase the max_workspace_size.
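
As a rough, untested sketch of what I mean (the base config path and shape values here are assumptions and may differ between MMDeploy versions):

_base_ = ['./base_dynamic.py', '../../_base_/backends/tensorrt-fp16.py']

backend_config = dict(
    # 4 GiB workspace instead of the default 1 GiB
    common_config=dict(max_workspace_size=1 << 32),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    # the first dimension is the batch size; raise opt/max to the
                    # largest batch you want the engine to support
                    min_shape=[1, 3, 320, 320],
                    opt_shape=[16, 3, 800, 1344],
                    max_shape=[16, 3, 1344, 1344])))
    ])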

@manhtd98

manhtd98 commented Jun 7, 2022

python3 ./tools/deploy.py ./configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py ../mmdetection/configs/insurance/cascade_mask_rcnn_restnext101.py ../mmdetection/pretrained/cascade_mask_rcnn_restnext101.pth ./demo/demo.jpg --device cuda:0

I can convert YOLO, but for Cascade only batch 1 succeeds. There is an error: Error[10]: [optimizer.cpp::computeCosts::2011] Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[3450 + (Unnamed Layer* 3695) [Shuffle]...Reshape_1887 + Reshape_1889 + Unsqueeze_1890 + Reshape_1905]}.)
[06/07/2022-20:30:15] [E] Error[2]: [builder.cpp::buildSerializedNetwork::609] Error Code 2: Internal Error (Assertion enginePtr != nullptr failed. )

@austinmw
Author

austinmw commented Jun 8, 2022

Hmm sorry I'm not sure about that error. Maybe it's a version compatibility issue with one of the libraries in the MMDeploy Dockerfile.

@leemengwei

Hi @manhtd98,
have you solved it? I had the same issue with Mask R-CNN: batch = 1 works, but other batch sizes don't!

@manhtd98

@leemengwei You need to increase the workspace size in trtexec. That allows a larger model and larger batch sizes.

@manhtd98

@austinmw How did you start the client inference? I succeeded in starting the SageMaker server on port 8080, but cannot send requests to it.

@austinmw
Author

I used the Amazon SageMaker docker container to run Triton: https://github.com/aws/amazon-sagemaker-examples/tree/main/sagemaker-triton

@manhtd98

manhtd98 commented Jul 1, 2022

@austinmw Could you provide the code you used? I've tried many times and am stuck here:

import requests
import tritonclient.http as httpclient

# inputs, outputs and keys are built earlier (omitted here)
request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
    inputs, outputs=outputs)

headers = {
    'Content-Type':
        'application/vnd.sagemaker-triton.binary+json;json-header-size={}'
        .format(header_length)
}
r = requests.post("http://192.168.81.111:8080", data=request_body, headers=headers)
r.raise_for_status()
results = dict((el, []) for el in keys)
user_data = r.content  # requests responses have no .body attribute; the raw bytes are in .content

@hh123445

hh123445 commented Feb 20, 2023

> I used the Amazon SageMaker docker container to run Triton: https://github.com/aws/amazon-sagemaker-examples/tree/main/sagemaker-triton

Hello, I'm trying to convert an mmdetection model into an .engine file through mmdeploy and deploy it to Triton. But when I make an inference request, the output of the model is all 0s and -1s. I need your help! For a detailed description of the problem, please refer to the following:
triton-inference-server/server#5382

Thank you very much!
