diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp index 3b2cf31910927..4b675e8da691c 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp @@ -7946,13 +7946,11 @@ LegalizerHelper::lowerISFPCLASS(MachineInstr &MI) { } LegalizerHelper::LegalizeResult LegalizerHelper::lowerSelect(MachineInstr &MI) { - // Implement vector G_SELECT in terms of XOR, AND, OR. + // Implement G_SELECT in terms of XOR, AND, OR. auto [DstReg, DstTy, MaskReg, MaskTy, Op1Reg, Op1Ty, Op2Reg, Op2Ty] = MI.getFirst4RegLLTs(); - if (!DstTy.isVector()) - return UnableToLegalize; - bool IsEltPtr = DstTy.getElementType().isPointer(); + bool IsEltPtr = DstTy.getScalarType().isPointer(); if (IsEltPtr) { LLT ScalarPtrTy = LLT::scalar(DstTy.getScalarSizeInBits()); LLT NewTy = DstTy.changeElementType(ScalarPtrTy); @@ -7962,7 +7960,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerSelect(MachineInstr &MI) { } if (MaskTy.isScalar()) { - // Turn the scalar condition into a vector condition mask. + // Turn the scalar condition into a vector condition mask if needed. Register MaskElt = MaskReg; @@ -7972,13 +7970,20 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerSelect(MachineInstr &MI) { MaskElt = MIRBuilder.buildSExtInReg(MaskTy, MaskElt, 1).getReg(0); // Continue the sign extension (or truncate) to match the data type. - MaskElt = MIRBuilder.buildSExtOrTrunc(DstTy.getElementType(), - MaskElt).getReg(0); + MaskElt = + MIRBuilder.buildSExtOrTrunc(DstTy.getScalarType(), MaskElt).getReg(0); - // Generate a vector splat idiom. - auto ShufSplat = MIRBuilder.buildShuffleSplat(DstTy, MaskElt); - MaskReg = ShufSplat.getReg(0); + if (DstTy.isVector()) { + // Generate a vector splat idiom. + auto ShufSplat = MIRBuilder.buildShuffleSplat(DstTy, MaskElt); + MaskReg = ShufSplat.getReg(0); + } else { + MaskReg = MaskElt; + } MaskTy = DstTy; + } else if (!DstTy.isVector()) { + // Cannot handle the case that mask is a vector and dst is a scalar. + return UnableToLegalize; } if (MaskTy.getSizeInBits() != DstTy.getSizeInBits()) { diff --git a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp index d7876b7ce8749..73837279701a9 100644 --- a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp @@ -3431,6 +3431,47 @@ TEST_F(AArch64GISelMITest, LowerUDIVREM) { EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; } +// Test G_SELECT lowering. +// Note: This is for testing the legalizer, aarch64 does not lower scalar +// selects like this. +TEST_F(AArch64GISelMITest, LowerSelect) { + setUp(); + if (!TM) + GTEST_SKIP(); + + // Declare your legalization info + DefineLegalizerInfo(A, { getActionDefinitionsBuilder(G_SELECT).lower(); }); + + LLT S1 = LLT::scalar(1); + LLT S32 = LLT::scalar(32); + auto Tst = B.buildTrunc(S1, Copies[0]); + auto SrcA = B.buildTrunc(S32, Copies[1]); + auto SrcB = B.buildTrunc(S32, Copies[2]); + auto SELECT = B.buildInstr(TargetOpcode::G_SELECT, {S32}, {Tst, SrcA, SrcB}); + + AInfo Info(MF->getSubtarget()); + DummyGISelObserver Observer; + LegalizerHelper Helper(*MF, Info, Observer, B); + // Perform Legalization + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.lower(*SELECT, 0, S32)); + + auto CheckStr = R"( + CHECK: [[TST:%[0-9]+]]:_(s1) = G_TRUNC + CHECK: [[TRUE:%[0-9]+]]:_(s32) = G_TRUNC + CHECK: [[FALSE:%[0-9]+]]:_(s32) = G_TRUNC + CHECK: [[MSK:%[0-9]+]]:_(s32) = G_SEXT [[TST]] + CHECK: [[M:%[0-9]+]]:_(s32) = G_CONSTANT i32 -1 + CHECK: [[NEGMSK:%[0-9]+]]:_(s32) = G_XOR [[MSK]]:_, [[M]]:_ + CHECK: [[TVAL:%[0-9]+]]:_(s32) = G_AND [[TRUE]]:_, [[MSK]]:_ + CHECK: [[FVAL:%[0-9]+]]:_(s32) = G_AND [[FALSE]]:_, [[NEGMSK]]:_ + CHECK: [[RES:%[0-9]+]]:_(s32) = G_OR [[TVAL]]:_, [[FVAL]]:_ + )"; + + // Check + EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; +} + // Test widening of G_UNMERGE_VALUES TEST_F(AArch64GISelMITest, WidenUnmerge) { setUp();