Skip to content

Commit

Permalink
only compare the score value while verifying results between ONNX and…
Browse files Browse the repository at this point in the history
… pytorch
  • Loading branch information
Han Ruobing committed Jun 22, 2020
1 parent 6791e95 commit f4eb9d9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
5 changes: 3 additions & 2 deletions mmdet/core/bbox/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ def bbox2result(bboxes, labels, num_classes):
if bboxes.shape[0] == 0:
return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
else:
bboxes = bboxes.cpu().numpy()
labels = labels.cpu().numpy()
if isinstance(bboxes, torch.Tensor):
bboxes = bboxes.cpu().numpy()
labels = labels.cpu().numpy()
return [bboxes[labels == i, :] for i in range(num_classes)]


Expand Down
8 changes: 5 additions & 3 deletions tools/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ def pytorch2onnx(model,
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file)
from mmdet.core import bbox2result
det_bboxes, det_labels = sess.run(
None, {net_feed_input[0]: one_img.detach().numpy()})
# only compare a part of result
onnx_res = det_bboxes[det_labels == 0, :]
bbox_results = bbox2result(det_bboxes, det_labels, 1)
onnx_results = bbox_results[0]
assert (np.abs(
(pytorch_result[0] - onnx_res) / pytorch_result[0]) > 0.01).sum(
) == 0, 'The outputs are different between Pytorch and ONNX'
(pytorch_result[0][:, 4] - onnx_results[:, 4])) > 0.01).sum(
) == 0, 'The outputs are different between Pytorch and ONNX'
print('The numerical values are same between Pytorch and ONNX')


Expand Down

0 comments on commit f4eb9d9

Please sign in to comment.