-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[InstCombine] Fold selects into masked loads #160522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Selects can be folded into masked loads if the masks are identical
@llvm/pr-subscribers-llvm-transforms Author: Matthew Devereau (MDevereau) Changes: Selects can be folded into masked loads if the masks are identical. Full diff: https://github.com/llvm/llvm-project/pull/160522.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4ea75409252bd..50dbc965a3b9b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4611,5 +4611,16 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
return replaceOperand(SI, 2, ConstantInt::get(FalseVal->getType(), 0));
}
+ Value *MaskedLoadPtr;
+ const APInt *MaskedLoadAlignment;
+ if (match(TrueVal,
+ m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment),
+ m_Specific(CondVal), m_Value())))
+ return replaceInstUsesWith(
+ SI, Builder.CreateMaskedLoad(
+ TrueVal->getType(), MaskedLoadPtr,
+ llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal,
+ FalseVal));
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/select-masked_load.ll b/llvm/test/Transforms/InstCombine/select-masked_load.ll
index b6bac612d6f9b..650b0b79c7cbf 100644
--- a/llvm/test/Transforms/InstCombine/select-masked_load.ll
+++ b/llvm/test/Transforms/InstCombine/select-masked_load.ll
@@ -26,8 +26,7 @@ define <4 x i32> @masked_load_and_zero_inactive_2(ptr %ptr, <4 x i1> %mask) {
; No transform when the load's passthrough cannot be reused or altered.
define <4 x i32> @masked_load_and_zero_inactive_3(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthrough) {
; CHECK-LABEL: @masked_load_and_zero_inactive_3(
-; CHECK-NEXT: [[LOAD:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHROUGH:%.*]])
-; CHECK-NEXT: [[MASKED:%.*]] = select <4 x i1> [[MASK]], <4 x i32> [[LOAD]], <4 x i32> zeroinitializer
+; CHECK-NEXT: [[MASKED:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr [[PTR:%.*]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> zeroinitializer)
; CHECK-NEXT: ret <4 x i32> [[MASKED]]
;
%load = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %ptr, i32 4, <4 x i1> %mask, <4 x i32> %passthrough)
@@ -116,6 +115,27 @@ entry:
ret <8 x float> %1
}
+define <vscale x 4 x float> @fold_sel_into_masked_load_scalable(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough) {
+; CHECK-LABEL: @fold_sel_into_masked_load_scalable(
+; CHECK-NEXT: [[SEL:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT: ret <vscale x 4 x float> [[SEL]]
+;
+ %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> zeroinitializer)
+ %sel = select <vscale x 4 x i1> %mask, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+ ret <vscale x 4 x float> %sel
+}
+
+define <vscale x 4 x float> @neg_fold_sel_into_masked_load_mask_mismatch(ptr %loc, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask2, <vscale x 4 x float> %passthrough) {
+; CHECK-LABEL: @neg_fold_sel_into_masked_load_mask_mismatch(
+; CHECK-NEXT: [[LOAD:%.*]] = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr [[LOC:%.*]], i32 1, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x float> [[PASSTHROUGH:%.*]])
+; CHECK-NEXT: [[SEL:%.*]] = select <vscale x 4 x i1> [[MASK2:%.*]], <vscale x 4 x float> [[LOAD]], <vscale x 4 x float> [[PASSTHROUGH]]
+; CHECK-NEXT: ret <vscale x 4 x float> [[SEL]]
+;
+ %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32.p0(ptr %loc, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x float> %passthrough)
+ %sel = select <vscale x 4 x i1> %mask2, <vscale x 4 x float> %load, <vscale x 4 x float> %passthrough
+ ret <vscale x 4 x float> %sel
+}
+
declare <8 x float> @llvm.masked.load.v8f32.p0(ptr, i32 immarg, <8 x i1>, <8 x float>)
declare <4 x i32> @llvm.masked.load.v4i32.p0(ptr, i32 immarg, <4 x i1>, <4 x i32>)
declare <4 x float> @llvm.masked.load.v4f32.p0(ptr, i32 immarg, <4 x i1>, <4 x float>)
|
const APInt *MaskedLoadAlignment; | ||
if (match(TrueVal, | ||
m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment), | ||
m_Specific(CondVal), m_Value()))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs a one-use check? Otherwise you'll end up with two masked loads.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I've added a check & test for it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you are already using PatternMatch, the more idiomatic way to do this would be m_OneUse(m_MaskedLoad(...)).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
const APInt *MaskedLoadAlignment; | ||
if (match(TrueVal, | ||
m_MaskedLoad(m_Value(MaskedLoadPtr), m_APInt(MaskedLoadAlignment), | ||
m_Specific(CondVal), m_Value()))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you are already using PatternMatch, the more idiomatic way to do this would be m_OneUse(m_MaskedLoad(...)).
return replaceInstUsesWith( | ||
SI, Builder.CreateMaskedLoad( | ||
TrueVal->getType(), MaskedLoadPtr, | ||
llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
llvm::Align(MaskedLoadAlignment->getZExtValue()), CondVal, | |
Align(MaskedLoadAlignment->getZExtValue()), CondVal, |
Selects can be folded into masked loads if the masks are identical.
Selects can be folded into masked loads if the masks are identical.