From e1aabf583749d780fd85b9eece980c94c3d82424 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Mon, 22 Aug 2022 14:00:28 +0800 Subject: [PATCH] [Doc] How to write a customized TensorRT plugin (#290) * first edition * fix lint * add 06, 07 * resolve comments * update index.rst * update title * update img --- docs/zh_cn/index.rst | 2 + .../tutorial/06_introduction_to_tensorrt.md | 516 ++++++++++++++++++ docs/zh_cn/tutorial/07_write_a_plugin.md | 497 +++++++++++++++++ resources/tutorial/IPluginV2DynamicExt.svg | 4 + resources/tutorial/srcnn.svg | 1 + 5 files changed, 1020 insertions(+) create mode 100644 docs/zh_cn/tutorial/06_introduction_to_tensorrt.md create mode 100644 docs/zh_cn/tutorial/07_write_a_plugin.md create mode 100644 resources/tutorial/IPluginV2DynamicExt.svg create mode 100644 resources/tutorial/srcnn.svg diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index e04f27649..e90b78cd6 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -96,6 +96,8 @@ tutorial/03_pytorch2onnx.md tutorial/04_onnx_custom_op.md tutorial/05_onnx_model_editing.md + tutorial/06_introduction_to_tensorrt.md + tutorial/07_write_a_plugin.md .. toctree:: :maxdepth: 1 diff --git a/docs/zh_cn/tutorial/06_introduction_to_tensorrt.md b/docs/zh_cn/tutorial/06_introduction_to_tensorrt.md new file mode 100644 index 000000000..2f330e525 --- /dev/null +++ b/docs/zh_cn/tutorial/06_introduction_to_tensorrt.md @@ -0,0 +1,516 @@ +# 第六章: TensorRT 模型构建与推理 + +模型部署入门教程继续更新啦!相信经过前几期的学习,大家已经对 ONNX 这一中间表示有了一个比较全面的认识,但是在具体的生产环境中,ONNX 模型常常需要被转换成能被具体推理后端使用的模型格式。本篇教程我们就和大家一起来认识大名鼎鼎的推理后端 TensorRT。 + +## TensorRT 简介 + +TensorRT 是由 NVIDIA 发布的深度学习框架,用于在其硬件上运行深度学习推理。TensorRT 提供量化感知训练和离线量化功能,用户可以选择 INT8 和 FP16 两种优化模式,将深度学习模型应用到不同任务的生产部署,如视频流、语音识别、推荐、欺诈检测、文本生成和自然语言处理。TensorRT 经过高度优化,可在 NVIDIA GPU 上运行, 并且可能是目前在 NVIDIA GPU 运行模型最快的推理引擎。关于 TensorRT 更具体的信息可以访问 [TensorRT官网](https://developer.nvidia.com/tensorrt) 了解。 + +## 安装 TensorRT + +### Windows + +默认在一台有 NVIDIA 显卡的机器上,提前安装好 [CUDA](https://developer.nvidia.com/cuda-toolkit-archive) 和 [CUDNN](https://developer.nvidia.com/rdp/cudnn-archive),登录 NVIDIA 官方网站下载和主机 CUDA 版本适配的 TensorRT 压缩包即可。 + +以 CUDA 版本是 10.2 为例,选择适配 CUDA 10.2 的 [zip 包](https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.5.1/zip/tensorrt-8.2.5.1.windows10.x86_64.cuda-10.2.cudnn8.2.zip),下载完成后,有 conda 虚拟环境的用户可以优先切换到虚拟环境中,然后在 powershell 中执行类似如下的命令安装并测试: + +```shell +cd \the\path\of\tensorrt\zip\file +Expand-Archive TensorRT-8.2.5.1.Windows10.x86_64.cuda-10.2.cudnn8.2.zip . +$env:TENSORRT_DIR = "$pwd\TensorRT-8.2.5.1" +$env:path = "$env:TENSORRT_DIR\lib;" + $env:path +pip install $env:TENSORRT_DIR\python\tensorrt-8.2.5.1-cp36-none-win_amd64.whl +python -c "import tensorrt;print(tensorrt.__version__)" +``` + +上述命令会在安装后检查 TensorRT 版本,如果打印结果是 8.2.5.1,说明安装 Python 包成功了。 + +### Linux + +和在 Windows 环境下安装类似,默认在一台有 NVIDIA 显卡的机器上,提前安装好 [CUDA](https://developer.nvidia.com/cuda-toolkit-archive) 和 [CUDNN](https://developer.nvidia.com/rdp/cudnn-archive),登录 NVIDIA 官方网站下载和主机 CUDA 版本适配的 TensorRT 压缩包即可。 + +以 CUDA 版本是 10.2 为例,选择适配 CUDA 10.2 的 [tar 包](https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.5.1/tars/tensorrt-8.2.5.1.linux.x86_64-gnu.cuda-10.2.cudnn8.2.tar.gz),然后执行类似如下的命令安装并测试: + +```shell +cd /the/path/of/tensorrt/tar/gz/file +tar -zxvf TensorRT-8.2.5.1.linux.x86_64-gnu.cuda-10.2.cudnn8.2.tar.gz +export TENSORRT_DIR=$(pwd)/TensorRT-8.2.5.1 +export LD_LIBRARY_PATH=$TENSORRT_DIR/lib:$LD_LIBRARY_PATH +pip install TensorRT-8.2.5.1/python/tensorrt-8.2.5.1-cp37-none-linux_x86_64.whl +python -c "import tensorrt;print(tensorrt.__version__)" +``` + +如果发现打印结果是 8.2.5.1,说明安装 Python 包成功了。 + +### Jetson + +对于 Jetson 平台,我们有非常详细的安装环境配置教程,可参考 [MMDeploy 安装文档](https://github.com/open-mmlab/mmdeploy/blob/master/docs/zh_cn/01-how-to-build/jetsons.md)。需要注意的是,在 Jetson 上配置的 CUDA 版本 TensorRT 版本与 JetPack 强相关的,我们选择适配硬件的版本即可。配置好环境后,通过 `python -c "import tensorrt;print(tensorrt.__version__)"` 查看TensorRT版本是否正确。 + +## 模型构建 + +我们使用 TensorRT 生成模型主要有两种方式: + +1. 直接通过 TensorRT 的 API 逐层搭建网络; +2. 将中间表示的模型转换成 TensorRT 的模型,比如将 ONNX 模型转换成 TensorRT 模型。 + +接下来,我们将用 Python 和 C++ 语言分别使用这两种方式构建 TensorRT 模型,并将生成的模型进行推理。 + +### 直接构建 + +利用 TensorRT 的 API 逐层搭建网络,这一过程类似使用一般的训练框架,如使用 Pytorch 或者TensorFlow 搭建网络。需要注意的是对于权重部分,如卷积或者归一化层,需要将权重内容赋值到 TensorRT 的网络中。本文就不详细展示,只搭建一个对输入做池化的简单网络。 + +#### 使用 Python API 构建 + +首先是使用 Python API 直接搭建 TensorRT 网络,这种方法主要是利用 `tensorrt.Builder` 的 `create_builder_config` 和 `create_network` 功能,分别构建 config 和 network,前者用于设置网络的最大工作空间等参数,后者就是网络主体,需要对其逐层添加内容。 + +此外,需要定义好输入和输出名称,将构建好的网络序列化,保存成本地文件。值得注意的是:如果想要网络接受不同分辨率的输入输出,需要使用 `tensorrt.Builder` 的 `create_optimization_profile` 函数,并设置最小、最大的尺寸。 + +实现代码如下: + +```python +import tensorrt as trt + +verbose = True +IN_NAME = 'input' +OUT_NAME = 'output' +IN_H = 224 +IN_W = 224 +BATCH_SIZE = 1 + +EXPLICIT_BATCH = 1 << (int)( + trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + +TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger() +with trt.Builder(TRT_LOGGER) as builder, builder.create_builder_config( +) as config, builder.create_network(EXPLICIT_BATCH) as network: + # define network + input_tensor = network.add_input( + name=IN_NAME, dtype=trt.float32, shape=(BATCH_SIZE, 3, IN_H, IN_W)) + pool = network.add_pooling( + input=input_tensor, type=trt.PoolingType.MAX, window_size=(2, 2)) + pool.stride = (2, 2) + pool.get_output(0).name = OUT_NAME + network.mark_output(pool.get_output(0)) + + # serialize the model to engine file + profile = builder.create_optimization_profile() + profile.set_shape_input('input', *[[BATCH_SIZE, 3, IN_H, IN_W]]*3) + builder.max_batch_size = 1 + config.max_workspace_size = 1 << 30 + engine = builder.build_engine(network, config) + with open('model_python_trt.engine', mode='wb') as f: + f.write(bytearray(engine.serialize())) + print("generating file done!") +``` + +#### 使用 C++ API 构建 + +对于想要直接用 C++ 语言构建网络的小伙伴来说,整个流程和上述 Python 的执行过程非常类似,需要注意的点主要有: + +1. `nvinfer1:: createInferBuilder` 对应 Python 中的 `tensorrt.Builder`,需要传入 `ILogger` 类的实例,但是 `ILogger` 是一个抽象类,需要用户继承该类并实现内部的虚函数。不过此处我们直接使用了 TensorRT 包解压后的 samples 文件夹 ../samples/common/logger.h 文件里的实现 `Logger` 子类。 +2. 设置 TensorRT 模型的输入尺寸,需要多次调用 `IOptimizationProfile` 的 `setDimensions` 方法,比 Python `略繁琐一些。IOptimizationProfile` 需要用 `createOptimizationProfile` 函数,对应 Python 的 `create_builder_config` 函数。 + +实现代码如下: + +```cpp +#include +#include + +#include +#include <../samples/common/logger.h> + +using namespace nvinfer1; +using namespace sample; + +const char* IN_NAME = "input"; +const char* OUT_NAME = "output"; +static const int IN_H = 224; +static const int IN_W = 224; +static const int BATCH_SIZE = 1; +static const int EXPLICIT_BATCH = 1 << (int)(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + +int main(int argc, char** argv) +{ + // Create builder + Logger m_logger; + IBuilder* builder = createInferBuilder(m_logger); + IBuilderConfig* config = builder->createBuilderConfig(); + + // Create model to populate the network + INetworkDefinition* network = builder->createNetworkV2(EXPLICIT_BATCH); + ITensor* input_tensor = network->addInput(IN_NAME, DataType::kFLOAT, Dims4{ BATCH_SIZE, 3, IN_H, IN_W }); + IPoolingLayer* pool = network->addPoolingNd(*input_tensor, PoolingType::kMAX, DimsHW{ 2, 2 }); + pool->setStrideNd(DimsHW{ 2, 2 }); + pool->getOutput(0)->setName(OUT_NAME); + network->markOutput(*pool->getOutput(0)); + + // Build engine + IOptimizationProfile* profile = builder->createOptimizationProfile(); + profile->setDimensions(IN_NAME, OptProfileSelector::kMIN, Dims4(BATCH_SIZE, 3, IN_H, IN_W)); + profile->setDimensions(IN_NAME, OptProfileSelector::kOPT, Dims4(BATCH_SIZE, 3, IN_H, IN_W)); + profile->setDimensions(IN_NAME, OptProfileSelector::kMAX, Dims4(BATCH_SIZE, 3, IN_H, IN_W)); + config->setMaxWorkspaceSize(1 << 20); + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + + // Serialize the model to engine file + IHostMemory* modelStream{ nullptr }; + assert(engine != nullptr); + modelStream = engine->serialize(); + + std::ofstream p("model.engine", std::ios::binary); + if (!p) { + std::cerr << "could not open output file to save model" << std::endl; + return -1; + } + p.write(reinterpret_cast(modelStream->data()), modelStream->size()); + std::cout << "generating file done!" << std::endl; + + // Release resources + modelStream->destroy(); + network->destroy(); + engine->destroy(); + builder->destroy(); + config->destroy(); + return 0; +} +``` + +### IR 转换模型 + +除了直接通过 TensorRT 的 API 逐层搭建网络并序列化模型,TensorRT 还支持将中间表示的模型(如 ONNX)转换成 TensorRT 模型。 + +#### 使用 Python API 转换 + +我们首先使用 Pytorch 实现一个和上文一致的模型,即只对输入做一次池化并输出;然后将 Pytorch 模型转换成 ONNX 模型;最后将 ONNX 模型转换成 TensorRT 模型。 +这里主要使用了 TensorRT 的 `OnnxParser` 功能,它可以将 ONNX 模型解析到 TensorRT 的网络中。最后我们同样可以得到一个 TensorRT 模型,其功能与上述方式实现的模型功能一致。 + +实现代码如下: + +```python +import torch +import onnx +import tensorrt as trt + + +onnx_model = 'model.onnx' + +class NaiveModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(2, 2) + + def forward(self, x): + return self.pool(x) + +device = torch.device('cuda:0') + +# generate ONNX model +torch.onnx.export(NaiveModel(), torch.randn(1, 3, 224, 224), onnx_model, input_names=['input'], output_names=['output'], opset_version=11) +onnx_model = onnx.load(onnx_model) + +# create builder and network +logger = trt.Logger(trt.Logger.ERROR) +builder = trt.Builder(logger) +EXPLICIT_BATCH = 1 << (int)( + trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) +network = builder.create_network(EXPLICIT_BATCH) + +# parse onnx +parser = trt.OnnxParser(network, logger) + +if not parser.parse(onnx_model.SerializeToString()): + error_msgs = '' + for error in range(parser.num_errors): + error_msgs += f'{parser.get_error(error)}\n' + raise RuntimeError(f'Failed to parse onnx, {error_msgs}') + +config = builder.create_builder_config() +config.max_workspace_size = 1<<20 +profile = builder.create_optimization_profile() + +profile.set_shape('input', [1,3 ,224 ,224], [1,3,224, 224], [1,3 ,224 ,224]) +config.add_optimization_profile(profile) +# create engine +with torch.cuda.device(device): + engine = builder.build_engine(network, config) + +with open('model.engine', mode='wb') as f: + f.write(bytearray(engine.serialize())) + print("generating file done!") +``` + +IR 转换时,如果有多 Batch、多输入、动态 shape 的需求,都可以通过多次调用 `set_shape` 函数进行设置。`set_shape` 函数接受的传参分别是:输入节点名称,可接受的最小输入尺寸,最优的输入尺寸,可接受的最大输入尺寸。一般要求这三个尺寸的大小关系为单调递增。 + +#### 使用 C++ API 转换 + +介绍了如何用 Python 语言将 ONNX 模型转换成 TensorRT 模型后,再介绍下如何用 C++ 将 ONNX 模型转换成 TensorRT 模型。这里通过 `NvOnnxParser`,我们可以将上一小节转换时得到的 ONNX 文件直接解析到网络中。 + +实现代码如下: + +```cpp +#include +#include + +#include +#include +#include <../samples/common/logger.h> + +using namespace nvinfer1; +using namespace nvonnxparser; +using namespace sample; + +int main(int argc, char** argv) +{ + // Create builder + Logger m_logger; + IBuilder* builder = createInferBuilder(m_logger); + const auto explicitBatch = 1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + IBuilderConfig* config = builder->createBuilderConfig(); + + // Create model to populate the network + INetworkDefinition* network = builder->createNetworkV2(explicitBatch); + + // Parse ONNX file + IParser* parser = nvonnxparser::createParser(*network, m_logger); + bool parser_status = parser->parseFromFile("model.onnx", static_cast(ILogger::Severity::kWARNING)); + + // Get the name of network input + Dims dim = network->getInput(0)->getDimensions(); + if (dim.d[0] == -1) // -1 means it is a dynamic model + { + const char* name = network->getInput(0)->getName(); + IOptimizationProfile* profile = builder->createOptimizationProfile(); + profile->setDimensions(name, OptProfileSelector::kMIN, Dims4(1, dim.d[1], dim.d[2], dim.d[3])); + profile->setDimensions(name, OptProfileSelector::kOPT, Dims4(1, dim.d[1], dim.d[2], dim.d[3])); + profile->setDimensions(name, OptProfileSelector::kMAX, Dims4(1, dim.d[1], dim.d[2], dim.d[3])); + config->addOptimizationProfile(profile); + } + + + // Build engine + config->setMaxWorkspaceSize(1 << 20); + ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); + + // Serialize the model to engine file + IHostMemory* modelStream{ nullptr }; + assert(engine != nullptr); + modelStream = engine->serialize(); + + std::ofstream p("model.engine", std::ios::binary); + if (!p) { + std::cerr << "could not open output file to save model" << std::endl; + return -1; + } + p.write(reinterpret_cast(modelStream->data()), modelStream->size()); + std::cout << "generate file success!" << std::endl; + + // Release resources + modelStream->destroy(); + network->destroy(); + engine->destroy(); + builder->destroy(); + config->destroy(); + return 0; +} +``` + +### 模型推理 + +前面,我们使用了两种构建 TensorRT 模型的方式,分别用 Python 和 C++ 两种语言共生成了四个 TensorRT 模型,这四个模型的功能理论上是完全一致的。 +接下来,我们将分别使用 Python 和 C++ 两种语言对生成的 TensorRT 模型进行推理。 + +#### 使用 Python API 推理 + +首先是使用 Python API 推理 TensorRT 模型,这里部分代码参考了 [MMDeploy](https://github.com/open-mmlab/mmdeploy)。运行下面代码,可以发现输入一个 `1x3x224x224` 的张量,输出一个 `1x3x112x112` 的张量,完全符合我们对输入池化后结果的预期。 + +```python +from typing import Union, Optional, Sequence,Dict,Any + +import torch +import tensorrt as trt + +class TRTWrapper(torch.nn.Module): + def __init__(self,engine: Union[str, trt.ICudaEngine], + output_names: Optional[Sequence[str]] = None) -> None: + super().__init__() + self.engine = engine + if isinstance(self.engine, str): + with trt.Logger() as logger, trt.Runtime(logger) as runtime: + with open(self.engine, mode='rb') as f: + engine_bytes = f.read() + self.engine = runtime.deserialize_cuda_engine(engine_bytes) + self.context = self.engine.create_execution_context() + names = [_ for _ in self.engine] + input_names = list(filter(self.engine.binding_is_input, names)) + self._input_names = input_names + self._output_names = output_names + + if self._output_names is None: + output_names = list(set(names) - set(input_names)) + self._output_names = output_names + + def forward(self, inputs: Dict[str, torch.Tensor]): + assert self._input_names is not None + assert self._output_names is not None + bindings = [None] * (len(self._input_names) + len(self._output_names)) + profile_id = 0 + for input_name, input_tensor in inputs.items(): + # check if input shape is valid + profile = self.engine.get_profile_shape(profile_id, input_name) + assert input_tensor.dim() == len( + profile[0]), 'Input dim is different from engine profile.' + for s_min, s_input, s_max in zip(profile[0], input_tensor.shape, + profile[2]): + assert s_min <= s_input <= s_max, \ + 'Input shape should be between ' \ + + f'{profile[0]} and {profile[2]}' \ + + f' but get {tuple(input_tensor.shape)}.' + idx = self.engine.get_binding_index(input_name) + + # All input tensors must be gpu variables + assert 'cuda' in input_tensor.device.type + input_tensor = input_tensor.contiguous() + if input_tensor.dtype == torch.long: + input_tensor = input_tensor.int() + self.context.set_binding_shape(idx, tuple(input_tensor.shape)) + bindings[idx] = input_tensor.contiguous().data_ptr() + + # create output tensors + outputs = {} + for output_name in self._output_names: + idx = self.engine.get_binding_index(output_name) + dtype = torch.float32 + shape = tuple(self.context.get_binding_shape(idx)) + + device = torch.device('cuda') + output = torch.empty(size=shape, dtype=dtype, device=device) + outputs[output_name] = output + bindings[idx] = output.data_ptr() + self.context.execute_async_v2(bindings, + torch.cuda.current_stream().cuda_stream) + return outputs + +model = TRTWrapper('model.engine', ['output']) +output = model(dict(input = torch.randn(1, 3, 224, 224).cuda())) +print(output) +``` + +#### 使用 C++ API 推理 + +最后,在很多实际生产环境中,我们都会使用 C++ 语言完成具体的任务,以达到更加高效的代码运行效果,另外 TensoRT 的用户一般也都更看重其在 C++ 下的使用,所以我们也用 C++ 语言实现一遍模型推理,这也可以和用 Python API 推理模型做一个对比。 + +实现代码如下: + +```cpp +#include +#include + +#include +#include <../samples/common/logger.h> + +#define CHECK(status) \ + do\ + {\ + auto ret = (status);\ + if (ret != 0)\ + {\ + std::cerr << "Cuda failure: " << ret << std::endl;\ + abort();\ + }\ + } while (0) + +using namespace nvinfer1; +using namespace sample; + +const char* IN_NAME = "input"; +const char* OUT_NAME = "output"; +static const int IN_H = 224; +static const int IN_W = 224; +static const int BATCH_SIZE = 1; +static const int EXPLICIT_BATCH = 1 << (int)(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + + +void doInference(IExecutionContext& context, float* input, float* output, int batchSize) +{ + const ICudaEngine& engine = context.getEngine(); + + // Pointers to input and output device buffers to pass to engine. + // Engine requires exactly IEngine::getNbBindings() number of buffers. + assert(engine.getNbBindings() == 2); + void* buffers[2]; + + // In order to bind the buffers, we need to know the names of the input and output tensors. + // Note that indices are guaranteed to be less than IEngine::getNbBindings() + const int inputIndex = engine.getBindingIndex(IN_NAME); + const int outputIndex = engine.getBindingIndex(OUT_NAME); + + // Create GPU buffers on device + CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 3 * IN_H * IN_W * sizeof(float))); + CHECK(cudaMalloc(&buffers[outputIndex], batchSize * 3 * IN_H * IN_W /4 * sizeof(float))); + + // Create stream + cudaStream_t stream; + CHECK(cudaStreamCreate(&stream)); + + // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host + CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * IN_H * IN_W * sizeof(float), cudaMemcpyHostToDevice, stream)); + context.enqueue(batchSize, buffers, stream, nullptr); + CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * 3 * IN_H * IN_W / 4 * sizeof(float), cudaMemcpyDeviceToHost, stream)); + cudaStreamSynchronize(stream); + + // Release stream and buffers + cudaStreamDestroy(stream); + CHECK(cudaFree(buffers[inputIndex])); + CHECK(cudaFree(buffers[outputIndex])); +} + +int main(int argc, char** argv) +{ + // create a model using the API directly and serialize it to a stream + char *trtModelStream{ nullptr }; + size_t size{ 0 }; + + std::ifstream file("model.engine", std::ios::binary); + if (file.good()) { + file.seekg(0, file.end); + size = file.tellg(); + file.seekg(0, file.beg); + trtModelStream = new char[size]; + assert(trtModelStream); + file.read(trtModelStream, size); + file.close(); + } + + Logger m_logger; + IRuntime* runtime = createInferRuntime(m_logger); + assert(runtime != nullptr); + ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr); + assert(engine != nullptr); + IExecutionContext* context = engine->createExecutionContext(); + assert(context != nullptr); + + // generate input data + float data[BATCH_SIZE * 3 * IN_H * IN_W]; + for (int i = 0; i < BATCH_SIZE * 3 * IN_H * IN_W; i++) + data[i] = 1; + + // Run inference + float prob[BATCH_SIZE * 3 * IN_H * IN_W /4]; + doInference(*context, data, prob, BATCH_SIZE); + + // Destroy the engine + context->destroy(); + engine->destroy(); + runtime->destroy(); + return 0; +} +``` + +## 总结 + +通过本文的学习,我们掌握了两种构建 TensorRT 模型的方式:直接通过 TensorRT 的 API 逐层搭建网络;将中间表示的模型转换成 TensorRT 的模型。不仅如此,我们还分别用 C++ 和 Python 两种语言完成了 TensorRT 模型的构建及推理,相信大家都有所收获!在下一篇文章中,我们将和大家一起学习何添加 TensorRT 自定义算子,敬请期待哦~ + +## FAQ + +1. Could not find: cudnn64_8.dll. Is it on your PATH? + 首先检查下自己的环境变量中是否包含 cudnn64_8.dll 所在的路径,若发现 cudnn 的路径在 C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2\\bin 中,但是里面只有 cudnn64_7.dll。解决方法是去 NVIDIA 官网下载 cuDNN zip 包,解压后,复制其中的 cudnn64_8.dll 到 CUDA Toolkit 的 bin 目录下。这时也可以复制一份 cudnn64_7.dll,然后将复制的那份改名成 cudnn64_8.dll,同样可以解决这个问题。 diff --git a/docs/zh_cn/tutorial/07_write_a_plugin.md b/docs/zh_cn/tutorial/07_write_a_plugin.md new file mode 100644 index 000000000..05e6c0ebb --- /dev/null +++ b/docs/zh_cn/tutorial/07_write_a_plugin.md @@ -0,0 +1,497 @@ +# 第七章: TensorRT 自定义插件 + +## 介绍 + +在前面的模型部署入门系列文章中,我们介绍了部署一个 PyTorch 模型到推理后端,如 ONNXRuntime,这其中可能遇到很多工程性的问题。 + +有些可以通过创建 ONNX 节点来解决,该节点仍然使用后端原生的实现进行推理。而有些无法导出到后端的算法,可以通过重写代码改变算法的实现过程,同样可以导出到 ONNX ,达到一致的效果。以上两种方式一般可以处理绝大多数的部署问题,同时也不需要向推理框架引入新的内容,是我们进行模型部署时候的优先选择。 + +然而,仍然存在部分模型,模型中某些算子无法通过上述两种方式绕过问题,这时候,如何对特定后端实现对应代码就极为重要。这也是本文将介绍的第三种方式——**自定义插件**。 + +自定义插件是很多推理框架支持用户自定义算子的方式,以 MMDeploy 为例,它是一个支持多种推理后端的算法库。目前支持的后端有: + +- ONNXRuntime +- TensorRT +- ncnn +- openvino +- PPLNN + 其中,前三种后端均实现了一些自定义的算子。例如 ONNXRuntime 中的调制可变性卷积,ncnn 中的topk 算子,TensorRT 中的 MultiLevelRoiAlign 。 + +介绍如何给后端自定义算子是一件相对复杂的事情,所以本文只针对其中一种后端 TensorRT,介绍自定义算子。如果读者对其他后端感兴趣,可以去他们的代码库查看,一般地,各个推理框架均有详细文档介绍如何添加客制化的算子实现。 + +## 在MMDeploy添加TensorRT插件 + +仍然以前面[教程二](./02_challenges.md)中的超分辨模型SRCNN为例。在教程二中,我们用 ONNXRuntime 作为后端,通过 PyTorch 的 symbolic 函数导出了一个支持动态 scale 的 ONNX 模型,这个模型可以直接用 ONNXRuntime 运行,这是因为 `NewInterpolate` 类导出的节点 `Resize` 就是ONNXRuntime支持的节点。下面我们尝试直接将教程二导出的 `srcnn3.onnx` 转换到TensorRT。 + +```python +from mmdeploy.backend.tensorrt import create_trt_engine + +engine = create_trt_engine( + 'srcnn3.onnx', + input_shapes=dict(input = dict( + min_shape=[1, 3, 256, 256], + opt_shape=[1, 3, 256, 256], + max_shape=[1, 3, 256, 256]))) +``` + +没有安装过MMDeploy的小伙伴可以先参考 [build](../01-how-to-build) 进行安装,安装完成后执行上述脚本,会有如下报错: + +```shell +RuntimeError: Failed to parse onnx, In node 1 (importResize): UNSUPPORTED_NODE: Assertion failed: mode != "cubic" && "This version of TensorRT does not support cubic interpolation!" +``` + +报错的原因有以下两方面: + +1. `srcnn3.onnx`文件中的 `Resize` 是 ONNX 原生节点。其插值方式之一 bicubic 并不被 TensorRT 支持(TensorRT 的 Resize Layer仅支持 nearest 和 bilinear 两种插值方式)。日志的错误信息也明确提示了这点; +2. 但即便将 "bicubic" 模式改为 "bilinear" ,转换仍然失败: `RuntimeError: Failed to parse onnx, In node 1 (importResize): UNSUPPORTED_NODE: Assertion failed: scales.is_weights() && Resize scales must be initializer!"`。这是因为 TensorRT 无法接受动态 scale 导致的。 + +### 创建ONNX节点 + +为解决上述问题,我们需要创建一个新的节点替换原生 Resize 节点,并且实现新节点对应的插件代码。 + +继续复用同样节点名的方式已经不可取,我们需要创建新的节点。改节点名称就叫 `Test::DynamicTRTResize`,这是种类C++的写法,`Test` 为域名,主要用于区分不同来源下的同名的节点,比如 `ONNX::` 和 `Test::`。当然了,ONNX本身也不存在 `DynamicTRTResize` 的节点名。 + +```python +import torch +from torch import nn +from torch.nn.functional import interpolate +import torch.onnx +import cv2 +import numpy as np +import os, requests + +# Download checkpoint and test image +urls = ['https://download.openmmlab.com/mmediting/restorers/srcnn/srcnn_x4k915_1x16_1000k_div2k_20200608-4186f232.pth', + 'https://raw.githubusercontent.com/open-mmlab/mmediting/master/tests/data/face/000001.png'] +names = ['srcnn.pth', 'face.png'] +for url, name in zip(urls, names): + if not os.path.exists(name): + open(name, 'wb').write(requests.get(url).content) + +class DynamicTRTResize(torch.autograd.Function): + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def symbolic(g, input, size_tensor, align_corners = False): + """Symbolic function for creating onnx op.""" + return g.op( + 'Test::DynamicTRTResize', + input, + size_tensor, + align_corners_i=align_corners) + + @staticmethod + def forward(g, input, size_tensor, align_corners = False): + """Run forward.""" + size = [size_tensor.size(-2), size_tensor.size(-1)] + return interpolate( + input, size=size, mode='bicubic', align_corners=align_corners) + + +class StrangeSuperResolutionNet(nn.Module): + + def __init__(self): + super().__init__() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4) + self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0) + self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2) + + self.relu = nn.ReLU() + + def forward(self, x, size_tensor): + x = DynamicTRTResize.apply(x, size_tensor) + out = self.relu(self.conv1(x)) + out = self.relu(self.conv2(out)) + out = self.conv3(out) + return out + + +def init_torch_model(): + torch_model = StrangeSuperResolutionNet() + + state_dict = torch.load('srcnn.pth')['state_dict'] + + # Adapt the checkpoint + for old_key in list(state_dict.keys()): + new_key = '.'.join(old_key.split('.')[1:]) + state_dict[new_key] = state_dict.pop(old_key) + + torch_model.load_state_dict(state_dict) + torch_model.eval() + return torch_model + + +model = init_torch_model() +factor = torch.rand([1, 1, 512, 512], dtype=torch.float) + +input_img = cv2.imread('face.png').astype(np.float32) + +# HWC to NCHW +input_img = np.transpose(input_img, [2, 0, 1]) +input_img = np.expand_dims(input_img, 0) + +# Inference +torch_output = model(torch.from_numpy(input_img), factor).detach().numpy() + +# NCHW to HWC +torch_output = np.squeeze(torch_output, 0) +torch_output = np.clip(torch_output, 0, 255) +torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8) + +# Show image +cv2.imwrite("face_torch.png", torch_output) + +x = torch.randn(1, 3, 256, 256) + +dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'factor': { + 0: 'batch1', + 2: 'height1', + 3: 'width1' + }, + 'output': { + 0: 'batch2', + 2: 'height2', + 3: 'width2' + }, + } + +with torch.no_grad(): + torch.onnx.export( + model, (x, factor), + "srcnn3.onnx", + opset_version=11, + input_names=['input', 'factor'], + output_names=['output'], + dynamic_axes=dynamic_axes) +``` + +执行上述脚本,我们导出成功了一个ONNX模型 `srcnn.onnx`。用[netron](https://netron.app/)打开这个模型可视化如下: + +![](../../../resources/tutorial/srcnn.svg) + +直接将该模型转换成TensorRT模型也是不可行的,这是因为TensorRT还无法解析 `DynamicTRTResize` 节点。而想要解析该节点,我们必须为TensorRT添加c++代码,实现该插件。 + +### C++实现 + +因为MMDeploy中已经实现了Bicubic Interpolate算子,所以我们可以复用其中的CUDA部分代码,只针对TensorRT实现支持动态scale的插件即可。对CUDA编程感兴趣的小伙伴可以参考CUDA的[官方教程](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html)。因为 `csrc/backend_ops/tensorrt/bicubic_interpolate` 中有我们需要的CUDA代码,所以我们可以直接在该文件夹加添加TensorRT相关的trt_dynamic_resize.hpp和trt_dynamic_resize.cpp文件,在这两个文件中分别声明和实现插件就可以了。我们也可以新建文件夹 `csrc/backend_ops/tensorrt/dynamic_resize`,将这两个文件直接放到这个文件夹下。 + +对TensorRT 7+,要实现这样一个自定义插件,我们需要写两个类。 + +- `DynamicTRTResize`,继承自[nvinfer1::IPluginV2DynamicExt](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_plugin_v2_dynamic_ext.html),完成插件的具体实现 +- `DynamicTRTResizeCreator`,继承自[nvinfer1::IPluginCreator](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_plugin_creator.html),是插件的工厂类,用于创建`DynamicTRTResize`插件的实例。 + +在MMDeploy中,由于有若干插件需要实现,所以我们在`mmdeploy/csrc/backend_ops/tensorrt/common/trt_plugin_base.hpp`中实现了`TRTPluginBase`和`TRTPluginCreatorBase`两个类,用于管理一些所有插件共有的属性方法。其中,`TRTPluginBase`是继承自`nvinfer1::IPluginV2DynamicExt`,而`TRTPluginCreatorBase`是继承自`nvinfer1::IPluginCreator`。这样,用户实现插件时只需继承这两个新的类即可。所以我们只需在`dynamic_resize`文件夹下.hpp文件中,引用`trt_plugin_base.hpp`头文件,然后实现类如下: + +```cpp +class DynamicTRTResize : public TRTPluginBase{} +class DynamicTRTResizeCreator : public TRTPluginCreatorBase{} +``` + +在trt_dynamic_resize.hpp中,我们声明如下内容: + +```cpp +#ifndef TRT_DYNAMIC_RESIZE_HPP +#define TRT_DYNAMIC_RESIZE_HPP +#include + +#include +#include +#include + +#include "trt_plugin_base.hpp" +namespace mmdeploy { +class DynamicTRTResize : public TRTPluginBase { + public: + DynamicTRTResize(const std::string &name, bool align_corners); + + DynamicTRTResize(const std::string name, const void *data, size_t length); + + DynamicTRTResize() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, + int nbInputs, nvinfer1::IExprBuilder &exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, + void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char *getPluginType() const TRT_NOEXCEPT override; + const char *getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void *buffer) const TRT_NOEXCEPT override; + + private: + bool mAlignCorners; +}; + +class DynamicTRTResizeCreator : public TRTPluginCreatorBase { + public: + DynamicTRTResizeCreator(); + + const char *getPluginName() const TRT_NOEXCEPT override; + + const char *getPluginVersion() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, + size_t serialLength) TRT_NOEXCEPT override; +}; +} // namespace mmdeploy +#endif // TRT_DYNAMIC_RESIZE_HPP +``` + +在这样一份头文件中,DynamicTRTResize类进行了如下的套娃继承: + +![](../../../resources/tutorial/IPluginV2DynamicExt.svg) + +从上面的图片和代码中我们发现,插件类`DynamicTRTResize`中我们定义了私有变量`mAlignCorners`,该变量表示是否`align corners`。此外只要实现构造析构函数和TensoRT中三个基类的方法即可。其中构造函数有二,分别用于创建插件和反序列化插件。而基类方法中: + +1. 基类`IPluginV2DynamicExt`的方法较为值得关注,`getOutputDimensions`获取输出张量的形状,`enqueue`真正负责执行我们的算法,内部一般会调用CUDA核函数。本文实现的插件直接调用MMDeploy已定义在`csrc/backend_ops/tensorrt/bicubic_interpolate`的核函数`bicubic_interpolate`。 +2. 基类`IPluginV2Ext`的方法,我们只要实现获取输出数据类型的`getOutputDataType`即可。 +3. 基类`IPluginV2`则是些获取插件类型和版本号的方法,此外则是序列化输入插件的参数的函数`serialize`和计算该参数的序列化后`buffer`大小的函数`getSerializationSize`,以及获取输出张量个数的方法`getNbOutputs`。还有部分公共方法被定义在`TRTPluginBase`类内了。 + +在插件工厂类 `DynamicTRTResizeCreator` 中,我们需要声明获取插件名称和版本的方法 `getPluginName` 和 `getPluginVersion`。同时我们还需要声明创建插件和反序列化插件的方法 `createPlugin` 和 `deserializePlugin`,前者调用 `DynamicTRTResize` 中创建插件的方法,后者调用反序列化插件的方法。 + +接下来,我们就实现上述声明吧。在.cpp文件中我们实现代码如下: + +```cpp +// Copyright (c) OpenMMLab. All rights reserved +#include "trt_dynamic_resize.hpp" + +#include + +#include + +#include "trt_plugin_helper.hpp" +#include "trt_serialize.hpp" +// 引入CUDA核函数bicubic_interpolate在的头文件,会在enqueue中使用 +#include "../bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp" + +using namespace nvinfer1; + +namespace mmdeploy { +namespace { +static const char *PLUGIN_VERSION{"1"}; +static const char *PLUGIN_NAME{"DynamicTRTResize"};//插件名需和ONNX节点名一致,在转换TensorRT模型时被触发 +} // namespace + +DynamicTRTResize::DynamicTRTResize(const std::string &name, bool align_corners) + : TRTPluginBase(name), mAlignCorners(align_corners) {} + +DynamicTRTResize::DynamicTRTResize(const std::string name, const void *data, + size_t length) + : TRTPluginBase(name) { + deserialize_value(&data, &length, &mAlignCorners); +} + +nvinfer1::IPluginV2DynamicExt *DynamicTRTResize::clone() const TRT_NOEXCEPT { + DynamicTRTResize *plugin = + new DynamicTRTResize(mLayerName, mAlignCorners); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; +} + +nvinfer1::DimsExprs DynamicTRTResize::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + // 输入张量有两个:input和size_tensor,后者只用于计算输出张量形状 + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = inputs[1].d[2]; + ret.d[3] = inputs[1].d[3]; + return ret; +} + +bool DynamicTRTResize::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc *ioDesc, + int nbInputs, int nbOutputs) TRT_NOEXCEPT { + if (pos == 0) { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + + } else { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } +} + +void DynamicTRTResize::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *outputs, + int nbOutputs) TRT_NOEXCEPT {} + +size_t DynamicTRTResize::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT { + return 0; +} + +int DynamicTRTResize::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, void *workSpace, + cudaStream_t stream) TRT_NOEXCEPT { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + + int height_out = outputDesc[0].dims.d[2]; + int width_out = outputDesc[0].dims.d[3]; + const void *x = inputs[0]; + void *output = outputs[0]; + + // TODO: add fp16 support + auto data_type = inputDesc[0].type; + switch (data_type) { + case nvinfer1::DataType::kFLOAT: + bicubic_interpolate((float *)x, (float *)output, batch, channels, height, width, + height_out, width_out, mAlignCorners, stream); + break; + default: + return 1; + break; + } + + return 0; +} + +nvinfer1::DataType DynamicTRTResize::getOutputDataType(int index, + const nvinfer1::DataType *inputTypes, + int nbInputs) const TRT_NOEXCEPT { + return inputTypes[0]; +} + +// IPluginV2 Methods +const char *DynamicTRTResize::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } + +const char *DynamicTRTResize::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } + +int DynamicTRTResize::getNbOutputs() const TRT_NOEXCEPT { return 1; } + +size_t DynamicTRTResize::getSerializationSize() const TRT_NOEXCEPT { + return serialized_size(mAlignCorners); +} + +void DynamicTRTResize::serialize(void *buffer) const TRT_NOEXCEPT { + serialize_value(&buffer, mAlignCorners); +} + +////////////////////// creator ///////////////////////////// + +DynamicTRTResizeCreator::DynamicTRTResizeCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char *DynamicTRTResizeCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } + +const char *DynamicTRTResizeCreator::getPluginVersion() const TRT_NOEXCEPT { + return PLUGIN_VERSION; +} + +nvinfer1::IPluginV2 *DynamicTRTResizeCreator::createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { + nvinfer1::Dims size{2, {1, 1}}; + bool align_corners = 1; + + for (int i = 0; i < fc->nbFields; i++) { + if (fc->fields[i].data == nullptr) { + continue; + } + std::string field_name(fc->fields[i].name); + //获取align_corners值,用于创建插件DynamicTRTResize的实例 + if (field_name.compare("align_corners") == 0) { + align_corners = static_cast(fc->fields[i].data)[0]; + } + } + // 创建插件DynamicTRTResize实例并返回 + DynamicTRTResize *plugin = new DynamicTRTResize(name, align_corners); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +nvinfer1::IPluginV2 *DynamicTRTResizeCreator::deserializePlugin( + const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { + auto plugin = new DynamicTRTResize(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} +REGISTER_TENSORRT_PLUGIN(DynamicTRTResizeCreator);//真正注册了该插件 +} // namespace mmdeploy +``` + +然后,我们就对MMDeploy重新build一次TensorRT的动态库`build/lib/libmmdeploy_tensorrt_ops.so`。一般编译成功就表示已经注册算子了,但是我们需要进行一些测试以保证结果正确。 + +### 测试 + +我们用TensorRT的python api查看一下目前的插件列表: + +```python +import tensorrt as trt +from mmdeploy.backend.tensorrt import load_tensorrt_plugin +load_tensorrt_plugin() +def get_plugin_names(): + return [pc.name for pc in trt.get_plugin_registry().plugin_creator_list] +print(get_plugin_names()) +``` + +可以发现 'DynamicTRTResize' 在插件列表中。然后我们对这个插件进行功能测试,看推理结果是否和PyTroch结果一致,并且可以动态控制输出尺寸。 + +```python +from mmdeploy.backend.tensorrt import create_trt_engine, save_trt_engine + +engine = create_trt_engine( + 'srcnn3.onnx', + input_shapes=dict(input = dict( + min_shape=[1, 3, 256, 256], + opt_shape=[1, 3, 256, 256], + max_shape=[1, 3, 256, 256]), + factor = dict(min_shape = [1, 1, 256, 256], opt_shape = [1, 1, 512, 512], max_shape = [1, 1, 1024, 1024]))) + +save_trt_engine(engine, 'srcnn3.engine') + +from mmdeploy.backend.tensorrt import TRTWrapper +trt_model = TRTWrapper('srcnn3.engine', ['output']) + +factor = torch.rand([1, 1, 768, 768], dtype=torch.float) +trt_output = trt_model.forward(dict(input = x.cuda(), factor = factor.cuda())) +torch_output = model.forward(x, factor) +assert np.allclose(trt_output['output'].cpu().numpy(), torch_output.cpu().detach(), rtol = 1e-3, atol = 1e-5) +``` + +对比 TensorRT 的输出结果和 PyTorch 的输出结果是否一致,程序如果不报错即可说明推理正确。此外,测试时我们使用和导出时不一样的尺寸,结果也和 PyTorch 一致,说明可以支持动态的尺寸。 + +## 总结 + +本篇教程我们主要讲述如何在 MMDeploy 代码库中添加一个自定义的 TensorRT 插件,整个过程不涉及太多更复杂的 CUDA 编程,相信小伙伴们学完可以自己实现想要的插件。 diff --git a/resources/tutorial/IPluginV2DynamicExt.svg b/resources/tutorial/IPluginV2DynamicExt.svg new file mode 100644 index 000000000..f3c2ae823 --- /dev/null +++ b/resources/tutorial/IPluginV2DynamicExt.svg @@ -0,0 +1,4 @@ + + + +
nvinfer1::IPluginV2
nvinfer1::IPluginV2
nvinfer1::IPluginV2Ext
nvinfer1::IPluginV2Ext
nvinfer1::IPluginV2DynamicExt
nvinfer1::IPluginV2DynamicExt
TRTPluginBase
TRTPluginBase
DynamicTRTResize
DynamicTRTResize
TensorRT
TensorRT
MMDeploy
MMDeploy
Text is not SVG - cannot display
diff --git a/resources/tutorial/srcnn.svg b/resources/tutorial/srcnn.svg new file mode 100644 index 000000000..ea35d01d8 --- /dev/null +++ b/resources/tutorial/srcnn.svg @@ -0,0 +1 @@ +inputfloat32[batch,3,height,width]factorfloat32[batch,1,height,width]DynamicTRTResizeDynamicTRTResize_0ConvConv_1float32[64,3,9,9]W〈64×3×9×9〉float32[64]B〈64〉ReluRelu_2ConvConv_3float32[32,64,1,1]W〈32×64×1×1〉float32[32]B〈32〉ReluRelu_4ConvConv_5float32[3,32,5,5]W〈3×32×5×5〉float32[3]B〈3〉outputfloat32[batch,3,height,width]