Skip to content

Commit

Permalink
[RISCV] Handle disjoint or in RISCVGatherScatterLowering (#77800)
Browse files Browse the repository at this point in the history
This patch adds support for the disjoint flag in the non-recursive case,
as well as adding an additional check for it in the recursive case. Note
that haveNoCommonBitsSet should be equivalent to having the disjoint
flag set, and the check can be removed in a follow-up patch.

Co-authored-by: Philip Reames <preames@rivosinc.com>

---------

Co-authored-by: Philip Reames <preames@rivosinc.com>
  • Loading branch information
lukel97 and preames committed Jan 15, 2024
1 parent fa5255e commit 0cf768e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
12 changes: 11 additions & 1 deletion llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,15 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
// multipled.
auto *BO = dyn_cast<BinaryOperator>(Start);
if (!BO || (BO->getOpcode() != Instruction::Add &&
BO->getOpcode() != Instruction::Or &&
BO->getOpcode() != Instruction::Shl &&
BO->getOpcode() != Instruction::Mul))
return std::make_pair(nullptr, nullptr);

if (BO->getOpcode() == Instruction::Or &&
!cast<PossiblyDisjointInst>(BO)->isDisjoint())
return std::make_pair(nullptr, nullptr);

// Look for an operand that is splatted.
unsigned OtherIndex = 0;
Value *Splat = getSplatValue(BO->getOperand(1));
Expand All @@ -163,6 +168,10 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
switch (BO->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
case Instruction::Or:
// TODO: We'd be better off creating disjoint or here, but we don't yet
// have an IRBuilder API for that.
[[fallthrough]];
case Instruction::Add:
Start = Builder.CreateAdd(Start, Splat);
break;
Expand Down Expand Up @@ -241,7 +250,8 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
return false;
case Instruction::Or:
// We need to be able to treat Or as Add.
if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL) &&
!cast<PossiblyDisjointInst>(BO)->isDisjoint())
return false;
break;
case Instruction::Add:
Expand Down
7 changes: 2 additions & 5 deletions llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,8 @@ define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {

define <vscale x 1 x i64> @straightline_offset_disjoint_or(ptr %p, i64 %offset) {
; CHECK-LABEL: @straightline_offset_disjoint_or(
; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
; CHECK-NEXT: [[STEP_SHL:%.*]] = shl <vscale x 1 x i64> [[STEP]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
; CHECK-NEXT: [[OFFSETV:%.*]] = or disjoint <vscale x 1 x i64> [[STEP_SHL]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETV]]
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 1
; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
Expand Down

0 comments on commit 0cf768e

Please sign in to comment.