Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support multi_scale_deform_attn trt plugin (#1844)
* support multi_scale_deform_attn trt plugin * fix lint error * add onnx symblic fun * init msdeformablecrossattn symblic fun & fix plugin * unittest for ms_deformable_cross_attn * fix template * update unittest * fix input contiguous of trtwrapper * update doc description
- Loading branch information
Showing
10 changed files
with
715 additions
and
0 deletions.
There are no files selected for viewing
181 changes: 181 additions & 0 deletions
181
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
#include "trt_ms_deform_attn.hpp" | ||
|
||
#include <assert.h> | ||
|
||
#include <chrono> | ||
|
||
#include "trt_ms_deform_attn_kernel.hpp" | ||
#include "trt_serialize.hpp" | ||
|
||
using namespace nvinfer1; | ||
|
||
namespace mmdeploy { | ||
namespace { | ||
static const char *PLUGIN_VERSION{"1"}; | ||
static const char *PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"}; | ||
} // namespace | ||
|
||
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string &name) | ||
: TRTPluginBase(name) {} | ||
|
||
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string name, | ||
const void *data, | ||
size_t length) | ||
: TRTPluginBase(name) {} | ||
MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {} | ||
|
||
nvinfer1::IPluginV2DynamicExt *MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT { | ||
MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(mLayerName); | ||
plugin->setPluginNamespace(getPluginNamespace()); | ||
|
||
return plugin; | ||
} | ||
|
||
nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions( | ||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, | ||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { | ||
nvinfer1::DimsExprs ret; | ||
ret.nbDims = 3; | ||
ret.d[0] = inputs[0].d[0]; | ||
ret.d[1] = inputs[3].d[1]; | ||
|
||
ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD, | ||
*inputs[0].d[2], *inputs[0].d[3]); | ||
|
||
return ret; | ||
} | ||
|
||
bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination( | ||
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { | ||
|
||
if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) | ||
{ | ||
if ((pos == 1) || (pos == 2)) | ||
{ | ||
return (ioDesc[pos].type == nvinfer1::DataType::kINT32); | ||
} | ||
else | ||
{ | ||
return ((ioDesc[pos].type == ioDesc[0].type) && | ||
((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) || (ioDesc[pos].type == nvinfer1::DataType::kHALF))); | ||
} | ||
} | ||
else | ||
{ | ||
return false; | ||
} | ||
} | ||
|
||
void MultiScaleDeformableAttnPluginDynamic::configurePlugin( | ||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, | ||
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT { | ||
} | ||
|
||
size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize( | ||
const nvinfer1::PluginTensorDesc *inputs, int nbInputs, | ||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT { | ||
return 0; | ||
} | ||
|
||
int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, | ||
const nvinfer1::PluginTensorDesc *outputDesc, | ||
const void *const *inputs, void *const *outputs, | ||
void *workSpace, | ||
cudaStream_t stream) TRT_NOEXCEPT { | ||
int32_t const batch = inputDesc[0].dims.d[0]; | ||
int32_t spatial_size = inputDesc[0].dims.d[1]; | ||
int32_t num_heads = inputDesc[0].dims.d[2]; | ||
int32_t channels = inputDesc[0].dims.d[3]; | ||
int32_t num_levels = inputDesc[1].dims.d[0]; | ||
int32_t num_query = inputDesc[3].dims.d[1]; | ||
int32_t num_point = inputDesc[3].dims.d[4]; | ||
int32_t rc = 0; | ||
if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) | ||
{ | ||
float const* value = static_cast<float const*>(inputs[0]); | ||
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]); | ||
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]); | ||
float const* samplingLoc = static_cast<float const*>(inputs[3]); | ||
float const* attnWeight = static_cast<float const*>(inputs[4]); | ||
float* output = static_cast<float*>(outputs[0]); | ||
|
||
rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output, | ||
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream); | ||
} | ||
else if (inputDesc[0].type == nvinfer1::DataType::kHALF) | ||
{ | ||
const __half* value = static_cast<const __half*>(inputs[0]); | ||
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]); | ||
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]); | ||
const __half* samplingLoc = static_cast<const __half*>(inputs[3]); | ||
const __half* attnWeight = static_cast<const __half*>(inputs[4]); | ||
__half* output = static_cast<__half*>(outputs[0]); | ||
|
||
rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output, | ||
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream); | ||
} | ||
|
||
return rc; | ||
} | ||
|
||
nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType( | ||
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { | ||
return inputTypes[0]; | ||
} | ||
|
||
// IPluginV2 Methods | ||
const char *MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT { | ||
return PLUGIN_NAME; | ||
} | ||
|
||
const char *MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { | ||
return PLUGIN_VERSION; | ||
} | ||
|
||
int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } | ||
|
||
size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { | ||
return 0; | ||
} | ||
|
||
void MultiScaleDeformableAttnPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {} | ||
|
||
void MultiScaleDeformableAttnPluginDynamic::attachToContext( | ||
cudnnContext *cudnnContext, cublasContext *cublasContext, | ||
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {} | ||
|
||
void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {} | ||
|
||
////////////////////// creator ///////////////////////////// | ||
|
||
MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() { | ||
mPluginAttributes.clear(); | ||
mFC.nbFields = mPluginAttributes.size(); | ||
mFC.fields = mPluginAttributes.data(); | ||
} | ||
|
||
const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { | ||
return PLUGIN_NAME; | ||
} | ||
|
||
const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { | ||
return PLUGIN_VERSION; | ||
} | ||
|
||
nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::createPlugin( | ||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { | ||
|
||
MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(name); | ||
plugin->setPluginNamespace(getPluginNamespace()); | ||
return plugin; | ||
} | ||
|
||
nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin( | ||
const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { | ||
auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength); | ||
plugin->setPluginNamespace(getPluginNamespace()); | ||
return plugin; | ||
} | ||
REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator); | ||
} // namespace mmdeploy |
71 changes: 71 additions & 0 deletions
71
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
#ifndef TRT_MS_DEFORM_ATTN_HPP | ||
#define TRT_MS_DEFORM_ATTN_HPP | ||
#include <cublas_v2.h> | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "trt_plugin_base.hpp" | ||
|
||
namespace mmdeploy { | ||
class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase { | ||
public: | ||
|
||
MultiScaleDeformableAttnPluginDynamic(const std::string &name); | ||
|
||
MultiScaleDeformableAttnPluginDynamic(const std::string name, const void *data, size_t length); | ||
|
||
MultiScaleDeformableAttnPluginDynamic(); | ||
|
||
~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override; | ||
|
||
// IPluginV2DynamicExt Methods | ||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; | ||
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, | ||
int nbInputs, nvinfer1::IExprBuilder &exprBuilder) | ||
TRT_NOEXCEPT override; | ||
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, | ||
int nbOutputs) TRT_NOEXCEPT override; | ||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, | ||
const nvinfer1::DynamicPluginTensorDesc *out, | ||
int nbOutputs) TRT_NOEXCEPT override; | ||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, | ||
const nvinfer1::PluginTensorDesc *outputs, | ||
int nbOutputs) const TRT_NOEXCEPT override; | ||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, | ||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, | ||
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; | ||
void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, | ||
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override; | ||
void detachFromContext() TRT_NOEXCEPT override; | ||
|
||
// IPluginV2Ext Methods | ||
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, | ||
int nbInputs) const TRT_NOEXCEPT override; | ||
|
||
// IPluginV2 Methods | ||
const char *getPluginType() const TRT_NOEXCEPT override; | ||
const char *getPluginVersion() const TRT_NOEXCEPT override; | ||
int getNbOutputs() const TRT_NOEXCEPT override; | ||
size_t getSerializationSize() const TRT_NOEXCEPT override; | ||
void serialize(void *buffer) const TRT_NOEXCEPT override; | ||
}; | ||
|
||
class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase { | ||
public: | ||
MultiScaleDeformableAttnPluginDynamicCreator(); | ||
|
||
const char *getPluginName() const TRT_NOEXCEPT override; | ||
|
||
const char *getPluginVersion() const TRT_NOEXCEPT override; | ||
|
||
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) | ||
TRT_NOEXCEPT override; | ||
|
||
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, | ||
size_t serialLength) TRT_NOEXCEPT override; | ||
}; | ||
} // namespace mmdeploy | ||
#endif // TRT_MS_DEFORM_ATTN_HPP |
57 changes: 57 additions & 0 deletions
57
csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
#include <assert.h> | ||
#include <cuda_fp16.h> | ||
|
||
#include "common_cuda_helper.hpp" | ||
#include "trt_ms_deform_attn_kernel.cuh" | ||
#include "trt_ms_deform_attn_kernel.hpp" | ||
#include "trt_plugin_helper.hpp" | ||
|
||
template <typename scalar_t> | ||
void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue, int32_t const* dataSpatialShapes, | ||
int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight, | ||
int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels, int32_t const numLevels, | ||
int32_t const numQuery, int32_t const numPoint, scalar_t* dataCol) | ||
{ | ||
int32_t const numKernels = batchSize * numQuery * numHeads * channels; | ||
int32_t const numActualKernels = batchSize * numQuery * numHeads * channels; | ||
|
||
ms_deformable_im2col_gpu_kernel<scalar_t><<<GET_BLOCKS(numActualKernels), THREADS_PER_BLOCK, 0, stream>>>( | ||
numKernels, dataValue, dataSpatialShapes, dataLevelStartIndex, dataSamplingLoc, dataAttnWeight, batchSize, | ||
spatialSize, numHeads, channels, numLevels, numQuery, numPoint, dataCol); | ||
} | ||
|
||
|
||
template <typename scalar_t> | ||
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes, | ||
const int32_t* levelStartIndex, const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch, | ||
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint, | ||
cudaStream_t stream) | ||
{ | ||
auto perValueSize = mSpatialSize * mNumHeads * mChannels; | ||
auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2; | ||
auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint; | ||
auto perOutputSize = mNumQuery * mNumHeads * mChannels; | ||
|
||
int32_t mIm2colStep = batch; | ||
|
||
for (int32_t n = 0; n < batch / mIm2colStep; ++n) | ||
{ | ||
auto columns = output + n * mIm2colStep * perOutputSize; | ||
ms_deformable_im2col_cuda<scalar_t>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex, | ||
samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep, | ||
mSpatialSize, mNumHeads, mChannels, mNumLevels, mNumQuery, mNumPoint, columns); | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
template int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* spatialShapes, | ||
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch, | ||
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint, | ||
cudaStream_t stream); | ||
|
||
template int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t* spatialShapes, | ||
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch, | ||
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint, | ||
cudaStream_t stream); |
Oops, something went wrong.