Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Refactor matchFunnelShift to allow more pattern (NFC) #68474

Merged
merged 2 commits into from
Oct 19, 2023

Conversation

HaohaiWen
Copy link
Contributor

Current implementation of matchFunnelShift only allows opposite shift pattern. Refactor it to allow more pattern.

Current implementation of matchFunnelShift only allows opposite shift
pattern. Refactor it to allow more pattern.
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 7, 2023

@llvm/pr-subscribers-llvm-transforms

Changes

Current implementation of matchFunnelShift only allows opposite shift pattern. Refactor it to allow more pattern.


Full diff: https://github.com/llvm/llvm-project/pull/68474.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+93-79)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index cbdab3e9c5fb91d..b04e070fd19d7d1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2732,100 +2732,114 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
   // rotate matching code under visitSelect and visitTrunc?
   unsigned Width = Or.getType()->getScalarSizeInBits();
 
-  // First, find an or'd pair of opposite shifts:
-  // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
-  BinaryOperator *Or0, *Or1;
-  if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
-      !match(Or.getOperand(1), m_BinOp(Or1)))
-    return nullptr;
-
-  Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
-  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
-      !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
-      Or0->getOpcode() == Or1->getOpcode())
+  Instruction *Or0, *Or1;
+  if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
+      !match(Or.getOperand(1), m_Instruction(Or1)))
     return nullptr;
 
-  // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
-  if (Or0->getOpcode() == BinaryOperator::LShr) {
-    std::swap(Or0, Or1);
-    std::swap(ShVal0, ShVal1);
-    std::swap(ShAmt0, ShAmt1);
-  }
-  assert(Or0->getOpcode() == BinaryOperator::Shl &&
-         Or1->getOpcode() == BinaryOperator::LShr &&
-         "Illegal or(shift,shift) pair");
-
-  // Match the shift amount operands for a funnel shift pattern. This always
-  // matches a subtraction on the R operand.
-  auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
-    // Check for constant shift amounts that sum to the bitwidth.
-    const APInt *LI, *RI;
-    if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
-      if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
-        return ConstantInt::get(L->getType(), *LI);
-
-    Constant *LC, *RC;
-    if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
-        match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
-        match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
-        match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
-      return ConstantExpr::mergeUndefsWith(LC, RC);
-
-    // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
-    // We limit this to X < Width in case the backend re-expands the intrinsic,
-    // and has to reintroduce a shift modulo operation (InstCombine might remove
-    // it after this fold). This still doesn't guarantee that the final codegen
-    // will match this original pattern.
-    if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
-      KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
-      return KnownL.getMaxValue().ult(Width) ? L : nullptr;
-    }
+  bool IsFshl = true; // Sub on LSHR.
+  SmallVector<Value *, 3> FShiftArgs;
 
-    // For non-constant cases, the following patterns currently only work for
-    // rotation patterns.
-    // TODO: Add general funnel-shift compatible patterns.
-    if (ShVal0 != ShVal1)
+  // First, find an or'd pair of opposite shifts:
+  // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
+  if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) {
+    Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+    if (!match(Or0,
+               m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+        !match(Or1,
+               m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+        Or0->getOpcode() == Or1->getOpcode())
       return nullptr;
 
-    // For non-constant cases we don't support non-pow2 shift masks.
-    // TODO: Is it worth matching urem as well?
-    if (!isPowerOf2_32(Width))
-      return nullptr;
+    // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+    if (Or0->getOpcode() == BinaryOperator::LShr) {
+      std::swap(Or0, Or1);
+      std::swap(ShVal0, ShVal1);
+      std::swap(ShAmt0, ShAmt1);
+    }
+    assert(Or0->getOpcode() == BinaryOperator::Shl &&
+           Or1->getOpcode() == BinaryOperator::LShr &&
+           "Illegal or(shift,shift) pair");
+
+    // Match the shift amount operands for a funnel shift pattern. This always
+    // matches a subtraction on the R operand.
+    auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
+      // Check for constant shift amounts that sum to the bitwidth.
+      const APInt *LI, *RI;
+      if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
+        if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
+          return ConstantInt::get(L->getType(), *LI);
+
+      Constant *LC, *RC;
+      if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
+          match(L,
+                m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+          match(R,
+                m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+          match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
+        return ConstantExpr::mergeUndefsWith(LC, RC);
+
+      // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
+      // We limit this to X < Width in case the backend re-expands the
+      // intrinsic, and has to reintroduce a shift modulo operation (InstCombine
+      // might remove it after this fold). This still doesn't guarantee that the
+      // final codegen will match this original pattern.
+      if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
+        KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+        return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+      }
 
-    // The shift amount may be masked with negation:
-    // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
-    Value *X;
-    unsigned Mask = Width - 1;
-    if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
-        match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
-      return X;
+      // For non-constant cases, the following patterns currently only work for
+      // rotation patterns.
+      // TODO: Add general funnel-shift compatible patterns.
+      if (ShVal0 != ShVal1)
+        return nullptr;
 
-    // Similar to above, but the shift amount may be extended after masking,
-    // so return the extended value as the parameter for the intrinsic.
-    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
-        match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
-                       m_SpecificInt(Mask))))
-      return L;
+      // For non-constant cases we don't support non-pow2 shift masks.
+      // TODO: Is it worth matching urem as well?
+      if (!isPowerOf2_32(Width))
+        return nullptr;
 
