diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index f8900f3434cca..0241ec4f2111d 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -308,10 +308,12 @@ std::optional<APFloat>
 ConstantFoldIntToFloat(unsigned Opcode, LLT DstTy, Register Src,
                        const MachineRegisterInfo &MRI);
 
-/// Tries to constant fold a G_CTLZ operation on \p Src. If \p Src is a vector
-/// then it tries to do an element-wise constant fold.
+/// Tries to constant fold a counting-zero operation (G_CTLZ or G_CTTZ) on \p
+/// Src. If \p Src is a vector then it tries to do an element-wise constant
+/// fold.
 std::optional<SmallVector<unsigned>>
-ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI);
+ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
+                       std::function<unsigned(APInt)> CB);
 
 /// Test if the given value is known to have exactly one bit set. This differs
 /// from computeKnownBits in that it doesn't necessarily determine which bit is
diff --git a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
index 1869e0d41a51f..a0bc325c6cda7 100644
--- a/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
@@ -256,10 +256,16 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
       return buildFConstant(DstOps[0], *Cst);
     break;
   }
-  case TargetOpcode::G_CTLZ: {
+  case TargetOpcode::G_CTLZ:
+  case TargetOpcode::G_CTTZ: {
     assert(SrcOps.size() == 1 && "Expected one source");
     assert(DstOps.size() == 1 && "Expected one dest");
-    auto MaybeCsts = ConstantFoldCTLZ(SrcOps[0].getReg(), *getMRI());
+    std::function<unsigned(APInt)> CB;
+    if (Opc == TargetOpcode::G_CTLZ)
+      CB = [](APInt V) -> unsigned { return V.countl_zero(); };
+    else
+      CB = [](APInt V) -> unsigned { return V.countr_zero(); };
+    auto MaybeCsts = ConstantFoldCountZeros(SrcOps[0].getReg(), *getMRI(), CB);
     if (!MaybeCsts)
       break;
     if (MaybeCsts->size() == 1)
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index a9fa73b60a097..8c41f8b1bdcdb 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -966,14 +966,15 @@ llvm::ConstantFoldIntToFloat(unsigned Opcode, LLT DstTy, Register Src,
 }
 
 std::optional<SmallVector<unsigned>>
-llvm::ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI) {
+llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
+                             std::function<unsigned(APInt)> CB) {
   LLT Ty = MRI.getType(Src);
   SmallVector<unsigned> FoldedCTLZs;
   auto tryFoldScalar = [&](Register R) -> std::optional<unsigned> {
     auto MaybeCst = getIConstantVRegVal(R, MRI);
     if (!MaybeCst)
       return std::nullopt;
-    return MaybeCst->countl_zero();
+    return CB(*MaybeCst);
   };
   if (Ty.isVector()) {
     // Try to constant fold each element.
diff --git a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
index 116099eff14aa..08857de3cf4e4 100644
--- a/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/CSETest.cpp
@@ -233,4 +233,46 @@ TEST_F(AArch64GISelMITest, TestConstantFoldCTL) {
   EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
 }
 
+TEST_F(AArch64GISelMITest, TestConstantFoldCTT) {
+  setUp();
+  if (!TM)
+    GTEST_SKIP();
+
+  LLT s32 = LLT::scalar(32);
+
+  GISelCSEInfo CSEInfo;
+  CSEInfo.setCSEConfig(std::make_unique<CSEConfigConstantOnly>());
+  CSEInfo.analyze(*MF);
+  B.setCSEInfo(&CSEInfo);
+  CSEMIRBuilder CSEB(B.getState());
+  auto Cst8 = CSEB.buildConstant(s32, 8);
+  auto *CttzDef = &*CSEB.buildCTTZ(s32, Cst8);
+  EXPECT_TRUE(CttzDef->getOpcode() == TargetOpcode::G_CONSTANT);
+  EXPECT_TRUE(CttzDef->getOperand(1).getCImm()->getZExtValue() == 3);
+
+  // Test vector.
+  auto Cst16 = CSEB.buildConstant(s32, 16);
+  auto Cst32 = CSEB.buildConstant(s32, 32);
+  auto Cst64 = CSEB.buildConstant(s32, 64);
+  LLT VecTy = LLT::fixed_vector(4, s32);
+  auto BV = CSEB.buildBuildVector(VecTy, {Cst8.getReg(0), Cst16.getReg(0),
+                                          Cst32.getReg(0), Cst64.getReg(0)});
+  CSEB.buildCTTZ(VecTy, BV);
+
+  auto CheckStr = R"(
+  ; CHECK: [[CST8:%[0-9]+]]:_(s32) = G_CONSTANT i32 8
+  ; CHECK: [[CST3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
+  ; CHECK: [[CST16:%[0-9]+]]:_(s32) = G_CONSTANT i32 16
+  ; CHECK: [[CST32:%[0-9]+]]:_(s32) = G_CONSTANT i32 32
+  ; CHECK: [[CST64:%[0-9]+]]:_(s32) = G_CONSTANT i32 64
+  ; CHECK: [[BV1:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[CST8]]:_(s32), [[CST16]]:_(s32), [[CST32]]:_(s32), [[CST64]]:_(s32)
+  ; CHECK: [[CST27:%[0-9]+]]:_(s32) = G_CONSTANT i32 4
+  ; CHECK: [[CST26:%[0-9]+]]:_(s32) = G_CONSTANT i32 5
+  ; CHECK: [[CST25:%[0-9]+]]:_(s32) = G_CONSTANT i32 6
+  ; CHECK: [[BV2:%[0-9]+]]:_(<4 x s32>) = G_BUILD_VECTOR [[CST3]]:_(s32), [[CST27]]:_(s32), [[CST26]]:_(s32), [[CST25]]:_(s32)
+  )";
+
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
+
 } // namespace