metatensor.torch.sort is broken #642

Closed
frostedoyster opened this issue May 31, 2024 · 11 comments · Fixed by #647
Labels
bug Something isn't working Operations Related to metatensor-operations in Python

Comments

@frostedoyster
Contributor

I'll leave it to @ppegolo to provide more details.

@DavideTisi
Contributor

Ahahah, I have been trying to fix/improve sort for two days.

@jwa7
Member

jwa7 commented May 31, 2024

Broken for me too! Here's an MWE:

import torch
import metatensor.torch as mts

tensor = mts.TensorMap(
    keys=mts.Labels(
        names=["a", "b"],
        values=torch.tensor([[2, 1], [1, 5]]),
    ),
    blocks=[
        mts.TensorBlock(
            values=torch.randn(3, 2),
            samples=mts.Labels(
                names=["s1", "s2", "s3"],
                values=torch.tensor(
                    [
                        [0, 1, 2],
                        [2, 3, 4],
                        [1, 5, 7],
                    ]
                )
            ),
            components=[],
            properties=mts.Labels(
                names=["p1", "p2"],
                values=torch.tensor(
                    [
                        [100, 0],
                        [5, 7000],
                    ]
                )
            )
        ),
        mts.TensorBlock(
            values=torch.randn(3, 3),
            samples=mts.Labels(
                names=["s1", "s2", "s3"],
                values=torch.tensor(
                    [
                        [0, 2, 2],
                        [0, 1, 2],
                        [1, 5, 7],
                    ]
                ),
            ),
            components=[],
            properties=mts.Labels(
                names=["p1", "p2"],
                values=torch.tensor(
                    [
                        [5, 10],
                        [5, 5],
                        [5, 6],
                    ]
                )
            ),
        )
    ]
)
# Sort along all axes (named sorted_tensor to avoid shadowing the builtin `sorted`)
sorted_tensor = mts.sort(tensor, "all")

# Print every piece of metadata (and the values) that the sort left unchanged
if tensor.keys == sorted_tensor.keys:  # not sorted
    print("Keys:", tensor.keys, sorted_tensor.keys)
for key in tensor.keys:
    block = tensor[key]
    block_sorted = sorted_tensor[key]
    if block.samples == block_sorted.samples:  # not sorted
        print("Key:", key, "samples:", block.samples, block_sorted.samples)
    if block.properties == block_sorted.properties:  # not sorted
        print("Key:", key, "properties:", block.properties, block_sorted.properties)
    if torch.all(block.values == block_sorted.values):  # not sorted
        print("Key:", key, "values:", block.values, block_sorted.values)

@Luthaf Luthaf added the Operations Related to metatensor-operations in Python label May 31, 2024
@ppegolo

ppegolo commented May 31, 2024

Thanks Joe, that's exactly the same issue I'm facing.

@Luthaf Luthaf added the bug Something isn't working label May 31, 2024
@jwa7
Member

jwa7 commented May 31, 2024

I'm guessing the error is here, in _dispatch.argsort_labels_values()?

    if isinstance(labels_values, TorchTensor):
        # torchscript does not support sorted for List[List[int]]
        # so we temporary do this trick. this will be fixed with issue #366
        max_int = torch.max(labels_values)
        idx = torch.sum(
            max_int ** torch.arange(labels_values.shape[1]) * labels_values, dim=1
        )
        return torch.argsort(idx, dim=-1, descending=reverse)

https://github.com/lab-cosmo/metatensor/blob/ec6ee61ebb6eab326462386171f5cd777f12a3c4/python/metatensor-operations/metatensor/operations/_dispatch.py#L126C1-L133C62

Seems like a strange hack
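
One way to see why this trick is fragile is to redo its arithmetic by hand. The sketch below is a pure-Python restatement of the quoted expression (sum over columns of max_int ** column_index * value), not the library code itself; the helper name scalar_keys is hypothetical and only used for illustration. It shows that two different rows can collapse onto the same scalar key, in which case an argsort on the keys cannot tell them apart.

```python
def scalar_keys(rows):
    """Restate the quoted trick in plain Python (illustrative helper, not metatensor API).

    Each row is collapsed to sum(max_int ** j * row[j]), where j is the column index.
    """
    max_int = max(max(row) for row in rows)
    return [sum(max_int**j * value for j, value in enumerate(row)) for row in rows]


# Two distinct rows; the largest entry is 2, so the column weights are [1, 2].
colliding_rows = [[0, 2], [2, 1]]
colliding_keys = scalar_keys(colliding_rows)
print(colliding_keys)  # both rows map to the same key

# The keys from the MWE above: the largest entry is 5, so the weights are [1, 5].
mwe_keys = scalar_keys([[2, 1], [1, 5]])
print(mwe_keys)
```

Because the base of the encoding is max_int rather than max_int + 1, a value equal to max_int in one column carries over into the weight of the next column, which is what produces the collision. Negative label values would presumably cause similar trouble, but that case is not checked here.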

@DavideTisi
Contributor

Hmm, this is weird: in cae018d I should have added tests for the metatensor-torch sort operation.

These are basically the same tests we use for the numpy version, and they seem to pass, at least on my local machine.
Is it possible that this is related to some CPU/GPU thing? Maybe the data are on the GPU and are not handled properly?

This is pure guessing.

Or maybe I did something wrong while running the tests.

@curiosity54
Contributor

We tested this on CPU, so I'm not sure either.

@DavideTisi
Contributor

OK, I simply do not know how to use tox.

@DavideTisi
Contributor

DavideTisi commented Jun 2, 2024

So in @jwa7's example the code does nothing because the starting TensorMap is already ordered. Even if you use numpy instead of torch, the result is the same. The code currently orders by the last column, and the inputs that @jwa7 wrote are already ordered by the last column.
Simply changing the keys to values=torch.tensor([[2, 1], [1, 0]]) would show that the code does in fact sort everything "correctly".
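
The disagreement here may come down to which column is treated as most significant. The following pure-Python sketch (the two helper names are hypothetical, and neither is claimed to be what metatensor implements) compares an argsort where the last column is most significant with an ordinary lexicographic argsort where the first column is most significant, using the keys from the MWE and the modified keys suggested above.

```python
def argsort_last_column_first(rows):
    """Indices that sort rows with the LAST column as the most significant one."""
    return sorted(range(len(rows)), key=lambda i: tuple(reversed(rows[i])))


def argsort_lexicographic(rows):
    """Indices that sort rows with the FIRST column as the most significant one."""
    return sorted(range(len(rows)), key=lambda i: tuple(rows[i]))


original_keys = [[2, 1], [1, 5]]  # keys from the MWE
modified_keys = [[2, 1], [1, 0]]  # keys suggested in the comment above

for name, keys in [("original", original_keys), ("modified", modified_keys)]:
    print(name, argsort_last_column_first(keys), argsort_lexicographic(keys))
```

With the original keys, the last-column-first ordering leaves the rows where they are while the lexicographic ordering swaps them; with the modified keys both conventions swap the rows. That would be consistent with one reader seeing "nothing happens" and another expecting a reordering on the same input.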

@jwa7
Member

jwa7 commented Jun 2, 2024

@DavideTisi Hmm I'm not sure I follow. What do you mean the TensorMap is sorted in my example? I created it so that the labels are not sorted.

When I run the code as above, there is a lot of printed output, showing that the sort operation doesn't do any sorting as intended. When modifying the example to instead import metatensor as mts and use numpy instead of torch, the example prints no output as the sort operation behaves as expected.

My conclusion was that it's a problem with the torch side of the operation, so I'm confused as to why you're seeing what you're seeing.

@DavideTisi
Contributor

@jwa7 OK, tomorrow I will take a closer look, because if I run your code with either torch or numpy I get the same result.

@DavideTisi
Contributor

So what I did in #644 to _dispatch did fix your example (which I also added to the metatensor.torch tests), but it is not a real fix for this problem, since the algorithm is not really stable.
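
For reference, a common way to get a lexicographic argsort without collapsing rows into a single number is to apply a stable sort once per column, from the last column to the first. The sketch below does this in pure Python, relying on the fact that list.sort is stable; the function name stable_lexsort is hypothetical, and this is not the change made in #644 or #647. Whether the same approach is usable under TorchScript is not something this thread establishes.

```python
def stable_lexsort(rows):
    """Lexicographic argsort built from one stable sort per column.

    Sorting by the last column first and the first column last means the
    first column ends up as the most significant one, while ties keep the
    order established by the earlier (less significant) passes.
    """
    if not rows:
        return []
    order = list(range(len(rows)))
    for column in reversed(range(len(rows[0]))):
        # list.sort is guaranteed stable, so earlier passes survive ties here
        order.sort(key=lambda i: rows[i][column])
    return order


# Sample labels of the second block in the MWE above
samples = [[0, 2, 2], [0, 1, 2], [1, 5, 7]]
order = stable_lexsort(samples)
print(order, [samples[i] for i in order])
```

Unlike the scalar-key encoding, this never compares anything but the original integers, so it cannot produce collisions and does not depend on the size or sign of the label values.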


6 participants