diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp index 8a440ed29ac35..b1171dac6a094 100644 --- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp @@ -148,9 +148,11 @@ static std::pair matchStridedStart(Value *Start, return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1)); } - // Not a constant, maybe it's a strided constant with a splat added to it. + // Not a constant, maybe it's a strided constant with a splat added or + // multipled. auto *BO = dyn_cast(Start); - if (!BO || BO->getOpcode() != Instruction::Add) + if (!BO || (BO->getOpcode() != Instruction::Add && + BO->getOpcode() != Instruction::Mul)) return std::make_pair(nullptr, nullptr); // Look for an operand that is splatted. @@ -169,10 +171,17 @@ static std::pair matchStridedStart(Value *Start, if (!Start) return std::make_pair(nullptr, nullptr); - // Add the splat value to the start. Builder.SetInsertPoint(BO); Builder.SetCurrentDebugLocation(DebugLoc()); - Start = Builder.CreateAdd(Start, Splat); + // Add the splat value to the start + if (BO->getOpcode() == Instruction::Add) { + Start = Builder.CreateAdd(Start, Splat); + } + // Or multiply the start and stride by the splat. + else if (BO->getOpcode() == Instruction::Mul) { + Start = Builder.CreateMul(Start, Splat); + Stride = Builder.CreateMul(Stride, Splat); + } return std::make_pair(Start, Stride); } diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll index 31fcf10fa3804..bcc73e039977a 100644 --- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll +++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll @@ -91,12 +91,11 @@ for.cond.cleanup: ; preds = %vector.body define @gather_loopless(ptr %p, i64 %stride) { ; CHECK-LABEL: @gather_loopless( -; CHECK-NEXT: [[STEP:%.*]] = call @llvm.experimental.stepvector.nxv1i64() -; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement poison, i64 [[STRIDE:%.*]], i64 0 -; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[SPLAT_INSERT]], poison, zeroinitializer -; CHECK-NEXT: [[OFFSETS:%.*]] = mul [[STEP]], [[SPLAT]] -; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], [[OFFSETS]] -; CHECK-NEXT: [[X:%.*]] = call @llvm.masked.gather.nxv1i64.nxv1p0( [[PTRS]], i32 8, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer), poison) +; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i64 1, [[STRIDE]] +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4 +; CHECK-NEXT: [[X:%.*]] = call @llvm.riscv.masked.strided.load.nxv1i64.p0.i64( poison, ptr [[TMP3]], i64 [[TMP4]], shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) ; CHECK-NEXT: ret [[X]] ; %step = call @llvm.experimental.stepvector.nxv1i64() @@ -115,12 +114,11 @@ define @gather_loopless(ptr %p, i64 %stride) { define void @scatter_loopless( %x, ptr %p, i64 %stride) { ; CHECK-LABEL: @scatter_loopless( -; CHECK-NEXT: [[STEP:%.*]] = call @llvm.experimental.stepvector.nxv1i64() -; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement poison, i64 [[STRIDE:%.*]], i64 0 -; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[SPLAT_INSERT]], poison, zeroinitializer -; CHECK-NEXT: [[OFFSETS:%.*]] = mul [[STEP]], [[SPLAT]] -; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], [[OFFSETS]] -; CHECK-NEXT: call void @llvm.masked.scatter.nxv1i64.nxv1p0( [[X:%.*]], [[PTRS]], i32 8, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) +; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i64 1, [[STRIDE]] +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4 +; CHECK-NEXT: call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64( [[X:%.*]], ptr [[TMP3]], i64 [[TMP4]], shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) ; CHECK-NEXT: ret void ; %step = call @llvm.experimental.stepvector.nxv1i64()