Fix log1p inaccuracies on complex inputs with large absolute values. #10503
The same comment as in #10376 (comment) applies here as well: the CI testOnComplexPlane failures are expected and should be fixed after google/jax#20144 lands.
Could we test the improved accuracy?
The accuracy of log1p is tested in google/jax#20144. See the description of that PR, which reports the accuracy improvements for log1p.
@cheshire gentle reminder
Understood, so the JAX test is an e2e integration test, but could we have an XLA unit test as well?
@cheshire Having correctness tests for complex functions in XLA rather than in JAX makes a lot of sense. That said, when testing functions for correctness, it is essential to have reference implementations whose correctness is verified in a platform-independent way. Unfortunately, not many libraries qualify as a reference library. For instance, within the Python/JAX framework, one obvious choice would be numpy. However, as it turns out, numpy functions often rely on system math libraries, so function evaluation results can differ between platforms. This is one important reason why we use mpmath as a reference library in JAX tests. While there exist a few bugs in mpmath itself that require workarounds, one of the most important properties of mpmath is that its function evaluation results are platform independent. In https://github.com/pearu/complex_function_validation/, I have validated complex functions from a number of Python array libraries for correctness, with the conclusion that mpmath is indeed a reasonable reference library. When considering XLA, one of the obvious choices is libc++. Unfortunately, libc++ also has correctness issues that must be tackled via correctness analysis and workarounds before it can be used as a reference library. The current tests in XLA are very minimal (functions are tested on only a couple of points) and miss many correctness issues, as the correctness analysis of JAX complex functions above also demonstrates. In the short term, one approach would be to generate tables of inputs and outputs for different complex functions (say, using mpmath) and use these tables as a reference database when testing XLA complex functions for correctness. What do you think?
In this case, having hardcoded expected output values, along with a long comment like the one above, indeed seems like the preferred option, thanks!
I suggest doing this in a follow-up PR because it is relevant for other functions as well, and to avoid the XLA and JAX repos going out of sync (google/jax#20436 has landed that includes
Sorry what's the connection between going out of sync and having a unit test? |
In this particular case, google/jax#20436 enabled unit tests for the log1p fixes that are implemented in this PR. Atm, this means that when running JAX tests from its main branch using jaxlib built from the XLA main branch, the JAX log1p unit tests will fail.
A quick solution to this is to create a JAX PR that disables the log1p unit tests. However, it would also make sense to land this PR, as we already test the related changes in JAX. In the future, we'll implement the corresponding unit tests in XLA as well. What do you think?
Sorry, I'm still confused. This PR improves log1p handling in XLA, right? So it would be natural to accompany it with a unit test? The question regarding JAX and/or JAX/XLA interop is orthogonal?
Correct.
True. I am working on it atm.
There is interference only when JAX implements tests for XLA features. AFAIK, this situation is not unique to this issue.
…te values. Imported from GitHub PR #10503 As in the title. Tests and improvement reports are in google/jax#20144. Accuracy tests are enabled in google/jax#20436 Copybara import of the project: -- 2b8f253 by Pearu Peterson <pearu.peterson@gmail.com>: Fix log1p inaccuracies on complex inputs with large absolute values. -- d35cef4 by Pearu Peterson <pearu.peterson@gmail.com>: Add tests to complex Log1p Merging this change closes #10503 FUTURE_COPYBARA_INTEGRATE_REVIEW=#10503 from pearu:pearu/log1p d35cef4 PiperOrigin-RevId: 620293408
…te values. Imported from GitHub PR openxla/xla#10503 As in the title. Tests and improvement reports are in google/jax#20144. Accuracy tests are enabled in google/jax#20436 Copybara import of the project: -- 2b8f2539d6bf364c0a97f65e186430c5eb3ed07b by Pearu Peterson <pearu.peterson@gmail.com>: Fix log1p inaccuracies on complex inputs with large absolute values. -- d35cef4f5fa09482c49edfee709e86c5ca29adde by Pearu Peterson <pearu.peterson@gmail.com>: Add tests to complex Log1p Merging this change closes #10503 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde PiperOrigin-RevId: 620293408
…te values. Imported from GitHub PR #10503 As in the title. Tests and improvement reports are in google/jax#20144. Accuracy tests are enabled in google/jax#20436 Copybara import of the project: -- 2b8f253 by Pearu Peterson <pearu.peterson@gmail.com>: Fix log1p inaccuracies on complex inputs with large absolute values. -- d35cef4 by Pearu Peterson <pearu.peterson@gmail.com>: Add tests to complex Log1p Merging this change closes #10503 FUTURE_COPYBARA_INTEGRATE_REVIEW=#10503 from pearu:pearu/log1p d35cef4 PiperOrigin-RevId: 621891416
…te values. Imported from GitHub PR openxla/xla#10503 As in the title. Tests and improvement reports are in google/jax#20144. Accuracy tests are enabled in google/jax#20436 Copybara import of the project: -- 2b8f2539d6bf364c0a97f65e186430c5eb3ed07b by Pearu Peterson <pearu.peterson@gmail.com>: Fix log1p inaccuracies on complex inputs with large absolute values. -- d35cef4f5fa09482c49edfee709e86c5ca29adde by Pearu Peterson <pearu.peterson@gmail.com>: Add tests to complex Log1p Merging this change closes #10503 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde PiperOrigin-RevId: 621891416
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde PiperOrigin-RevId: 607147423
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde PiperOrigin-RevId: 620894314
This is part of an effort to move runtime targets to the runtime folder. #5758 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde PiperOrigin-RevId: 621605191
tl;dr: this gives a 1.26x compilation time speedup for a large, dense model in XLA:GPU.

The largest perf leaf seen in profiles of a large, dense model is related to computing the post order. Surprisingly, it is not the DFS itself that is most expensive; rather, most of the time is spent scanning through HloComputation::Instructions() to identify DFS roots. The reason this scan becomes expensive as instructions are removed is that the vector holding HloInstructionInfo (introduced in cl/600130708 || openxla/xla@247280ab727) is not shrunk as it flows through the pipeline, forcing us to walk through many deleted "tombstone" entries.

Here is the histogram of the number of tombstones encountered during post order computations for this model:

```
[        1 - 1,536,345) ****************************** (1,300,248)
[1,536,345 - 3,072,690)  (2)
[3,072,690 - 4,609,034)  (364)
[4,609,034 - 6,145,378)  (10,443)
```

To ameliorate this, this CL shrinks the vector periodically, so far only between passes. This is done by running compaction on the vector during HloComputation::Cleanup(), which is called after every pass. The cost of compaction is made proportional to the number of deleted entries by swapping, if needed, each tombstone with the rightmost (within the vector) non-deleted entry.

This brings the number of seen tombstones down significantly:

```
[      1 -   327,699) ****************************** (937,541)
[327,699 -   655,396)  (308)
[655,396 -   983,094)  (0)
[983,094 - 1,310,792)  (1)
```

Note: we could further improve compaction by calling Cleanup() from within some passes, instead of just between passes. However, that would not yield a significant gain; at least for this model, scanning the instruction vector now takes ~1% of total time (vs. ~17% before).

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde PiperOrigin-RevId: 619057964
…atibility. FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde PiperOrigin-RevId: 619687562
…te values. Imported from GitHub PR openxla/xla#10503 As in the title. Tests and improvement reports are in google/jax#20144. Accuracy tests are enabled in google/jax#20436 Copybara import of the project: -- 2b8f2539d6bf364c0a97f65e186430c5eb3ed07b by Pearu Peterson <pearu.peterson@gmail.com>: Fix log1p inaccuracies on complex inputs with large absolute values. -- d35cef4f5fa09482c49edfee709e86c5ca29adde by Pearu Peterson <pearu.peterson@gmail.com>: Add tests to complex Log1p Merging this change closes #10503 PiperOrigin-RevId: 621917683
This ports openxla/xla#10503 by @pearu. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters.
This ports openxla/xla#10503 by @pearu. The new implementation matches mpmath's results for most inputs, see caveats in the linked pull request. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters.
This ports openxla/xla#10503 by @pearu. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters. This is a fixed version of llvm#88260. The previous version relied on implementation-specific behavior in the order of evaluation of maxAbsOfRealPlusOneAndImagMinusOne's operands.
This ports openxla/xla#10503 by @pearu. In addition to the filecheck test here, the accuracy was tested with XLA's complex_unary_op_test and its MLIR emitters. This is a fixed version of #88260. The previous version relied on implementation-specific behavior in the order of evaluation of maxAbsOfRealPlusOneAndImagMinusOne's operands.
As in the title.
Tests and improvement reports are in google/jax#20144.
Accuracy tests are enabled in google/jax#20436
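
To illustrate the kind of inaccuracy the title refers to, here is a minimal Python sketch. It is not the XLA implementation from this PR; `log1p_naive` and `log1p_scaled` are hypothetical names, and the scaling trick shown is a generic textbook technique assumed here for illustration. The naive form computes the real part as `0.5 * log((x + 1)^2 + y^2)`, which overflows to infinity once `|z|` exceeds roughly the square root of the largest finite float. Factoring out the larger of `|x + 1|` and `|y|` keeps every intermediate value finite.

```python
import cmath
import math


def log1p_naive(z: complex) -> complex:
    """Real part via 0.5 * log((x + 1)^2 + y^2); the sum overflows for large |z|."""
    x, y = z.real, z.imag
    s = (x + 1.0) * (x + 1.0) + y * y  # becomes inf when |z| is very large
    return complex(0.5 * math.log(s), math.atan2(y, x + 1.0))


def log1p_scaled(z: complex) -> complex:
    """Same formula with the larger component factored out to avoid overflow.

    Sketch only: it targets large |z| and makes no attempt to stay accurate
    for small |z|, where forming x + 1 already loses digits.
    """
    x, y = z.real, z.imag
    a, b = abs(x + 1.0), abs(y)
    mx, mn = max(a, b), min(a, b)
    if mx == 0.0:  # z == -1, where log1p has its singularity
        return complex(-math.inf, math.atan2(y, x + 1.0))
    r = mn / mx  # always in [0, 1], so r * r cannot overflow
    # log(sqrt(mx^2 + mn^2)) == log(mx) + 0.5 * log1p((mn / mx)^2)
    return complex(math.log(mx) + 0.5 * math.log1p(r * r), math.atan2(y, x + 1.0))


z = complex(1e200, 1e200)
print("naive :", log1p_naive(z))
print("scaled:", log1p_scaled(z))
print("cmath :", cmath.log(1 + z))
```

For this input the naive real part is infinite, while the scaled version stays finite and can be compared against `cmath.log(1 + z)`.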