Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support multi_scale_deform_attn trt plugin #1844

Merged
merged 9 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_ms_deform_attn.hpp"

#include <assert.h>

#include <chrono>

#include "trt_ms_deform_attn_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
// Version string reported by the plugin and by its creator.
const char *PLUGIN_VERSION = "1";
// Plugin type name reported by the plugin and by its creator.
const char *PLUGIN_NAME = "MMCVMultiScaleDeformableAttention";
}  // namespace

// Creates a plugin for the layer called `name`. The only work done here is
// forwarding the layer name to TRTPluginBase.
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(
    const std::string &name)
    : TRTPluginBase(name) {
}

// Deserialising constructor: creates a plugin for the layer called `name`.
//
// `data` / `length` describe the serialised blob handed in by the caller.
// This constructor does not read them; only the layer name is forwarded to
// TRTPluginBase.
MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(
    const std::string name, const void *data, size_t length)
    : TRTPluginBase(name) {
  // Intentionally unused.
  (void)data;
  (void)length;
}
// Destructor: there is nothing to release. TRT_NOEXCEPT is spelled out so the
// definition matches the exception specification used on the other member
// definitions in this file.
MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT {}

// Returns a newly allocated plugin constructed from this plugin's layer name
// (mLayerName) and assigned the same plugin namespace.
nvinfer1::IPluginV2DynamicExt *MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT {
  auto *copy = new MultiScaleDeformableAttnPluginDynamic(mLayerName);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}

// Computes the symbolic shape of the single output tensor.
//
// The output is rank 3:
//   d[0] = dimension 0 of input 0,
//   d[1] = dimension 1 of input 3,
//   d[2] = dimension 2 of input 0 multiplied by dimension 3 of input 0.
// `outputIndex` and `nbInputs` are not consulted.
nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs &firstInput = inputs[0];
  const nvinfer1::DimsExprs &fourthInput = inputs[3];

  nvinfer1::DimsExprs outDims;
  outDims.nbDims = 3;
  outDims.d[0] = firstInput.d[0];
  outDims.d[1] = fourthInput.d[1];
  // Last axis is the product of the two trailing dimensions of input 0.
  outDims.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *firstInput.d[2],
                                       *firstInput.d[3]);
  return outDims;
}

// Reports whether the tensor at position `pos` of `ioDesc` has a supported
// format/type combination.
//
// Rules implemented here:
//   * every tensor must use the kLINEAR format;
//   * positions 1 and 2 must be kINT32;
//   * every other position must be kFLOAT or kHALF and must have the same
//     type as position 0.
bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  const nvinfer1::PluginTensorDesc &desc = ioDesc[pos];

  if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
    return false;
  }

  const bool isInt32Slot = (pos == 1) || (pos == 2);
  if (isInt32Slot) {
    return desc.type == nvinfer1::DataType::kINT32;
  }

  const bool isFloatOrHalf = (desc.type == nvinfer1::DataType::kFLOAT) ||
                             (desc.type == nvinfer1::DataType::kHALF);
  return (desc.type == ioDesc[0].type) && isFloatOrHalf;
}

// Configuration hook. Intentionally a no-op: none of the arguments are read
// and no state is stored.
void MultiScaleDeformableAttnPluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT {
  // Nothing to configure.
}

// Returns the scratch-space requirement in bytes, which is always zero
// regardless of the tensor descriptions passed in.
size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT {
  const size_t noWorkspace = 0;
  return noWorkspace;
}

// Runs the multi-scale deformable attention forward pass on `stream`.
//
// Inputs, in the order they are read below:
//   inputs[0]  value            float or half
//   inputs[1]  spatial shapes   int32
//   inputs[2]  level start idx  int32
//   inputs[3]  sampling locs    same type as value
//   inputs[4]  attention weight same type as value
// Sizes are taken from the input descriptors:
//   inputDesc[0].dims = {batch, spatial_size, num_heads, channels}
//   inputDesc[1].dims.d[0] = num_levels
//   inputDesc[3].dims.d[1] = num_query, inputDesc[3].dims.d[4] = num_point
// `outputDesc` and `workSpace` are not used.
//
// Returns the status of ms_deform_attn_cuda_forward (0 on success), or a
// non-zero value when the type of input 0 is neither kFLOAT nor kHALF.
// Previously that case returned 0 without launching anything, so an
// unwritten output was reported as success.
int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                                                   const nvinfer1::PluginTensorDesc *outputDesc,
                                                   const void *const *inputs, void *const *outputs,
                                                   void *workSpace,
                                                   cudaStream_t stream) TRT_NOEXCEPT {
  int32_t const batch = inputDesc[0].dims.d[0];
  int32_t const spatial_size = inputDesc[0].dims.d[1];
  int32_t const num_heads = inputDesc[0].dims.d[2];
  int32_t const channels = inputDesc[0].dims.d[3];
  int32_t const num_levels = inputDesc[1].dims.d[0];
  int32_t const num_query = inputDesc[3].dims.d[1];
  int32_t const num_point = inputDesc[3].dims.d[4];

  // The two integer inputs have the same type on both the float and half paths.
  int32_t const *spatialShapes = static_cast<int32_t const *>(inputs[1]);
  int32_t const *levelStartIndex = static_cast<int32_t const *>(inputs[2]);

  switch (inputDesc[0].type) {
    case nvinfer1::DataType::kFLOAT: {
      float const *value = static_cast<float const *>(inputs[0]);
      float const *samplingLoc = static_cast<float const *>(inputs[3]);
      float const *attnWeight = static_cast<float const *>(inputs[4]);
      float *output = static_cast<float *>(outputs[0]);
      return ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc,
                                         attnWeight, output, batch, spatial_size, num_heads,
                                         channels, num_levels, num_query, num_point, stream);
    }
    case nvinfer1::DataType::kHALF: {
      const __half *value = static_cast<const __half *>(inputs[0]);
      const __half *samplingLoc = static_cast<const __half *>(inputs[3]);
      const __half *attnWeight = static_cast<const __half *>(inputs[4]);
      __half *output = static_cast<__half *>(outputs[0]);
      return ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc,
                                         attnWeight, output, batch, spatial_size, num_heads,
                                         channels, num_levels, num_query, num_point, stream);
    }
    default:
      // Unsupported element type: report failure instead of silent success.
      return 1;
  }
}

