-
Notifications
You must be signed in to change notification settings - Fork 15k
[InstCombine] Support multi-use values in cast elimination transforms #165877
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[InstCombine] Support multi-use values in cast elimination transforms #165877
Conversation
|
@llvm/pr-subscribers-llvm-transforms Author: Valeriy Savchenko (SavchenkoValeriy) Changes
This change tracks visited values and defers decisions on multi-use values until we verify all their users were visited. Applied to truncation and sext. Zext unchanged due to its dual-return nature. Patch is 47.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165877.diff 7 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4c9b10a094981..6184c6d25d929 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -12,14 +12,21 @@
#include "InstCombineInternal.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include <iterator>
#include <optional>
using namespace llvm;
@@ -27,12 +34,19 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
-/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
-/// true for, actually insert the code to evaluate the expression.
-Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
- bool isSigned) {
+using EvaluatedMap = SmallDenseMap<Value *, Value *, 8>;
+
+static Value *EvaluateInDifferentTypeImpl(Value *V, Type *Ty, bool isSigned,
+ InstCombinerImpl &IC,
+ EvaluatedMap &Processed) {
+ // Since we cover transformation of isntructions with multiple users, we might
+ // come to the same node via multiple paths. We should not create a
+ // replacement for every single one of them though.
+ if (const auto It = Processed.find(V); It != Processed.end())
+ return It->getSecond();
+
if (Constant *C = dyn_cast<Constant>(V))
- return ConstantFoldIntegerCast(C, Ty, isSigned, DL);
+ return ConstantFoldIntegerCast(C, Ty, isSigned, IC.getDataLayout());
// Otherwise, it must be an instruction.
Instruction *I = cast<Instruction>(V);
@@ -50,8 +64,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
case Instruction::Shl:
case Instruction::UDiv:
case Instruction::URem: {
- Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
- Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
+ Value *LHS = EvaluateInDifferentTypeImpl(I->getOperand(0), Ty, isSigned, IC,
+ Processed);
+ Value *RHS = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned, IC,
+ Processed);
Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
if (Opc == Instruction::LShr || Opc == Instruction::AShr)
Res->setIsExact(I->isExact());
@@ -72,8 +88,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
Opc == Instruction::SExt);
break;
case Instruction::Select: {
- Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
- Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned);
+ Value *True = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned,
+ IC, Processed);
+ Value *False = EvaluateInDifferentTypeImpl(I->getOperand(2), Ty, isSigned,
+ IC, Processed);
Res = SelectInst::Create(I->getOperand(0), True, False);
break;
}
@@ -81,8 +99,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
PHINode *OPN = cast<PHINode>(I);
PHINode *NPN = PHINode::Create(Ty, OPN->getNumIncomingValues());
for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) {
- Value *V =
- EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned);
+ Value *V = EvaluateInDifferentTypeImpl(OPN->getIncomingValue(i), Ty,
+ isSigned, IC, Processed);
NPN->addIncoming(V, OPN->getIncomingBlock(i));
}
Res = NPN;
@@ -90,8 +108,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
}
case Instruction::FPToUI:
case Instruction::FPToSI:
- Res = CastInst::Create(
- static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty);
+ Res = CastInst::Create(static_cast<Instruction::CastOps>(Opc),
+ I->getOperand(0), Ty);
break;
case Instruction::Call:
if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
@@ -111,8 +129,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
auto *ScalarTy = cast<VectorType>(Ty)->getElementType();
auto *VTy = cast<VectorType>(I->getOperand(0)->getType());
auto *FixedTy = VectorType::get(ScalarTy, VTy->getElementCount());
- Value *Op0 = EvaluateInDifferentType(I->getOperand(0), FixedTy, isSigned);
- Value *Op1 = EvaluateInDifferentType(I->getOperand(1), FixedTy, isSigned);
+ Value *Op0 = EvaluateInDifferentTypeImpl(I->getOperand(0), FixedTy,
+ isSigned, IC, Processed);
+ Value *Op1 = EvaluateInDifferentTypeImpl(I->getOperand(1), FixedTy,
+ isSigned, IC, Processed);
Res = new ShuffleVectorInst(Op0, Op1,
cast<ShuffleVectorInst>(I)->getShuffleMask());
break;
@@ -123,7 +143,22 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
}
Res->takeName(I);
- return InsertNewInstWith(Res, I->getIterator());
+ Value *Result = IC.InsertNewInstWith(Res, I->getIterator());
+ // There is no need in keeping track of the old value/new value relationship
+ // when we have only one user, we came have here from that user and no-one
+ // else cares.
+ if (!V->hasOneUse()) {
+ Processed[V] = Result;
+ }
+ return Result;
+}
+
+/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
+/// true for, actually insert the code to evaluate the expression.
+Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
+ bool isSigned) {
+ EvaluatedMap Processed;
+ return EvaluateInDifferentTypeImpl(V, Ty, isSigned, *this, Processed);
}
Instruction::CastOps
@@ -227,9 +262,175 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
return nullptr;
}
+namespace {
+
+/// Helper class for evaluating whether a value can be computed in a different
+/// type without changing its value. Used by cast simplification transforms.
+class TypeEvaluationHelper {
+public:
+ /// Return true if we can evaluate the specified expression tree as type Ty
+ /// instead of its larger type, and arrive with the same value.
+ /// This is used by code that tries to eliminate truncates.
+ [[nodiscard]] static bool canEvaluateTruncated(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+
+ /// Determine if the specified value can be computed in the specified wider
+ /// type and produce the same low bits. If not, return false.
+ [[nodiscard]] static bool canEvaluateZExtd(Value *V, Type *Ty,
+ unsigned &BitsToClear,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+
+ /// Return true if we can take the specified value and return it as type Ty
+ /// without inserting any new casts and without changing the value of the
+ /// common low bits.
+ [[nodiscard]] static bool canEvaluateSExtd(Value *V, Type *Ty);
+
+private:
+ /// Constants and extensions/truncates from the destination type are always
+ /// free to be evaluated in that type.
+ [[nodiscard]] static bool canAlwaysEvaluateInType(Value *V, Type *Ty);
+
+ /// Check if we traversed all the users of the multi-use values we've seen.
+ [[nodiscard]] bool allPendingVisited() const {
+ return llvm::all_of(Pending,
+ [this](Value *V) { return Visited.contains(V); });
+ }
+
+ /// A generic wrapper for canEvaluate* recursions to inject visitation
+ /// tracking and enforce correct multi-use value evaluations.
+ [[nodiscard]] bool
+ canEvaluate(Value *V, Type *Ty,
+ llvm::function_ref<bool(Value *, Type *Type)> Pred) {
+ if (canAlwaysEvaluateInType(V, Ty))
+ return true;
+
+ if (!isa<Instruction>(V))
+ return false;
+
+ auto *I = cast<Instruction>(V);
+ // We insert false by default to return false when we encounter user loops.
+ const auto [It, Inserted] = Visited.insert({V, false});
+
+ // There are three possible cases for us having information on this value
+ // in the Visited map:
+ // 1. We properly checked it and concluded that we can evaluate it (true)
+ // 2. We properly checked it and concluded that we can't (false)
+ // 3. We started to check it, but during the recursive traversal we came
+ // back to it.
+ //
+ // For cases 1 and 2, we can safely return the stored result. For case 3, we
+ // can potentially have a situation where we can evaluate recursive user
+ // chains, but that can be quite tricky to do properly and isntead, we
+ // return false.
+ //
+ // In any case, we should return whatever was there in the map to begin
+ // with.
+ if (!Inserted)
+ return It->getSecond();
+
+ // We can easily make a decision about single-user values whether they can
+ // be evaluated in a different type or not, we came from that user. This is
+ // not as simple for multi-user values.
+ //
+ // In general, we have the following case (inverted control-flow, users are
+ // at the top):
+ //
+ // Cast %A
+ // ____|
+ // /
+ // %A = Use %B, %C
+ // ________| |
+ // / |
+ // %B = Use %D |
+ // ________| |
+ // / |
+ // %D = Use %C |
+ // ________|___|
+ // /
+ // %C = ...
+ //
+ // In this case, when we check %A, %B and %C, we are confident that we can
+ // make the decision here and now, since we came from their only users.
+ //
+ // For %C, it is harder. We come there twice, and when we come the first
+ // time, it's hard to tell if we will visit the second user (technically
+ // it's not hard, but we might need a lot of repetitive checks with non-zero
+ // cost).
+ //
+ // In the case above, we are allowed to evaluate %C in different type
+ // because all of it users were part of the traversal.
+ //
+ // In the following case, however, we can't make this conclusion:
+ //
+ // Cast %A
+ // ____|
+ // /
+ // %A = Use %B, %C
+ // ________| |
+ // / |
+ // %B = Use %D |
+ // ________| |
+ // / |
+ // %D = Use %C |
+ // | |
+ // foo(%C) | | <- never traversing foo(%C)
+ // ________|___|
+ // /
+ // %C = ...
+ //
+ // In this case, we still can evaluate %C in a different type, but we'd need
+ // to create a copy of the original %C to be used in foo(%C). Such
+ // duplication might be not profitable.
+ //
+ // For this reason, we collect all users of the mult-user values and mark
+ // them as "pending" and defer this decision to the very end. When we are
+ // done and and ready to have a positive verdict, we should double-check all
+ // of the pending users and ensure that we visited them. allPendingVisited
+ // predicate checks exactly that.
+ if (!I->hasOneUse()) {
+ llvm::transform(I->uses(), std::back_inserter(Pending),
+ [](Use &U) { return U.getUser(); });
+ }
+
+ const bool Result = Pred(V, Ty);
+ // We have to set result this way and not via It because Pred is recursive
+ // and it is very likely that we grew Visited and invalidated It.
+ Visited[V] = Result;
+ return Result;
+ }
+
+ /// Filter out values that we can not evaluate in the destination type for
+ /// free.
+ [[nodiscard]] bool canNotEvaluateInType(Value *V, Type *Ty);
+
+ [[nodiscard]] bool canEvaluateTruncatedImpl(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateTruncatedPred(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateZExtdImpl(Value *V, Type *Ty,
+ unsigned &BitsToClear,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateSExtdImpl(Value *V, Type *Ty);
+ [[nodiscard]] bool canEvaluateSExtdPred(Value *V, Type *Ty);
+
+ /// A bookkeeping map to memorize an already made decision for a traversed
+ /// value.
+ SmallDenseMap<Value *, bool, 8> Visited;
+
+ /// A list of pending values to check in the end.
+ SmallVector<Value *, 8> Pending;
+};
+
+} // anonymous namespace
+
/// Constants and extensions/truncates from the destination type are always
/// free to be evaluated in that type. This is a helper for canEvaluate*.
-static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canAlwaysEvaluateInType(Value *V, Type *Ty) {
if (isa<Constant>(V))
return match(V, m_ImmConstant());
@@ -243,7 +444,7 @@ static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
/// Filter out values that we can not evaluate in the destination type for free.
/// This is a helper for canEvaluate*.
-static bool canNotEvaluateInType(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canNotEvaluateInType(Value *V, Type *Ty) {
if (!isa<Instruction>(V))
return true;
// We don't extend or shrink something that has multiple uses -- doing so
@@ -265,13 +466,27 @@ static bool canNotEvaluateInType(Value *V, Type *Ty) {
///
/// This function works on both vectors and scalars.
///
-static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
- Instruction *CxtI) {
- if (canAlwaysEvaluateInType(V, Ty))
- return true;
- if (canNotEvaluateInType(V, Ty))
- return false;
+bool TypeEvaluationHelper::canEvaluateTruncated(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
+ TypeEvaluationHelper TYH;
+ return TYH.canEvaluateTruncatedImpl(V, Ty, IC, CxtI) &&
+ // We need to check whether we visited all users of multi-user values,
+ // and we have to do it at the very end, outside of the recursion.
+ TYH.allPendingVisited();
+}
+bool TypeEvaluationHelper::canEvaluateTruncatedImpl(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
+ return canEvaluate(V, Ty, [this, &IC, CxtI](Value *V, Type *Ty) {
+ return canEvaluateTruncatedPred(V, Ty, IC, CxtI);
+ });
+}
+
+bool TypeEvaluationHelper::canEvaluateTruncatedPred(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
auto *I = cast<Instruction>(V);
Type *OrigTy = V->getType();
switch (I->getOpcode()) {
@@ -282,8 +497,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
case Instruction::Or:
case Instruction::Xor:
// These operators can all arbitrarily be extended or truncated.
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
case Instruction::UDiv:
case Instruction::URem: {
@@ -296,8 +511,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
// based on later context may introduce a trap.
if (IC.MaskedValueIsZero(I->getOperand(0), Mask, I) &&
IC.MaskedValueIsZero(I->getOperand(1), Mask, I)) {
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, I) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, I);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
break;
}
@@ -308,8 +523,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
KnownBits AmtKnownBits =
llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
if (AmtKnownBits.getMaxValue().ult(BitWidth))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
break;
}
case Instruction::LShr: {
@@ -329,12 +544,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) {
auto DemandedBits = Trunc->getType()->getScalarSizeInBits();
if ((MaxShiftAmt + DemandedBits).ule(BitWidth))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
if (IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, CxtI))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
break;
}
@@ -351,8 +566,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
unsigned ShiftedBits = OrigBitWidth - BitWidth;
if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), CxtI))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
break;
}
case Instruction::Trunc:
@@ -365,18 +580,18 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
return true;
case Instruction::Select: {
SelectInst *SI = cast<SelectInst>(I);
- return canEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) &&
- canEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(SI->getTrueValue(), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(SI->getFalseValue(), Ty, IC, CxtI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
- // get into trouble with cyclic PHIs here because we only consider
- // instructions with a single use.
+ // get into trouble with cyclic PHIs here because canEvaluate handles use
+ // chain loops.
PHINode *PN = cast<PHINode>(I);
- for (Value *IncValue : PN->incoming_values())
- if (!canEvaluateTruncated(IncValue, Ty, IC, CxtI))
- return false;
- return true;
+ return llvm::all_of(
+ PN->incoming_values(), [this, Ty, &IC, CxtI](Value *IncValue) {
+ return canEvaluateTruncatedImpl(IncValue, Ty, IC, CxtI);
+ });
}
case Instruction::FPToUI:
case Instruction::FPToSI: {
@@ -385,14 +600,14 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
// that did not exist in the original code.
Type *InputTy = I->g...
[truncated]
|
c6e4b32 to
bc04f35
Compare
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
bc04f35 to
289336f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me. I will have a deeper review later.
|
|
||
| /// Check if we traversed all the users of the multi-use values we've seen. | ||
| [[nodiscard]] bool allPendingVisited() const { | ||
| return llvm::all_of(Pending, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see the reason for adding a Pending buffer. Can't we just use I->users()?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might have multiple values that we visit during our traversal that we need to check at the end of traversal. So, it doesn't matter if we keep track of multi-user values themselves or their users, we still would need to have a buffer for that.
289336f to
5559d54
Compare
5559d54 to
5886f6e
Compare
canEvaluateTruncatedandcanEvaluateSExtdpreviously rejected multi-use values to avoid duplication. This was overly conservative, if all users of a multi-use value are part of the transform, we can evaluate it in a different type without duplication.This change tracks visited values and defers decisions on multi-use values until we verify all their users were visited.
EvaluateInDifferentTypenow memoizes multi-use values to avoid creating duplicates.Applied to truncation and sext. Zext unchanged due to its dual-return nature.