I want to eventually contribute to this Keras issue, and since this would be my first contribution, I want to start small and implement arctan2. If nobody is currently working on this, I would like to contribute this feature. Specifically, my implementation plan, based on looking at arctan, is as follows:
Add an ArcTan2 class to mlx/primitives.h, and its corresponding vjp, jvp, and vmap functions in mlx/primitives.cpp
Add an ArcTan2 struct to mlx/backend/common/ops.h, and overload the function call operator.
Add an ArcTan2 struct to mlx/backend/metal/unary.h, and add instantiate_unary_float calls in mlx/backend/metal/unary.metal. The necessary metal::precise::atan2 function seems to already be implemented.
Add tests, and update docs.
Please let me know if I can proceed.
Your description is right, modulo the fact that it is a binary op, not a unary op.
I would look at another binary op like Add as an example. Most of the implementation should be boilerplate copied from another binary primitive. The core change will be in the op struct that gets passed to the CPU / GPU binary op kernel.
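To illustrate what "an op struct passed to a binary kernel" can look like, here is a rough, self-contained sketch. It is not the MLX source: `binary_apply` is a hypothetical stand-in for the real CPU kernel, and the `ArcTan2` struct is only modeled on the functor style described above.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical op struct: the function call operator takes two operands,
// which is what makes it a binary op rather than a unary one.
struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) const {
    return std::atan2(y, x);
  }
};

// Hypothetical stand-in for a CPU binary kernel: applies `op` elementwise
// to two equally sized inputs (no broadcasting in this sketch).
template <typename T, typename Op>
std::vector<T> binary_apply(const std::vector<T>& a,
                            const std::vector<T>& b,
                            Op op) {
  std::vector<T> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = op(a[i], b[i]);
  }
  return out;
}
```

The kernel stays generic, so adding a new binary primitive mostly amounts to supplying a new struct with a two-argument call operator.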
I think I did it :) When running make, I get warnings about the use of binary(...), although the build succeeds and the tests pass. This is because the binary function also handles bool, so std::atan2 warns about precision loss. I'm not entirely sure what to do there. Maybe for future binary ops that do not apply to bool, we could have a separate function called binary_fp, like we have for unary ops?
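One way the binary_fp idea could work is to dispatch only over floating-point types, so the op is never instantiated for bool and the precision-loss warning has nothing to trigger on. The sketch below is an assumption about the shape of such a helper, not existing MLX code: the `Dtype` enum, `binary_fp`, and `ArcTan2FloatOnly` are all hypothetical.

```cpp
#include <cmath>
#include <stdexcept>
#include <type_traits>

// Hypothetical, reduced dtype tag for illustration only.
enum class Dtype { bool_, int32, float32 };

// Hypothetical float-only op: refuses to compile for non-floating types.
struct ArcTan2FloatOnly {
  template <typename T>
  T operator()(T y, T x) const {
    static_assert(std::is_floating_point<T>::value,
                  "ArcTan2FloatOnly requires a floating-point type");
    return std::atan2(y, x);
  }
};

// Hypothetical float-only dispatcher: only the float32 branch instantiates
// the op, so no bool or integer overload of the op is ever compiled.
template <typename Op>
float binary_fp(Dtype dtype, float a, float b, Op op) {
  switch (dtype) {
    case Dtype::float32:
      return op(a, b);
    default:
      throw std::invalid_argument("binary_fp: floating-point dtype required");
  }
}
```

Whether unsupported dtypes should raise an error or be promoted to a floating-point type first is a design choice this sketch does not settle.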