Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LV] Add support for uniform parameters on vectorized function variants #72891

Merged
merged 1 commit into from
Nov 28, 2023

Conversation

huntergr-arm
Copy link
Collaborator

Parameters marked as uniform take a scalar value, assuming the value is
invariant in the scalar loop.

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 20, 2023

@llvm/pr-subscribers-llvm-transforms

Author: Graham Hunter (huntergr-arm)

Changes

Parameters marked as uniform take a scalar value, assuming the value is
invariant in the scalar loop.


Full diff: https://github.com/llvm/llvm-project/pull/72891.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+8)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll (+20-12)
  • (modified) llvm/test/Transforms/LoopVectorize/uniform-args-call-variants.ll (+6-11)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index d52afc5508be947..601d668d37a1022 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7031,6 +7031,14 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
           switch (Param.ParamKind) {
           case VFParamKind::Vector:
             break;
+          case VFParamKind::OMP_Uniform: {
+            Value *ScalarParam = CI->getArgOperand(Param.ParamPos);
+            // Make sure the scalar parameter in the loop is invariant.
+            if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(ScalarParam),
+                                              TheLoop))
+              ParamsOk = false;
+            break;
+          }
           case VFParamKind::GlobalPredicate:
             UsesMask = true;
             break;
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll b/llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll
index 2c0298b849f513b..cbf697586013494 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/uniform-args-call-variants.ll
@@ -9,17 +9,25 @@ define void @test_uniform(ptr noalias %dst, ptr readonly %src, i64 %uniform , i6
 ; CHECK-LABEL: define void @test_uniform
 ; CHECK-SAME: (ptr noalias [[DST:%.*]], ptr readonly [[SRC:%.*]], i64 [[UNIFORM:%.*]], i64 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
-; CHECK:       for.body:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[GEPSRC:%.*]] = getelementptr double, ptr [[SRC]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    [[DATA:%.*]] = load double, ptr [[GEPSRC]], align 8
-; CHECK-NEXT:    [[CALL:%.*]] = call double @foo(double [[DATA]], i64 [[UNIFORM]]) #[[ATTR1:[0-9]+]]
-; CHECK-NEXT:    [[GEPDST:%.*]] = getelementptr inbounds double, ptr [[DST]], i64 [[INDVARS_IV]]
-; CHECK-NEXT:    store double [[CALL]], ptr [[GEPDST]], align 8
-; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
-; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp eq i64 [[INDVARS_IV_NEXT]], [[N]]
-; CHECK-NEXT:    br i1 [[EXITCOND]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 [[TMP0]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[N]], i64 [[TMP1]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_ENTRY:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 0, i64 [[N]])
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = phi <vscale x 2 x i1> [ [[ACTIVE_LANE_MASK_ENTRY]], [[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr double, ptr [[SRC]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <vscale x 2 x double> @llvm.masked.load.nxv2f64.p0(ptr [[TMP3]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]], <vscale x 2 x double> poison)
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 2 x double> @foo_uniform(<vscale x 2 x double> [[WIDE_MASKED_LOAD]], i64 [[UNIFORM]], <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds double, ptr [[DST]], i64 [[INDEX]]
+; CHECK-NEXT:    call void @llvm.masked.store.nxv2f64.p0(<vscale x 2 x double> [[TMP4]], ptr [[TMP5]], i32 8, <vscale x 2 x i1> [[ACTIVE_LANE_MASK]])
+; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP7:%.*]] = shl i64 [[TMP6]], 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[TMP7]]
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK_NEXT]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 [[INDEX]], i64 [[TMP2]])
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <vscale x 2 x i1> [[ACTIVE_LANE_MASK_NEXT]], i64 0
+; CHECK-NEXT:    br i1 [[TMP8]], label [[VECTOR_BODY]], label [[FOR_COND_CLEANUP:%.*]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       for.cond.cleanup:
 ; CHECK-NEXT:    ret void
 ;
@@ -51,7 +59,7 @@ define void @test_uniform_not_invariant(ptr noalias %dst, ptr readonly %src, i64
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[GEPSRC:%.*]] = getelementptr double, ptr [[SRC]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[DATA:%.*]] = load double, ptr [[GEPSRC]], align 8
-; CHECK-NEXT:    [[CALL:%.*]] = call double @foo(double [[DATA]], i64 [[INDVARS_IV]]) #[[ATTR1]]
+; CHECK-NEXT:    [[CALL:%.*]] = call double @foo(double [[DATA]], i64 [[INDVARS_IV]]) #[[ATTR5:[0-9]+]]
 ; CHECK-NEXT:    [[GEPDST:%.*]] = getelementptr inbounds double, ptr [[DST]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    store double [[CALL]], ptr [[GEPDST]], align 8
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
diff --git a/llvm/test/Transforms/LoopVectorize/uniform-args-call-variants.ll b/llvm/test/Transforms/LoopVectorize/uniform-args-call-variants.ll
index 4d322f30b693b49..9c12b9c2dcfca49 100644
--- a/llvm/test/Transforms/LoopVectorize/uniform-args-call-variants.ll
+++ b/llvm/test/Transforms/LoopVectorize/uniform-args-call-variants.ll
@@ -16,17 +16,12 @@ define void @test_uniform(ptr noalias %dst, ptr readonly %src, i64 %uniform , i6
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[SRC]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> [[WIDE_LOAD]], i64 0
-; CHECK-NEXT:    [[TMP2:%.*]] = call double @foo(double [[TMP1]], i64 [[UNIFORM]]) #[[ATTR0:[0-9]+]]
-; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x double> [[WIDE_LOAD]], i64 1
-; CHECK-NEXT:    [[TMP4:%.*]] = call double @foo(double [[TMP3]], i64 [[UNIFORM]]) #[[ATTR0]]
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x double> poison, double [[TMP2]], i64 0
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x double> [[TMP5]], double [[TMP4]], i64 1
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds double, ptr [[DST]], i64 [[INDEX]]
-; CHECK-NEXT:    store <2 x double> [[TMP6]], ptr [[TMP7]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = call <2 x double> @foo_uniform(<2 x double> [[WIDE_LOAD]], i64 [[UNIFORM]])
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds double, ptr [[DST]], i64 [[INDEX]]
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[TMP2]], align 8
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP8]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP3]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK:       middle.block:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[N_VEC]], [[N]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
@@ -37,7 +32,7 @@ define void @test_uniform(ptr noalias %dst, ptr readonly %src, i64 %uniform , i6
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[GEPSRC:%.*]] = getelementptr double, ptr [[SRC]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[DATA:%.*]] = load double, ptr [[GEPSRC]], align 8
-; CHECK-NEXT:    [[CALL:%.*]] = call double @foo(double [[DATA]], i64 [[UNIFORM]]) #[[ATTR0]]
+; CHECK-NEXT:    [[CALL:%.*]] = call double @foo(double [[DATA]], i64 [[UNIFORM]]) #[[ATTR0:[0-9]+]]
 ; CHECK-NEXT:    [[GEPDST:%.*]] = getelementptr inbounds double, ptr [[DST]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    store double [[CALL]], ptr [[GEPDST]], align 8
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1

Parameters marked as uniform take a scalar value, assuming the value is
invariant in the scalar loop.
@huntergr-arm
Copy link
Collaborator Author

Rebased on top of b1fba56 and adjusted ISA name to 's' for scalable vector tests. Added a case where the uniform parameter has a smaller element type than the vector parameter.

@huntergr-arm huntergr-arm merged commit 104b7c6 into llvm:main Nov 28, 2023
3 checks passed
@huntergr-arm huntergr-arm deleted the uniform-params branch November 28, 2023 15:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants