-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Complex maths functions #21539
Complex maths functions #21539
Conversation
Also modify it to pass some more tests, there seem to be issues still with approximations being mismatched between different backends, but to fix that will require implementing erf for complex inputs too, which should be its own commit
To improve support of complex numbers. pow was having stability issues due to exp(log(r)), while exp had no native complex number support so this was added (and eventually the call to it from pow was dropped, but its still useful to have). Also adjusted the tests of both, and of gelu
Thanks for contributing to Ivy! 😊👏 |
If you are working on an open task, please edit the PR description to link to the issue you've created. For more information, please check ToDo List Issues Guide. Thank you 🤗 |
Implemented for asin, acos, atan, asinh, acosh, atanh. Also, modified tests for functions (e.g. exp, log, sqrt) that already support complex numbers across backends so that complex numbers will also be tested. This introduces a few new failures, mostly related to numerical stability in JAX, but this is only due to greater test coverage.
That included refactoring GeLU so the calculation of the complex approximation is done in the backend rather than the Ivy API. Also restricts the domain of the numpy square function to prevent overflows, and updates the paddle version on paddle frontend's expm1
Restrict test cases to avoid generation of too large or too small numbers, modify the JAX backend log1p because it gives incorrect values for some complex numbers.
There remains a bug with |
Restrict the domain of more tests to avoid test cases that introduce instability
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't gone through the implementation in a lot of detail, but just made a few suggestions. Let's see what @Ishticode thinks, I'll also request @MahmoudAshraf97's review for the changes to the paddle
backend. Thanks for diving into this @jshepherd01 😄
Adding `complex_mode` to the test of tanh exposed a bug with the handle_complex_input decorator, which is also fixed in this commit. Also removes `complex_mode` argument from the `_forward` method of some stateful API classes (as per #21902) and adds some context for the magic numbers in paddle.gelu
Standardised `complex_mode` docstring across wrapper functions, add `complex_mode` to `array.gelu`, fix `complex_mode` parameter in stateful API, remove special case handling in `jax_backend.log1p` and restrict its tests to managable values.
`input_dtype[0]` is already a string so I don't need to turn it into one
Fix doctsrings on the stateful API for ReLU and GELU, add complex_mode to tests for ReLU, GELU, LeakyReLU
For all activation functions which currently accept it: tanh, relu, leaky_relu, and gelu.
When complex_mode is "split", the inner function is called twice by @handle_complex_input, so previously the result of the second function was overwriting the result of the first in the `out` array. Fixed by passing views of the real and imaginary parts of the `out` array to their respective function calls.
The array itself had a safety factor already, but the alpha value did not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Almost there, just a final couple of changes, thanks @jshepherd01 😄
ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
Outdated
Show resolved
Hide resolved
…into complex-maths-functions
Apply this change to the `softmax` function, which newly gained this argument. Also reflect this change in docstrings of relevant functions.
…#21539) Add support for (and testing of) complex-valued inputs for various mathematical functions: `acos`, `acosh`, `asin`, `asinh`, `atan`, `atanh`, `cos`, `cosh`, `exp`, `exp2`, `expm1`, `log`, `log10`, `log1p`, `log2`, `sin`, `sinh`, `sqrt`, `tan`, `tanh`. Modify testing for existing complex activation functions to test `complex_mode` Add `complex_mode` argument to `tanh` Fix `handle_complex_input` decorator Modify docstrings and function signatures relating to `complex_mode` Refactor `gelu` to better leverage backend frameworks Learnt an important lesson about splitting up PRs properly
…#21539) Add support for (and testing of) complex-valued inputs for various mathematical functions: `acos`, `acosh`, `asin`, `asinh`, `atan`, `atanh`, `cos`, `cosh`, `exp`, `exp2`, `expm1`, `log`, `log10`, `log1p`, `log2`, `sin`, `sinh`, `sqrt`, `tan`, `tanh`. Modify testing for existing complex activation functions to test `complex_mode` Add `complex_mode` argument to `tanh` Fix `handle_complex_input` decorator Modify docstrings and function signatures relating to `complex_mode` Refactor `gelu` to better leverage backend frameworks Learnt an important lesson about splitting up PRs properly
Add complex number support to tanh (as an activation function as well as just a mathematical one, so it now has
@handle_complex_input
and the like) and exp. Modify pow for better numerical stability, and a slight change to gelu to resolve some type issues. Now also adds complex number support to all trig functions, including inverse and hyperbolic, in the API, as well as expanding test coverage to also check complex values.