-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCVGatherScatterLowering] Support vp_gather and vp_scatter #73612
[RISCVGatherScatterLowering] Support vp_gather and vp_scatter #73612
Conversation
@llvm/pr-subscribers-backend-risc-v Author: None (ShivaChen) ChangesSupport transferring vp_gather to experimental_vp_strided_load and vp_scatter to experimental_vp_strided_store. Full diff: https://github.com/llvm/llvm-project/pull/73612.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 5ad1e082344e77a..973b970f8ce131d 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -65,7 +65,7 @@ class RISCVGatherScatterLowering : public FunctionPass {
private:
bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
- Value *AlignOp);
+ MaybeAlign MA);
std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
IRBuilderBase &Builder);
@@ -459,9 +459,8 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
Type *DataType,
Value *Ptr,
- Value *AlignOp) {
+ MaybeAlign MA) {
// Make sure the operation will be supported by the backend.
- MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
EVT DataTypeVT = TLI->getValueType(*DL, DataType);
if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
return false;
@@ -493,11 +492,22 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
Intrinsic::riscv_masked_strided_load,
{DataType, BasePtr->getType(), Stride->getType()},
{II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
- else
+ else if (II->getIntrinsicID() == Intrinsic::vp_gather)
+ Call = Builder.CreateIntrinsic(
+ Intrinsic::experimental_vp_strided_load,
+ {DataType, BasePtr->getType(), Stride->getType()},
+ {BasePtr, Stride, II->getArgOperand(1), II->getArgOperand(2)});
+ else if (II->getIntrinsicID() == Intrinsic::masked_scatter)
Call = Builder.CreateIntrinsic(
Intrinsic::riscv_masked_strided_store,
{DataType, BasePtr->getType(), Stride->getType()},
{II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
+ else if (II->getIntrinsicID() == Intrinsic::vp_scatter)
+ Call = Builder.CreateIntrinsic(
+ Intrinsic::experimental_vp_strided_store,
+ {DataType, BasePtr->getType(), Stride->getType()},
+ {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(2),
+ II->getArgOperand(3)});
Call->takeName(II);
II->replaceAllUsesWith(Call);
@@ -533,22 +543,40 @@ bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
for (BasicBlock &BB : F) {
for (Instruction &I : BB) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
- if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
+ if (!II)
+ continue;
+ if (II->getIntrinsicID() == Intrinsic::masked_gather ||
+ II->getIntrinsicID() == Intrinsic::vp_gather) {
Gathers.push_back(II);
- } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
+ } else if (II->getIntrinsicID() == Intrinsic::masked_scatter ||
+ II->getIntrinsicID() == Intrinsic::vp_scatter) {
Scatters.push_back(II);
}
}
}
// Rewrite gather/scatter to form strided load/store if possible.
- for (auto *II : Gathers)
+ MaybeAlign MA;
+ for (auto *II : Gathers) {
+ if (II->getIntrinsicID() == Intrinsic::masked_gather)
+ MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
+ else if (II->getIntrinsicID() == Intrinsic::vp_gather)
+ MA = II->getAttributes().getParamAttrs(0).getAlignment();
+
Changed |= tryCreateStridedLoadStore(
- II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
- for (auto *II : Scatters)
+ II, II->getType(), II->getArgOperand(0), MA);
+ }
+
+ for (auto *II : Scatters) {
+ if (II->getIntrinsicID() == Intrinsic::masked_scatter)
+ MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue();
+ else if (II->getIntrinsicID() == Intrinsic::vp_scatter)
+ MA = II->getAttributes().getParamAttrs(1).getAlignment();
+
Changed |=
tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
- II->getArgOperand(1), II->getArgOperand(2));
+ II->getArgOperand(1), MA);
+ }
// Remove any dead phis.
while (!MaybeDeadPHIs.empty()) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index fcb3742eb2363ba..e722bea1919b06f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -230,6 +230,69 @@ define void @constant_stride(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
ret void
}
+define void @vp_gather_scatter(ptr %A, i64 %vl, i64 %stride, i64 %n.vec) {
+; CHECK-LABEL: @vp_gather_scatter(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[VECTOR_PH:%.*]]
+; CHECK: vector.ph:
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDVARS_IV36:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDVARS_IV_NEXT37:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[EVL_BASED_IV:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[AVL:%.*]] = sub i64 100, [[EVL_BASED_IV]]
+; CHECK-NEXT: [[EVL:%.*]] = tail call i32 @llvm.experimental.get.vector.length.i64(i64 [[AVL]], i32 4, i1 true)
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [200 x i32], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i64 [[INDVARS_IV36]]
+; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vp.strided.load.nxv4i32.p0.i64(ptr [[TMP0]], i64 800, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT: [[TMP1:%.*]] = shl nsw <vscale x 4 x i32> [[WIDE_MASKED_GATHER]], shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+; CHECK-NEXT: call void @llvm.experimental.vp.strided.store.nxv4i32.p0.i64(<vscale x 4 x i32> [[TMP1]], ptr [[TMP0]], i64 800, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT: [[TMP2:%.*]] = zext i32 [[EVL]] to i64
+; CHECK-NEXT: [[INDEX_EVL_NEXT]] = add i64 [[EVL_BASED_IV]], [[TMP2]]
+; CHECK-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], [[VL:%.*]]
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STRIDE:%.*]]
+; CHECK-NEXT: [[INDVARS_IV_NEXT37]] = add nuw nsw i64 [[INDVARS_IV36]], 1
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC:%.*]]
+; CHECK-NEXT: br i1 [[TMP3]], label [[END:%.*]], label [[VECTOR_BODY]]
+; CHECK: end:
+; CHECK-NEXT: ret void
+;
+entry:
+ %stepvector = tail call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %.splatinsert = insertelement <vscale x 4 x i64> poison, i64 %stride, i64 0
+ %.splat = shufflevector <vscale x 4 x i64> %.splatinsert, <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer
+ br label %vector.ph
+
+vector.ph: ; preds = %for.inc14, %entry
+ br label %vector.body
+
+vector.body: ; preds = %vector.body, %vector.ph
+ %indvars.iv36 = phi i64 [ 0, %vector.ph ], [ %indvars.iv.next37, %vector.body ]
+ %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
+ %evl.based.iv = phi i64 [ 0, %vector.ph ], [ %index.evl.next, %vector.body ]
+ %vec.ind = phi <vscale x 4 x i64> [ %stepvector, %vector.ph ], [ %vec.ind.next, %vector.body ]
+ %avl = sub i64 100, %evl.based.iv
+ %evl = tail call i32 @llvm.experimental.get.vector.length.i64(i64 %avl, i32 4, i1 true)
+ %0 = getelementptr inbounds [200 x i32], ptr %A, <vscale x 4 x i64> %vec.ind, i64 %indvars.iv36
+ %wide.masked.gather = tail call <vscale x 4 x i32> @llvm.vp.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> align 4 %0, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 %evl)
+ %1 = shl nsw <vscale x 4 x i32> %wide.masked.gather, shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+ tail call void @llvm.vp.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> %1, <vscale x 4 x ptr> align 4 %0, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 %evl)
+ %2 = zext i32 %evl to i64
+ %index.evl.next = add i64 %evl.based.iv, %2
+ %index.next = add i64 %index, %vl
+ %vec.ind.next = add <vscale x 4 x i64> %vec.ind, %.splat
+ %indvars.iv.next37 = add nuw nsw i64 %indvars.iv36, 1
+ %3 = icmp eq i64 %index.next, %n.vec
+ br i1 %3, label %end, label %vector.body
+
+end:
+ ret void
+}
+
declare i64 @llvm.vscale.i64()
declare void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64>, <vscale x 1 x ptr>, i32, <vscale x 1 x i1>)
declare <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr>, i32, <vscale x 1 x i1>, <vscale x 1 x i64>)
+declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+declare i32 @llvm.experimental.get.vector.length.i64(i64, i32 immarg, i1 immarg)
+declare <vscale x 4 x i32> @llvm.vp.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr>, <vscale x 4 x i1>, i32)
+declare void @llvm.vp.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32>, <vscale x 4 x ptr>, <vscale x 4 x i1>, i32)
|
You can test this locally with the following command:git-clang-format --diff ffcc5c7796b00dd7f945cb2a1bd30399f616c6b0 12c50c30d4963051836016978297ddfe3c2dcb47 -- llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp View the diff from clang-format here.diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 973b970f8c..4ef9d84309 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -563,8 +563,8 @@ bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
else if (II->getIntrinsicID() == Intrinsic::vp_gather)
MA = II->getAttributes().getParamAttrs(0).getAlignment();
- Changed |= tryCreateStridedLoadStore(
- II, II->getType(), II->getArgOperand(0), MA);
+ Changed |=
+ tryCreateStridedLoadStore(II, II->getType(), II->getArgOperand(0), MA);
}
for (auto *II : Scatters) {
@@ -573,9 +573,8 @@ bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
else if (II->getIntrinsicID() == Intrinsic::vp_scatter)
MA = II->getAttributes().getParamAttrs(1).getAlignment();
- Changed |=
- tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
- II->getArgOperand(1), MA);
+ Changed |= tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
+ II->getArgOperand(1), MA);
}
// Remove any dead phis.
|
I'm not opposed to this patch, but I thought the BSC vectorizer that produces VP intrinsics emits strided loads/stores directly? |
It seems Roger favors a new patch. |
We had something like this patch in our downstream, but we got a lot better results doing it in the vectorizer with SCEV information like what's in the BSC vectorizer. |
Thanks for the information. Could I ask whether there is a plan to upstream the VP strided intrinsic generation in the vectorizer? |
Yes, we have a plan to fully upstream our vectorizer (with some changes, of course, to meet the requirements of the upstream implementation) |
That would be great! Thanks :-) |
Support transferring vp_gather to experimental_vp_strided_load and vp_scatter to experimental_vp_strided_store.