Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Fix FP16 Precision for Sigmoid Op #14727

Merged
merged 4 commits into from
Feb 22, 2023
Merged

Conversation

er3x3
Copy link
Contributor

@er3x3 er3x3 commented Feb 17, 2023

Current Sigmoid's CUDA kernel uses target data type for all computation. For some small negative numbers, if using FP16, it will loss precision. For example, for input [-7.8477, 7.3320, -7.8008, 6.6016], the expected output is [3.9047e-04, 9.9935e-01, 4.0919e-04, 9.9864e-01], but current kernel will generate result [0.0000, 0.9990, 0.0000, 0.9990]. If some sub-graph contains Sigmoid, such as BinaryCrossEntropyWithLogits, it's likely to produce NaN as compute result.

The PR fixes this by using FP32 for kernel internal computation. Note that the fix will not have perf regression, as CUDA's _Exp will also do float to half casting, so the fix doesn't introduce extra cast. We move the cast to right begin and end of the whole kernel so that other parts of computation are also in FP32 (instead of only Exp).

hanbitmyths
hanbitmyths previously approved these changes Feb 17, 2023
Copy link
Contributor

@hanbitmyths hanbitmyths left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@askhade
Copy link
Contributor

askhade commented Feb 17, 2023

@iK1D can you bring back the unit test which was deleted in this PR: https://github.com/microsoft/onnxruntime/pull/12301/files

@er3x3
Copy link
Contributor Author

er3x3 commented Feb 18, 2023

@iK1D can you bring back the unit test which was deleted in this PR: https://github.com/microsoft/onnxruntime/pull/12301/files

Just added back. But logically this is unnecessary. Previously we called ATen so the UT guarantee we called ATen properly. But now, logically there should be UT in exporter side to make sure the exporter logic is correct, and our UTs at Op level cover the correctness of single Op compute. The Sigmoid Op UT for FP16 should have covered the issue before, but the error tolerance is too big for FP16 to cover such very small error, but it's not easy to chang the error tolerance as it may effect other UTs.

@er3x3 er3x3 merged commit e9ec4c0 into main Feb 22, 2023
@er3x3 er3x3 deleted the weicwang/sigmoid_half_precision branch February 22, 2023 01:16
PatriceVignola pushed a commit that referenced this pull request Feb 22, 2023
Current Sigmoid's CUDA kernel uses target data type for all computation.
For some small negative numbers, if using FP16, it will loss precision.
For example, for input [-7.8477, 7.3320, -7.8008, 6.6016], the expected
output is [3.9047e-04, 9.9935e-01, 4.0919e-04, 9.9864e-01], but current
kernel will generate result [0.0000, 0.9990, 0.0000, 0.9990]. If some
sub-graph contains Sigmoid, such as BinaryCrossEntropyWithLogits, it's
likely to produce NaN as compute result.

The PR fixes this by using FP32 for kernel internal computation. Note
that the fix will not have perf regression, as CUDA's _Exp will also do
float to half casting, so the fix doesn't introduce extra cast. We move
the cast to right begin and end of the whole kernel so that other parts
of computation are also in FP32 (instead of only Exp).
PatriceVignola pushed a commit that referenced this pull request Feb 22, 2023
Current Sigmoid's CUDA kernel uses target data type for all computation.
For some small negative numbers, if using FP16, it will loss precision.
For example, for input [-7.8477, 7.3320, -7.8008, 6.6016], the expected
output is [3.9047e-04, 9.9935e-01, 4.0919e-04, 9.9864e-01], but current
kernel will generate result [0.0000, 0.9990, 0.0000, 0.9990]. If some
sub-graph contains Sigmoid, such as BinaryCrossEntropyWithLogits, it's
likely to produce NaN as compute result.

The PR fixes this by using FP32 for kernel internal computation. Note
that the fix will not have perf regression, as CUDA's _Exp will also do
float to half casting, so the fix doesn't introduce extra cast. We move
the cast to right begin and end of the whole kernel so that other parts
of computation are also in FP32 (instead of only Exp).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants