diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 74c4bba72ee90..169ff9d22f989 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -10435,14 +10435,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, case ISD::STORE: { auto *Store = cast(N); SDValue Val = Store->getValue(); - // Combine store of vmv.x.s to vse with VL of 1. - // FIXME: Support FP. - if (Val.getOpcode() == RISCVISD::VMV_X_S) { + // Combine store of vmv.x.s/vfmv.f.s to vse with VL of 1. + // vfmv.f.s is represented as extract element from 0. Match it late to avoid + // any illegal types. + if (Val.getOpcode() == RISCVISD::VMV_X_S || + (DCI.isAfterLegalizeDAG() && + Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT && + isNullConstant(Val.getOperand(1)))) { SDValue Src = Val.getOperand(0); MVT VecVT = Src.getSimpleValueType(); EVT MemVT = Store->getMemoryVT(); - // The memory VT and the element type must match. - if (MemVT == VecVT.getVectorElementType()) { + // VecVT should be scalable and memory VT should match the element type. + if (VecVT.isScalableVector() && + MemVT == VecVT.getVectorElementType()) { SDLoc DL(N); MVT MaskVT = getMaskTypeFor(VecVT); return DAG.getStoreVP( diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td index aec02162a2fd9..c07bb775c7968 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -1035,15 +1035,6 @@ foreach fvti = AllFloatVectors in { //===----------------------------------------------------------------------===// let Predicates = [HasVInstructionsAnyF] in foreach vti = AllFloatVectors in { - // Fold store of vmv.f.s to a vse with VL=1. - defvar store_instr = !cast("PseudoVSE"#vti.SEW#"_V_"#vti.LMul.MX); - - let AddedComplexity = 2 in { - // Add complexity to increase the priority of this pattern being matched. - def : Pat<(store (extractelt (vti.Vector vti.RegClass:$rs2), 0), GPR:$rs1), - (store_instr vti.RegClass:$rs2, GPR:$rs1, 1, vti.Log2SEW)>; - } - defvar vmv_f_s_inst = !cast(!strconcat("PseudoVFMV_", vti.ScalarSuffix, "_S_", vti.LMul.MX));