Skip to content

Conversation

MDevereau
Copy link
Contributor

Selects can be folded into masked loads if the masks are identical.

Selects can be folded into masked loads if the masks are identical
@MDevereau MDevereau requested a review from david-arm September 24, 2025 13:16
@MDevereau MDevereau requested a review from nikic as a code owner September 24, 2025 13:16
@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms labels Sep 24, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 24, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Matthew Devereau (MDevereau)

Changes

Selects can be folded into masked loads if the masks are identical.


Full diff: https://github.com/llvm/llvm-project/pull/160522.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+11)
  • (modified) llvm/test/Transforms/InstCombine/select-masked_load.ll (+22-2)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4ea75409252bd..50dbc965a3b9b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4611,5 +4611,16 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       return replaceOperand(SI, 2, ConstantInt::get(FalseVal->getType(), 0));
   }
 
+  Value *MaskedLoadPtr;
+  const APInt *MaskedLoadAlignment;
+  if (match(TrueVal,
+            m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment),
+                         m_Specific(CondVal), m_Value())))
+    return replaceInstUsesWith(
+        SI, Builder.CreateMaskedLoad(
+                TrueVal->getType(), MaskedLoadPtr,
+                llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal,
+                FalseVal));
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/select-masked_load.ll b/llvm/test/Transforms/InstCombine/select-masked_load.ll
index b6bac612d6f9b..650b0b79c7cbf 100644
--- a/llvm/test/Transforms/InstCombine/select-masked_load.ll
+++ b/llvm/test/Transforms/InstCombine/select-masked_load.ll
@@ -26,8 +26,7 @@ define <4 x i32> @masked_load_and_zero_inactive_2(ptr %ptr, <4 x i1> %mask) {
 ; No transform when the load's passthrough cannot be reused or altered.
 define <4 x i32> @masked_load_and_zero_inactive_3(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthrough) {
 ; CHECK-LABEL: @masked_load_and_zero_inactive_3(
-; CHECK-NEXT:    [[LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHROUGH:%.*]])
-; CHECK-NEXT:    [[MASKED:%.*]] = select <4 x i1> [[MASK]], <4 x i32> [[LOAD]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[MASKED:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> zeroinitializer)
 ; CHECK-NEXT:    ret <4 x i32> [[MASKED]]
 ;
   %load = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)
@@ -116,6 +115,27 @@ entry:
   ret <8 x float> %1
 }
 
+define <vscale x 4 x float> @fold_sel_into_masked_load_scalable(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough) {
+; CHECK-LABEL: @fold_sel_into_masked_load_scalable(
+; CHECK-NEXT:    [[SEL:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT:    ret <vscale x 4 x float> [[SEL]]
+;
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+  %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+  ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @neg_fold_sel_into_masked_load_mask_mismatch(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask2, <vscale x 4 x float> %passthrough) {
+; CHECK-LABEL: @neg_fold_sel_into_masked_load_mask_mismatch(
+; CHECK-NEXT:    [[LOAD:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT:    [[SEL:%.*]] = select <vscale x 4 x i1> [[MASK2:%.*]], <vscale x 4 x float> [[LOAD]], <vscale x 4 x float> [[PASSTHROUGH]]
+; CHECK-NEXT:    ret <vscale x 4 x float> [[SEL]]
+;
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough)
+  %sel = select <vscale x 4 x i1> %mask2, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+  ret <vscale x 4 x float> %sel
+}
+
 declare <8 x float> @llvm.masked.load.v8f32.p0(ptr, i32 immarg, <8 x i1>, <8 x float>)
 declare <4 x i32> @llvm.masked.load.v4i32.p0(ptr, i32 immarg, <4 x i1>, <4 x i32>)
 declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32 immarg, <4 x i1>, <4 x float>)

const APInt *MaskedLoadAlignment;
if (match(TrueVal,
m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment),
m_Specific(CondVal), m_Value())))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs a one-use check? Otherwise you'll end up with two masked loads.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've added a check & test for it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are already using PatternMatch, the more idiomatic way to do this would be m_OneUse(m_MakedLoad(...)).

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

const APInt *MaskedLoadAlignment;
if (match(TrueVal,
m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment),
m_Specific(CondVal), m_Value())))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are already using PatternMatch, the more idiomatic way to do this would be m_OneUse(m_MakedLoad(...)).

return replaceInstUsesWith(
SI, Builder.CreateMaskedLoad(
TrueVal->getType(), MaskedLoadPtr,
llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal,
Align(MaskedLoadAlignment->getZExtValue()), CondVal,

@MDevereau MDevereau merged commit eb85899 into llvm:main Sep 24, 2025
9 checks passed
@MDevereau MDevereau deleted the select_masked_load_fold branch September 24, 2025 17:30
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
Selects can be folded into masked loads if the masks are identical.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants