
[DAGCombine] Fix multi-use miscompile in load combine #81492

Closed · wants to merge 1 commit

Conversation

@nikic (Contributor) commented Feb 12, 2024

The load combine replaces a number of original loads with one new load and also replaces the output chains of the original loads with the output chain of the new load. This is only correct if the old loads actually get removed; otherwise they may get incorrectly reordered.
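
For context, the chain replacement in question looks roughly like the following sketch (my paraphrase of the relevant MatchLoadCombine step, simplified for illustration):

```cpp
// Simplified sketch: after building the wide NewLoad, the combiner redirects
// every original load's chain result (result number 1) to the new load's
// chain. If an old load is not actually deleted, its remaining users have
// lost their ordering edge to it, so an aliasing store can be scheduled
// across it.
for (LoadSDNode *L : Loads)
  DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
```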

The code did enforce that all involved operations are one-use (which also guarantees that the loads will be removed), with one exception: for vector loads, multi-use was allowed to support multiple element extracts from one load.

This patch collects these extract element nodes and then validates that the value results of the loads are only used by them (see the validation loop in the diff below).

I think an alternative fix would be to replace the uses of the old output chains with TokenFactors that include both the old output chains and the new output chain. However, I think the proposed patch is preferable: the profitability of the transform in the general multi-use case is unclear, since it may increase the overall number of loads.
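
For reference, that alternative would look roughly like the following sketch (illustration only, not part of this patch; it mirrors what SelectionDAG::makeEquivalentMemoryOrdering does internally):

```cpp
// Sketch of the rejected alternative: merge each old load's output chain
// with the new load's chain, so users of the old chain stay ordered after
// both loads even if the old load survives.
for (LoadSDNode *L : Loads) {
  SDValue OldChain = SDValue(L, 1);
  SDValue NewChain = SDValue(NewLoad.getNode(), 1);
  SDValue TF = DAG.getNode(ISD::TokenFactor, SDLoc(L), MVT::Other,
                           OldChain, NewChain);
  DAG.ReplaceAllUsesOfValueWith(OldChain, TF);
  // The RAUW above also rewrote the TokenFactor's own operand to point at
  // itself; restore it so the TokenFactor still depends on the old chain.
  DAG.UpdateNodeOperands(TF.getNode(), OldChain, NewChain);
}
```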

Fixes #80911.
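
To make the failure mode concrete, here is a scalar C++ analogue of the pattern (my reconstruction for illustration only; the actual reproducer uses a <4 x i8> vector load, see the X86 test in the diff):

```cpp
#include <cstdint>

// The byte loads feed a load-combine candidate, but they also have other
// uses (in the real test the loaded vector is stored back), so the combine
// does not remove them. If their chains are nonetheless rewired to the new
// wide load, nothing keeps the potentially aliasing store below ordered
// with respect to them, and it can be moved across the loads.
uint32_t analogue(uint8_t *ptr, uint32_t *clobber) {
  uint8_t lo = ptr[0];                        // original load, multi-use
  uint8_t hi = ptr[1];                        // original load, multi-use
  *clobber = 0;                               // may alias *ptr
  ptr[0] = lo;                                // extra uses of the loads
  ptr[1] = hi;
  return (uint32_t)lo | ((uint32_t)hi << 8);  // load-combine candidate
}
```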

@llvmbot (Collaborator) commented Feb 12, 2024

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-llvm-selectiondag

Author: Nikita Popov (nikic)



Full diff: https://github.com/llvm/llvm-project/pull/81492.diff

4 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+24-9)
  • (modified) llvm/test/CodeGen/AArch64/load-combine.ll (+5-3)
  • (modified) llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll (+6-4)
  • (modified) llvm/test/CodeGen/X86/load-combine.ll (+16-9)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d3cd9b1671e1b9..45114b85e25d8c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8668,6 +8668,7 @@ using SDByteProvider = ByteProvider<SDNode *>;
 static std::optional<SDByteProvider>
 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                       std::optional<uint64_t> VectorIndex,
+                      SmallPtrSetImpl<SDNode *> &ExtractElements,
                       unsigned StartingIndex = 0) {
 
   // Typical i64 by i8 pattern requires recursion up to 8 calls depth
@@ -8694,12 +8695,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
 
   switch (Op.getOpcode()) {
   case ISD::OR: {
-    auto LHS =
-        calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
+    auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
+                                     VectorIndex, ExtractElements);
     if (!LHS)
       return std::nullopt;
-    auto RHS =
-        calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
+    auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1,
+                                     VectorIndex, ExtractElements);
     if (!RHS)
       return std::nullopt;
 
@@ -8726,7 +8727,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
     return Index < ByteShift
                ? SDByteProvider::getConstantZero()
                : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
-                                       Depth + 1, VectorIndex, Index);
+                                       Depth + 1, VectorIndex, ExtractElements,
+                                       Index);
   }
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
@@ -8743,11 +8745,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
                        SDByteProvider::getConstantZero())
                  : std::nullopt;
     return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
-                                 StartingIndex);
+                                 ExtractElements, StartingIndex);
   }
   case ISD::BSWAP:
     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
-                                 Depth + 1, VectorIndex, StartingIndex);
+                                 Depth + 1, VectorIndex, ExtractElements,
+                                 StartingIndex);
   case ISD::EXTRACT_VECTOR_ELT: {
     auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
     if (!OffsetOp)
@@ -8772,8 +8775,9 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
     if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
       return std::nullopt;
 
+    ExtractElements.insert(Op.getNode());
     return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
-                                 VectorIndex, StartingIndex);
+                                 VectorIndex, ExtractElements, StartingIndex);
   }
   case ISD::LOAD: {
     auto L = cast<LoadSDNode>(Op.getNode());
@@ -9110,6 +9114,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   SDValue Chain;
 
   SmallPtrSet<LoadSDNode *, 8> Loads;
+  SmallPtrSet<SDNode *, 8> ExtractElements;
   std::optional<SDByteProvider> FirstByteProvider;
   int64_t FirstOffset = INT64_MAX;
 
@@ -9119,7 +9124,9 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   unsigned ZeroExtendedBytes = 0;
   for (int i = ByteWidth - 1; i >= 0; --i) {
     auto P =
-        calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
+        calculateByteProvider(SDValue(N, 0), i, 0,
+                              /*VectorIndex*/ std::nullopt, ExtractElements,
+
                               /*StartingIndex*/ i);
     if (!P)
       return SDValue();
@@ -9245,6 +9252,14 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
   if (!Allowed || !Fast)
     return SDValue();
 
+  // calculatebyteProvider() allows multi-use for vector loads. Ensure that
+  // all uses are in vector element extracts that are part of the pattern.
+  for (LoadSDNode *L : Loads)
+    if (L->getMemoryVT().isVector())
+      for (auto It = L->use_begin(); It != L->use_end(); ++It)
+        if (It.getUse().getResNo() == 0 && !ExtractElements.contains(*It))
+          return SDValue();
+
   SDValue NewLoad =
       DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
                      Chain, FirstLoad->getBasePtr(),
diff --git a/llvm/test/CodeGen/AArch64/load-combine.ll b/llvm/test/CodeGen/AArch64/load-combine.ll
index 57f61e5303ecf9..b30ee45aa4d1a0 100644
--- a/llvm/test/CodeGen/AArch64/load-combine.ll
+++ b/llvm/test/CodeGen/AArch64/load-combine.ll
@@ -606,10 +606,12 @@ define void @short_vector_to_i32_unused_high_i8(ptr %in, ptr %out, ptr %p) {
 ; CHECK-LABEL: short_vector_to_i32_unused_high_i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ldr s0, [x0]
-; CHECK-NEXT:    ldrh w9, [x0]
 ; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    umov w8, v0.h[2]
-; CHECK-NEXT:    orr w8, w9, w8, lsl #16
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    umov w9, v0.h[0]
+; CHECK-NEXT:    umov w10, v0.h[2]
+; CHECK-NEXT:    bfi w9, w8, #8, #8
+; CHECK-NEXT:    orr w8, w9, w10, lsl #16
 ; CHECK-NEXT:    str w8, [x1]
 ; CHECK-NEXT:    ret
   %ld = load <4 x i8>, ptr %in, align 4
diff --git a/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll b/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
index c27e44609c527f..96921082801821 100644
--- a/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
+++ b/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
@@ -205,12 +205,14 @@ define i64 @load_3xi16_combine(ptr addrspace(1) %p) #0 {
 ; GCN-LABEL: load_3xi16_combine:
 ; GCN:       ; %bb.0:
 ; GCN-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; GCN-NEXT:    global_load_dword v2, v[0:1], off
-; GCN-NEXT:    global_load_ushort v3, v[0:1], off offset:4
+; GCN-NEXT:    global_load_dword v3, v[0:1], off
+; GCN-NEXT:    global_load_ushort v2, v[0:1], off offset:4
+; GCN-NEXT:    s_mov_b32 s4, 0xffff
 ; GCN-NEXT:    s_waitcnt vmcnt(1)
-; GCN-NEXT:    v_mov_b32_e32 v0, v2
+; GCN-NEXT:    v_and_b32_e32 v0, 0xffff0000, v3
+; GCN-NEXT:    v_and_or_b32 v0, v3, s4, v0
 ; GCN-NEXT:    s_waitcnt vmcnt(0)
-; GCN-NEXT:    v_mov_b32_e32 v1, v3
+; GCN-NEXT:    v_mov_b32_e32 v1, v2
 ; GCN-NEXT:    s_setpc_b64 s[30:31]
   %gep.p = getelementptr i16, ptr addrspace(1) %p, i32 1
   %gep.2p = getelementptr i16, ptr addrspace(1) %p, i32 2
diff --git a/llvm/test/CodeGen/X86/load-combine.ll b/llvm/test/CodeGen/X86/load-combine.ll
index 7e4e11fcc75c20..530e17a0b0f099 100644
--- a/llvm/test/CodeGen/X86/load-combine.ll
+++ b/llvm/test/CodeGen/X86/load-combine.ll
@@ -1283,26 +1283,33 @@ define i32 @zext_load_i32_by_i8_bswap_shl_16(ptr %arg) {
   ret i32 %tmp8
 }
 
-; FIXME: This is a miscompile.
 define i32 @pr80911_vector_load_multiuse(ptr %ptr, ptr %clobber) nounwind {
 ; CHECK-LABEL: pr80911_vector_load_multiuse:
 ; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushl %edi
 ; CHECK-NEXT:    pushl %esi
-; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %ecx
 ; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %edx
-; CHECK-NEXT:    movl (%edx), %esi
-; CHECK-NEXT:    movzwl (%edx), %eax
-; CHECK-NEXT:    movl $0, (%ecx)
-; CHECK-NEXT:    movl %esi, (%edx)
+; CHECK-NEXT:    movl {{[0-9]+}}(%esp), %esi
+; CHECK-NEXT:    movzbl (%esi), %ecx
+; CHECK-NEXT:    movzbl 1(%esi), %eax
+; CHECK-NEXT:    movzwl 2(%esi), %edi
+; CHECK-NEXT:    movl $0, (%edx)
+; CHECK-NEXT:    movw %di, 2(%esi)
+; CHECK-NEXT:    movb %al, 1(%esi)
+; CHECK-NEXT:    movb %cl, (%esi)
+; CHECK-NEXT:    shll $8, %eax
+; CHECK-NEXT:    orl %ecx, %eax
 ; CHECK-NEXT:    popl %esi
+; CHECK-NEXT:    popl %edi
 ; CHECK-NEXT:    retl
 ;
 ; CHECK64-LABEL: pr80911_vector_load_multiuse:
 ; CHECK64:       # %bb.0:
-; CHECK64-NEXT:    movzwl (%rdi), %eax
+; CHECK64-NEXT:    movaps (%rdi), %xmm0
 ; CHECK64-NEXT:    movl $0, (%rsi)
-; CHECK64-NEXT:    movl (%rdi), %ecx
-; CHECK64-NEXT:    movl %ecx, (%rdi)
+; CHECK64-NEXT:    movss %xmm0, (%rdi)
+; CHECK64-NEXT:    movaps %xmm0, -{{[0-9]+}}(%rsp)
+; CHECK64-NEXT:    movzwl -{{[0-9]+}}(%rsp), %eax
 ; CHECK64-NEXT:    retq
   %load = load <4 x i8>, ptr %ptr, align 16
   store i32 0, ptr %clobber

@@ -9245,6 +9252,14 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
  if (!Allowed || !Fast)
    return SDValue();

  // calculatebyteProvider() allows multi-use for vector loads. Ensure that
Collaborator commented:

byte -> Byte

      for (auto It = L->use_begin(); It != L->use_end(); ++It)
        if (It.getUse().getResNo() == 0 && !ExtractElements.contains(*It))
          return SDValue();

@topperc (Collaborator) commented Feb 13, 2024:

Should we go ahead and call DAG.makeEquivalentMemoryOrdering instead of ReplaceAllUsesOfValueWith below for maximum safety?
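
For illustration, the suggestion would replace the direct chain replacement with something like this (a sketch, assuming the makeEquivalentMemoryOrdering(OldChain, NewMemOpChain) overload):

```cpp
// Sketch: rather than replacing each old load's chain outright, merge it
// with the new load's chain via a TokenFactor, keeping every prior chain
// user correctly ordered even when the old load is not deleted.
for (LoadSDNode *L : Loads)
  DAG.makeEquivalentMemoryOrdering(SDValue(L, 1),
                                   SDValue(NewLoad.getNode(), 1));
```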

@nikic (Contributor, Author) replied:

Gah, I've been looking everywhere for a function that does this and couldn't find it. After looking at a few other uses of that function, it seems we actually have quite a few folds that are happy to introduce extra loads to avoid other instructions, so strictly avoiding multi-use loads in this fold may not be desirable. I've opened #81586 to use makeEquivalentMemoryOrdering() instead.

@nikic closed this Feb 13, 2024
Linked issue: Miscompile in DAGCombine (#80911)