@@ -13083,15 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1308313083 SDValue LHSExtOp = LHS->getOperand(0);
1308413084 EVT LHSExtOpVT = LHSExtOp.getValueType();
1308513085
13086- // Return 'select(P, Op, splat(0))' if P is nonzero,
13087- // or 'P' otherwise.
13088- auto tryPredicate = [&](SDValue P, SDValue Op) {
13086+ // Sets Op = select(P, Op, splat(0)) if P is nonzero, or Op otherwise.
13087+ // Set ToFreezeOp = freeze(ToFreezeOp) if the value may be poison, to
13088+ // keep the same semantics.
13089+ auto ApplyPredicate = [&](SDValue P, SDValue &Op, SDValue &ToFreezeOp) {
1308913090 if (!P)
13090- return Op;
13091+ return;
13092+ if (!DAG.isGuaranteedNotToBePoison(ToFreezeOp))
13093+ ToFreezeOp = DAG.getFreeze(ToFreezeOp);
1309113094 EVT OpVT = Op.getValueType();
1309213095 SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
1309313096 : DAG.getConstant(0, DL, OpVT);
13094- return DAG.getSelect(DL, OpVT, P, Op, Zero);
13097+ Op = DAG.getSelect(DL, OpVT, P, Op, Zero);
1309513098 };
1309613099
1309713100 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
@@ -13116,10 +13119,9 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1311613119 TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
1311713120 return SDValue();
1311813121
13119- SDValue Constant =
13120- tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
13121- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
13122- Constant);
13122+ SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT);
13123+ ApplyPredicate(Pred, C, LHSExtOp);
13124+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C);
1312313125 }
1312413126
1312513127 unsigned RHSOpcode = RHS->getOpcode();
@@ -13160,7 +13162,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1316013162 TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
1316113163 return SDValue();
1316213164
13163- RHSExtOp = tryPredicate (Pred, RHSExtOp);
13165+ ApplyPredicate (Pred, RHSExtOp, LHSExtOp );
1316413166 return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
1316513167}
1316613168
0 commit comments