Skip to content

Commit

Permalink
fix template
Browse files Browse the repository at this point in the history
  • Loading branch information
cxiang26 committed Mar 14, 2023
1 parent dcf0caf commit 16d7c1e
Showing 1 changed file with 12 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue, i
spatialSize, numHeads, channels, numLevels, numQuery, numPoint, dataCol);
}

template <>
int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,

template <typename scalar_t>
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream)
{
Expand All @@ -37,33 +38,20 @@ int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* sp
for (int32_t n = 0; n < batch / mIm2colStep; ++n)
{
auto columns = output + n * mIm2colStep * perOutputSize;
ms_deformable_im2col_cuda<float>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
ms_deformable_im2col_cuda<scalar_t>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep,
mSpatialSize, mNumHeads, mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
}

return 0;
}

template <>
int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
template int32_t ms_deform_attn_cuda_forward<float>(const float* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream)
{
auto perValueSize = mSpatialSize * mNumHeads * mChannels;
auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
auto perOutputSize = mNumQuery * mNumHeads * mChannels;
cudaStream_t stream);

int32_t mIm2colStep = batch;
for (int32_t n = 0; n < batch / mIm2colStep; ++n)
{
auto columns = output + n * mIm2colStep * perOutputSize;
ms_deformable_im2col_cuda<__half>(stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
samplingLoc + n * mIm2colStep * perSampleLocSize, attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep,
mSpatialSize, mNumHeads, mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
}

return 0;
}
template int32_t ms_deform_attn_cuda_forward<__half>(const __half* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream);

0 comments on commit 16d7c1e

Please sign in to comment.