Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Handle disjoint or in RISCVGatherScatterLowering #77800

Merged
merged 2 commits into from
Jan 15, 2024

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Jan 11, 2024

This patch adds support for the disjoint flag in the non-recursive case, as well as adding an additional check for it in the recursive case. Note that haveNoCommonBitsSet should be equivalent to having the disjoint flag set, and the check can be removed in a follow-up patch.

Co-authored-by: Philip Reames preames@rivosinc.com

This patch adds support for the disjoint flag in the non-recursive case, but
for the recursive case we were already handling this by checking that there
were no common bits. This patch replaces that check with a check for the
disjoint flag instead, since instcombine will already compute it for us.

Co-authored-by: Philip Reames <preames@rivosinc.com>
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 11, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

This patch adds support for the disjoint flag in the non-recursive case, but
for the recursive case we were already handling this by checking that there
were no common bits. This patch replaces that check with a check for the
disjoint flag instead, since instcombine will already compute it for us.

Co-authored-by: Philip Reames <preames@rivosinc.com>


Full diff: https://github.com/llvm/llvm-project/pull/77800.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp (+10-1)
  • (modified) llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll (+2-5)
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 1129206800ad36..1dcb83a6078ed7 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -136,10 +136,15 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   // multipled.
   auto *BO = dyn_cast<BinaryOperator>(Start);
   if (!BO || (BO->getOpcode() != Instruction::Add &&
+              BO->getOpcode() != Instruction::Or &&
               BO->getOpcode() != Instruction::Shl &&
               BO->getOpcode() != Instruction::Mul))
     return std::make_pair(nullptr, nullptr);
 
+  if (BO->getOpcode() == Instruction::Or &&
+      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
+    return std::make_pair(nullptr, nullptr);
+
   // Look for an operand that is splatted.
   unsigned OtherIndex = 0;
   Value *Splat = getSplatValue(BO->getOperand(1));
@@ -163,6 +168,10 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   switch (BO->getOpcode()) {
   default:
     llvm_unreachable("Unexpected opcode");
+  case Instruction::Or:
+    // TODO: We'd be better off creating disjoint or here, but we don't yet
+    // have an IRBuilder API for that.
+    [[fallthrough]];
   case Instruction::Add:
     Start = Builder.CreateAdd(Start, Splat);
     break;
@@ -241,7 +250,7 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
     return false;
   case Instruction::Or:
     // We need to be able to treat Or as Add.
-    if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
+    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
       return false;
     break;
   case Instruction::Add:
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index 838089baa46fc4..54e5d39e248544 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -183,11 +183,8 @@ define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
 
 define <vscale x 1 x i64> @straightline_offset_disjoint_or(ptr %p, i64 %offset) {
 ; CHECK-LABEL: @straightline_offset_disjoint_or(
-; CHECK-NEXT:    [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[STEP_SHL:%.*]] = shl <vscale x 1 x i64> [[STEP]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    [[OFFSETV:%.*]] = or disjoint <vscale x 1 x i64> [[STEP_SHL]], shufflevector (<vscale x 1 x i64> insertelement (<vscale x 1 x i64> poison, i64 1, i32 0), <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETV]]
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 1
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp Outdated Show resolved Hide resolved
@lukel97 lukel97 merged commit 0cf768e into llvm:main Jan 15, 2024
4 checks passed
lukel97 added a commit that referenced this pull request Jan 15, 2024
…ts. NFC

InstCombine will add the disjoint flag to these or instructions. This patch
adds them to the tests so that it matches the input RISCVGatherScatterLowering
will receive in practice, allowing us to rely on said disjoint flag:
#77800 (comment)
lukel97 added a commit that referenced this pull request Jan 15, 2024
If an or instruction has no common bits set in its operands, InstCombine will
set the disjoint flag. This means we shouldn't need to compute it ourselves
anymore in RISCVGatherScatterLowering, and can just rely on said flag being
set.
Originally split out from #77800
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
This patch adds support for the disjoint flag in the non-recursive case,
as well as adding an additional check for it in the recursive case. Note
that haveNoCommonBitsSet should be equivalent to having the disjoint
flag set, and the check can be removed in a follow-up patch.

Co-authored-by: Philip Reames <preames@rivosinc.com>

---------

Co-authored-by: Philip Reames <preames@rivosinc.com>
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…ts. NFC

InstCombine will add the disjoint flag to these or instructions. This patch
adds them to the tests so that it matches the input RISCVGatherScatterLowering
will receive in practice, allowing us to rely on said disjoint flag:
llvm#77800 (comment)
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
If an or instruction has no common bits set in its operands, InstCombine will
set the disjoint flag. This means we shouldn't need to compute it ourselves
anymore in RISCVGatherScatterLowering, and can just rely on said flag being
set.
Originally split out from llvm#77800
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants