Skip to content

Commit

Permalink
[RISCV] Handle non-recursive muls of strides in gather/scatter lowering
Browse files Browse the repository at this point in the history
The gather/scatter lowering pass can already fold multiplies of a step vector
into the stride in the recursive case, so this extends that folding to the
non-recursive case.
The logic can probably be shared between the two cases at some point, which
would also extend the non-recursive case to shls and ors.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D146983
  • Loading branch information
lukel97 committed Mar 27, 2023
1 parent 2da8ed3 commit d49f2c6
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
17 changes: 13 additions & 4 deletions llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,11 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
}

// Not a constant, maybe it's a strided constant with a splat added to it.
// Not a constant, maybe it's a strided constant with a splat added or
// multiplied.
auto *BO = dyn_cast<BinaryOperator>(Start);
if (!BO || BO->getOpcode() != Instruction::Add)
if (!BO || (BO->getOpcode() != Instruction::Add &&
BO->getOpcode() != Instruction::Mul))
return std::make_pair(nullptr, nullptr);

// Look for an operand that is splatted.
Expand All @@ -169,10 +171,17 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
if (!Start)
return std::make_pair(nullptr, nullptr);

// Add the splat value to the start.
Builder.SetInsertPoint(BO);
Builder.SetCurrentDebugLocation(DebugLoc());
Start = Builder.CreateAdd(Start, Splat);
// Add the splat value to the start
if (BO->getOpcode() == Instruction::Add) {
Start = Builder.CreateAdd(Start, Splat);
}
// Or multiply the start and stride by the splat.
else if (BO->getOpcode() == Instruction::Mul) {
Start = Builder.CreateMul(Start, Splat);
Stride = Builder.CreateMul(Stride, Splat);
}
return std::make_pair(Start, Stride);
}

Expand Down
22 changes: 10 additions & 12 deletions llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,11 @@ for.cond.cleanup: ; preds = %vector.body

define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
; CHECK-LABEL: @gather_loopless(
; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[STRIDE:%.*]], i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; CHECK-NEXT: [[OFFSETS:%.*]] = mul <vscale x 1 x i64> [[STEP]], [[SPLAT]]
; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETS]]
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
Expand All @@ -115,12 +114,11 @@ define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {

define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
; CHECK-LABEL: @scatter_loopless(
; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[STRIDE:%.*]], i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; CHECK-NEXT: [[OFFSETS:%.*]] = mul <vscale x 1 x i64> [[STEP]], [[SPLAT]]
; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETS]]
; CHECK-NEXT: call void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64> [[X:%.*]], <vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4
; CHECK-NEXT: call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> [[X:%.*]], ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret void
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
Expand Down

0 comments on commit d49f2c6

Please sign in to comment.