metatensor.torch.sort is broken #642
Comments
Ahahah, I have been trying to fix/improve sort for two days now. |
Broken for me too! Here's an MWE:
import torch
import metatensor.torch as mts

# Build a TensorMap whose keys, samples and properties are deliberately unsorted
tensor = mts.TensorMap(
    keys=mts.Labels(
        names=["a", "b"],
        values=torch.tensor([[2, 1], [1, 5]]),
    ),
    blocks=[
        mts.TensorBlock(
            values=torch.randn(3, 2),
            samples=mts.Labels(
                names=["s1", "s2", "s3"],
                values=torch.tensor(
                    [
                        [0, 1, 2],
                        [2, 3, 4],
                        [1, 5, 7],
                    ]
                ),
            ),
            components=[],
            properties=mts.Labels(
                names=["p1", "p2"],
                values=torch.tensor(
                    [
                        [100, 0],
                        [5, 7000],
                    ]
                ),
            ),
        ),
        mts.TensorBlock(
            values=torch.randn(3, 3),
            samples=mts.Labels(
                names=["s1", "s2", "s3"],
                values=torch.tensor(
                    [
                        [0, 2, 2],
                        [0, 1, 2],
                        [1, 5, 7],
                    ]
                ),
            ),
            components=[],
            properties=mts.Labels(
                names=["p1", "p2"],
                values=torch.tensor(
                    [
                        [5, 10],
                        [5, 5],
                        [5, 6],
                    ]
                ),
            ),
        ),
    ],
)

# Sort (named sorted_tensor to avoid shadowing the builtin `sorted`)
sorted_tensor = mts.sort(tensor, "all")

# Print every piece of metadata that is unchanged by the sort, i.e. not sorted
if tensor.keys == sorted_tensor.keys:
    print("Keys:", tensor.keys, sorted_tensor.keys)

for key in tensor.keys:
    block = tensor[key]
    block_sorted = sorted_tensor[key]

    if block.samples == block_sorted.samples:  # not sorted
        print("Key:", key, "samples:", block.samples, block_sorted.samples)

    if block.properties == block_sorted.properties:  # not sorted
        print("Key:", key, "properties:", block.properties, block_sorted.properties)

    if torch.all(block.values == block_sorted.values):  # not sorted
        print("Key:", key, "values:", block.values, block_sorted.values) |
Thanks Joe, that's exactly the same issue I'm facing. |
I'm guessing the error is here, in
if isinstance(labels_values, TorchTensor):
    # torchscript does not support sorted for List[List[int]]
    # so we temporary do this trick. this will be fixed with issue #366
    max_int = torch.max(labels_values)
    idx = torch.sum(
        max_int ** torch.arange(labels_values.shape[1]) * labels_values, dim=1
    )
    return torch.argsort(idx, dim=-1, descending=reverse)
Seems like a strange hack |
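To make the suspicion about the quoted trick concrete, here is a minimal pure-Python sketch of the same arithmetic. It is not the library code: the helper names `packed_keys`, `packed_order` and `lexicographic_order` are hypothetical, and plain lists stand in for the torch tensor. If the snippet is read as written, column j is weighted by max_int**j, so the last column becomes the most significant one, and using max_int itself as the base (rather than something larger) lets different rows collapse to the same key.

```python
def packed_keys(rows):
    """Collapse each row to one integer, as the quoted snippet appears to do."""
    max_int = max(max(row) for row in rows)
    # Column j is weighted by max_int**j, so the LAST column dominates.
    return [sum(max_int**j * v for j, v in enumerate(row)) for row in rows]


def packed_order(rows):
    """Argsort of the packed keys (ascending)."""
    keys = packed_keys(rows)
    return sorted(range(len(rows)), key=keys.__getitem__)


def lexicographic_order(rows):
    """Argsort comparing rows column by column, first column first."""
    return sorted(range(len(rows)), key=lambda i: rows[i])


# Sample labels of the first block in the MWE above
samples = [[0, 1, 2], [2, 3, 4], [1, 5, 7]]

print(packed_keys(samples))          # [105, 219, 379]
print(packed_order(samples))         # [0, 1, 2] -> rows left in their input order
print(lexicographic_order(samples))  # [0, 2, 1] -> the order one would expect

# Two different rows that receive the same packed key
print(packed_keys([[2, 0], [0, 1]]))  # [2, 2]
```

Under these assumptions the packed keys leave the MWE's first-block samples in their original order, which would match the "nothing gets sorted" symptom reported above. Whether this is the actual cause in metatensor-torch is not confirmed in this thread.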
Hmm, this is weird. In here (cae018d) I should have added tests for the metatensor-torch sort operation. These are basically the same tests we use for the numpy version, and they seem to pass, at least on my local machine. This is pure guessing, or maybe I did something wrong while running the tests. |
We tested this on CPU, so I'm also not sure. |
OK, I simply do not know how to use tox. |
So in @jwa7's example the code does nothing because the starting TensorMap is already ordered. Even if you use |
@DavideTisi Hmm, I'm not sure I follow. What do you mean by the TensorMap being sorted in my example? I created it so that the labels are not sorted. When I run the code as above, there is a lot of printed output, showing that the My conclusion was that it's a problem with the torch side of the operation, so I'm confused as to why you're seeing what you're seeing. |
@jwa7 OK, tomorrow I will take a closer look, because if I run your code with either torch or numpy I get the same result. |
So what I did in #644 to |
I'll leave it to @ppegolo for more details.