
Invalid behavior for ops.take with torch #18449

Open
fchollet opened this issue Jul 2, 2023 · 17 comments

Comments

@fchollet
Member

fchollet commented Jul 2, 2023

It appears that ops.take(embed_matrix, ids) behaves differently in torch compared to the other two backends.

@asingh9530
Contributor

@fchollet if this issue is still valid, can I pick it up? Also, could you share sample code that reproduces the unexpected behaviour?

@fchollet
Member Author

Yes, it's still valid -- you're welcome to take it on!

You can start by writing a unit test for it -- if I recall correctly, the issue has to do with using take() with multi-dimensional IDs: for example, a 2D embed_matrix with 2D or 3D (batched) IDs.
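As a reference for the expected semantics (a sketch using NumPy, which the other backends follow; shapes here are hypothetical, chosen for illustration), take with a 2D matrix and 2D batched IDs along axis=0 gathers rows, with the ID dimensions replacing axis 0 in the output shape:

```python
import numpy as np

# Hypothetical shapes for illustration: a (vocab=4, dim=3)
# embedding matrix and a 2D batch of IDs.
embed_matrix = np.arange(12.0).reshape(4, 3)
ids = np.array([[0, 2], [3, 1]])

out = np.take(embed_matrix, ids, axis=0)
print(out.shape)  # (2, 2, 3): the ID dims replace axis 0
```

A torch implementation that matches this for any ID rank is what the unit test should pin down.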

@asingh9530
Contributor

@fchollet got it.

@asingh9530
Contributor

asingh9530 commented Jul 13, 2023

@fchollet are we open to writing test cases in pytest, or only using unittest? Also, the test cases should go here, right?

@ariG23498
Contributor

> @fchollet are we open to write test cases in pytest or only using unittest ? also we will be writing test cases here only right ?

Hey @asingh9530, this is where the ops.take comes from. You would need to make the necessary changes here! And as for the tests, you would need to make changes to this file.

Hope that helps 🤗

@asingh9530
Contributor

@ariG23498 I think that is just a wrapper, since it ends up calling backend.take, and the problem occurs in torch's take. I think this is the correct place to make changes -- correct me if I am wrong 🤔 Also, thanks for pointing out the exact test file.

@ariG23498
Contributor

Yes! You are correct, sorry for pointing to the wrapper and not the actual backend file!

@asingh9530
Contributor

@ariG23498 I think a test case has already been written (lines 191 to 204), so what remains is fixing the bug with backend == 'torch':

```python
def test_take(self):
    x = KerasTensor([None, 3])
    self.assertEqual(knp.take(x, 1).shape, ())
    self.assertEqual(knp.take(x, [1, 2]).shape, (2,))
    self.assertEqual(
        knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (None, 2, 2)
    )

    x = KerasTensor([None, 3, 3])
    self.assertEqual(knp.take(x, 1, axis=1).shape, (None, 3))
    self.assertEqual(knp.take(x, [1, 2]).shape, (2,))
    self.assertEqual(
        knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (None, 2, 2, 3)
    )
```

@ariG23498
Contributor

> You can start by writing a unit test for it -- if I recall correctly the issue has to do with using take() with multi-dimensional IDs. So like, a 2D embed_matrix and 2D or 3D (batched) IDs?

I think you would also need to add this test, as @fchollet has advised. This would help make the take() op more robust across the backends. Once you see it fail for torch (hopefully), you can move on to resolving the issue.

Does that help?

@asingh9530
Contributor

@ariG23498 yes, that is what I am planning. Thanks.

@asingh9530
Contributor

As per my understanding, the invalid behaviour of take in the torch backend comes from the following logic, defined here starting at line 749:

```python
if x.ndim == 2 and (axis is None or axis == 0):
    # This case is equivalent to embedding lookup.
    return torch.nn.functional.embedding(indices, x)
```

The core issue is that this branch treats every such call as a lookup of 1-D embedding vectors in a 2-D embedding matrix for the given indices, which makes it impossible to run, say, the following:

```python
tensor = [[4, 3, 5], [6, 7, 8]]
indices = [0, 2, 5]
axis = None
```

which should output [4, 5, 8] (the behaviour of torch.take(), which indexes into the flattened tensor), but as the logic stands it gets treated as an embedding lookup and fails.
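For comparison, NumPy's take with axis=None indexes into the flattened array, which is exactly the behaviour described above:

```python
import numpy as np

tensor = np.array([[4, 3, 5], [6, 7, 8]])
indices = [0, 2, 5]

# axis=None flattens first: [4, 3, 5, 6, 7, 8][[0, 2, 5]] -> [4, 5, 8]
result = np.take(tensor, indices)
print(result)  # [4 5 8]
```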

Possible solution 1:
Add a lookup flag as a parameter in the function call; but since this would only be applicable to the torch backend, it might not be optimal.

Possible solution 2:
Make lookup a separate operation rather than part of take; but I believe this would require rewriting many other components, so I am not sure how feasible it is.

@fchollet @ariG23498 let me know your thoughts on this 🙂.
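Another direction would be to branch on axis rather than only on x.ndim, so the embedding path is never taken when axis is None. A rough NumPy sketch of these reference semantics (take_reference is a hypothetical name, not the Keras implementation):

```python
import numpy as np

def take_reference(x, indices, axis=None):
    # Hypothetical sketch of NumPy-style take semantics, which the
    # torch backend should match.
    x = np.asarray(x)
    indices = np.asarray(indices)
    if axis is None:
        # axis=None indexes into the flattened tensor -- never an
        # embedding lookup.
        return x.reshape(-1)[indices]
    # With an explicit axis, gather along it; for a 2D x and axis=0
    # this coincides with an embedding lookup.
    return np.take(x, indices, axis=axis)

print(take_reference([[4, 3, 5], [6, 7, 8]], [0, 2, 5]))  # [4 5 8]
```

Under these semantics the failing example above works, while the embedding-lookup case (2D x, explicit axis=0) is unchanged.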

@asingh9530
Contributor

@fchollet Can I proceed to make the change and raise a PR?

@asingh9530
Contributor

@fchollet If we would like the behaviour of ops.take() with the torch backend to be equivalent to numpy.take(), then I think we are still fine for 2-D, 3-D, or any multi-dimensional indices. But if we want it to replicate the behaviour of only torch.take() + torch.index_select(), then an added lookup variable would be required, as mentioned above.

@ariG23498 If there is a specific case you encountered where ops.take() does not behave like numpy.take(), please share it -- I have not been able to reproduce any set of inputs where its behaviour differs from numpy.take().

@fchollet
Member Author

@asingh9530 absolutely, please open a PR! Numpy behavior should indeed be our reference.

@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023
@sachinprasadhs sachinprasadhs self-assigned this Apr 12, 2024
@sachinprasadhs
Collaborator

I tried the code below with the torch backend; it now gives the same result as TensorFlow and JAX.

```python
import os
os.environ["KERAS_BACKEND"] = "torch"
import numpy as np
import keras

a = np.array([[0.0, 1.0, 2.0],
              [10.0, 11.0, 12.0],
              [20.0, 21.0, 22.0],
              [30.0, 31.0, 32.0]])
print(keras.ops.take(a, [3, 1], axis=0))
```

```
tensor([[30., 31., 32.],
        [10., 11., 12.]], dtype=torch.float64)
```

Only for out-of-range indices do the backends still differ: torch throws an IndexError, JAX handles it with nan values, and TensorFlow returns the values at the last index of the axis.
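For reference, NumPy (the agreed behaviour baseline) also raises on out-of-range indices by default, though its mode parameter can opt into clipping or wrapping instead -- a small sketch:

```python
import numpy as np

a = np.array([[0.0, 1.0, 2.0],
              [10.0, 11.0, 12.0]])

# Default mode='raise': out-of-range indices raise IndexError,
# matching the torch behaviour described above.
try:
    np.take(a, [5], axis=0)
    raised = False
except IndexError:
    raised = True
print("raised:", raised)

# mode='clip' clamps to the last valid index, similar to the
# TensorFlow behaviour described above.
clipped = np.take(a, [5], axis=0, mode="clip")
print(clipped)  # [[10. 11. 12.]]
```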


This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale label Apr 27, 2024
@innat-asj

innat-asj commented Apr 27, 2024

@fchollet
FYI, #19238 (comment)
