diff --git a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp index 64ae4e94a8c92..3a865fcf921b0 100644 --- a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp @@ -40,6 +40,30 @@ static LegalityPredicate typeIsScalarFPArith(unsigned TypeIdx, }; } +static LegalityPredicate +typeIsLegalIntOrFPVec(unsigned TypeIdx, + std::initializer_list IntOrFPVecTys, + const RISCVSubtarget &ST) { + LegalityPredicate P = [=, &ST](const LegalityQuery &Query) { + return ST.hasVInstructions() && + (Query.Types[TypeIdx].getScalarSizeInBits() != 64 || + ST.hasVInstructionsI64()) && + (Query.Types[TypeIdx].getElementCount().getKnownMinValue() != 1 || + ST.getELen() == 64); + }; + + return all(typeInSet(TypeIdx, IntOrFPVecTys), P); +} + +static LegalityPredicate +typeIsLegalBoolVec(unsigned TypeIdx, std::initializer_list BoolVecTys, + const RISCVSubtarget &ST) { + LegalityPredicate HasV = [=, &ST](const LegalityQuery &Query) { + return ST.hasVInstructions(); + }; + return all(typeInSet(TypeIdx, BoolVecTys), HasV); +} + RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) : STI(ST), XLen(STI.getXLen()), sXLen(LLT::scalar(XLen)) { const LLT sDoubleXLen = LLT::scalar(2 * XLen); @@ -50,6 +74,14 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) const LLT s32 = LLT::scalar(32); const LLT s64 = LLT::scalar(64); + const LLT nxv1s1 = LLT::scalable_vector(1, s1); + const LLT nxv2s1 = LLT::scalable_vector(2, s1); + const LLT nxv4s1 = LLT::scalable_vector(4, s1); + const LLT nxv8s1 = LLT::scalable_vector(8, s1); + const LLT nxv16s1 = LLT::scalable_vector(16, s1); + const LLT nxv32s1 = LLT::scalable_vector(32, s1); + const LLT nxv64s1 = LLT::scalable_vector(64, s1); + const LLT nxv1s8 = LLT::scalable_vector(1, s8); const LLT nxv2s8 = LLT::scalable_vector(2, s8); const LLT nxv4s8 = LLT::scalable_vector(4, s8); @@ -78,22 +110,16 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST) using namespace TargetOpcode; - auto AllVecTys = {nxv1s8, nxv2s8, nxv4s8, nxv8s8, nxv16s8, nxv32s8, - nxv64s8, nxv1s16, nxv2s16, nxv4s16, nxv8s16, nxv16s16, - nxv32s16, nxv1s32, nxv2s32, nxv4s32, nxv8s32, nxv16s32, - nxv1s64, nxv2s64, nxv4s64, nxv8s64}; + auto BoolVecTys = {nxv1s1, nxv2s1, nxv4s1, nxv8s1, nxv16s1, nxv32s1, nxv64s1}; + + auto IntOrFPVecTys = {nxv1s8, nxv2s8, nxv4s8, nxv8s8, nxv16s8, nxv32s8, + nxv64s8, nxv1s16, nxv2s16, nxv4s16, nxv8s16, nxv16s16, + nxv32s16, nxv1s32, nxv2s32, nxv4s32, nxv8s32, nxv16s32, + nxv1s64, nxv2s64, nxv4s64, nxv8s64}; getActionDefinitionsBuilder({G_ADD, G_SUB, G_AND, G_OR, G_XOR}) .legalFor({s32, sXLen}) - .legalIf(all( - typeInSet(0, AllVecTys), - LegalityPredicate([=, &ST](const LegalityQuery &Query) { - return ST.hasVInstructions() && - (Query.Types[0].getScalarSizeInBits() != 64 || - ST.hasVInstructionsI64()) && - (Query.Types[0].getElementCount().getKnownMinValue() != 1 || - ST.getELen() == 64); - }))) + .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST)) .widenScalarToNextPow2(0) .clampScalar(0, s32, sXLen);