Skip to content

Commit

Permalink
[RISCVGatherScatterLowering] Remove restriction that shift must have …
Browse files Browse the repository at this point in the history
…constant operand

This has been present from the original patch which added the pass, and doesn't appear to be strongly justified. We do need to be careful of commutativity.

Differential Revision: https://reviews.llvm.org/D150468
  • Loading branch information
preames committed May 12, 2023
1 parent 64f1fda commit 297e06c
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 4 deletions.
6 changes: 2 additions & 4 deletions llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,6 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
case Instruction::Add:
break;
case Instruction::Shl:
// Only support shift by constant.
if (!isa<Constant>(BO->getOperand(1)))
return false;
break;
case Instruction::Mul:
break;
Expand All @@ -251,7 +248,8 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
Index = cast<Instruction>(BO->getOperand(0));
OtherOp = BO->getOperand(1);
} else if (isa<Instruction>(BO->getOperand(1)) &&
L->contains(cast<Instruction>(BO->getOperand(1)))) {
L->contains(cast<Instruction>(BO->getOperand(1))) &&
Instruction::isCommutative(BO->getOpcode())) {
Index = cast<Instruction>(BO->getOperand(1));
OtherOp = BO->getOperand(0);
} else {
Expand Down
94 changes: 94 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,100 @@ for.cond.cleanup: ; preds = %vector.body
ret void
}

; Positive test for D150468: the shift amount %shift is loop-invariant but
; NOT a constant. With the constant-operand restriction removed, the pass
; recognizes %vec.ind << %shift as a strided recurrence and rewrites the
; masked gather into a riscv strided load. The CHECK lines confirm the
; scalarized form: byte stride TMP0 = (1 << %shift) * 4 and per-iteration
; index step STEP = 8 << %shift.
define void @gather_unknown_pow2(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
; CHECK-LABEL: @gather_unknown_pow2(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[STEP:%.*]] = shl i64 8, [[SHIFT:%.*]]
; CHECK-NEXT: [[STRIDE:%.*]] = shl i64 1, [[SHIFT]]
; CHECK-NEXT: [[TMP0:%.*]] = mul i64 [[STRIDE]], 4
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[ENTRY]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[B:%.*]], i64 [[VEC_IND_SCALAR]]
; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.riscv.masked.strided.load.v8i32.p0.i64(<8 x i32> undef, ptr [[TMP1]], i64 [[TMP0]], <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
; CHECK-NEXT: [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
; CHECK-NEXT: [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
; CHECK-NEXT: store <8 x i32> [[I4]], ptr [[I2]], align 1
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STEP]]
; CHECK-NEXT: [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT: br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
; CHECK: for.cond.cleanup:
; CHECK-NEXT: ret void
;
entry:
  ; Splat the runtime shift amount so it can be used as a vector shl operand.
  %.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
  %.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
  br label %vector.body

vector.body: ; preds = %vector.body, %entry
  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
  %vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
  ; Variable (but loop-invariant) shift: index operand first, shift second —
  ; the operand order matchStridedRecurrence supports for non-commutative shl.
  %i = shl nsw <8 x i64> %vec.ind, %.splat
  %i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
  %wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
  %i2 = getelementptr inbounds i32, ptr %A, i64 %index
  %wide.load = load <8 x i32>, ptr %i2, align 1
  %i4 = add <8 x i32> %wide.load, %wide.masked.gather
  store <8 x i32> %i4, ptr %i2, align 1
  %index.next = add nuw i64 %index, 8
  %vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
  %i6 = icmp eq i64 %index.next, 1024
  br i1 %i6, label %for.cond.cleanup, label %vector.body

for.cond.cleanup: ; preds = %vector.body
  ret void
}

; Negative test for D150468: here the shl operands are swapped relative to
; @gather_unknown_pow2 — the loop-varying %vec.ind is the SHIFT AMOUNT and
; the splat is the value being shifted. shl is not commutative, so the pass
; must not treat this as a strided recurrence. The CHECK lines verify the
; masked gather is left untouched (no strided-load rewrite).
define void @negative_shl_non_commute(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
; CHECK-LABEL: @negative_shl_non_commute(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <8 x i64> poison, i64 [[SHIFT:%.*]], i64 0
; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <8 x i64> [[DOTSPLATINSERT]], <8 x i64> poison, <8 x i32> zeroinitializer
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_IND:%.*]] = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, [[ENTRY]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[I:%.*]] = shl nsw <8 x i64> [[DOTSPLAT]], [[VEC_IND]]
; CHECK-NEXT: [[I1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], <8 x i64> [[I]]
; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> [[I1]], i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
; CHECK-NEXT: [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
; CHECK-NEXT: [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
; CHECK-NEXT: store <8 x i32> [[I4]], ptr [[I2]], align 1
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-NEXT: [[VEC_IND_NEXT]] = add <8 x i64> [[VEC_IND]], <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
; CHECK-NEXT: [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT: br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
; CHECK: for.cond.cleanup:
; CHECK-NEXT: ret void
;
entry:
  %.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
  %.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
  br label %vector.body

vector.body: ; preds = %vector.body, %entry
  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
  %vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
  ; Operands reversed: splat << vec.ind. The recurrence is in operand 1,
  ; which the pass only accepts for commutative opcodes — shl is not.
  %i = shl nsw <8 x i64> %.splat, %vec.ind
  %i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
  %wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
  %i2 = getelementptr inbounds i32, ptr %A, i64 %index
  %wide.load = load <8 x i32>, ptr %i2, align 1
  %i4 = add <8 x i32> %wide.load, %wide.masked.gather
  store <8 x i32> %i4, ptr %i2, align 1
  %index.next = add nuw i64 %index, 8
  %vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
  %i6 = icmp eq i64 %index.next, 1024
  br i1 %i6, label %for.cond.cleanup, label %vector.body

for.cond.cleanup: ; preds = %vector.body
  ret void
}

;void scatter_pow2(signed char * __restrict A, signed char * __restrict B) {
; for (int i = 0; i < 1024; ++i)
; A[i * 4] += B[i];
Expand Down

0 comments on commit 297e06c

Please sign in to comment.