diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 995ae75da1c30..3b69edacb8982 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -17867,6 +17867,7 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, SmallVector Worklist; SmallPtrSet Inserted; + SmallPtrSet ExtensionsToRemove; Worklist.push_back(N); Inserted.insert(N); SmallVector CombinesToApply; @@ -17876,22 +17877,25 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, NodeExtensionHelper LHS(Root, 0, DAG, Subtarget); NodeExtensionHelper RHS(Root, 1, DAG, Subtarget); - auto AppendUsersIfNeeded = [&Worklist, &Subtarget, - &Inserted](const NodeExtensionHelper &Op) { - if (Op.needToPromoteOtherUsers()) { - for (SDUse &Use : Op.OrigOperand->uses()) { - SDNode *TheUser = Use.getUser(); - if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget)) - return false; - // We only support the first 2 operands of FMA. - if (Use.getOperandNo() >= 2) - return false; - if (Inserted.insert(TheUser).second) - Worklist.push_back(TheUser); - } - } - return true; - }; + auto AppendUsersIfNeeded = + [&Worklist, &Subtarget, &Inserted, + &ExtensionsToRemove](const NodeExtensionHelper &Op) { + if (Op.needToPromoteOtherUsers()) { + // Remember that we're supposed to remove this extension. + ExtensionsToRemove.insert(Op.OrigOperand.getNode()); + for (SDUse &Use : Op.OrigOperand->uses()) { + SDNode *TheUser = Use.getUser(); + if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget)) + return false; + // We only support the first 2 operands of FMA. + if (Use.getOperandNo() >= 2) + return false; + if (Inserted.insert(TheUser).second) + Worklist.push_back(TheUser); + } + } + return true; + }; // Control the compile time by limiting the number of node we look at in // total. @@ -17912,6 +17916,15 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N, std::optional Res = FoldingStrategy(Root, LHS, RHS, DAG, Subtarget); if (Res) { + // If this strategy wouldn't remove an extension we're supposed to + // remove, reject it. + if (!Res->LHSExt.has_value() && + ExtensionsToRemove.contains(LHS.OrigOperand.getNode())) + continue; + if (!Res->RHSExt.has_value() && + ExtensionsToRemove.contains(RHS.OrigOperand.getNode())) + continue; + Matched = true; CombinesToApply.push_back(*Res); // All the inputs that are extended need to be folded, otherwise diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll index b1f0eee3e9f52..034186210513c 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll @@ -595,12 +595,11 @@ define @mismatched_extend_sub_add_commuted( ; FOLDING: # %bb.0: ; FOLDING-NEXT: vsetvli a0, zero, e32, m2, ta, ma ; FOLDING-NEXT: vzext.vf2 v10, v8 -; FOLDING-NEXT: vsext.vf2 v12, v9 ; FOLDING-NEXT: vsetvli zero, zero, e16, m1, ta, ma -; FOLDING-NEXT: vwsub.wv v10, v10, v9 -; FOLDING-NEXT: vwaddu.wv v12, v12, v8 +; FOLDING-NEXT: vwsub.wv v12, v10, v9 +; FOLDING-NEXT: vwadd.wv v10, v10, v9 ; FOLDING-NEXT: vsetvli zero, zero, e32, m2, ta, ma -; FOLDING-NEXT: vmul.vv v8, v10, v12 +; FOLDING-NEXT: vmul.vv v8, v12, v10 ; FOLDING-NEXT: ret %a = zext %x to %b = sext %y to