diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 4aba42b014d17..e1f1c49094424 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -7242,9 +7242,14 @@ static SDValue performANY_EXTENDCombine(SDNode *N, // Try to form VWMUL or VWMULU. // FIXME: Support VWMULSU. -static SDValue combineMUL_VLToVWMUL(SDNode *N, SDValue Op0, SDValue Op1, - SelectionDAG &DAG) { +static SDValue combineMUL_VLToVWMUL_VL(SDNode *N, SelectionDAG &DAG, + bool Commute) { assert(N->getOpcode() == RISCVISD::MUL_VL && "Unexpected opcode"); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + if (Commute) + std::swap(Op0, Op1); + bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL; bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL; if ((!IsSignExt && !IsZeroExt) || !Op0.hasOneUse()) @@ -7887,15 +7892,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, } break; } - case RISCVISD::MUL_VL: { - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - if (SDValue V = combineMUL_VLToVWMUL(N, Op0, Op1, DAG)) + case RISCVISD::MUL_VL: + if (SDValue V = combineMUL_VLToVWMUL_VL(N, DAG, /*Commute*/ false)) return V; - if (SDValue V = combineMUL_VLToVWMUL(N, Op1, Op0, DAG)) - return V; - return SDValue(); - } + // Mul is commutative. + return combineMUL_VLToVWMUL_VL(N, DAG, /*Commute*/ true); case ISD::STORE: { auto *Store = cast(N); SDValue Val = Store->getValue();