Skip to content

Commit

Permalink
[RISCVGatherScatterLowering] Support shl in non-recursive matching
Browse files Browse the repository at this point in the history
We can apply the same logic as for multiply since a left shift is just a multiply by a power of two. Note that since shl is not commutative, we do need to be careful to make sure that the splat is the RHS of the instruction.

Differential Revision: https://reviews.llvm.org/D150471
  • Loading branch information
preames committed May 12, 2023
1 parent 5d57a9f commit 715a043
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 8 deletions.
26 changes: 18 additions & 8 deletions llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,16 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
// multiplied.
auto *BO = dyn_cast<BinaryOperator>(Start);
if (!BO || (BO->getOpcode() != Instruction::Add &&
BO->getOpcode() != Instruction::Shl &&
BO->getOpcode() != Instruction::Mul))
return std::make_pair(nullptr, nullptr);

// Look for an operand that is splatted.
unsigned OtherIndex = 1;
Value *Splat = getSplatValue(BO->getOperand(0));
if (!Splat) {
Splat = getSplatValue(BO->getOperand(1));
OtherIndex = 0;
unsigned OtherIndex = 0;
Value *Splat = getSplatValue(BO->getOperand(1));
if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
Splat = getSplatValue(BO->getOperand(0));
OtherIndex = 1;
}
if (!Splat)
return std::make_pair(nullptr, nullptr);
Expand All @@ -158,13 +159,22 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
Builder.SetCurrentDebugLocation(DebugLoc());
// Add the splat value to the start, or scale the start and stride by the
// splat (multiply for mul, shift left for shl).
if (BO->getOpcode() == Instruction::Add) {
switch (BO->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
case Instruction::Add:
Start = Builder.CreateAdd(Start, Splat);
} else {
assert(BO->getOpcode() == Instruction::Mul && "Unexpected opcode");
break;
case Instruction::Mul:
Start = Builder.CreateMul(Start, Splat);
Stride = Builder.CreateMul(Stride, Splat);
break;
case Instruction::Shl:
Start = Builder.CreateShl(Start, Splat);
Stride = Builder.CreateShl(Stride, Splat);
break;
}

return std::make_pair(Start, Stride);
}

Expand Down
88 changes: 88 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,94 @@ define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
ret <vscale x 1 x i64> %x
}

; Straight-line (loop-free) gather whose index vector is the step vector plus
; a splat of the scalar %offset. The expected output below replaces the
; masked gather with a strided load: the splatted value is folded into a
; scalar start (0 + %offset), the address becomes a scalar i32 GEP off %p, and
; the stride operand is the constant 4.
define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
; CHECK-LABEL: @straightline_offset_add(
; CHECK-NEXT: [[TMP1:%.*]] = add i64 0, [[OFFSET:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP2]], i64 4, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
; Broadcast %offset to every lane (insert into lane 0, then shuffle with an
; all-zero mask).
%splat.insert = insertelement <vscale x 1 x i64> poison, i64 %offset, i64 0
%splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; Index vector: step + splat, with the splat as the right-hand operand.
%offsetv = add <vscale x 1 x i64> %step, %splat
%ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offsetv
; Gather with alignment 8, an all-true mask and a poison passthru.
%x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
<vscale x 1 x ptr> %ptrs,
i32 8,
<vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
<vscale x 1 x i64> poison
)
ret <vscale x 1 x i64> %x
}

; Straight-line gather whose index vector is the step vector shifted left by a
; splat of the constant 3 (the splat is the shift amount, i.e. the right-hand
; operand of the shl). The expected output below replaces the masked gather
; with a strided load from a scalar GEP at index 0 off %p with a constant
; stride operand of 32.
; NOTE(review): 32 is consistent with (1 << 3) * 4, i.e. the shifted unit
; stride scaled by the 4-byte i32 element size -- confirm against the pass.
define <vscale x 1 x i64> @straightline_offset_shl(ptr %p) {
; CHECK-LABEL: @straightline_offset_shl(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 0
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 32, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
; Broadcast the constant shift amount 3 to every lane.
%splat.insert = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
%splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; Index vector: step << splat, with the splat as the right-hand operand.
%offset = shl <vscale x 1 x i64> %step, %splat
%ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
; Gather with alignment 8, an all-true mask and a poison passthru.
%x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
<vscale x 1 x ptr> %ptrs,
i32 8,
<vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
<vscale x 1 x i64> poison
)
ret <vscale x 1 x i64> %x
}

; Negative test: the shl operands are swapped relative to the positive shl
; tests, so the splat of 3 is the value being shifted and the step vector is
; the shift amount. The expected output below is the input unchanged: the
; stepvector, splat, shl, vector GEP and masked gather all remain, and no
; strided load is formed.
define <vscale x 1 x i64> @neg_shl_is_not_commutative(ptr %p) {
; CHECK-LABEL: @neg_shl_is_not_commutative(
; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; CHECK-NEXT: [[OFFSET:%.*]] = shl <vscale x 1 x i64> [[SPLAT]], [[STEP]]
; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSET]]
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
; Broadcast the constant 3 to every lane.
%splat.insert = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
%splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; Index vector: splat << step. Here the splat is the LEFT-hand operand, so
; each lane is 3 shifted by its lane index rather than a scaled step.
%offset = shl <vscale x 1 x i64> %splat, %step
%ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
; Gather with alignment 8, an all-true mask and a poison passthru.
%x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
<vscale x 1 x ptr> %ptrs,
i32 8,
<vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
<vscale x 1 x i64> poison
)
ret <vscale x 1 x i64> %x
}

; Same shape as the constant-shift case, but the shift amount is the
; non-constant scalar argument %shift, splatted and used as the right-hand
; operand of the shl. The expected output below materialises scalar shifts
; instead of a folded constant: start = 0 << %shift, stride = 1 << %shift,
; a scalar i32 GEP off %p indexed by the start, and a stride operand equal to
; the shifted stride multiplied by 4, feeding a strided load.
define <vscale x 1 x i64> @straightline_offset_shl_nonc(ptr %p, i64 %shift) {
; CHECK-LABEL: @straightline_offset_shl_nonc(
; CHECK-NEXT: [[TMP1:%.*]] = shl i64 0, [[SHIFT:%.*]]
; CHECK-NEXT: [[TMP2:%.*]] = shl i64 1, [[SHIFT]]
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
; Broadcast the runtime shift amount %shift to every lane.
%splat.insert = insertelement <vscale x 1 x i64> poison, i64 %shift, i64 0
%splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; Index vector: step << splat, with the splat as the right-hand operand.
%offset = shl <vscale x 1 x i64> %step, %splat
%ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
; Gather with alignment 8, an all-true mask and a poison passthru.
%x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
<vscale x 1 x ptr> %ptrs,
i32 8,
<vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
<vscale x 1 x i64> poison
)
ret <vscale x 1 x i64> %x
}

define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
; CHECK-LABEL: @scatter_loopless(
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
Expand Down

0 comments on commit 715a043

Please sign in to comment.