From c223655e2b48b50dc246e6d0abbe1eaf14efffa8 Mon Sep 17 00:00:00 2001 From: Matthew Devereau Date: Wed, 24 Sep 2025 12:54:25 +0000 Subject: [PATCH 1/3] [InstCombine] Fold selects into masked loads Selects can be folded into masked loads if the masks are identical --- .../InstCombine/InstCombineSelect.cpp | 11 +++++++++ .../InstCombine/select-masked_load.ll | 24 +++++++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 4ea75409252bd..50dbc965a3b9b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -4611,5 +4611,16 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { return replaceOperand(SI, 2, ConstantInt::get(FalseVal->getType(), 0)); } + Value *MaskedLoadPtr; + const APInt *MaskedLoadAlignment; + if (match(TrueVal, + m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment), + m_Specific(CondVal), m_Value()))) + return replaceInstUsesWith( + SI, Builder.CreateMaskedLoad( + TrueVal->getType(), MaskedLoadPtr, + llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal, + FalseVal)); + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/select-masked_load.ll b/llvm/test/Transforms/InstCombine/select-masked_load.ll index b6bac612d6f9b..650b0b79c7cbf 100644 --- a/llvm/test/Transforms/InstCombine/select-masked_load.ll +++ b/llvm/test/Transforms/InstCombine/select-masked_load.ll @@ -26,8 +26,7 @@ define <4 x i32> @masked_load_and_zero_inactive_2(ptr %ptr, <4 x i1> %mask) { ; No transform when the load's passthrough cannot be reused or altered. 
define <4 x i32> @masked_load_and_zero_inactive_3(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthrough) { ; CHECK-LABEL: @masked_load_and_zero_inactive_3( -; CHECK-NEXT: [[LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHROUGH:%.*]]) -; CHECK-NEXT: [[MASKED:%.*]] = select <4 x i1> [[MASK]], <4 x i32> [[LOAD]], <4 x i32> zeroinitializer +; CHECK-NEXT: [[MASKED:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> zeroinitializer) ; CHECK-NEXT: ret <4 x i32> [[MASKED]] ; %load = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough) @@ -116,6 +115,27 @@ entry: ret <8 x float> %1 } +define <vscale x 4 x float> @fold_sel_into_masked_load_scalable(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough) { +; CHECK-LABEL: @fold_sel_into_masked_load_scalable( +; CHECK-NEXT: [[SEL:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]]) +; CHECK-NEXT: ret <vscale x 4 x float> [[SEL]] +; + %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer) + %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough + ret <vscale x 4 x float> %sel +} + +define <vscale x 4 x float> @neg_fold_sel_into_masked_load_mask_mismatch(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask2, <vscale x 4 x float> %passthrough) { +; CHECK-LABEL: @neg_fold_sel_into_masked_load_mask_mismatch( +; CHECK-NEXT: [[LOAD:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]]) +; CHECK-NEXT: [[SEL:%.*]] = select <vscale x 4 x i1> [[MASK2:%.*]], <vscale x 4 x float> [[LOAD]], <vscale x 4 x float> [[PASSTHROUGH]] +; CHECK-NEXT: ret <vscale x 4 x float> [[SEL]] +; + %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough) + %sel = select <vscale x 4 x i1> %mask2, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough + ret <vscale x 4 x float> %sel +} + declare <8 x float> @llvm.masked.load.v8f32.p0(ptr, i32 immarg, <8 x i1>, <8 x float>) declare <4 x i32> @llvm.masked.load.v4i32.p0(ptr, i32 immarg, <4 x i1>, <4 x i32>) declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32 immarg, <4 x i1>, <4 x float>) From 
c802e73bc1b7abda1419b2ae40934245a00c1e34 Mon Sep 17 00:00:00 2001 From: Matthew Devereau Date: Wed, 24 Sep 2025 14:04:56 +0000 Subject: [PATCH 2/3] Add hasOneUse check and test --- .../Transforms/InstCombine/InstCombineSelect.cpp | 3 ++- .../Transforms/InstCombine/select-masked_load.ll | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 50dbc965a3b9b..6ffaebf425394 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -4613,7 +4613,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *MaskedLoadPtr; const APInt *MaskedLoadAlignment; - if (match(TrueVal, + if (TrueVal->hasOneUse() && + match(TrueVal, m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment), m_Specific(CondVal), m_Value()))) return replaceInstUsesWith( diff --git a/llvm/test/Transforms/InstCombine/select-masked_load.ll b/llvm/test/Transforms/InstCombine/select-masked_load.ll index 650b0b79c7cbf..22e30ac019a5d 100644 --- a/llvm/test/Transforms/InstCombine/select-masked_load.ll +++ b/llvm/test/Transforms/InstCombine/select-masked_load.ll @@ -136,6 +136,19 @@ define <vscale x 4 x float> @neg_fold_sel_into_masked_load_mask_mismatch(ptr %lo ret <vscale x 4 x float> %sel } +define <vscale x 4 x float> @fold_sel_into_masked_load_scalable_one_use_check(ptr %loc1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough, ptr %loc2) { +; CHECK-LABEL: @fold_sel_into_masked_load_scalable_one_use_check( +; CHECK-NEXT: [[LOAD:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> zeroinitializer) +; CHECK-NEXT: [[SEL:%.*]] = select <vscale x 4 x i1> [[MASK]], <vscale x 4 x float> [[LOAD]], <vscale x 4 x float> [[PASSTHROUGH:%.*]] +; CHECK-NEXT: call void @llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> [[LOAD]], ptr [[LOC2:%.*]], i32 1, <vscale x 4 x i1> [[MASK]]) +; CHECK-NEXT: ret <vscale x 4 x float> [[SEL]] +; + %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc1, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer) + %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough + call void 
@llvm.masked.store.nxv4f32.p0(<vscale x 4 x float> %load, ptr %loc2, i32 1, <vscale x 4 x i1> %mask) + ret <vscale x 4 x float> %sel +} + declare <8 x float> @llvm.masked.load.v8f32.p0(ptr, i32 immarg, <8 x i1>, <8 x float>) declare <4 x i32> @llvm.masked.load.v4i32.p0(ptr, i32 immarg, <4 x i1>, <4 x i32>) declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32 immarg, <4 x i1>, <4 x float>) From 438081406d4eaa4db7b9d991148ae7aedf31b4bf Mon Sep 17 00:00:00 2001 From: Matthew Devereau Date: Wed, 24 Sep 2025 15:17:55 +0000 Subject: [PATCH 3/3] Add suggested changes --- .../Transforms/InstCombine/InstCombineSelect.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 6ffaebf425394..b6b3a95f35c76 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -4613,15 +4613,13 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *MaskedLoadPtr; const APInt *MaskedLoadAlignment; - if (TrueVal->hasOneUse() && - match(TrueVal, - m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment), - m_Specific(CondVal), m_Value()))) + if (match(TrueVal, m_OneUse(m_MaskedLoad(m_Value(MaskedLoadPtr), + m_APInt(MaskedLoadAlignment), + m_Specific(CondVal), m_Value())))) return replaceInstUsesWith( - SI, Builder.CreateMaskedLoad( - TrueVal->getType(), MaskedLoadPtr, - llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal, - FalseVal)); + SI, Builder.CreateMaskedLoad(TrueVal->getType(), MaskedLoadPtr, + Align(MaskedLoadAlignment->getZExtValue()), + CondVal, FalseVal)); return nullptr; }