diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp index 1129206800ad3..cd438e153068e 100644 --- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp @@ -136,10 +136,15 @@ static std::pair matchStridedStart(Value *Start, // multipled. auto *BO = dyn_cast(Start); if (!BO || (BO->getOpcode() != Instruction::Add && + BO->getOpcode() != Instruction::Or && BO->getOpcode() != Instruction::Shl && BO->getOpcode() != Instruction::Mul)) return std::make_pair(nullptr, nullptr); + if (BO->getOpcode() == Instruction::Or && + !cast(BO)->isDisjoint()) + return std::make_pair(nullptr, nullptr); + // Look for an operand that is splatted. unsigned OtherIndex = 0; Value *Splat = getSplatValue(BO->getOperand(1)); @@ -163,6 +168,10 @@ static std::pair matchStridedStart(Value *Start, switch (BO->getOpcode()) { default: llvm_unreachable("Unexpected opcode"); + case Instruction::Or: + // TODO: We'd be better off creating disjoint or here, but we don't yet + // have an IRBuilder API for that. + [[fallthrough]]; case Instruction::Add: Start = Builder.CreateAdd(Start, Splat); break; @@ -241,7 +250,8 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L, return false; case Instruction::Or: // We need to be able to treat Or as Add. - if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL)) + if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL) && + !cast(BO)->isDisjoint()) return false; break; case Instruction::Add: diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll index 838089baa46fc..54e5d39e24854 100644 --- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll +++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll @@ -183,11 +183,8 @@ define @straightline_offset_add(ptr %p, i64 %offset) { define @straightline_offset_disjoint_or(ptr %p, i64 %offset) { ; CHECK-LABEL: @straightline_offset_disjoint_or( -; CHECK-NEXT: [[STEP:%.*]] = call @llvm.experimental.stepvector.nxv1i64() -; CHECK-NEXT: [[STEP_SHL:%.*]] = shl [[STEP]], shufflevector ( insertelement ( poison, i64 1, i32 0), poison, zeroinitializer) -; CHECK-NEXT: [[OFFSETV:%.*]] = or disjoint [[STEP_SHL]], shufflevector ( insertelement ( poison, i64 1, i32 0), poison, zeroinitializer) -; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], [[OFFSETV]] -; CHECK-NEXT: [[X:%.*]] = call @llvm.masked.gather.nxv1i64.nxv1p0( [[PTRS]], i32 8, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer), poison) +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 1 +; CHECK-NEXT: [[X:%.*]] = call @llvm.riscv.masked.strided.load.nxv1i64.p0.i64( poison, ptr [[TMP1]], i64 8, shufflevector ( insertelement ( poison, i1 true, i64 0), poison, zeroinitializer)) ; CHECK-NEXT: ret [[X]] ; %step = call @llvm.experimental.stepvector.nxv1i64()