-    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
-        match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
-      return L;
+      // The shift amount may be masked with negation:
+      // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
+      Value *X;
+      unsigned Mask = Width - 1;
+      if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
+          match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
+        return X;
+
+      // Similar to above, but the shift amount may be extended after masking,
+      // so return the extended value as the parameter for the intrinsic.
+      if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+          match(R,
+                m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
+                      m_SpecificInt(Mask))))
+        return L;
+
+      if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+          match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
+        return L;
 
-    return nullptr;
-  };
+      return nullptr;
+    };
 
-  Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
-  bool IsFshl = true; // Sub on LSHR.
-  if (!ShAmt) {
-    ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
-    IsFshl = false; // Sub on SHL.
+    Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
+    if (!ShAmt) {
+      ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
+      IsFshl = false; // Sub on SHL.
+    }
+    if (!ShAmt)
+      return nullptr;
+
+    FShiftArgs = {ShVal0, ShVal1, ShAmt};
   }
-  if (!ShAmt)
+
+  if (FShiftArgs.empty())
     return nullptr;
 
   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
   Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
-  return CallInst::Create(F, {ShVal0, ShVal1, ShAmt});
+  return CallInst::Create(F, FShiftArgs);
 }
 
 /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing patch description, alive2 proofs and tests.

@HaohaiWen
Copy link
Contributor Author

Missing patch description, alive2 proofs and tests.

This is a NFC patch. No new tests are needed.

@nikic
Copy link
Contributor

nikic commented Oct 7, 2023

This is a NFC patch. No new tests are needed.

Sorry, I missed the NFC in the title. Here is a diff without whitespace changes that shows that this is mostly just an indentation change: https://github.com/llvm/llvm-project/pull/68474/files?diff=split&w=1

Please put up the non-NFC patch you want to base on this. It's hard to say whether this makes sense without the dependent patch.

@HaohaiWen
Copy link
Contributor Author

This is a NFC patch. No new tests are needed.

Sorry, I missed the NFC in the title. Here is a diff without whitespace changes that shows that this is mostly just an indentation change: https://github.com/llvm/llvm-project/pull/68474/files?diff=split&w=1

Please put up the non-NFC patch you want to base on this. It's hard to say whether this makes sense without the dependent patch.

non-NFC patch: 68ab662

if (ShVal0 != ShVal1)
// First, find an or'd pair of opposite shifts:
// or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the rationale for this check? Isn't it covered by the matching of logical shifts?

But either way, if it is indeed useful, can you invert it and early return to keep nested scopes down.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to match either
or (shl %x, 16) (shr %y, 16)
or (shl %x, 16) (zext %y)

if (isa(Or0) && isa(Or1)) is used to filter in first case and that is the original code matched.
68ab662 filters in second case.
If we want to further reduce nested scopes. We may need to split this function to several static functions to match each case, but this requires to pass some args and may not be concise.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code seems new? Not that its wrong or anything.

But either way don't understand why it can't be an early return. If we don't enter the if FShiftArgs will be null and we will return anyways.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code seems new? Not that its wrong or anything.

I added this if statement in this patch.

It can be early returned in this patch, but I'd like to add another else if on 5b3b1bb. If I removed this if statement, then I need to add the if back and indent the code in the if statement on #68502. This make the diff not clear.

image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay fair enough.

@goldsteinn
Copy link
Contributor

LGTM. @nikic you still want changes?

@HaohaiWen
Copy link
Contributor Author

If no objection, I'll merge it tomorrow.

@HaohaiWen HaohaiWen merged commit 8ff3e4f into llvm:main Oct 19, 2023
3 checks passed
@HaohaiWen HaohaiWen deleted the fshl branch October 19, 2023 01:06
@goldsteinn
Copy link
Contributor

Rebase you other patch then I can review.

@HaohaiWen
Copy link
Contributor Author

Rebase you other patch then I can review.

Already rebased. Please take a look #68502

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants