Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PatternMatch] Add a matching helper m_ElementWiseBitCast. NFC. #80764

Merged
merged 1 commit into from
Feb 7, 2024

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Feb 5, 2024

This patch introduces a matching helper m_ElementWiseBitCast, which is used for matching element-wise int<-> fp casts.
The motivation of this patch is to avoid duplicating checks in #80740 and #80414.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 5, 2024

@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

This patch introduces a matching helper m_ElementWiseBitCast, which is used for matching element-wise int<-> fp casts.
The motivation of this patch is to avoid duplicating checks in #80740 and #80414.


Full diff: https://github.com/llvm/llvm-project/pull/80764.diff

5 Files Affected:

  • (modified) llvm/include/llvm/IR/PatternMatch.h (+24)
  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+1-1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+9-13)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+5-12)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+4-2)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 878079c4fe4e8e..7ce42e30639092 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1711,6 +1711,30 @@ m_BitCast(const OpTy &Op) {
   return CastOperator_match<OpTy, Instruction::BitCast>(Op);
 }
 
+template <typename Op_t> struct ElementWiseBitCast_match {
+  Op_t Op;
+
+  ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    if (auto *I = dyn_cast<BitCastInst>(V)) {
+      Type *SrcType = I->getSrcTy();
+      Type *DstType = I->getType();
+      // Make sure the bitcast doesn't change between scalar and vector and
+      // doesn't change the number of vector elements.
+      if (SrcType->isVectorTy() == DstType->isVectorTy() &&
+          SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits())
+        return Op.match(I->getOperand(0));
+    }
+    return false;
+  }
+};
+
+template <typename OpTy>
+inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
+  return ElementWiseBitCast_match<OpTy>(Op);
+}
+
 /// Matches PtrToInt.
 template <typename OpTy>
 inline CastOperator_match<OpTy, Instruction::PtrToInt>
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2793b798f35f36..01b017142cfcb0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   // floating-point casts:
   // icmp slt (bitcast (uitofp X)),  0 --> false
   // icmp sgt (bitcast (uitofp X)), -1 --> true
-  if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+  if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) {
     if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
       return ConstantInt::getFalse(ITy);
     if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6ca4d6d673068e..0409f33445f0af 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2493,14 +2493,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   // Assumes any IEEE-represented type has the sign bit in the high bit.
   // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
   Value *CastOp;
-  if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+  if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
       match(Op1, m_MaxSignedValue()) &&
       !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
-        Attribute::NoImplicitFloat)) {
+          Attribute::NoImplicitFloat)) {
     Type *EltTy = CastOp->getType()->getScalarType();
-    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-        EltTy->getPrimitiveSizeInBits() ==
-        I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
       Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
       return new BitCastInst(FAbs, I.getType());
     }
@@ -3925,13 +3923,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   // This is generous interpretation of noimplicitfloat, this is not a true
   // floating-point operation.
   Value *CastOp;
-  if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+  if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+      match(Op1, m_SignMask()) &&
       !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
           Attribute::NoImplicitFloat)) {
     Type *EltTy = CastOp->getType()->getScalarType();
-    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-        EltTy->getPrimitiveSizeInBits() ==
-        I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
       Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
       Value *FNegFAbs = Builder.CreateFNeg(FAbs);
       return new BitCastInst(FNegFAbs, I.getType());
@@ -4701,13 +4698,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
     // Assumes any IEEE-represented type has the sign bit in the high bit.
     // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
     Value *CastOp;
-    if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+    if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+        match(Op1, m_SignMask()) &&
         !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
             Attribute::NoImplicitFloat)) {
       Type *EltTy = CastOp->getType()->getScalarType();
-      if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-          EltTy->getPrimitiveSizeInBits() ==
-          I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+      if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
         Value *FNeg = Builder.CreateFNeg(CastOp);
         return new BitCastInst(FNeg, I.getType());
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 380cb3504209d3..cda1061fb35a8d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1834,15 +1834,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
   Value *V;
   if (!Cmp.getParent()->getParent()->hasFnAttribute(
           Attribute::NoImplicitFloat) &&
-      Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) {
-    Type *SrcType = V->getType();
-    Type *DstType = X->getType();
-    Type *FPType = SrcType->getScalarType();
-    // Make sure the bitcast doesn't change between scalar and vector and
-    // doesn't change the number of vector elements.
-    if (SrcType->isVectorTy() == DstType->isVectorTy() &&
-        SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() &&
-        FPType->isIEEELikeFPTy() && C1 == *C2) {
+      Cmp.isEquality() &&
+      match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) {
+    Type *FPType = V->getType()->getScalarType();
+    if (FPType->isIEEELikeFPTy() && C1 == *C2) {
       APInt ExponentMask =
           APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt();
       if (C1 == ExponentMask) {
@@ -7754,9 +7749,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
   // Ignore signbit of bitcasted int when comparing equality to FP 0.0:
   // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
   if (match(Op1, m_PosZeroFP()) &&
-      match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
-      X->getType()->isVectorTy() == OpType->isVectorTy() &&
-      X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
+      match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) {
     ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
     if (Pred == FCmpInst::FCMP_OEQ)
       IntPred = ICmpInst::ICMP_EQ;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 2756f81ed9e620..d6f447125269c9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2382,7 +2382,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
   const APInt *C;
   bool IsTrueIfSignSet;
   ICmpInst::Predicate Pred;
-  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
+  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
+                                   m_APInt(C)))) ||
       !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
       X->getType() != SelType)
     return nullptr;
@@ -2783,7 +2784,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
     CmpInst::Predicate Pred;
     const APInt *C;
     bool TrueIfSigned;
-    if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) ||
+    if (!match(CondVal,
+               m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
         !IC.isSignBitCheck(Pred, *C, TrueIfSigned))
       continue;
     if (!match(TrueVal, m_FNeg(m_Specific(X))))

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a unit test in unittests/IR/PatternMatch.cpp

llvm/include/llvm/IR/PatternMatch.h Outdated Show resolved Hide resolved
llvm/include/llvm/IR/PatternMatch.h Outdated Show resolved Hide resolved
llvm/lib/Analysis/InstructionSimplify.cpp Show resolved Hide resolved
llvm/unittests/IR/PatternMatch.cpp Show resolved Hide resolved
Comment on lines +186 to +192
// If it's a bitcast involving vectors, make sure it has the same number
// of elements on both sides.
if (CI.getOpcode() != Instruction::BitCast ||
match(&CI, m_ElementWiseBitCast(m_Value()))) {
if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
return NV;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this another untested bug fix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the check in InstructionCombing.cpp:

 // If it's a bitcast involving vectors, make sure it has the same number of
 // elements on both sides.
 if (auto *BC = dyn_cast<BitCastInst>(&Op)) {
   VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy());
   VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy());

   // Verify that either both or neither are vectors.
   if ((SrcTy == nullptr) != (DestTy == nullptr))
     return nullptr;

   // If vectors, verify that they have the same number of elements.
   if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount())
     return nullptr;
 }

@dtcxzyw dtcxzyw merged commit f37d81f into llvm:main Feb 7, 2024
3 of 4 checks passed
@dtcxzyw dtcxzyw deleted the elementwise-bitcast branch February 7, 2024 13:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants