[X86] matchAddressRecursively - move ZERO_EXTEND patterns into matchIndexRecursively #85081

Open
wants to merge 1 commit into base: main

Conversation

RicoAfoat
Contributor

@RicoAfoat RicoAfoat commented Mar 13, 2024

Fixes #82431 - see #82431 for more information.

  1. Move all ZERO_EXTEND patterns into matchIndexRecursively.
  2. Match ZERO_EXTEND patterns recursively so that the result cannot be reduced any further later on, which fixes the bug.
  3. Change the return type of matchIndexRecursively, because there is no direct indicator of whether a pattern was matched. Alternatives like N != matchIndexRecursively(N, AM, Depth + 1) do not help, because during matching SDNodes may be deallocated and reallocated at the same memory address, so the return value and N may appear "equal" on the surface.

(That is to say, the return value's node may be built at the same address where N's node was just deallocated. A minimal sketch of the new call pattern follows.)
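
A minimal sketch of the call pattern this enables, simplified from the diff below (surrounding context elided):

// Sketch only: with the old SDValue return type, "did anything match?" had to
// be inferred by comparing the result against N, which is unreliable because a
// freshly created node can reuse the address of a node that was just deleted.
// The explicit std::optional makes the outcome unambiguous.
std::optional<SDValue>
X86DAGToDAGISel::matchIndexRecursively(SDValue N, X86ISelAddressMode &AM,
                                       unsigned Depth);

// Caller that only wants "the improved index, or the original value":
AM.IndexReg = matchIndexRecursively(ShVal, AM, Depth + 1).value_or(ShVal);

// Caller that must know whether a match actually happened:
if (auto Index = matchIndexRecursively(N, AM, Depth + 1)) {
  AM.IndexReg = Index.value();
  return false; // address fully matched
}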

Additional problems to be fixed:

  1. Maybe avoid directly modifying AM.IndexReg in matchIndexRecursively.

Collaborator

@RKSimon RKSimon left a comment

A few initial things I've noticed

auto ret_val = matchIndexRecursively(N, AM, Depth + 1);
if (ret_val)
return ret_val.value();
return N;
Collaborator

this can all be reduced to

return matchIndexRecursively(N, AM, Depth + 1).value_or(N);

@@ -2670,97 +2766,14 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
break;
}
case ISD::ZERO_EXTEND: {
// Try to widen a zexted shift left to the same size as its use, so we can
// match the shift as a scale factor.
if (AM.IndexReg.getNode() != nullptr || AM.Scale != 1)
Collaborator

Can we remove the AM.Scale != 1? matchIndexRecursively needs to be able to accumulate scale

AM.IndexReg = matchIndexRecursively(Zext, AM, Depth + 1);
// All relevant operations are moved into matchIndexRecursively.
auto Index = matchIndexRecursively(N, AM, Depth + 1);
if (Index) {
Collaborator

if (auto Index = matchIndexRecursively(N, AM, Depth + 1)) {

if (ret_val)
AM.IndexReg = ret_val.value();
else
AM.IndexReg = IndexOp;
Collaborator

AM.IndexReg = matchIndexRecursively(IndexOp, AM, 0).value_or(IndexOp);

if (!ShAmtC)
return std::nullopt;
unsigned ShAmtV = ShAmtC->getZExtValue();
if (ShAmtV > 3 || (1 << ShAmtV) * AM.Scale > 8)
Contributor Author

Yes, I think we can remove AM.Scale != 1 in matchAddressRecursively. I relaxed the restriction here, but not for the last case, where we try to deal with the mask and shift.

@RicoAfoat RicoAfoat marked this pull request as ready for review March 15, 2024 00:53
@llvmbot
Collaborator

llvmbot commented Mar 15, 2024

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-x86

Author: RicoAfoat (RicoAfoat)

Changes

Fixes #82431 - see #82431 for more information.

  1. Move all ZERO_EXTEND patterns into matchIndexRecursively.
  2. Match ZERO_EXTEND patterns recursively so that the result cannot be reduced any further later on, which fixes the bug.
  3. Change the return type of matchIndexRecursively, because there is no direct indicator of whether a pattern was matched. Alternatives like N != matchIndexRecursively(N, AM, Depth + 1) do not help, because during matching SDNodes may be deallocated and reallocated at the same memory address, so the return value and N may appear "equal" on the surface.

(That is to say, the return value's node may be built at the same address where N's node was just deallocated.)

Additional problems to be fixed:

  1. The adapter lambda expression is not that elegant.
  2. Maybe avoid directly modifying AM.IndexReg in matchIndexRecursively.

Full diff: https://github.com/llvm/llvm-project/pull/85081.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelDAGToDAG.cpp (+152-145)
  • (modified) llvm/test/CodeGen/X86/inline-asm-memop.ll (+53-9)
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 76c6c1645239ab..06a322fd243212 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -212,8 +212,8 @@ namespace {
     bool matchAddress(SDValue N, X86ISelAddressMode &AM);
     bool matchVectorAddress(SDValue N, X86ISelAddressMode &AM);
     bool matchAdd(SDValue &N, X86ISelAddressMode &AM, unsigned Depth);
-    SDValue matchIndexRecursively(SDValue N, X86ISelAddressMode &AM,
-                                  unsigned Depth);
+    std::optional<SDValue>
+    matchIndexRecursively(SDValue N, X86ISelAddressMode &AM, unsigned Depth);
     bool matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
                                  unsigned Depth);
     bool matchVectorAddressRecursively(SDValue N, X86ISelAddressMode &AM,
@@ -2278,18 +2278,24 @@ static bool foldMaskedShiftToBEXTR(SelectionDAG &DAG, SDValue N,
   return false;
 }
 
-// Attempt to peek further into a scaled index register, collecting additional
-// extensions / offsets / etc. Returns /p N if we can't peek any further.
-SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
-                                               X86ISelAddressMode &AM,
-                                               unsigned Depth) {
+// Change the return type of matchIndexRecursively to std::optional<SDValue> to
+// make it easier for caller to tell whether the match is successful. Attempt to
+// peek further into a scaled index register, collecting additional extensions /
+// offsets / etc. Returns /p N if we can't peek any further.
+std::optional<SDValue>
+X86DAGToDAGISel::matchIndexRecursively(SDValue N, X86ISelAddressMode &AM,
+                                       unsigned Depth) {
   assert(AM.IndexReg.getNode() == nullptr && "IndexReg already matched");
   assert((AM.Scale == 1 || AM.Scale == 2 || AM.Scale == 4 || AM.Scale == 8) &&
          "Illegal index scale");
 
+  auto adapter = [this, &AM, &Depth](SDValue N) -> SDValue {
+    return matchIndexRecursively(N, AM, Depth + 1).value_or(N);
+  };
+
   // Limit recursion.
   if (Depth >= SelectionDAG::MaxRecursionDepth)
-    return N;
+    return std::nullopt;
 
   EVT VT = N.getValueType();
   unsigned Opc = N.getOpcode();
@@ -2299,14 +2305,14 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
     auto *AddVal = cast<ConstantSDNode>(N.getOperand(1));
     uint64_t Offset = (uint64_t)AddVal->getSExtValue() * AM.Scale;
     if (!foldOffsetIntoAddress(Offset, AM))
-      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+      return adapter(N.getOperand(0));
   }
 
   // index: add(x,x) -> index: x, scale * 2
   if (Opc == ISD::ADD && N.getOperand(0) == N.getOperand(1)) {
     if (AM.Scale <= 4) {
       AM.Scale *= 2;
-      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+      return adapter(N.getOperand(0));
     }
   }
 
@@ -2316,7 +2322,7 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
     uint64_t ScaleAmt = 1ULL << ShiftAmt;
     if ((AM.Scale * ScaleAmt) <= 8) {
       AM.Scale *= ScaleAmt;
-      return matchIndexRecursively(N.getOperand(0), AM, Depth + 1);
+      return adapter(N.getOperand(0));
     }
   }
 
@@ -2346,54 +2352,137 @@ SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
     }
   }
 
-  // index: zext(add_nuw(x,c)) -> index: zext(x), disp + zext(c)
-  // index: zext(addlike(x,c)) -> index: zext(x), disp + zext(c)
-  // TODO: call matchIndexRecursively(AddSrc) if we won't corrupt sext?
-  if (Opc == ISD::ZERO_EXTEND && !VT.isVector() && N.hasOneUse()) {
+  if (Opc == ISD::ZERO_EXTEND) {
+    // index: zext(add_nuw(x,c)) -> index: zext(x), disp + zext(c)
+    // index: zext(addlike(x,c)) -> index: zext(x), disp + zext(c)
     SDValue Src = N.getOperand(0);
-    unsigned SrcOpc = Src.getOpcode();
-    if (((SrcOpc == ISD::ADD && Src->getFlags().hasNoUnsignedWrap()) ||
-         CurDAG->isADDLike(Src)) &&
-        Src.hasOneUse()) {
-      if (CurDAG->isBaseWithConstantOffset(Src)) {
-        SDValue AddSrc = Src.getOperand(0);
-        uint64_t Offset = Src.getConstantOperandVal(1);
-        if (!foldOffsetIntoAddress(Offset * AM.Scale, AM)) {
-          SDLoc DL(N);
-          SDValue Res;
-          // If we're also scaling, see if we can use that as well.
-          if (AddSrc.getOpcode() == ISD::SHL &&
-              isa<ConstantSDNode>(AddSrc.getOperand(1))) {
-            SDValue ShVal = AddSrc.getOperand(0);
-            uint64_t ShAmt = AddSrc.getConstantOperandVal(1);
-            APInt HiBits =
-                APInt::getHighBitsSet(AddSrc.getScalarValueSizeInBits(), ShAmt);
-            uint64_t ScaleAmt = 1ULL << ShAmt;
-            if ((AM.Scale * ScaleAmt) <= 8 &&
-                (AddSrc->getFlags().hasNoUnsignedWrap() ||
-                 CurDAG->MaskedValueIsZero(ShVal, HiBits))) {
-              AM.Scale *= ScaleAmt;
-              SDValue ExtShVal = CurDAG->getNode(Opc, DL, VT, ShVal);
-              SDValue ExtShift = CurDAG->getNode(ISD::SHL, DL, VT, ExtShVal,
-                                                 AddSrc.getOperand(1));
-              insertDAGNode(*CurDAG, N, ExtShVal);
-              insertDAGNode(*CurDAG, N, ExtShift);
-              AddSrc = ExtShift;
-              Res = ExtShVal;
+    if (!VT.isVector() && N.hasOneUse()) {
+      unsigned SrcOpc = Src.getOpcode();
+      if (((SrcOpc == ISD::ADD && Src->getFlags().hasNoUnsignedWrap()) ||
+           CurDAG->isADDLike(Src)) &&
+          Src.hasOneUse()) {
+        if (CurDAG->isBaseWithConstantOffset(Src)) {
+          SDValue AddSrc = Src.getOperand(0);
+          uint64_t Offset = Src.getConstantOperandVal(1);
+          if (!foldOffsetIntoAddress(Offset * AM.Scale, AM)) {
+            SDLoc DL(N);
+            SDValue Res;
+            // If we're also scaling, see if we can use that as well.
+            if (AddSrc.getOpcode() == ISD::SHL &&
+                isa<ConstantSDNode>(AddSrc.getOperand(1))) {
+              SDValue ShVal = AddSrc.getOperand(0);
+              uint64_t ShAmt = AddSrc.getConstantOperandVal(1);
+              APInt HiBits = APInt::getHighBitsSet(
+                  AddSrc.getScalarValueSizeInBits(), ShAmt);
+              uint64_t ScaleAmt = 1ULL << ShAmt;
+              if ((AM.Scale * ScaleAmt) <= 8 &&
+                  (AddSrc->getFlags().hasNoUnsignedWrap() ||
+                   CurDAG->MaskedValueIsZero(ShVal, HiBits))) {
+                AM.Scale *= ScaleAmt;
+                SDValue ExtShVal = CurDAG->getNode(Opc, DL, VT, ShVal);
+                SDValue ExtShift = CurDAG->getNode(ISD::SHL, DL, VT, ExtShVal,
+                                                   AddSrc.getOperand(1));
+                insertDAGNode(*CurDAG, N, ExtShVal);
+                insertDAGNode(*CurDAG, N, ExtShift);
+                AddSrc = ExtShift;
+                Res = adapter(ExtShVal);
+              }
             }
+            SDValue ExtSrc = CurDAG->getNode(Opc, DL, VT, AddSrc);
+            SDValue ExtVal = CurDAG->getConstant(Offset, DL, VT);
+            SDValue ExtAdd = CurDAG->getNode(SrcOpc, DL, VT, ExtSrc, ExtVal);
+            insertDAGNode(*CurDAG, N, ExtSrc);
+            insertDAGNode(*CurDAG, N, ExtVal);
+            insertDAGNode(*CurDAG, N, ExtAdd);
+            CurDAG->ReplaceAllUsesWith(N, ExtAdd);
+            CurDAG->RemoveDeadNode(N.getNode());
+            // AM.IndexReg can be further picked
+            SDValue Zext = adapter(ExtSrc);
+            return Res ? Res : Zext;
           }
-          SDValue ExtSrc = CurDAG->getNode(Opc, DL, VT, AddSrc);
-          SDValue ExtVal = CurDAG->getConstant(Offset, DL, VT);
-          SDValue ExtAdd = CurDAG->getNode(SrcOpc, DL, VT, ExtSrc, ExtVal);
-          insertDAGNode(*CurDAG, N, ExtSrc);
-          insertDAGNode(*CurDAG, N, ExtVal);
-          insertDAGNode(*CurDAG, N, ExtAdd);
-          CurDAG->ReplaceAllUsesWith(N, ExtAdd);
-          CurDAG->RemoveDeadNode(N.getNode());
-          return Res ? Res : ExtSrc;
         }
       }
     }
+
+    // Peek through mask: zext(and(shl(x,c1),c2))
+    APInt Mask = APInt::getAllOnes(Src.getScalarValueSizeInBits());
+    if (Src.getOpcode() == ISD::AND && Src.hasOneUse())
+      if (auto *MaskC = dyn_cast<ConstantSDNode>(Src.getOperand(1))) {
+        Mask = MaskC->getAPIntValue();
+        Src = Src.getOperand(0);
+      }
+
+    if (Src.getOpcode() == ISD::SHL && Src.hasOneUse()) {
+      // Give up if the shift is not a valid scale factor [1,2,3].
+      SDValue ShlSrc = Src.getOperand(0);
+      SDValue ShlAmt = Src.getOperand(1);
+      auto *ShAmtC = dyn_cast<ConstantSDNode>(ShlAmt);
+      if (!ShAmtC)
+        return std::nullopt;
+      unsigned ShAmtV = ShAmtC->getZExtValue();
+      if (ShAmtV > 3 || (1 << ShAmtV) * AM.Scale > 8)
+        return std::nullopt;
+
+      // The narrow shift must only shift out zero bits (it must be 'nuw').
+      // That makes it safe to widen to the destination type.
+      APInt HighZeros =
+          APInt::getHighBitsSet(ShlSrc.getValueSizeInBits(), ShAmtV);
+      if (!Src->getFlags().hasNoUnsignedWrap() &&
+          !CurDAG->MaskedValueIsZero(ShlSrc, HighZeros & Mask))
+        return std::nullopt;
+
+      // zext (shl nuw i8 %x, C1) to i32
+      // --> shl (zext i8 %x to i32), (zext C1)
+      // zext (and (shl nuw i8 %x, C1), C2) to i32
+      // --> shl (zext i8 (and %x, C2 >> C1) to i32), (zext C1)
+      MVT SrcVT = ShlSrc.getSimpleValueType();
+      MVT VT = N.getSimpleValueType();
+      SDLoc DL(N);
+
+      SDValue Res = ShlSrc;
+      if (!Mask.isAllOnes()) {
+        Res = CurDAG->getConstant(Mask.lshr(ShAmtV), DL, SrcVT);
+        insertDAGNode(*CurDAG, N, Res);
+        Res = CurDAG->getNode(ISD::AND, DL, SrcVT, ShlSrc, Res);
+        insertDAGNode(*CurDAG, N, Res);
+      }
+      SDValue Zext = CurDAG->getNode(ISD::ZERO_EXTEND, DL, VT, Res);
+      insertDAGNode(*CurDAG, N, Zext);
+      SDValue NewShl = CurDAG->getNode(ISD::SHL, DL, VT, Zext, ShlAmt);
+      insertDAGNode(*CurDAG, N, NewShl);
+      CurDAG->ReplaceAllUsesWith(N, NewShl);
+      CurDAG->RemoveDeadNode(N.getNode());
+
+      // Convert the shift to scale factor.
+      AM.Scale *= 1 << ShAmtV;
+      // If matchIndexRecursively is not called here,
+      // Zext may be replaced by other nodes but later used to call a builder
+      // method
+      return adapter(Zext);
+    }
+
+    // TODO: Do not modify AM.IndexReg inside directly.
+    if (AM.Scale != 1)
+      return std::nullopt;
+
+    if (Src.getOpcode() == ISD::SRL && !Mask.isAllOnes()) {
+      // Try to fold the mask and shift into an extract and scale.
+      if (!foldMaskAndShiftToExtract(*CurDAG, N, Mask.getZExtValue(), Src,
+                                     Src.getOperand(0), AM))
+        return AM.IndexReg;
+
+      // Try to fold the mask and shift directly into the scale.
+      if (!foldMaskAndShiftToScale(*CurDAG, N, Mask.getZExtValue(), Src,
+                                   Src.getOperand(0), AM))
+        return AM.IndexReg;
+
+      // Try to fold the mask and shift into BEXTR and scale.
+      if (!foldMaskedShiftToBEXTR(*CurDAG, N, Mask.getZExtValue(), Src,
+                                  Src.getOperand(0), AM, *Subtarget))
+        return AM.IndexReg;
+    }
+
+    return std::nullopt;
   }
 
   // TODO: Handle extensions, shifted masks etc.
@@ -2479,7 +2568,8 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
       if (Val == 1 || Val == 2 || Val == 3) {
         SDValue ShVal = N.getOperand(0);
         AM.Scale = 1 << Val;
-        AM.IndexReg = matchIndexRecursively(ShVal, AM, Depth + 1);
+        AM.IndexReg =
+            matchIndexRecursively(ShVal, AM, Depth + 1).value_or(ShVal);
         return false;
       }
     }
@@ -2670,97 +2760,13 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
     break;
   }
   case ISD::ZERO_EXTEND: {
-    // Try to widen a zexted shift left to the same size as its use, so we can
-    // match the shift as a scale factor.
-    if (AM.IndexReg.getNode() != nullptr || AM.Scale != 1)
+    if (AM.IndexReg.getNode() != nullptr)
       break;
-
-    SDValue Src = N.getOperand(0);
-
-    // See if we can match a zext(addlike(x,c)).
-    // TODO: Move more ZERO_EXTEND patterns into matchIndexRecursively.
-    if (Src.getOpcode() == ISD::ADD || Src.getOpcode() == ISD::OR)
-      if (SDValue Index = matchIndexRecursively(N, AM, Depth + 1))
-        if (Index != N) {
-          AM.IndexReg = Index;
-          return false;
-        }
-
-    // Peek through mask: zext(and(shl(x,c1),c2))
-    APInt Mask = APInt::getAllOnes(Src.getScalarValueSizeInBits());
-    if (Src.getOpcode() == ISD::AND && Src.hasOneUse())
-      if (auto *MaskC = dyn_cast<ConstantSDNode>(Src.getOperand(1))) {
-        Mask = MaskC->getAPIntValue();
-        Src = Src.getOperand(0);
-      }
-
-    if (Src.getOpcode() == ISD::SHL && Src.hasOneUse()) {
-      // Give up if the shift is not a valid scale factor [1,2,3].
-      SDValue ShlSrc = Src.getOperand(0);
-      SDValue ShlAmt = Src.getOperand(1);
-      auto *ShAmtC = dyn_cast<ConstantSDNode>(ShlAmt);
-      if (!ShAmtC)
-        break;
-      unsigned ShAmtV = ShAmtC->getZExtValue();
-      if (ShAmtV > 3)
-        break;
-
-      // The narrow shift must only shift out zero bits (it must be 'nuw').
-      // That makes it safe to widen to the destination type.
-      APInt HighZeros =
-          APInt::getHighBitsSet(ShlSrc.getValueSizeInBits(), ShAmtV);
-      if (!Src->getFlags().hasNoUnsignedWrap() &&
-          !CurDAG->MaskedValueIsZero(ShlSrc, HighZeros & Mask))
-        break;
-
-      // zext (shl nuw i8 %x, C1) to i32
-      // --> shl (zext i8 %x to i32), (zext C1)
-      // zext (and (shl nuw i8 %x, C1), C2) to i32
-      // --> shl (zext i8 (and %x, C2 >> C1) to i32), (zext C1)
-      MVT SrcVT = ShlSrc.getSimpleValueType();
-      MVT VT = N.getSimpleValueType();
-      SDLoc DL(N);
-
-      SDValue Res = ShlSrc;
-      if (!Mask.isAllOnes()) {
-        Res = CurDAG->getConstant(Mask.lshr(ShAmtV), DL, SrcVT);
-        insertDAGNode(*CurDAG, N, Res);
-        Res = CurDAG->getNode(ISD::AND, DL, SrcVT, ShlSrc, Res);
-        insertDAGNode(*CurDAG, N, Res);
-      }
-      SDValue Zext = CurDAG->getNode(ISD::ZERO_EXTEND, DL, VT, Res);
-      insertDAGNode(*CurDAG, N, Zext);
-      SDValue NewShl = CurDAG->getNode(ISD::SHL, DL, VT, Zext, ShlAmt);
-      insertDAGNode(*CurDAG, N, NewShl);
-      CurDAG->ReplaceAllUsesWith(N, NewShl);
-      CurDAG->RemoveDeadNode(N.getNode());
-
-      // Convert the shift to scale factor.
-      AM.Scale = 1 << ShAmtV;
-      // If matchIndexRecursively is not called here,
-      // Zext may be replaced by other nodes but later used to call a builder
-      // method
-      AM.IndexReg = matchIndexRecursively(Zext, AM, Depth + 1);
+    // All relevant operations are moved into matchIndexRecursively.
+    if (auto Index = matchIndexRecursively(N, AM, Depth + 1)) {
+      AM.IndexReg = Index.value();
       return false;
     }
-
-    if (Src.getOpcode() == ISD::SRL && !Mask.isAllOnes()) {
-      // Try to fold the mask and shift into an extract and scale.
-      if (!foldMaskAndShiftToExtract(*CurDAG, N, Mask.getZExtValue(), Src,
-                                     Src.getOperand(0), AM))
-        return false;
-
-      // Try to fold the mask and shift directly into the scale.
-      if (!foldMaskAndShiftToScale(*CurDAG, N, Mask.getZExtValue(), Src,
-                                   Src.getOperand(0), AM))
-        return false;
-
-      // Try to fold the mask and shift into BEXTR and scale.
-      if (!foldMaskedShiftToBEXTR(*CurDAG, N, Mask.getZExtValue(), Src,
-                                  Src.getOperand(0), AM, *Subtarget))
-        return false;
-    }
-
     break;
   }
   }
@@ -2859,9 +2865,10 @@ bool X86DAGToDAGISel::selectVectorAddr(MemSDNode *Parent, SDValue BasePtr,
 
   // Attempt to match index patterns, as long as we're not relying on implicit
   // sign-extension, which is performed BEFORE scale.
-  if (IndexOp.getScalarValueSizeInBits() == BasePtr.getScalarValueSizeInBits())
-    AM.IndexReg = matchIndexRecursively(IndexOp, AM, 0);
-  else
+  if (IndexOp.getScalarValueSizeInBits() ==
+      BasePtr.getScalarValueSizeInBits()) {
+    AM.IndexReg = matchIndexRecursively(IndexOp, AM, 0).value_or(IndexOp);
+  } else
     AM.IndexReg = IndexOp;
 
   unsigned AddrSpace = Parent->getPointerInfo().getAddrSpace();
diff --git a/llvm/test/CodeGen/X86/inline-asm-memop.ll b/llvm/test/CodeGen/X86/inline-asm-memop.ll
index 83442498076102..38b81b11a8dfc2 100644
--- a/llvm/test/CodeGen/X86/inline-asm-memop.ll
+++ b/llvm/test/CodeGen/X86/inline-asm-memop.ll
@@ -1,20 +1,46 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
-; RUN: llc -mtriple=x86_64-unknown-linux-gnu -O0 < %s | FileCheck %s
+; RUN: llc -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
 
 ; A bug in X86DAGToDAGISel::matchAddressRecursively create a zext SDValue which
 ; is quickly replaced by other SDValue but already pushed into vector for later
 ; calling for SelectionDAGISel::Select_INLINEASM getNode builder, see issue
 ; 82431 for more infomation.
 
-define void @PR82431(i8 %call, ptr %b) {
-; CHECK-LABEL: PR82431:
+define i64 @PR82431_0(i8 %call, ptr %b) {
+; CHECK-LABEL: PR82431_0:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    movb %dil, %al
-; CHECK-NEXT:    addb $1, %al
-; CHECK-NEXT:    movzbl %al, %eax
-; CHECK-NEXT:    # kill: def $rax killed $eax
-; CHECK-NEXT:    shlq $3, %rax
-; CHECK-NEXT:    addq %rax, %rsi
+; CHECK-NEXT:    movzbl %dil, %eax
+; CHECK-NEXT:    movq 8(%rsi,%rax,8), %rax
+; CHECK-NEXT:    retq
+entry:
+  %narrow = add nuw i8 %call, 1
+  %idxprom = zext i8 %narrow to i64
+  %arrayidx = getelementptr [1 x i64], ptr %b, i64 0, i64 %idxprom
+  %ret_val = load i64, ptr %arrayidx
+  ret i64 %ret_val
+}
+
+define i32 @PR82431_1(i32 %0, ptr %f) {
+; CHECK-LABEL: PR82431_1:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    # kill: def $edi killed $edi def $rdi
+; CHECK-NEXT:    andl $4, %edi
+; CHECK-NEXT:    movl 4(%rsi,%rdi,2), %eax
+; CHECK-NEXT:    retq
+entry:
+  %shr = lshr i32 %0, 1
+  %and = and i32 %shr, 2
+  %add = or i32 %and, 1
+  %idxprom = zext i32 %add to i64
+  %arrayidx = getelementptr [0 x i32], ptr %f, i64 0, i64 %idxprom
+  %ret_val = load i32, ptr %arrayidx
+  ret i32 %ret_val
+}
+
+define void @PR82431_2(i8 %call, ptr %b) {
+; CHECK-LABEL: PR82431_2:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    movzbl %dil, %eax
 ; CHECK-NEXT:    #APP
 ; CHECK-NEXT:    #NO_APP
 ; CHECK-NEXT:    retq
@@ -25,3 +51,21 @@ entry:
   tail call void asm "", "=*m,*m,~{dirflag},~{fpsr},~{flags}"(ptr elementtype(i64) %arrayidx, ptr elementtype(i64) %arrayidx)
   ret void
 }
+
+define void @PR82431_3(i32 %0, ptr %f) {
+; CHECK-LABEL: PR82431_3:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    # kill: def $edi killed $edi def $rdi
+; CHECK-NEXT:    andl $4, %edi
+; CHECK-NEXT:    #APP
+; CHECK-NEXT:    #NO_APP
+; CHECK-NEXT:    retq
+entry:
+  %shr = lshr i32 %0, 1
+  %and = and i32 %shr, 2
+  %add = or i32 %and, 1
+  %idxprom = zext i32 %add to i64
+  %arrayidx = getelementptr [0 x i32], ptr %f, i64 0, i64 %idxprom
+  tail call void asm "", "=*m,*m,~{dirflag},~{fpsr},~{flags}"(ptr elementtype(i32) %arrayidx, ptr elementtype(i32) %arrayidx)
+  ret void
+}

@Mekapaedia

Does this patch also solve/affect #72026?

@RKSimon
Collaborator

RKSimon commented Apr 24, 2024

@RicoAfoat Sorry this fell off my radar over the Easter break - I'll take a look soon.

@Mekapaedia

I tested this patch rebased on main and it does seem to correctly compile the reproducer from #72026, so this PR could close that and, by extension, #85443.

// make it easier for caller to tell whether the match is successful. Attempt to
// peek further into a scaled index register, collecting additional extensions /
// offsets / etc. Returns /p N if we can't peek any further.
std::optional<SDValue>
Collaborator

I didn't look too closely at that patch. Do we need a std::optional here? SDValue already has an empty state, can't we use that like we normally do in SelectionDAG?
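
For illustration, a rough sketch of that alternative (an assumption about the shape, not code from this PR): an empty SDValue can act as the "no match" sentinel.

// Sketch of the suggested alternative: return SDValue() instead of
// std::nullopt and rely on SDValue's boolean conversion, as other
// SelectionDAG helpers commonly do.
SDValue X86DAGToDAGISel::matchIndexRecursively(SDValue N,
                                               X86ISelAddressMode &AM,
                                               unsigned Depth) {
  if (Depth >= SelectionDAG::MaxRecursionDepth)
    return SDValue(); // empty value == no match
  // ... pattern matching as in this patch ...
  return SDValue();
}

// Callers still distinguish "matched" from "unchanged":
if (SDValue Index = matchIndexRecursively(N, AM, Depth + 1))
  AM.IndexReg = Index;
else
  AM.IndexReg = N;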

// method
AM.IndexReg = matchIndexRecursively(Zext, AM, Depth + 1);
// All relevant operations are moved into matchIndexRecursively.
if (auto Index = matchIndexRecursively(N, AM, Depth + 1)) {
Contributor Author

#85081 (comment)

@topperc @RKSimon

Most of the time, AM.IndexReg is set to the return value of matchIndexRecursively whenever the function is called, except here. I'm wondering whether a bool would be a nicer return type, with AM.IndexReg changed directly inside matchIndexRecursively (we already do that for AM.Scale).

In addition, some functions (foldMaskAndShiftToExtract, foldMaskAndShiftToScale, foldMaskedShiftToBEXTR) that directly modify AM.IndexReg and were previously called from matchAddressRecursively have been moved into matchIndexRecursively. It would be a little strange to return a value from those paths; otherwise we would need to adjust them as well.

Should we change it to a bool? (A rough sketch of that shape is below.)
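
A hypothetical sketch of the bool-returning shape being asked about, for discussion only (not what this PR implements):

// Hypothetical alternative: report success and write AM.IndexReg internally,
// the same way matchIndexRecursively already mutates AM.Scale.
bool X86DAGToDAGISel::matchIndexRecursively(SDValue N, X86ISelAddressMode &AM,
                                            unsigned Depth) {
  if (Depth >= SelectionDAG::MaxRecursionDepth)
    return false;
  // ... on a successful match: AM.IndexReg = <matched index>; return true;
  return false;
}

// Caller in matchAddressRecursively's ZERO_EXTEND case:
if (matchIndexRecursively(N, AM, Depth + 1))
  return false; // the address expression was consumed
break;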

@topperc
Collaborator

topperc commented Apr 25, 2024

To make sure I understand. The fundamental issue seems to be that X86 is deleting nodes while Select_INLINEASM is holding references to them from the selection of an earlier memory operand.

So my immediate thoughts/questions are:
- Can we remove the deletion? If the node is truly dead, it will be pruned when we reach it in the isel selection process. This might affect use counts and the behavior of selection while that node is still hanging out in the graph.
- Can we disable the deletion when selecting inline assembly?
- Maybe Select_INLINEASM should be holding the nodes as HandleSDNodes so they will be updated by replaceAllUses (a rough sketch of that idea follows).
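
A minimal sketch of that last idea, for illustration only (this PR does not implement it; the variable names are hypothetical):

// HandleSDNode (SelectionDAGNodes.h) registers itself as a user of the value
// it wraps, so CurDAG->ReplaceAllUsesWith(...) retargets it automatically.
// Holding memory-operand values this way means they cannot silently go stale
// while later operands are selected and nodes are replaced or deleted.
HandleSDNode Handle(MemOpValue);        // MemOpValue: a previously selected operand
// ... selecting later operands may replace or delete nodes ...
SDValue Current = Handle.getValue();    // always the up-to-date replacement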

Collaborator

@RKSimon RKSimon left a comment

I'm not sure about this - the point of matchIndexRecursively was that it could improve the IndexReg / AM.Scale values as it worked up the tree; it's not designed to fail in any way, just to stop recursing.

Technically, matchIndexRecursively could return void and update AM.IndexReg internally (like it does AM.Scale). I didn't do it that way because I felt the recursion worked better by setting the value on return, but I may have been wrong.

@Mekapaedia

Is there any update on this? The reproducer from #72026 is still failing with trunk on godbolt.

Labels: backend:X86, llvm:SelectionDAG