
[Feature] arctan2 #1065

Closed
Tracked by #19571
yrahul3910 opened this issue May 2, 2024 · 3 comments

Labels
enhancement New feature or request

Comments

@yrahul3910
Contributor

I want to eventually contribute to this Keras issue, and since this would be my first contribution, I want to start small and implement arctan2. If nobody is currently working on this, I would like to contribute this feature. Specifically, my implementation plan, based on looking at arctan, is as follows:

  • Add an ArcTan2 class to mlx/primitives.h, and its corresponding vjp, jvp, and vmap functions in mlx/primitives.cpp
  • Add an ArcTan2 struct to mlx/backend/common/ops.h, and overload the function call operator.
  • Add an ArcTan2 struct to mlx/backend/metal/unary.h, and add instantiate_unary_float calls in mlx/backend/metal/unary.metal. The necessary metal::precise::atan2 function appears to already exist.
  • Add tests, and update docs.

Please let me know if I can proceed.
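For the vjp and jvp items in the plan above, the gradients follow from the partial derivatives of z = atan2(y, x): dz/dy = x / (x² + y²) and dz/dx = -y / (x² + y²). The sketch below is a scalar illustration only, assuming the first argument plays the role of y. The function names are hypothetical and are not MLX's primitive API. It includes a central finite difference for checking the formula numerically.

```cpp
#include <cassert>
#include <cmath>
#include <utility>

// Partial derivatives of z = atan2(y, x), returned as {dz/dy, dz/dx}:
//   dz/dy =  x / (x^2 + y^2)
//   dz/dx = -y / (x^2 + y^2)
// The gradient is undefined at the origin (x == y == 0).
std::pair<double, double> atan2_grad(double y, double x) {
  double denom = x * x + y * y;
  assert(denom != 0.0 && "atan2 gradient is undefined at the origin");
  return {x / denom, -y / denom};
}

// VJP: scale each partial derivative by the upstream cotangent g.
std::pair<double, double> atan2_vjp(double y, double x, double g) {
  auto [dy, dx] = atan2_grad(y, x);
  return {g * dy, g * dx};
}

// JVP: the output tangent is dz/dy * ty + dz/dx * tx.
double atan2_jvp(double y, double x, double ty, double tx) {
  auto [dy, dx] = atan2_grad(y, x);
  return dy * ty + dx * tx;
}

// Central finite difference in y, for numerically checking atan2_grad.
double finite_diff_dy(double y, double x, double h = 1e-6) {
  return (std::atan2(y + h, x) - std::atan2(y - h, x)) / (2 * h);
}
```

For example, at y = 1, x = 2 the denominator is 5, so the partials are 0.4 and -0.2.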

@awni
Member

awni commented May 2, 2024

Your description is right, modulo the fact that it is a binary op, not a unary op.

I would look at another binary op like Add as an example. Most of the implementation should be boilerplate copied from another binary primitive. The core change will be in the op struct that gets passed to the CPU / GPU binary op kernel.
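To illustrate the "op struct passed to the binary kernel" pattern described above, here is a minimal sketch. The ArcTan2 functor mirrors the shape of an element-wise op like Add (a templated call operator on two scalars), but `apply_binary` is a hypothetical stand-in for MLX's real binary kernels, which also handle broadcasting, strides, and scalar/vector fast paths that this sketch omits.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Element-wise op struct: maps two scalars of type T to one T.
// The first argument is treated as y and the second as x, as in std::atan2.
struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) const {
    return std::atan2(y, x);
  }
};

// Hypothetical binary kernel: applies Op element-wise to two contiguous
// inputs of equal size. Real kernels are more general than this.
template <typename T, typename Op>
std::vector<T> apply_binary(const std::vector<T>& a, const std::vector<T>& b, Op op) {
  assert(a.size() == b.size() && "sketch assumes equally sized inputs");
  std::vector<T> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = op(a[i], b[i]);
  }
  return out;
}
```

Because the kernel is generic over `Op`, adding a new binary op mostly means writing a new functor; the loop (or its GPU equivalent) is shared.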

@awni
Member

awni commented May 2, 2024

And yes please add this, that would be great!

awni added the enhancement (New feature or request) label on May 3, 2024
yrahul3910 mentioned this issue on May 5, 2024
4 tasks
@yrahul3910
Contributor Author

I think I did it :) When running make, I get warnings about the use of binary(...), although the build succeeds and the tests pass. This is because binary also instantiates the op for bool, so converting the floating-point result of std::atan2 back to bool triggers a precision-loss warning. I'm not entirely sure what to do there. Maybe for future binary ops that don't support bool, we could have a separate function called binary_fp, like we have for unary ops?
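One way to realize the binary_fp idea from the comment above is to make the op refuse non-floating types at compile time and have a dispatcher that only ever instantiates it for floating-point dtypes. This is a hypothetical sketch, not the resolution adopted in MLX: the `Dtype` enum and `binary_fp_scalar` are illustrative names, and the dispatcher works on single scalars to keep the example short.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>
#include <type_traits>

// Hypothetical dtype tag standing in for the library's own dtype enum.
enum class Dtype { Bool, Int32, Float32, Float64 };

struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) const {
    // Fail to compile for bool or integer T instead of silently narrowing
    // std::atan2's floating-point result.
    static_assert(std::is_floating_point_v<T>, "ArcTan2 requires a floating-point type");
    return std::atan2(y, x);
  }
};

// Hypothetical floating-point-only dispatcher: only float and double
// instantiations of Op are created, so no bool specialization exists to warn.
template <typename Op>
double binary_fp_scalar(Dtype dtype, double y, double x, Op op) {
  switch (dtype) {
    case Dtype::Float32:
      return op(static_cast<float>(y), static_cast<float>(x));
    case Dtype::Float64:
      return op(y, x);
    default:
      throw std::invalid_argument("binary_fp_scalar: floating-point dtype required");
  }
}
```

The design choice is to move the dtype restriction into the dispatcher, so ops like Add can keep using a dispatcher that covers bool while ArcTan2 never gets a bool instantiation.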

awni closed this as completed on May 8, 2024
2 participants