// The output has the same element type as input 0. `index` and `nbInputs`
// are not consulted.
nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
  const nvinfer1::DataType firstInputType = inputTypes[0];
  return firstInputType;
}

// IPluginV2 Methods
// Returns the plugin type name, i.e. the file-local PLUGIN_NAME constant.
const char *MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT {
  const char *const typeName = PLUGIN_NAME;
  return typeName;
}

// Returns the plugin version string, i.e. the file-local PLUGIN_VERSION constant.
const char *MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT {
  const char *const version = PLUGIN_VERSION;
  return version;
}

// The plugin always reports exactly one output tensor.
int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT {
  const int outputCount = 1;
  return outputCount;
}

// Always reports a serialised size of zero bytes.
size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  const size_t serializedBytes = 0;
  return serializedBytes;
}

// Serialisation hook. Intentionally a no-op: `buffer` is left untouched.
void MultiScaleDeformableAttnPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
  // Nothing is written.
}

// Context-attach hook. Intentionally a no-op: none of the handles passed in
// are stored or used.
void MultiScaleDeformableAttnPluginDynamic::attachToContext(
    cudnnContext *cudnnContext, cublasContext *cublasContext,
    nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {
  // Nothing to attach.
}

// Context-detach hook. Intentionally a no-op, mirroring attachToContext().
void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {
  // Nothing to detach.
}

////////////////////// creator /////////////////////////////

// Sets up the creator's field collection (mFC) to describe an emptied
// attribute list: mPluginAttributes is cleared, then mFC is pointed at it.
MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() {
  mPluginAttributes.clear();

  // Publish the (empty) attribute storage through the field collection.
  mFC.fields = mPluginAttributes.data();
  mFC.nbFields = mPluginAttributes.size();
}

// Returns the name of the plugin this creator builds (the file-local
// PLUGIN_NAME constant, also returned by the plugin's getPluginType()).
const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
  const char *const pluginName = PLUGIN_NAME;
  return pluginName;
}

// Returns the version of the plugin this creator builds (the file-local
// PLUGIN_VERSION constant, also returned by the plugin's getPluginVersion()).
const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT {
  const char *const pluginVersion = PLUGIN_VERSION;
  return pluginVersion;
}

// Creates a new plugin for the layer called `name` and gives it this
// creator's plugin namespace. `fc` is not read.
nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  auto *created = new MultiScaleDeformableAttnPluginDynamic(name);
  created->setPluginNamespace(getPluginNamespace());
  return created;
}

// Rebuilds a plugin for the layer called `name` by passing the serialised
// blob (`serialData`, `serialLength`) to the plugin's deserialising
// constructor, then gives it this creator's plugin namespace.
nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin(
    const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
  auto *restored = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength);
  restored->setPluginNamespace(getPluginNamespace());
  return restored;
}
// Registers the creator through the REGISTER_TENSORRT_PLUGIN macro.
// NOTE(review): the macro's definition is not visible in this file; presumably
// it adds the creator to TensorRT's plugin registry at load time -- confirm.
REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator);
} // namespace mmdeploy
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_MS_DEFORM_ATTN_HPP
#define TRT_MS_DEFORM_ATTN_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
// TensorRT dynamic-shape plugin for multi-scale deformable attention.
// The class declares no data members of its own; it derives from TRTPluginBase.
class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase {
public:

// Constructs a plugin for the layer called `name`.
MultiScaleDeformableAttnPluginDynamic(const std::string &name);

// Constructs a plugin for the layer called `name` from a serialised blob of
// `length` bytes starting at `data`.
MultiScaleDeformableAttnPluginDynamic(const std::string name, const void *data, size_t length);

// NOTE(review): no definition of this default constructor appears in the
// accompanying .cpp of this change -- confirm it is defined elsewhere or
// remove the declaration.
MultiScaleDeformableAttnPluginDynamic();

~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override;

// IPluginV2DynamicExt Methods
// Returns a newly allocated copy of this plugin.
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
// Computes the symbolic output shape from the input shapes.
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
TRT_NOEXCEPT override;
// Reports whether the format/type of the tensor at `pos` is supported.
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
// Returns the scratch workspace size in bytes.
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override;
// Launches the forward computation on `stream`; returns 0 on success.
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;

// IPluginV2Ext Methods
// Returns the element type of output `index` given the input types.
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;

// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
// Number of output tensors produced by the plugin.
int getNbOutputs() const TRT_NOEXCEPT override;
// Number of bytes serialize() writes.
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;
};

// Creator (factory) for MultiScaleDeformableAttnPluginDynamic; derives from
// TRTPluginCreatorBase.
class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase {
public:
MultiScaleDeformableAttnPluginDynamicCreator();

// Name of the plugin this creator builds.
const char *getPluginName() const TRT_NOEXCEPT override;

// Version of the plugin this creator builds.
const char *getPluginVersion() const TRT_NOEXCEPT override;

// Creates a new plugin for the layer called `name`.
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;

// Recreates a plugin for the layer called `name` from `serialLength` bytes
// of serialised data at `serialData`.
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_MS_DEFORM_ATTN_HPP
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <assert.h>
#include <cuda_fp16.h>

#include "common_cuda_helper.hpp"
#include "trt_ms_deform_attn_kernel.cuh"
#include "trt_ms_deform_attn_kernel.hpp"
#include "trt_plugin_helper.hpp"

// Launches ms_deformable_im2col_gpu_kernel on `stream`.
//
// The launch is sized for batchSize * numQuery * numHeads * channels
// elements, using GET_BLOCKS(...) blocks of THREADS_PER_BLOCK threads and no
// dynamic shared memory. All pointer arguments are forwarded to the kernel
// unchanged, so they must be valid for device access.
//
// When the element count is zero no launch is issued: a launch sized for
// zero elements has no work to do, and a zero-block grid would be rejected
// as an invalid configuration.
// NOTE(review): the element count is a 32-bit product; confirm callers keep
// batchSize * numQuery * numHeads * channels below 2^31.
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue,
                               int32_t const* dataSpatialShapes,
                               int32_t const* dataLevelStartIndex,
                               scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight,
                               int32_t const batchSize, int32_t const spatialSize,
                               int32_t const numHeads, int32_t const channels,
                               int32_t const numLevels, int32_t const numQuery,
                               int32_t const numPoint, scalar_t* dataCol)
{
    int32_t const numKernels = batchSize * numQuery * numHeads * channels;
    if (numKernels == 0)
    {
        // Empty problem: nothing to compute, and nothing to launch.
        return;
    }

    ms_deformable_im2col_gpu_kernel<scalar_t><<<GET_BLOCKS(numKernels), THREADS_PER_BLOCK, 0, stream>>>(
        numKernels, dataValue, dataSpatialShapes, dataLevelStartIndex, dataSamplingLoc, dataAttnWeight,
        batchSize, spatialSize, numHeads, channels, numLevels, numQuery, numPoint, dataCol);
}


// Host entry point for the multi-scale deformable attention forward pass.
//
// Forwards all arguments to ms_deformable_im2col_cuda, which enqueues the
// kernel on `stream`, processing the whole batch in a single step and
// writing the result to `output`.
//
// Returns 0 on success (including the empty-batch case) and a non-zero value
// when the CUDA runtime reports an error after the launch.
//
// Changes relative to the previous implementation:
//   * The im2col step was always equal to `batch`, so the loop over
//     batch / step executed exactly once with zero offsets -- and divided by
//     zero when `batch` was 0. The loop is replaced by one call, guarded
//     against an empty batch.
//   * The launch status is now checked instead of unconditionally returning 0.
template <typename scalar_t>
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
                                    const int32_t* levelStartIndex, const scalar_t* samplingLoc,
                                    const scalar_t* attnWeight, scalar_t* output, int32_t batch,
                                    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels,
                                    int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
                                    cudaStream_t stream)
{
    if (batch == 0)
    {
        // Empty batch: no work, and avoids the former division by zero.
        return 0;
    }

    ms_deformable_im2col_cuda<scalar_t>(stream, value, spatialShapes, levelStartIndex, samplingLoc,
                                        attnWeight, batch, mSpatialSize, mNumHeads, mChannels,
                                        mNumLevels, mNumQuery, mNumPoint, output);

    // Kernel launches do not return a status; query the runtime for
    // launch-configuration errors without synchronising the stream.
    return (cudaGetLastError() == cudaSuccess) ? 0 : 1;
}

// Explicit instantiations of ms_deform_attn_cuda_forward for the two element
// types the plugin's enqueue() dispatches on: float and __half.

// float instantiation.
template int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream);

// __half instantiation.
template int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream);