Skip to content

Commit

Permalink
update unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
cxiang26 committed Mar 15, 2023
1 parent f5317c6 commit 0a65320
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#include <cuda_runtime.h>

template <typename scalar_t>
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
const int32_t* levelStartIndex, const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch,
int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes, const int32_t* levelStartIndex,
const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch, int32_t mSpatialSize,
int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
cudaStream_t stream);

#endif
23 changes: 15 additions & 8 deletions tests/test_ops/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,23 +1234,25 @@ def test_multi_scale_deformable_attn(backend, save_dir=None):
Bs = 2
Nh = 8
Nc = 32
Nq = 100
Np = 200
Nq = 32
Np = 32
spatial_shapes = [[68, 120], [34, 60]]
Nl = len(spatial_shapes)
value_spatial_shapes = torch.LongTensor(spatial_shapes).cuda()
Nl = value_spatial_shapes.shape[0]
Nk = sum([spatial_shapes[i][0] * spatial_shapes[i][1] for i in range(Nl)])
value = torch.rand(Bs, Nk, Nh, Nc).cuda()
value_spatial_shapes = torch.LongTensor(spatial_shapes).cuda()
level_start_index = torch.LongTensor(
[0, spatial_shapes[0][0] * spatial_shapes[0][1]]).cuda()
level_start_index = torch.cat((
value_spatial_shapes.new_zeros((1, )),
value_spatial_shapes.prod(1).cumsum(0)[:-1].to(torch.int64),
))
sampling_locations = torch.rand(Bs, Nq, Nh, Nl, Np, 2).cuda()
attention_weights = torch.rand(Bs, Nq, Nh, Nl, Np).cuda()

class TestModel(torch.nn.Module):

def __init__(self) -> None:
super().__init__()
self.im2col_step = 64
self.im2col_step = 32

def forward(self, value, value_spatial_shapes, level_start_index,
sampling_locations, attention_weights):
Expand All @@ -1262,7 +1264,12 @@ def forward(self, value, value_spatial_shapes, level_start_index,

model = TestModel().cuda()

with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
with RewriterContext(
Config({'backend_config': {
'type': backend.backend_name
}}),
backend=backend.backend_name,
opset=11):
backend.run_and_validate(
model, [
value, value_spatial_shapes, level_start_index,
Expand Down

0 comments on commit 0a65320

Please sign in to comment.