Skip to content

Commit 7506ebc

Browse files
committed
Add freeze
1 parent edbebb4 commit 7506ebc

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13083,15 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1308313083
SDValue LHSExtOp = LHS->getOperand(0);
1308413084
EVT LHSExtOpVT = LHSExtOp.getValueType();
1308513085

13086-
// Return 'select(P, Op, splat(0))' if P is nonzero,
13087-
// or 'P' otherwise.
13088-
auto tryPredicate = [&](SDValue P, SDValue Op) {
13086+
// Sets Op = select(P, Op, splat(0)) if P is nonzero, or Op otherwise.
13087+
// Set ToFreezeOp = freeze(ToFreezeOp) if the value may be poison, to
13088+
// keep the same semantics.
13089+
auto ApplyPredicate = [&](SDValue P, SDValue &Op, SDValue &ToFreezeOp) {
1308913090
if (!P)
13090-
return Op;
13091+
return;
13092+
if (!DAG.isGuaranteedNotToBePoison(ToFreezeOp))
13093+
ToFreezeOp = DAG.getFreeze(ToFreezeOp);
1309113094
EVT OpVT = Op.getValueType();
1309213095
SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
1309313096
: DAG.getConstant(0, DL, OpVT);
13094-
return DAG.getSelect(DL, OpVT, P, Op, Zero);
13097+
Op = DAG.getSelect(DL, OpVT, P, Op, Zero);
1309513098
};
1309613099

1309713100
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
@@ -13116,10 +13119,9 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1311613119
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
1311713120
return SDValue();
1311813121

13119-
SDValue Constant =
13120-
tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
13121-
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
13122-
Constant);
13122+
SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT);
13123+
ApplyPredicate(Pred, C, LHSExtOp);
13124+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C);
1312313125
}
1312413126

1312513127
unsigned RHSOpcode = RHS->getOpcode();
@@ -13160,7 +13162,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1316013162
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
1316113163
return SDValue();
1316213164

13163-
RHSExtOp = tryPredicate(Pred, RHSExtOp);
13165+
ApplyPredicate(Pred, RHSExtOp, LHSExtOp);
1316413166
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
1316513167
}
1316613168

0 commit comments

Comments
 (0)