Invalid behavior for ops.take with torch #18449
@fchollet If this issue is still valid, can I pick it up? Also, if possible, could you share sample code that reproduces the unexpected behaviour?
Yes, it's still valid, and you're welcome to take it on! You can start by writing a unit test for it. If I recall correctly, the issue has to do with using
@fchollet Got it.
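A possible shape for such a unit test, sketched under assumptions: it compares any take implementation against np.take on 1-D and 2-D indices, a negative index with a negative axis, and axis=None. The helper name assert_take_matches_numpy and its (x, indices, axis) calling convention are hypothetical and not part of the Keras test suite.

```python
import numpy as np


def assert_take_matches_numpy(take_fn):
    """Hypothetical test helper: check take_fn(x, indices, axis) against np.take."""
    x = np.arange(24, dtype="float32").reshape(2, 3, 4)
    cases = [
        ([1, 0], 0),                # 1-D indices along the first axis
        ([[0, 2], [1, 1]], 1),      # 2-D indices along a middle axis
        ([3, -1], -1),              # negative index and negative axis
        ([[0, 23], [5, 7]], None),  # axis=None: index into the flattened input
    ]
    for indices, axis in cases:
        indices = np.array(indices)
        expected = np.take(x, indices, axis=axis)
        # np.asarray also accepts CPU torch tensors returned by a backend.
        actual = np.asarray(take_fn(x, indices, axis))
        assert actual.shape == expected.shape, (indices.tolist(), axis)
        np.testing.assert_array_equal(actual, expected)
```

To exercise a backend, one could pass something like `lambda x, i, axis: keras.ops.convert_to_numpy(keras.ops.take(x, i, axis=axis))` with KERAS_BACKEND set to "torch"; that wiring is an assumption, not the existing test.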
Hey @asingh9530, this is where the
Hope that helps 🤗
@ariG23498 I think this is just a wrapper, since it will be running
Yes! You are correct. Sorry for pointing to the wrapper and not the actual backend file!
@ariG23498 I think the test case has already been written (lines 191 to 204), so what remains is fixing the bug with
I think you would also need to add this test, as @fchollet has advised. This would help make the
Does that help?
@ariG23498 Yes, this is what I am planning. Thanks.
As I understand it, the invalid behaviour of
The core issue with this logic is that it primarily searches for 1-D values,
which ideally should output
Possible solution 1:
Possible solution 2:
@fchollet @ariG23498 let me know your thoughts on this 🙂.
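For reference, a small NumPy example of how np.take treats the axis argument and multi-dimensional indices. Under NumPy semantics the output shape is a.shape[:axis] + indices.shape + a.shape[axis+1:], and with axis=None the input is flattened first, so logic that only handles 1-D index values would not cover every case. The array values here are illustrative.

```python
import numpy as np

a = np.array([[0, 1.0, 2.0],
              [10.0, 11.0, 12.0],
              [20.0, 21.0, 22.0],
              [30.0, 31.0, 32.0]])

# axis=0 with 1-D indices selects whole rows:
# [[30., 31., 32.], [10., 11., 12.]], shape (2, 3)
rows = np.take(a, [3, 1], axis=0)

# axis=None flattens `a` first, so each index picks a single element:
# flattened position 3 is 10.0 and position 1 is 1.0
flat = np.take(a, [3, 1])

# 2-D indices with axis=0: shape is indices.shape + a.shape[1:] == (2, 2, 3)
batched = np.take(a, [[0, 1], [2, 3]], axis=0)
```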
@fchollet Can I proceed to make the change and raise a PR?
@fchollet I think we would like the behaviour of
@ariG23498 If there is any specific case you encountered where
@asingh9530 Absolutely, please open a PR! NumPy behavior should indeed be our reference.
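One way a backend whose selection primitive only accepts 1-D indices (torch.index_select, for example, takes a 1-D index tensor) could still match np.take: flatten the indices, select along the axis, then reshape to a.shape[:axis] + indices.shape + a.shape[axis+1:]. This is a NumPy sketch of that idea under stated assumptions, not the actual Keras torch backend code; take_via_1d_select is a hypothetical name, and the 1-D fancy-indexing step stands in for the backend primitive.

```python
import numpy as np


def take_via_1d_select(a, indices, axis=None):
    """Hypothetical NumPy-semantics take built from a 1-D selection step."""
    a = np.asarray(a)
    indices = np.asarray(indices, dtype=np.intp)  # integer indices
    if axis is None:
        # NumPy semantics: with axis=None the input is flattened first.
        a = a.reshape(-1)
        axis = 0
    axis = axis % a.ndim  # normalize a negative axis
    n = a.shape[axis]

    # Flatten the (possibly multi-dimensional) indices to 1-D.
    flat = indices.reshape(-1)
    # Map negative indices into [0, n), as np.take's default mode="raise" allows.
    flat = np.where(flat < 0, flat + n, flat)
    if flat.size and (flat.min() < 0 or flat.max() >= n):
        raise IndexError("index out of range for axis of size %d" % n)

    # 1-D selection along `axis`; stands in for a primitive such as
    # torch.index_select(input, dim, index).
    selector = (slice(None),) * axis + (flat,)
    selected = a[selector]

    # Split the selected axis back into the original index shape.
    out_shape = a.shape[:axis] + indices.shape + a.shape[axis + 1:]
    return selected.reshape(out_shape)
```

The reshape at the end is what makes 2-D indices (such as batched token ids) come out with the same shape NumPy produces.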
I tried the code below with torch, and it now gives the same result as TensorFlow and JAX.

import os
os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import keras

a = np.array([[0, 1.0, 2.0],
              [10.0, 11.0, 12.0],
              [20.0, 21.0, 22.0],
              [30.0, 31.0, 32.0]])
print(keras.ops.take(a, [3, 1], axis=0))

Output:

tensor([[30., 31., 32.],
        [10., 11., 12.]], dtype=torch.float64)

Only with an out-of-range index does torch throw an error.
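Since NumPy is the agreed reference, it may help to recall how np.take itself handles an out-of-range index: the default mode="raise" raises IndexError, while mode="clip" and mode="wrap" remap the index. A minimal sketch with illustrative values; it says nothing about what the TensorFlow or JAX backends do.

```python
import numpy as np

x = np.array([10, 20, 30])

clipped = np.take(x, [4], mode="clip")  # 4 is clipped to the last valid index, 2 -> [30]
wrapped = np.take(x, [4], mode="wrap")  # 4 wraps around to 4 % 3 == 1 -> [20]

try:
    np.take(x, [4])  # default mode="raise"
    raised = False
except IndexError:
    raised = True
```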
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
It appears that ops.take(embed_matrix, ids) has a different behavior in torch compared to the other two backends.
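For context, a NumPy sketch of what an embedding lookup like ops.take(embed_matrix, ids) computes under NumPy semantics, which the thread treats as the reference. The sizes and ids are made up for illustration. Note that selecting rows needs axis=0; without an axis, NumPy flattens the matrix and returns scalars instead of embedding vectors.

```python
import numpy as np

vocab_size, embed_dim = 5, 3
# Row i of the matrix holds the embedding for token id i.
embed_matrix = np.arange(vocab_size * embed_dim, dtype="float32").reshape(vocab_size, embed_dim)
ids = np.array([[0, 4], [2, 2]])  # shape (batch=2, seq_len=2)

# Embedding lookup: select rows along axis 0 -> shape (2, 2, 3).
looked_up = np.take(embed_matrix, ids, axis=0)

# Without an axis the matrix is flattened first, so each id picks one scalar -> shape (2, 2).
scalars = np.take(embed_matrix, ids)
```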