Skip to content

Commit

Permalink
Apply changes from pytorch#37846 to test_topk_smallest_unsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
malfet committed Jun 11, 2020
1 parent 3bb338a commit d6db5d5
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,11 @@ def forward(self, x):
def test_topk_smallest_unsorted(self):
class MyModule(torch.nn.Module):
def forward(self, x, k):
return torch.topk(x, k, largest=False, sorted=False)
# When sorted=False, order of elements in the outout tensors
# are not expected to match between PyTorch and ORT
topk_unsorted = torch.topk(x, k, largest=False, sorted=False)
topk_sorted = torch.topk(x, k, largest=False, sorted=True)
return topk_sorted, torch.sort(topk_unsorted.values).values

x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
Expand Down

0 comments on commit d6db5d5

Please sign in to comment.