[llvm][InstCombine] bitcast bfloat half castpair bug #79832

Merged
merged 3 commits into from
Jan 31, 2024

nasherm
Contributor

@nasherm nasherm commented Jan 29, 2024

InstCombine miscompiles cast pairs of the form `bitcast bfloat to half` + `<FPOp> half to <Ty>` or `bitcast half to bfloat` + `<FPOp> bfloat to <Ty>`. For example, `bitcast bfloat to half` + `fpext half to double` and `bitcast half to bfloat` + `fpext bfloat to double` are reduced to `fpext bfloat to double` and `fpext half to double` respectively. This fold is incorrect: it assumes that `bfloat` and `half` share a representation merely because they have the same width.
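
To make the representation mismatch concrete, here is a minimal standalone sketch (not part of this patch) that decodes the same 16 bits under both layouts. The helper names `decode16`, `decodeHalf` and `decodeBFloat` are invented for illustration. The field widths are the standard ones: `half` is IEEE binary16 with 5 exponent and 10 mantissa bits, while `bfloat` has 8 exponent and 7 mantissa bits.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decode a 16-bit pattern laid out as
// 1 sign bit, ExpBits exponent bits, MantBits mantissa bits (IEEE-style).
double decode16(uint16_t Bits, unsigned ExpBits, unsigned MantBits) {
  const int Bias = (1 << (ExpBits - 1)) - 1;
  const unsigned ExpMask = (1u << ExpBits) - 1;
  const unsigned Sign = Bits >> 15;
  const unsigned Exp = (Bits >> MantBits) & ExpMask;
  const unsigned Mant = Bits & ((1u << MantBits) - 1);
  const double Frac = static_cast<double>(Mant) / (1u << MantBits);

  double Mag;
  if (Exp == ExpMask) // all-ones exponent: infinity or NaN
    Mag = Mant ? std::numeric_limits<double>::quiet_NaN()
               : std::numeric_limits<double>::infinity();
  else if (Exp == 0) // zero or subnormal: no implicit leading 1
    Mag = std::ldexp(Frac, 1 - Bias);
  else // normal number: implicit leading 1
    Mag = std::ldexp(1.0 + Frac, static_cast<int>(Exp) - Bias);
  return Sign ? -Mag : Mag;
}

// half: 5 exponent bits, 10 mantissa bits.
double decodeHalf(uint16_t Bits) { return decode16(Bits, 5, 10); }

// bfloat: 8 exponent bits, 7 mantissa bits.
double decodeBFloat(uint16_t Bits) { return decode16(Bits, 8, 7); }
```

For instance, the pattern `0x3C00` decodes to 1.0 as `half` but to 2^-7 = 0.0078125 as `bfloat`, and `0x3F80` decodes to 1.0 as `bfloat` but to 1.875 as `half`. Dropping the `bitcast` and extending the source type directly therefore changes the resulting `double`.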

Fixes #61984

@llvmbot
Collaborator

llvmbot commented Jan 29, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: Nashe Mncube (nasherm)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/79832.diff

2 Files Affected:

  • (modified) llvm/lib/IR/Instructions.cpp (+7)
  • (added) llvm/test/Transforms/InstCombine/bitcast-bfloat-half-mixing.ll (+70)
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 87874c3abc4680..e268184b17f92a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3214,6 +3214,13 @@ unsigned CastInst::isEliminableCastPair(
         return secondOp;
       return 0;
     case 6:
+      // In cast pairs bfloat and half float shouldn't be treated as equivalent
+      // if the first operation is a bitcast i.e. if we have
+      // bitcast bfloat to half + fpext half to double we shouldn't reduce to
+      // fpext bfloat to double as this isn't equal to fpext half to double.
+      // This has been generalised for all float pairs that have the same width.
+      if (SrcTy->getPrimitiveSizeInBits() == MidTy->getPrimitiveSizeInBits())
+        return 0;
       // No-op cast in first op implies secondOp as long as the SrcTy
       // is a floating point.
       if (SrcTy->isFloatingPointTy())
diff --git a/llvm/test/Transforms/InstCombine/bitcast-bfloat-half-mixing.ll b/llvm/test/Transforms/InstCombine/bitcast-bfloat-half-mixing.ll
new file mode 100644
index 00000000000000..3878f45c7326e5
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/bitcast-bfloat-half-mixing.ll
@@ -0,0 +1,70 @@
+; RUN: opt -S -passes=instcombine %s | FileCheck %s
+
+define double @F0([2 x bfloat] %P0) {
+entry:
+  %P0.extract = extractvalue [2 x bfloat] %P0, 1
+  %conv0 = bitcast bfloat %P0.extract to half
+  %0 = fpext half %conv0 to double
+  ret double %0
+}
+
+; CHECK: fpext half %conv0 to double
+; CHECK-NOT: fpext bfloat %P0.extract to double
+
+define double @F1([2 x half] %P1) {
+entry:
+  %P1.extract = extractvalue [2 x half] %P1, 1
+  %conv1 = bitcast half %P1.extract to bfloat
+  %0 = fpext bfloat %conv1 to double
+  ret double %0
+}
+
+; CHECK: fpext bfloat %conv1 to double
+; CHECK-NOT: fpext half %P1.extract to double
+
+define i32 @F2([2 x bfloat] %P2) {
+entry:
+  %P2.extract = extractvalue [2 x bfloat] %P2, 1
+  %conv2 = bitcast bfloat %P2.extract to half
+  %0 = fptoui half %conv2 to i32
+  ret i32 %0
+}
+
+; CHECK: fptoui half %conv2 to i32
+; CHECK-NOT: fptoui bfloat %P2.extract to i32
+
+define i32 @F3([2 x half] %P3) {
+entry:
+  %P3.extract = extractvalue [2 x half] %P3, 1
+  %conv3 = bitcast half %P3.extract to bfloat
+  %0 = fptoui bfloat %conv3 to i32
+  ret i32 %0
+}
+
+; CHECK: fptoui bfloat %conv3 to i32
+; CHECK-NOT: fptoui half %P3.extract to i32
+
+
+define i32 @F4([2 x bfloat] %P4) {
+entry:
+  %P4.extract = extractvalue [2 x bfloat] %P4, 1
+  %conv4 = bitcast bfloat %P4.extract to half
+  %0 = fptosi half %conv4 to i32
+  ret i32 %0
+}
+
+; CHECK: fptosi half %conv4 to i32
+; CHECK-NOT: fptosi bfloat %P4.extract to i32
+
+define i32 @F5([2 x half] %P5) {
+entry:
+  %P5.extract = extractvalue [2 x half] %P5, 1
+  %conv5 = bitcast half %P5.extract to bfloat
+  %0 = fptosi bfloat %conv5 to i32
+  ret i32 %0
+}
+
+; CHECK: fptosi bfloat %conv5 to i32
+; CHECK-NOT: fptosi half %P5.extract to i32
+
+

@nasherm nasherm force-pushed the nashe/bitcast-bfloat-half-cast-pair branch 2 times, most recently from 3e4c9fa to 4f7b10f Compare January 30, 2024 10:15
@nikic nikic self-requested a review January 30, 2024 11:29
@nasherm nasherm force-pushed the nashe/bitcast-bfloat-half-cast-pair branch 2 times, most recently from e72f7e3 to 682363e Compare January 30, 2024 15:00
github-actions bot commented Jan 30, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 78e0cca135076154abab21eadd146dc1dfd3549f 6960958a6815754852c4f0a1c2f3fb8714a0dde0 -- llvm/lib/IR/Instructions.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index ce0df53d9f..ad582241a1 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3148,24 +3148,25 @@ unsigned CastInst::isEliminableCastPair(
   const unsigned numCastOps =
     Instruction::CastOpsEnd - Instruction::CastOpsBegin;
   static const uint8_t CastResults[numCastOps][numCastOps] = {
-    // T        F  F  U  S  F  F  P  I  B  A  -+
-    // R  Z  S  P  P  I  I  T  P  2  N  T  S   |
-    // U  E  E  2  2  2  2  R  E  I  T  C  C   +- secondOp
-    // N  X  X  U  S  F  F  N  X  N  2  V  V   |
-    // C  T  T  I  I  P  P  C  T  T  P  T  T  -+
-    {  1, 0, 0,99,99, 0, 0,99,99,99, 0, 3, 0}, // Trunc         -+
-    {  8, 1, 9,99,99, 2,17,99,99,99, 2, 3, 0}, // ZExt           |
-    {  8, 0, 1,99,99, 0, 2,99,99,99, 0, 3, 0}, // SExt           |
-    {  0, 0, 0,99,99, 0, 0,99,99,99, 0, 3, 0}, // FPToUI         |
-    {  0, 0, 0,99,99, 0, 0,99,99,99, 0, 3, 0}, // FPToSI         |
-    { 99,99,99, 0, 0,99,99, 0, 0,99,99, 4, 0}, // UIToFP         +- firstOp
-    { 99,99,99, 0, 0,99,99, 0, 0,99,99, 4, 0}, // SIToFP         |
-    { 99,99,99, 0, 0,99,99, 0, 0,99,99, 4, 0}, // FPTrunc        |
-    { 99,99,99, 2, 2,99,99, 8, 2,99,99, 4, 0}, // FPExt          |
-    {  1, 0, 0,99,99, 0, 0,99,99,99, 7, 3, 0}, // PtrToInt       |
-    { 99,99,99,99,99,99,99,99,99,11,99,15, 0}, // IntToPtr       |
-    {  5, 5, 5, 0, 0, 5, 5, 0, 0,16, 5, 1,14}, // BitCast        |
-    {  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,13,12}, // AddrSpaceCast -+
+      // T        F  F  U  S  F  F  P  I  B  A  -+
+      // R  Z  S  P  P  I  I  T  P  2  N  T  S   |
+      // U  E  E  2  2  2  2  R  E  I  T  C  C   +- secondOp
+      // N  X  X  U  S  F  F  N  X  N  2  V  V   |
+      // C  T  T  I  I  P  P  C  T  T  P  T  T  -+
+      {1, 0, 0, 99, 99, 0, 0, 99, 99, 99, 0, 3, 0},  // Trunc         -+
+      {8, 1, 9, 99, 99, 2, 17, 99, 99, 99, 2, 3, 0}, // ZExt           |
+      {8, 0, 1, 99, 99, 0, 2, 99, 99, 99, 0, 3, 0},  // SExt           |
+      {0, 0, 0, 99, 99, 0, 0, 99, 99, 99, 0, 3, 0},  // FPToUI         |
+      {0, 0, 0, 99, 99, 0, 0, 99, 99, 99, 0, 3, 0},  // FPToSI         |
+      {99, 99, 99, 0, 0, 99, 99, 0, 0, 99, 99, 4,
+       0}, // UIToFP         +- firstOp
+      {99, 99, 99, 0, 0, 99, 99, 0, 0, 99, 99, 4, 0},      // SIToFP         |
+      {99, 99, 99, 0, 0, 99, 99, 0, 0, 99, 99, 4, 0},      // FPTrunc        |
+      {99, 99, 99, 2, 2, 99, 99, 8, 2, 99, 99, 4, 0},      // FPExt          |
+      {1, 0, 0, 99, 99, 0, 0, 99, 99, 99, 7, 3, 0},        // PtrToInt       |
+      {99, 99, 99, 99, 99, 99, 99, 99, 99, 11, 99, 15, 0}, // IntToPtr       |
+      {5, 5, 5, 0, 0, 5, 5, 0, 0, 16, 5, 1, 14},           // BitCast        |
+      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 12},           // AddrSpaceCast -+
   };
 
   // TODO: This logic could be encoded into the table above and handled in the

@nasherm nasherm force-pushed the nashe/bitcast-bfloat-half-cast-pair branch from 682363e to 6100a7c Compare January 30, 2024 15:51
@nasherm
Contributor Author

nasherm commented Jan 30, 2024

clang-format wants the entire `CastResults` table reformatted, but that formatting hurts readability.

Contributor

@nikic nikic left a comment

LGTM

@nasherm nasherm force-pushed the nashe/bitcast-bfloat-half-cast-pair branch from ab55882 to 6960958 Compare January 31, 2024 10:54
@nasherm nasherm merged commit d309261 into llvm:main Jan 31, 2024
2 of 4 checks passed
Successfully merging this pull request may close these issues.

[AArch64] Miscompilation of code mixing fp16 and bf16