Skip to content

Commit

Permalink
support multi_scale_deform_attn trt plugin (#1844)
Browse files Browse the repository at this point in the history
* support multi_scale_deform_attn trt plugin

* fix lint error

* add onnx symblic fun

* init msdeformablecrossattn symblic fun & fix plugin

* unittest for ms_deformable_cross_attn

* fix template

* update unittest

* fix input contiguous of trtwrapper

* update doc description
  • Loading branch information
cxiang26 committed Mar 17, 2023
1 parent 8e2f655 commit 50bd7e9
Show file tree
Hide file tree
Showing 10 changed files with 715 additions and 0 deletions.
@@ -0,0 +1,181 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_ms_deform_attn.hpp"

#include <assert.h>

#include <chrono>

#include "trt_ms_deform_attn_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"};
} // namespace

MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string &name)
: TRTPluginBase(name) {}

MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string name,
const void *data,
size_t length)
: TRTPluginBase(name) {}
MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {}

nvinfer1::IPluginV2DynamicExt *MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT {
MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(mLayerName);
plugin->setPluginNamespace(getPluginNamespace());

return plugin;
}

nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs ret;
ret.nbDims = 3;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[3].d[1];

ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD,
*inputs[0].d[2], *inputs[0].d[3]);

return ret;
}

bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {

if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR)
{
if ((pos == 1) || (pos == 2))
{
return (ioDesc[pos].type == nvinfer1::DataType::kINT32);
}
else
{
return ((ioDesc[pos].type == ioDesc[0].type) &&
((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) || (ioDesc[pos].type == nvinfer1::DataType::kHALF)));
}
}
else
{
return false;
}
}

void MultiScaleDeformableAttnPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT {
}

size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT {
return 0;
}

int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs,
void *workSpace,
cudaStream_t stream) TRT_NOEXCEPT {
int32_t const batch = inputDesc[0].dims.d[0];
int32_t spatial_size = inputDesc[0].dims.d[1];
int32_t num_heads = inputDesc[0].dims.d[2];
int32_t channels = inputDesc[0].dims.d[3];
int32_t num_levels = inputDesc[1].dims.d[0];
int32_t num_query = inputDesc[3].dims.d[1];
int32_t num_point = inputDesc[3].dims.d[4];
int32_t rc = 0;
if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
{
float const* value = static_cast<float const*>(inputs[0]);
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
float const* samplingLoc = static_cast<float const*>(inputs[3]);
float const* attnWeight = static_cast<float const*>(inputs[4]);
float* output = static_cast<float*>(outputs[0]);

rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output,
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream);
}
else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
{
const __half* value = static_cast<const __half*>(inputs[0]);
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
const __half* samplingLoc = static_cast<const __half*>(inputs[3]);
const __half* attnWeight = static_cast<const __half*>(inputs[4]);
__half* output = static_cast<__half*>(outputs[0]);

rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output,
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream);
}

return rc;
}

nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}

// IPluginV2 Methods
const char *MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}

const char *MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}

int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }

size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
return 0;
}

void MultiScaleDeformableAttnPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {}

void MultiScaleDeformableAttnPluginDynamic::attachToContext(
cudnnContext *cudnnContext, cublasContext *cublasContext,
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {}

void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {}

////////////////////// creator /////////////////////////////

MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() {
mPluginAttributes.clear();
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}

const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}

nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {

MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(name);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator);
} // namespace mmdeploy
@@ -0,0 +1,71 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_MS_DEFORM_ATTN_HPP
#define TRT_MS_DEFORM_ATTN_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase {
public:

MultiScaleDeformableAttnPluginDynamic(const std::string &name);

MultiScaleDeformableAttnPluginDynamic(const std::string name, const void *data, size_t length);

MultiScaleDeformableAttnPluginDynamic();

~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override;

// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;

// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;

// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;
};

class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase {
public:
MultiScaleDeformableAttnPluginDynamicCreator();

const char *getPluginName() const TRT_NOEXCEPT override;

const char *getPluginVersion() const TRT_NOEXCEPT override;

nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;

nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_MS_DEFORM_ATTN_HPP
@@ -0,0 +1,57 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <assert.h>
#include <cuda_fp16.h>

#include "common_cuda_helper.hpp"
#include "trt_ms_deform_attn_kernel.cuh"
#include "trt_ms_deform_attn_kernel.hpp"
#include "trt_plugin_helper.hpp"

template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue, int32_t const* dataSpatialShapes,
int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight,
int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels, int32_t const numLevels,
int32_t const numQuery, int32_t const numPoint, scalar_t* dataCol)
{
int32_t const numKernels = batchSize * numQuery * numHeads * channels;
int32_t const numActualKernels = batchSize * numQuery * numHeads * channels;

ms_deformable_im2col_gpu_kernel<scalar_t><<<GET_BLOCKS(numActualKernels), THREADS_PER_BLOCK, 0, stream>>>(
numKernels, dataValue, dataSpatialShapes, dataLevelStartIndex, dataSamplingLoc, dataAttnWeight, batchSize,
spatialSize, numHeads, channels, numLevels, numQuery, numPoint, dataCol);
}


template <typename scalar_t>
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream)
{
auto perValueSize = mSpatialSize * mNumHeads * mChannels;
auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
auto perOutputSize = mNumQuery * mNumHeads * mChannels;

int32_t mIm2colStep = batch;

for (int32_t n = 0; n < batch / mIm2colStep; ++n)
{
auto columns = output + n * mIm2colStep * perOutputSize;
ms_deformable_im2col_cuda<scalar_t>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep,
mSpatialSize, mNumHeads, mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
}

return 0;
}

template int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream);

template int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream);

0 comments on commit 50bd7e9

Please sign in to comment.