Skip to content

Commit

Permalink
Fix bug in sort (#437)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Sep 22, 2023
1 parent 2d5cec2 commit 1862c15
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.9a10"
version = "0.0.9a11"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
2 changes: 1 addition & 1 deletion src/redcat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,7 +1869,7 @@ def sort(self, *args, **kwargs) -> torch.return_types.sort:
[4, 3, 2, 1, 0]], batch_dim=0))
"""
out = torch.sort(self._data, *args, **kwargs)
return type(out)([BatchedTensor(data=o, batch_dim=self._batch_dim) for o in out])
return type(out)([self._create_new_batch(o) for o in out])

def sort_along_batch(self, *args, **kwargs) -> torch.return_types.sort:
r"""Sorts the elements of the batch along the batch dimension in
Expand Down
4 changes: 2 additions & 2 deletions src/redcat/tensorseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,9 @@ def sort_along_seq(self, *args, **kwargs) -> torch.return_types.sort:
>>> batch.sort_along_seq(descending=True)
torch.return_types.sort(
values=tensor([[4, 3, 2, 1, 0],
[9, 8, 7, 6, 5]], batch_dim=0),
[9, 8, 7, 6, 5]], batch_dim=0, seq_dim=1),
indices=tensor([[4, 3, 2, 1, 0],
[4, 3, 2, 1, 0]], batch_dim=0))
[4, 3, 2, 1, 0]], batch_dim=0, seq_dim=1))
"""
return self.sort(self._seq_dim, *args, **kwargs)

Expand Down

0 comments on commit 1862c15

Please sign in to comment.