diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 7c80982b0d7a49..598024c87724b5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -401,37 +401,6 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { if (!EC.isScalable() && IndexC->getValue().uge(NumElts)) return nullptr; - // This instruction only demands the single element from the input vector. - // Skip for scalable type, the number of elements is unknown at - // compile-time. - if (!EC.isScalable() && NumElts != 1) { - // If the input vector has a single use, simplify it based on this use - // property. - if (SrcVec->hasOneUse()) { - APInt UndefElts(NumElts, 0); - APInt DemandedElts(NumElts, 0); - DemandedElts.setBit(IndexC->getZExtValue()); - if (Value *V = - SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) - return replaceOperand(EI, 0, V); - } else { - // If the input vector has multiple uses, simplify it based on a union - // of all elements used. - APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec); - if (!DemandedElts.isAllOnes()) { - APInt UndefElts(NumElts, 0); - if (Value *V = SimplifyDemandedVectorElts( - SrcVec, DemandedElts, UndefElts, 0 /* Depth */, - true /* AllowMultipleUsers */)) { - if (V != SrcVec) { - SrcVec->replaceAllUsesWith(V); - return &EI; - } - } - } - } - } - if (Instruction *I = foldBitcastExtElt(EI)) return I; @@ -549,6 +518,44 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { } } } + + // Run demanded elements after other transforms as this can drop flags on + // binops. If there are two paths to the same final result, we prefer the + // one which doesn't force us to drop flags. 
+ if (IndexC) { + ElementCount EC = EI.getVectorOperandType()->getElementCount(); + unsigned NumElts = EC.getKnownMinValue(); + // This instruction only demands the single element from the input vector. + // Skip for scalable type, the number of elements is unknown at + // compile-time. + if (!EC.isScalable() && NumElts != 1) { + // If the input vector has a single use, simplify it based on this use + // property. + if (SrcVec->hasOneUse()) { + APInt UndefElts(NumElts, 0); + APInt DemandedElts(NumElts, 0); + DemandedElts.setBit(IndexC->getZExtValue()); + if (Value *V = + SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) + return replaceOperand(EI, 0, V); + } else { + // If the input vector has multiple uses, simplify it based on a union + // of all elements used. + APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec); + if (!DemandedElts.isAllOnes()) { + APInt UndefElts(NumElts, 0); + if (Value *V = SimplifyDemandedVectorElts( + SrcVec, DemandedElts, UndefElts, 0 /* Depth */, + true /* AllowMultipleUsers */)) { + if (V != SrcVec) { + SrcVec->replaceAllUsesWith(V); + return &EI; + } + } + } + } + } + } return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/X86/x86-avx512-inseltpoison.ll b/llvm/test/Transforms/InstCombine/X86/x86-avx512-inseltpoison.ll index 630770645f5d0b..a4c84905b5dacc 100644 --- a/llvm/test/Transforms/InstCombine/X86/x86-avx512-inseltpoison.ll +++ b/llvm/test/Transforms/InstCombine/X86/x86-avx512-inseltpoison.ll @@ -921,8 +921,8 @@ declare float @llvm.fma.f32(float, float, float) #1 define <4 x float> @test_mask_vfmadd_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_mask_vfmadd_ss( ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 -; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 +; CHECK-NEXT: 
[[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i32 0 ; CHECK-NEXT: [[TMP4:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float [[TMP2]], float [[TMP3]]) ; CHECK-NEXT: [[TMP5:%.*]] = and i8 [[MASK:%.*]], 1 ; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP5]], 0 @@ -1063,8 +1063,8 @@ define double @test_mask_vfmadd_sd_1(<2 x double> %a, <2 x double> %b, <2 x doub define <4 x float> @test_maskz_vfmadd_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_maskz_vfmadd_ss( ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 -; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i32 0 ; CHECK-NEXT: [[TMP4:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float [[TMP2]], float [[TMP3]]) ; CHECK-NEXT: [[TMP5:%.*]] = and i8 [[MASK:%.*]], 1 ; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP5]], 0 @@ -1202,8 +1202,8 @@ define double @test_maskz_vfmadd_sd_1(<2 x double> %a, <2 x double> %b, <2 x dou define <4 x float> @test_mask3_vfmadd_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_mask3_vfmadd_ss( -; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i32 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 ; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 ; CHECK-NEXT: [[TMP4:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float [[TMP2]], float [[TMP3]]) ; CHECK-NEXT: [[TMP5:%.*]] = and i8 [[MASK:%.*]], 1 @@ -1342,8 +1342,8 @@ define double @test_mask3_vfmadd_sd_1(<2 x double> %a, <2 x double> %b, <2 x dou define <4 x float> @test_mask3_vfmsub_ss(<4 x 
float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_mask3_vfmsub_ss( -; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i32 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 ; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 ; CHECK-NEXT: [[TMP4:%.*]] = fneg float [[TMP3]] ; CHECK-NEXT: [[TMP5:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float [[TMP2]], float [[TMP4]]) @@ -1542,9 +1542,9 @@ define double @test_mask3_vfmsub_sd_1_unary_fneg(<2 x double> %a, <2 x double> % define <4 x float> @test_mask3_vfnmsub_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_mask3_vfnmsub_ss( -; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i32 0 ; CHECK-NEXT: [[TMP2:%.*]] = fneg float [[TMP1]] -; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 ; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 ; CHECK-NEXT: [[TMP5:%.*]] = fneg float [[TMP4]] ; CHECK-NEXT: [[TMP6:%.*]] = call float @llvm.fma.f32(float [[TMP2]], float [[TMP3]], float [[TMP5]]) diff --git a/llvm/test/Transforms/InstCombine/X86/x86-avx512.ll b/llvm/test/Transforms/InstCombine/X86/x86-avx512.ll index 1d48351ae233af..7772946751c5c5 100644 --- a/llvm/test/Transforms/InstCombine/X86/x86-avx512.ll +++ b/llvm/test/Transforms/InstCombine/X86/x86-avx512.ll @@ -921,8 +921,8 @@ declare float @llvm.fma.f32(float, float, float) #1 define <4 x float> @test_mask_vfmadd_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_mask_vfmadd_ss( ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: 
[[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 -; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i32 0 ; CHECK-NEXT: [[TMP4:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float [[TMP2]], float [[TMP3]]) ; CHECK-NEXT: [[TMP5:%.*]] = and i8 [[MASK:%.*]], 1 ; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP5]], 0 @@ -1063,8 +1063,8 @@ define double @test_mask_vfmadd_sd_1(<2 x double> %a, <2 x double> %b, <2 x doub define <4 x float> @test_maskz_vfmadd_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_maskz_vfmadd_ss( ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 -; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i32 0 ; CHECK-NEXT: [[TMP4:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float [[TMP2]], float [[TMP3]]) ; CHECK-NEXT: [[TMP5:%.*]] = and i8 [[MASK:%.*]], 1 ; CHECK-NEXT: [[DOTNOT:%.*]] = icmp eq i8 [[TMP5]], 0 @@ -1202,8 +1202,8 @@ define double @test_maskz_vfmadd_sd_1(<2 x double> %a, <2 x double> %b, <2 x dou define <4 x float> @test_mask3_vfmadd_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_mask3_vfmadd_ss( -; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i32 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 ; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 ; CHECK-NEXT: [[TMP4:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float 
[[TMP2]], float [[TMP3]]) ; CHECK-NEXT: [[TMP5:%.*]] = and i8 [[MASK:%.*]], 1 @@ -1342,8 +1342,8 @@ define double @test_mask3_vfmadd_sd_1(<2 x double> %a, <2 x double> %b, <2 x dou define <4 x float> @test_mask3_vfmsub_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_mask3_vfmsub_ss( -; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i32 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 ; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 ; CHECK-NEXT: [[TMP4:%.*]] = fneg float [[TMP3]] ; CHECK-NEXT: [[TMP5:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float [[TMP2]], float [[TMP4]]) @@ -1542,9 +1542,9 @@ define double @test_mask3_vfmsub_sd_1_unary_fneg(<2 x double> %a, <2 x double> % define <4 x float> @test_mask3_vfnmsub_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) { ; CHECK-LABEL: @test_mask3_vfnmsub_ss( -; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i32 0 ; CHECK-NEXT: [[TMP2:%.*]] = fneg float [[TMP1]] -; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 ; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 ; CHECK-NEXT: [[TMP5:%.*]] = fneg float [[TMP4]] ; CHECK-NEXT: [[TMP6:%.*]] = call float @llvm.fma.f32(float [[TMP2]], float [[TMP3]], float [[TMP5]]) diff --git a/llvm/test/Transforms/InstCombine/X86/x86-fma.ll b/llvm/test/Transforms/InstCombine/X86/x86-fma.ll index 4893131baf23b1..cddb1bf9c4e00a 100644 --- a/llvm/test/Transforms/InstCombine/X86/x86-fma.ll +++ b/llvm/test/Transforms/InstCombine/X86/x86-fma.ll @@ -5,8 +5,8 @@ target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" define 
<4 x float> @test_vfmadd_ss(<4 x float> %a, <4 x float> %b, <4 x float> %c) { ; CHECK-LABEL: @test_vfmadd_ss( ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[A:%.*]], i64 0 -; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i64 0 -; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x float> [[B:%.*]], i32 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[C:%.*]], i32 0 ; CHECK-NEXT: [[TMP4:%.*]] = call float @llvm.fma.f32(float [[TMP1]], float [[TMP2]], float [[TMP3]]) ; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x float> [[A]], float [[TMP4]], i64 0 ; CHECK-NEXT: ret <4 x float> [[TMP5]] diff --git a/llvm/test/Transforms/InstCombine/scalarization-inseltpoison.ll b/llvm/test/Transforms/InstCombine/scalarization-inseltpoison.ll index f4712fc3a908d9..a7b233e5f71bb7 100644 --- a/llvm/test/Transforms/InstCombine/scalarization-inseltpoison.ll +++ b/llvm/test/Transforms/InstCombine/scalarization-inseltpoison.ll @@ -228,7 +228,7 @@ define i32 @extelt_binop_binop_insertelt(<4 x i32> %A, <4 x i32> %B, i32 %f) { ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[B:%.*]], i32 0 ; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP1]], [[F:%.*]] ; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[B]], i32 0 -; CHECK-NEXT: [[E:%.*]] = mul i32 [[TMP2]], [[TMP3]] +; CHECK-NEXT: [[E:%.*]] = mul nsw i32 [[TMP2]], [[TMP3]] ; CHECK-NEXT: ret i32 [[E]] ; %v = insertelement <4 x i32> %A, i32 %f, i32 0 diff --git a/llvm/test/Transforms/InstCombine/scalarization.ll b/llvm/test/Transforms/InstCombine/scalarization.ll index 586509542b8b82..08cc534b76f8e2 100644 --- a/llvm/test/Transforms/InstCombine/scalarization.ll +++ b/llvm/test/Transforms/InstCombine/scalarization.ll @@ -67,8 +67,8 @@ define void @scalarize_phi(i32 * %n, float * %inout) { ; CHECK-NEXT: [[TMP0:%.*]] = phi float [ [[T0]], [[ENTRY:%.*]] ], [ [[TMP1:%.*]], [[FOR_BODY:%.*]] ] ; CHECK-NEXT: [[I_0:%.*]] = phi i32 [ 0, 
[[ENTRY]] ], [ [[INC:%.*]], [[FOR_BODY]] ] ; CHECK-NEXT: [[T1:%.*]] = load i32, i32* [[N:%.*]], align 4 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[I_0]], [[T1]] -; CHECK-NEXT: br i1 [[CMP]], label [[FOR_END:%.*]], label [[FOR_BODY]] +; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i32 [[I_0]], [[T1]] +; CHECK-NEXT: br i1 [[CMP_NOT]], label [[FOR_END:%.*]], label [[FOR_BODY]] ; CHECK: for.body: ; CHECK-NEXT: store volatile float [[TMP0]], float* [[INOUT]], align 4 ; CHECK-NEXT: [[TMP1]] = fmul float [[TMP0]], 0x4002A3D700000000 @@ -221,14 +221,12 @@ define float @extelt_binop_insertelt(<4 x float> %A, <4 x float> %B, float %f) { } ; We recurse to find a scalarizable operand. -; FIXME: We should propagate the IR flags including wrapping flags. - define i32 @extelt_binop_binop_insertelt(<4 x i32> %A, <4 x i32> %B, i32 %f) { ; CHECK-LABEL: @extelt_binop_binop_insertelt( ; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[B:%.*]], i32 0 ; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP1]], [[F:%.*]] ; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[B]], i32 0 -; CHECK-NEXT: [[E:%.*]] = mul i32 [[TMP2]], [[TMP3]] +; CHECK-NEXT: [[E:%.*]] = mul nsw i32 [[TMP2]], [[TMP3]] ; CHECK-NEXT: ret i32 [[E]] ; %v = insertelement <4 x i32> %A, i32 %f, i32 0