Fix log1p inaccuracies on complex inputs with large absolute values. #10503

Closed
wants to merge 2 commits into from

Conversation

pearu
Contributor

@pearu pearu commented Mar 13, 2024

As in the title.

Tests and improvement reports are in google/jax#20144.

Accuracy tests are enabled in google/jax#20436
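
To illustrate the kind of inaccuracy the title refers to, here is a minimal sketch in plain Python. It is not the XLA code from this PR. It assumes the textbook identity Re log1p(x + iy) = 0.5 * log((x + 1)^2 + y^2), shows that the squares overflow for inputs with large absolute values, and contrasts this with a rescaled form that factors out the larger component. Both function names are hypothetical, and the rescaled form addresses only the large-magnitude case.

```python
import math


def log1p_real_naive(x: float, y: float) -> float:
    # Textbook form. The products overflow to inf once |x + 1| or |y|
    # exceeds roughly 1e154 in float64, so the result becomes inf.
    return 0.5 * math.log((x + 1.0) * (x + 1.0) + y * y)


def log1p_real_scaled(x: float, y: float) -> float:
    # Factor out the larger of |x + 1| and |y| so that nothing is squared
    # at full magnitude: log(mx) + 0.5 * log1p((mn / mx)^2).
    a, b = abs(x + 1.0), abs(y)
    mx, mn = max(a, b), min(a, b)
    if mx == 0.0:
        return -math.inf  # z == -1, i.e. log(0)
    r = mn / mx
    return math.log(mx) + 0.5 * math.log1p(r * r)


naive = log1p_real_naive(1e200, 1e200)    # overflows
scaled = log1p_real_scaled(1e200, 1e200)  # finite, about ln(1e200 * sqrt(2))
```

For z = 1e200 + 1e200i the naive form returns inf, while the rescaled form stays finite.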

@pearu
Contributor Author

pearu commented Mar 13, 2024

The same comment as in #10376 (comment) applies here as well: the CI testOnComplexPlane failures are expected and should be fixed once google/jax#20144 lands.

Member

@cheshire cheshire left a comment

Could we test the improved accuracy?

@pearu
Contributor Author

pearu commented Mar 15, 2024

Could we test the improved accuracy?

The accuracy of log1p is tested in google/jax#20144 . See the description of that PR that reports the accuracy improvements for log1p.

@tdanyluk
Member

@cheshire gentle reminder

@cheshire
Member

The accuracy of log1p is tested in google/jax#20144 . See the description of that PR that reports the accuracy improvements for log1p.

Understood, so the JAX test is an e2e integration test, but could we have an XLA unit test as well?

@pearu
Contributor Author

pearu commented Mar 26, 2024

... could we have an XLA unit test as well?

@cheshire Having correctness tests for complex functions in XLA rather than in JAX makes a lot of sense.

That said, when testing functions for correctness, it is essential to have reference implementations whose correctness is verified in a platform-independent way. Unfortunately, not many libraries qualify as a reference library.

For instance, within the Python/JAX framework, one obvious choice would be numpy. However, as it turns out, numpy functions often rely on system math libraries, so evaluation results can differ between platforms. This is one important reason why the JAX tests use mpmath as a reference library. While there exist a few bugs in mpmath itself that require workarounds, one of its most important properties is that its function evaluation results are platform independent. In https://github.com/pearu/complex_function_validation/, I have validated complex functions from a number of Python array libraries for correctness, concluding that mpmath is indeed a reasonable reference library.
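
The platform-independence argument can be illustrated without any third-party package. The sketch below uses Python's stdlib `decimal` module as a stand-in for mpmath (an assumption made only to keep the example dependency-free; the JAX tests use mpmath) and computes only the real part of log1p(x + iy), since `decimal` has no arctangent. The helper name `log1p_real_reference` and the default of 50 digits are hypothetical choices.

```python
from decimal import Decimal, localcontext


def log1p_real_reference(x: float, y: float, digits: int = 50) -> float:
    """Real part of log1p(x + iy), evaluated in `digits`-digit decimal arithmetic."""
    with localcontext() as ctx:
        ctx.prec = digits
        # Decimal(float) converts the binary value exactly, so no input
        # error is introduced before the high-precision evaluation.
        re1 = Decimal(x) + 1
        im = Decimal(y)
        # |1 + z|^2 cannot overflow here: Decimal exponents are far wider
        # than float64's.
        norm_sq = re1 * re1 + im * im
        # Round to the nearest float64 only at the very end.
        return float(norm_sq.ln() / 2)


reference = log1p_real_reference(1e200, 1e200)
```

Because the arithmetic is done in software at a fixed precision, the result does not depend on the system math library. Inputs very close to zero would need more digits than the default used here.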

When considering XLA, one of the obvious choices is libc++. Unfortunately, there exist incorrectness issues in libc++ as well, which must be tackled via correctness analysis and workarounds before it can be used as a reference library. The current tests in XLA are very minimal (functions are tested on only a couple of points) and miss many incorrectness issues, as the correctness analysis of JAX complex functions above also demonstrates.

In the short term, one approach would be to generate tables of inputs and outputs for different complex functions (say, using mpmath) and use these tables as a reference database when testing XLA complex functions for correctness.

What do you think?
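
The table-based approach could look roughly like the following. This is a sketch, not XLA test code: the two table rows were derived by hand for illustration (log1p(0) = 0, and for z = 1e200 + 1e200i the result is ln(1e200 * sqrt(2)) + i*pi/4), whereas a real table would be generated offline with mpmath. `log1p_candidate` is a hypothetical stand-in for the implementation under test, and the tolerance is an arbitrary choice.

```python
import cmath
import math

# (input, expected) pairs. Illustrative only: a real table would be
# generated with a high-precision library and checked in.
REFERENCE_TABLE = [
    (complex(0.0, 0.0), complex(0.0, 0.0)),
    # ln(1e200 * sqrt(2)) + i*pi/4
    (complex(1e200, 1e200), complex(460.8635921890891, 0.7853981633974483)),
]


def log1p_candidate(z: complex) -> complex:
    # Hypothetical implementation under test; here simply cmath.log(1 + z).
    return cmath.log(1 + z)


def check_against_table(fn, table, rel_tol=1e-12):
    """Return the inputs on which fn disagrees with the table."""
    failures = []
    for z, expected in table:
        got = fn(z)
        ok = (
            math.isclose(got.real, expected.real, rel_tol=rel_tol, abs_tol=1e-300)
            and math.isclose(got.imag, expected.imag, rel_tol=rel_tol, abs_tol=1e-300)
        )
        if not ok:
            failures.append(z)
    return failures


failures = check_against_table(log1p_candidate, REFERENCE_TABLE)
```

Since the expected values are fixed numbers, the check itself does not depend on any reference library being available on the test platform.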

@cheshire
Member

In this case having hardcoded expected output values along with the long comment as the one above indeed seems like the preferred option, thanks!

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Mar 27, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Mar 27, 2024
@pearu
Contributor Author

pearu commented Mar 27, 2024

In this case having hardcoded expected output values along with the long comment as the one above indeed seems like the preferred option, thanks!

I suggest doing this in a follow-up PR because it is relevant for other functions as well, and to avoid the XLA and JAX repos going out of sync (google/jax#20436, which includes the xla_extension_version update related to this PR, has already landed).

@cheshire
Member

Sorry what's the connection between going out of sync and having a unit test?

@pearu
Contributor Author

pearu commented Mar 27, 2024

Sorry what's the connection between going out of sync and having a unit test?

In this particular case, google/jax#20436 enabled the unit tests for the log1p fixes implemented in this PR. At the moment, this means that running the JAX tests from its main branch with a jaxlib built from the XLA main branch makes the JAX log1p unit tests fail. xla_extension_version is designed to prevent this, but only under the assumption that the related XLA and JAX PRs land more or less simultaneously. When an unrelated XLA PR updates xla_extension_version and lands first, the related XLA and JAX PRs end up out of sync. The current situation is a bit worse: XLA and JAX main branches are out of sync.
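
The gating mechanism described here can be pictured with a small sketch. It is not JAX's actual test code: `XLA_EXTENSION_VERSION` and `REQUIRED_VERSION` are hypothetical placeholders with made-up values. The point is only that a test depending on an XLA fix is skipped when the installed extension predates the version associated with that fix.

```python
import unittest

XLA_EXTENSION_VERSION = 100  # hypothetical: version reported by the installed jaxlib
REQUIRED_VERSION = 101       # hypothetical: version assumed to contain the log1p fix


class Log1pAccuracyTest(unittest.TestCase):
    def test_large_complex_inputs(self):
        if XLA_EXTENSION_VERSION < REQUIRED_VERSION:
            self.skipTest("requires an XLA build that includes the log1p fix")
        # The real accuracy comparison would go here.


def run_suite() -> unittest.TestResult:
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(Log1pAccuracyTest)
    result = unittest.TestResult()
    suite.run(result)
    return result


result = run_suite()
```

If an unrelated change bumps the version first, the comparison passes even though the fix is absent, and the test runs against code it was not written for.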

@pearu
Contributor Author

pearu commented Mar 27, 2024

The current situation is a bit worse: XLA and JAX main branches are out of sync.

A quick solution is to create a JAX PR that disables the log1p unit tests. That said, it would also make sense to land this PR, since the related changes are already tested in JAX, and to implement the corresponding unit tests in XLA in the future. What do you think?

@cheshire
Member

Sorry I'm still confused. This PR improves log1p handling in XLA, right? So it would be natural to accompany it with a unit test? The question regarding JAX and/or JAX/XLA interop is orthogonal?

@pearu
Contributor Author

pearu commented Mar 27, 2024

Sorry I'm still confused. This PR improves log1p handling in XLA, right?

Correct.

So it would be natural to accompany it with a unit test?

True. I am working on it atm.

The question regarding JAX and/or JAX/XLA interop is orthogonal?

There is interference only when JAX implements tests for XLA features. AFAIK, this situation is not unique to this issue.

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Mar 29, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Mar 29, 2024
copybara-service bot pushed a commit that referenced this pull request Apr 4, 2024
…te values.

Imported from GitHub PR #10503

As in the title.

Tests and improvement reports are in google/jax#20144.

Accuracy tests are enabled in google/jax#20436
Copybara import of the project:

--
2b8f253 by Pearu Peterson <pearu.peterson@gmail.com>:

Fix log1p inaccuracies on complex inputs with large absolute values.

--
d35cef4 by Pearu Peterson <pearu.peterson@gmail.com>:

Add tests to complex Log1p

Merging this change closes #10503

FUTURE_COPYBARA_INTEGRATE_REVIEW=#10503 from pearu:pearu/log1p d35cef4
PiperOrigin-RevId: 620293408
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 4, 2024
…te values.

Imported from GitHub PR openxla/xla#10503

As in the title.

Tests and improvement reports are in google/jax#20144.

Accuracy tests are enabled in google/jax#20436
Copybara import of the project:

--
2b8f2539d6bf364c0a97f65e186430c5eb3ed07b by Pearu Peterson <pearu.peterson@gmail.com>:

Fix log1p inaccuracies on complex inputs with large absolute values.

--
d35cef4f5fa09482c49edfee709e86c5ca29adde by Pearu Peterson <pearu.peterson@gmail.com>:

Add tests to complex Log1p

Merging this change closes #10503

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde
PiperOrigin-RevId: 620293408
copybara-service bot pushed a commit that referenced this pull request Apr 4, 2024
…te values.

Imported from GitHub PR #10503

As in the title.

Tests and improvement reports are in google/jax#20144.

Accuracy tests are enabled in google/jax#20436
Copybara import of the project:

--
2b8f253 by Pearu Peterson <pearu.peterson@gmail.com>:

Fix log1p inaccuracies on complex inputs with large absolute values.

--
d35cef4 by Pearu Peterson <pearu.peterson@gmail.com>:

Add tests to complex Log1p

Merging this change closes #10503

FUTURE_COPYBARA_INTEGRATE_REVIEW=#10503 from pearu:pearu/log1p d35cef4
PiperOrigin-RevId: 621891416
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 4, 2024
…te values.

Imported from GitHub PR openxla/xla#10503

As in the title.

Tests and improvement reports are in google/jax#20144.

Accuracy tests are enabled in google/jax#20436
Copybara import of the project:

--
2b8f2539d6bf364c0a97f65e186430c5eb3ed07b by Pearu Peterson <pearu.peterson@gmail.com>:

Fix log1p inaccuracies on complex inputs with large absolute values.

--
d35cef4f5fa09482c49edfee709e86c5ca29adde by Pearu Peterson <pearu.peterson@gmail.com>:

Add tests to complex Log1p

Merging this change closes #10503

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde
PiperOrigin-RevId: 621891416
@copybara-service copybara-service bot closed this in 2a29934 Apr 4, 2024
copybara-service bot pushed a commit that referenced this pull request Apr 4, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=#10503 from pearu:pearu/log1p d35cef4
PiperOrigin-RevId: 620894314
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 4, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde
PiperOrigin-RevId: 607147423
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 4, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde
PiperOrigin-RevId: 620894314
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 4, 2024
This is part of an effort to move runtime targets to the runtime folder. #5758

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde
PiperOrigin-RevId: 621605191
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 4, 2024
tl;dr: this gives a 1.26x compilation time speedup for a large, dense
model in XLA:GPU.

The largest perf leaf seen in profiles of a large, dense model
is related to computing the post order. Surprisingly, it is not
the DFS itself that is most expensive; rather, most of the time is
spent scanning through HloComputation::Instructions() to identify
DFS roots.

The reason this scan becomes expensive as instructions are removed
is that the vector holding HloInstructionInfo (introduced in
cl/600130708 || openxla/xla@247280ab727)
is not shrunk as it flows through the pipeline, so we have
to walk through many deleted "tombstone" entries. Here is the
histogram of the number of tombstones encountered during post order
computations for this model:

```
[        1 - 1,536,345) ****************************** (1,300,248)
[1,536,345 - 3,072,690)  (2)
[3,072,690 - 4,609,034)  (364)
[4,609,034 - 6,145,378)  (10,443)
```

To ameliorate this, this CL shrinks the vector periodically,
so far only between passes. This is done by running compaction
on the vector during HloComputation::Cleanup(), which is called
after every pass. The cost of compaction is made proportional to
the number of deleted entries by swapping--if needed--each tombstone
with the rightmost (within the vector) non-deleted entry.
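
The swap-based compaction can be sketched as follows. This is a simplified Python illustration, not the HloComputation code: tombstones are modeled as `None`, and the function assumes the caller already knows the tombstone positions (`dead_indices`), which is an assumption of this sketch rather than something the commit message states. The name `compact` is hypothetical.

```python
def compact(entries: list, dead_indices: list) -> list:
    """Remove tombstones (None) in place; element order is not preserved."""
    for i in sorted(dead_indices):
        # Drop tombstones sitting at the tail so the last entry is live.
        while entries and entries[-1] is None:
            entries.pop()
        if i >= len(entries):
            # Every remaining tombstone was at the tail and is already gone.
            break
        # Fill the hole with the rightmost live entry.
        entries[i] = entries.pop()
    return entries


compacted = compact(["a", None, "b", None, "c", None], [1, 3, 5])
```

Each tombstone triggers at most one pop and one move, so the work grows with the number of deleted entries rather than with the length of the vector (apart from sorting the indices).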

This brings the number of seen tombstones down significantly:

```
[        1 -   327,699) ****************************** (937,541)
[  327,699 -   655,396)  (308)
[  655,396 -   983,094)  (0)
[  983,094 - 1,310,792)  (1)
```

Note: we could further improve compaction by calling Cleanup()
from some passes, instead of just between passes. However, that
would not yield a significant gain; at least for this model,
scanning the instructions' vector now takes ~1% of total time
(vs. ~17% before).
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde
PiperOrigin-RevId: 619057964
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 4, 2024
…atibility.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10503 from pearu:pearu/log1p d35cef4f5fa09482c49edfee709e86c5ca29adde
PiperOrigin-RevId: 619687562
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 4, 2024
…te values.

Imported from GitHub PR openxla/xla#10503

As in the title.

Tests and improvement reports are in google/jax#20144.

Accuracy tests are enabled in google/jax#20436
Copybara import of the project:

--
2b8f2539d6bf364c0a97f65e186430c5eb3ed07b by Pearu Peterson <pearu.peterson@gmail.com>:

Fix log1p inaccuracies on complex inputs with large absolute values.

--
d35cef4f5fa09482c49edfee709e86c5ca29adde by Pearu Peterson <pearu.peterson@gmail.com>:

Add tests to complex Log1p

Merging this change closes #10503

PiperOrigin-RevId: 621917683
jreiffers added a commit to jreiffers/llvm-project that referenced this pull request Apr 10, 2024
This ports openxla/xla#10503 by @pearu. In addition
to the filecheck test here, the accuracy was tested with XLA's
complex_unary_op_test and its MLIR emitters.
jreiffers added a commit to llvm/llvm-project that referenced this pull request Apr 10, 2024
This ports openxla/xla#10503 by @pearu. The new
implementation matches mpmath's results for most inputs, see caveats in
the linked pull request. In addition to the filecheck test here, the
accuracy was tested with XLA's complex_unary_op_test and its MLIR
emitters.
jreiffers added a commit to jreiffers/llvm-project that referenced this pull request Apr 11, 2024
This ports openxla/xla#10503 by @pearu. In addition
to the filecheck test here, the accuracy was tested with XLA's
complex_unary_op_test and its MLIR emitters.

This is a fixed version of
llvm#88260. The previous version
relied on implementation-specific behavior in the order of evaluation of
maxAbsOfRealPlusOneAndImagMinusOne's operands.
jreiffers added a commit to llvm/llvm-project that referenced this pull request Apr 11, 2024
This ports openxla/xla#10503 by @pearu. In addition to the filecheck
test here, the accuracy was tested with XLA's complex_unary_op_test and
its MLIR emitters.

This is a fixed version of
#88260. The previous version
relied on implementation-specific behavior in the order of evaluation of
maxAbsOfRealPlusOneAndImagMinusOne's operands.

7 participants