-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV] Handle disjoint or in RISCVGatherScatterLowering #77800
Conversation
This patch adds support for the disjoint flag in the non-recursive case, but for the recursive case we were already handling this by checking that there were no common bits. This patch replaces that check with a check for the disjoint flag instead, since instcombine will already compute it for us. Co-authored-by: Philip Reames <preames@rivosinc.com>
@llvm/pr-subscribers-backend-risc-v Author: Luke Lau (lukel97) ChangesThis patch adds support for the disjoint flag in the non-recursive case, but Co-authored-by: Philip Reames <preames@rivosinc.com> Full diff: https://github.com/llvm/llvm-project/pull/77800.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 1129206800ad36..1dcb83a6078ed7 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -136,10 +136,15 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
// multipled.
auto *BO = dyn_cast<BinaryOperator>(Start);
if (!BO || (BO->getOpcode() != Instruction::Add &&
+ BO->getOpcode() != Instruction::Or &&
BO->getOpcode() != Instruction::Shl &&
BO->getOpcode() != Instruction::Mul))
return std::make_pair(nullptr, nullptr);
+ if (BO->getOpcode() == Instruction::Or &&
+ !cast<PossiblyDisjointInst>(BO)->isDisjoint())
+ return std::make_pair(nullptr, nullptr);
+
// Look for an operand that is splatted.
unsigned OtherIndex = 0;
Value *Splat = getSplatValue(BO->getOperand(1));
@@ -163,6 +168,10 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
switch (BO->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
+ case Instruction::Or:
+ // TODO: We'd be better off creating disjoint or here, but we don't yet
+ // have an IRBuilder API for that.
+ [[fallthrough]];
case Instruction::Add:
Start = Builder.CreateAdd(Start, Splat);
break;
@@ -241,7 +250,7 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
return false;
case Instruction::Or:
// We need to be able to treat Or as Add.
- if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
+ if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
return false;
break;
case Instruction::Add:
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index 838089baa46fc4..54e5d39e248544 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -183,11 +183,8 @@ define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
define <vscale x 1 x i64> @straightline_offset_disjoint_or(ptr %p, i64 %offset) {
; CHECK-LABEL: @straightline_offset_disjoint_or(
-; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT: [[STEP_SHL:%.*]] = shl <vscale x 1 x i64> [[STEP]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT: [[OFFSETV:%.*]] = or disjoint <vscale x 1 x i64> [[STEP_SHL]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETV]]
-; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 1
+; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…ts. NFC InstCombine will add the disjoint flag to these or instructions. This patch adds them to the tests so that it matches the input RISCVGatherScatterLowering will receive in practice, allowing us to rely on said disjoint flag: #77800 (comment)
If an or instruction has no common bits set in its operands, InstCombine will set the disjoint flag. This means we shouldn't need to compute it ourselves anymore in RISCVGatherScatterLowering, and can just rely on said flag being set. Originally split out from #77800
This patch adds support for the disjoint flag in the non-recursive case, as well as adding an additional check for it in the recursive case. Note that haveNoCommonBitsSet should be equivalent to having the disjoint flag set, and the check can be removed in a follow-up patch. Co-authored-by: Philip Reames <preames@rivosinc.com> --------- Co-authored-by: Philip Reames <preames@rivosinc.com>
…ts. NFC InstCombine will add the disjoint flag to these or instructions. This patch adds them to the tests so that it matches the input RISCVGatherScatterLowering will receive in practice, allowing us to rely on said disjoint flag: llvm#77800 (comment)
If an or instruction has no common bits set in its operands, InstCombine will set the disjoint flag. This means we shouldn't need to compute it ourselves anymore in RISCVGatherScatterLowering, and can just rely on said flag being set. Originally split out from llvm#77800
This patch adds support for the disjoint flag in the non-recursive case, as well as adding an additional check for it in the recursive case. Note that haveNoCommonBitsSet should be equivalent to having the disjoint flag set, and the check can be removed in a follow-up patch.
Co-authored-by: Philip Reames preames@rivosinc.com