Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][GlobalISel] Legalize Scalable Vector Loads and Stores #84965

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

jiahanxie353
Copy link
Contributor

@jiahanxie353 jiahanxie353 commented Mar 12, 2024

This patch works on legalizing load and store instruction for scalable vectors

Copy link

github-actions bot commented Mar 25, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link

✅ With the latest revision this PR passed the Python code formatter.

@michaelmaitland michaelmaitland self-requested a review March 25, 2024 16:24
@michaelmaitland
Copy link
Contributor

Can you please update PR description?

@@ -220,7 +220,8 @@ struct TypePairAndMemDesc {
Align >= Other.Align &&
// FIXME: This perhaps should be stricter, but the current legality
// rules are written only considering the size.
MemTy.getSizeInBits() == Other.MemTy.getSizeInBits();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did this really cause a problem? It looks like TypeSize::operator== should not cause a crash for scalable vectors:

constexpr bool operator==(const FixedOrScalableQuantity &RHS) const {
  return Quantity == RHS.Quantity && Scalable == RHS.Scalable;
}

Copy link
Collaborator

@topperc topperc Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing the min value is losing the scalable bit. A fixed size and scalable size should not be compatible right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing the min value is losing the scalable bit. Why a fixed size and scalable size should not be compatible right?

Should scalable nxv2s8 be compatible with s8? What about nxv2s8 and s8? I'm a bit unsure about the logic behind

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have an s8 its type size in bits will be (ElementCount::Fixed(1), 8 bits) If you have an nxv2s8 its type size in bits will be (ElementCount::Sclable(2), 8). (ElementCount::Fixed(1), 8) != (ElementCount::Sclable(2), 8)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, (ElementCount::Scalable(2), 8) != (ElementCount::Fixed(1), 8), but does that mean they are incompatible?
Since (ElementCount::Scalable(2), 8) can hold at least two s8 types, it'll definitely be a multiple of (ElementCount::Fixed(1), 8). So it seems to me (ElementCount::Scalable(2), 8) can hold/be "compatible" with (ElementCount::Fixed(1), 8), if it makes sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true we can do BIT_CAST if ScalableEltTy != FixedEltTy && MinNumElts * EltTy == FixedEltTy (2 * 8 == 16 in your example).
I was wondering what if ScalableEltTy == FixedEltTy and MinNumElts > 1? In these cases, the scalable vector is the "superset" of the fixed one so they can be compatible in the sense of fitting one into another. Am I understanding it correctly?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't currently support a G_BITCAST between a scalar and a vector at the moment. I was using it as an example to show that if the types do not obviously match, then maybe it is a better idea to explicitly convert them into a form that does match using an instruction.

In the case of ScalableEltTy == ScalarEltTy and MinNumElts > 1 I am suggesting that we disallow the load because the TypeSize does not match, and instead rely on an instruction to get us in a form where type size does match. My point is that it is possible but we should be explicit in making it possible. Vectors and scalars use different registers on RISC-V and it is likely that that "convert instruction" will make it easier to select a transfer from vector to scalar or scalar to vector if we need to do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense, thanks!
To answer your question:

Did this really cause a problem? It looks like TypeSize::operator== should not cause a crash for scalable vectors: ...

The answer is yes. And the reason why I was changing to MemTy.getSizeInBits().getKnownMinValue() == Other.MemTy.getSizeInBits().getKnownMinValue() is because I have:

auto &LoadStoreActions = getActionDefinitionsBuilder({G_LOAD, G_STORE})
.legalForTypesWithMemDesc({{s32, p0, s8, 8},
{s32, p0, s16, 16},
{s32, p0, s32, 32},
{p0, p0, sXLen, XLen},
{nxv1s8, p0, s8, 8}});

In this case, MemTy is nxv1s8 and Other.MemTy is s8. Therefore, based on our discussion, they should be incompatible. And seems like the only solution to make them compatible is:

.legalForTypesWithMemDesc({{s32, p0, s8, 8},
                                                     {s32, p0, s16, 16},
                                                     {s32, p0, s32, 32},
                                                     {p0, p0, sXLen, XLen},
                                                     {nxv1s8, p0, nxv1s8, 8}});

But this doesn't look right to me..

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The third parameter definitely needs to be a scalable type to match true size of nxv1s8.

AArch64 has this for vectors, but they don't support scalable vectors yet. The third parameter is a scalar the same total size as the vector

      .legalForTypesWithMemDesc({{s8, p0, s8, 8},                                
                                 {s16, p0, s16, 8},                              
                                 {s32, p0, s32, 8},                              
                                 {s64, p0, s64, 8},                              
                                 {p0, p0, s64, 8},                               
                                 {s128, p0, s128, 8},                            
                                 {v8s8, p0, s64, 8},                             
                                 {v16s8, p0, s128, 8},                           
                                 {v4s16, p0, s64, 8},                            
                                 {v8s16, p0, s128, 8},                           
                                 {v2s32, p0, s64, 8},                            
                                 {v4s32, p0, s128, 8},                           
                                 {v2s64, p0, s128, 8}})   

So {nxv1s8, p0, s8, 8} seems wrong to me, but I'm not sure what we can put there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess since isCompatible just checks the size, I guess {v16s8, p0, s128, 8} or {v16s8, p0, v16s8, 8} are equivalent for AArch64?

So using {nxv1s8, p0, nxv1s8, 8} would be correct for RISC-V.

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp Outdated Show resolved Hide resolved
llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp Outdated Show resolved Hide resolved
{s32, p0, s16, 16},
{s32, p0, s32, 32},
{p0, p0, sXLen, XLen}});
auto &LoadStoreActions = getActionDefinitionsBuilder({G_LOAD, G_STORE})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do lowerIfMemSizeNotByteSizePow2 like AArch64?

@@ -0,0 +1,29 @@
# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
# RUN: llc -mtriple=riscv32 -mattr=+v -run-pass=legalizer %s -o - | FileCheck %s
# RUN: llc -mtriple=riscv64 -mattr=+v -run-pass=legalizer %s -o - | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You changed G_LOAD and G_STORE legalizer rules above. I only see G_LOAD tests. Either we need to add tests and change PR description or we need to not change legalizer rules for G_STORE.

}

...
---
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be testing all the cases that came out of the IRTranslator tests.

@llvmbot
Copy link
Collaborator

llvmbot commented Mar 25, 2024

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-risc-v

Author: Jiahan Xie (jiahanxie353)

Changes

This patch works on legalizing load instruction for scalable vectors


Full diff: https://github.com/llvm/llvm-project/pull/84965.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h (+2-1)
  • (modified) llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp (+6-2)
  • (modified) llvm/lib/CodeGen/MIRParser/MIParser.cpp (+1-1)
  • (modified) llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp (+6-6)
  • (added) llvm/test/CodeGen/RISCV/GlobalISel/legalizer/rvv/legalize-load.mir (+29)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 6afaea3f3fc5c6..5d60f4f1829397 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -220,7 +220,8 @@ struct TypePairAndMemDesc {
            Align >= Other.Align &&
            // FIXME: This perhaps should be stricter, but the current legality
            // rules are written only considering the size.
-           MemTy.getSizeInBits() == Other.MemTy.getSizeInBits();
+           MemTy.getSizeInBits().getKnownMinValue() ==
+               Other.MemTy.getSizeInBits().getKnownMinValue();
   }
 };
 
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 1b25da8833e4fb..fb18a2fd4d3e8c 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -3330,8 +3330,12 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
   LLT MemTy = MMO.getMemoryType();
   MachineFunction &MF = MIRBuilder.getMF();
 
-  unsigned MemSizeInBits = MemTy.getSizeInBits();
-  unsigned MemStoreSizeInBits = 8 * MemTy.getSizeInBytes();
+  unsigned MemSizeInBits = MemTy.isScalable()
+                               ? MemTy.getSizeInBits().getKnownMinValue()
+                               : MemTy.getSizeInBits();
+  unsigned MemStoreSizeInBits =
+      MemTy.isScalable() ? 8 * MemTy.getSizeInBytes().getKnownMinValue()
+                         : 8 * MemTy.getSizeInBytes();
 
   if (MemSizeInBits != MemStoreSizeInBits) {
     if (MemTy.isVector())
diff --git a/llvm/lib/CodeGen/MIRParser/MIParser.cpp b/llvm/lib/CodeGen/MIRParser/MIParser.cpp
index 691c60d22724f3..43f6d3219bc6da 100644
--- a/llvm/lib/CodeGen/MIRParser/MIParser.cpp
+++ b/llvm/lib/CodeGen/MIRParser/MIParser.cpp
@@ -3415,7 +3415,7 @@ bool MIParser::parseMachineMemoryOperand(MachineMemOperand *&Dest) {
     if (expectAndConsume(MIToken::rparen))
       return true;
 
-    Size = MemoryType.getSizeInBytes();
+    Size = MemoryType.getSizeInBytes().getKnownMinValue();
   }
 
   MachinePointerInfo Ptr = MachinePointerInfo();
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
index 64ae4e94a8c929..bcded4b227e7ef 100644
--- a/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
@@ -210,12 +210,12 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
       .clampScalar(0, s32, (XLen == 64 || ST.hasStdExtD()) ? s64 : s32)
       .clampScalar(1, sXLen, sXLen);
 
-  auto &LoadStoreActions =
-      getActionDefinitionsBuilder({G_LOAD, G_STORE})
-          .legalForTypesWithMemDesc({{s32, p0, s8, 8},
-                                     {s32, p0, s16, 16},
-                                     {s32, p0, s32, 32},
-                                     {p0, p0, sXLen, XLen}});
+  auto &LoadStoreActions = getActionDefinitionsBuilder({G_LOAD, G_STORE})
+                               .legalForTypesWithMemDesc({{s32, p0, s8, 8},
+                                                          {s32, p0, s16, 16},
+                                                          {s32, p0, s32, 32},
+                                                          {p0, p0, sXLen, XLen},
+                                                          {nxv1s8, p0, s8, 8}});
   auto &ExtLoadActions =
       getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
           .legalForTypesWithMemDesc({{s32, p0, s8, 8}, {s32, p0, s16, 16}});
diff --git a/llvm/test/CodeGen/RISCV/GlobalISel/legalizer/rvv/legalize-load.mir b/llvm/test/CodeGen/RISCV/GlobalISel/legalizer/rvv/legalize-load.mir
new file mode 100644
index 00000000000000..1b62f555207165
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/GlobalISel/legalizer/rvv/legalize-load.mir
@@ -0,0 +1,29 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+# RUN: llc -mtriple=riscv32 -mattr=+v -run-pass=legalizer %s -o - | FileCheck %s
+# RUN: llc -mtriple=riscv64 -mattr=+v -run-pass=legalizer %s -o - | FileCheck %s
+--- |
+
+  define <vscale x 1 x i8> @vload_nx1i8(ptr %pa) {
+    %va = load <vscale x 1 x i8>, ptr %pa
+    ret <vscale x 1 x i8> %va
+  }
+
+...
+---
+name:            vload_nx1i8
+body:             |
+  bb.1 (%ir-block.0):
+    liveins: $x10
+
+    ; CHECK-LABEL: name: vload_nx1i8
+    ; CHECK: liveins: $x10
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p0) = COPY $x10
+    ; CHECK-NEXT: [[LOAD:%[0-9]+]]:_(<vscale x 1 x s8>) = G_LOAD [[COPY]](p0) :: (load (<vscale x 1 x s8>) from %ir.pa)
+    ; CHECK-NEXT: $v8 = COPY [[LOAD]](<vscale x 1 x s8>)
+    ; CHECK-NEXT: PseudoRET implicit $v8
+    %0:_(p0) = COPY $x10
+    %1:_(<vscale x 1 x s8>) = G_LOAD %0(p0) :: (load (<vscale x 1 x s8>) from %ir.pa)
+    $v8 = COPY %1(<vscale x 1 x s8>)
+    PseudoRET implicit $v8
+

@@ -220,7 +220,8 @@ struct TypePairAndMemDesc {
Align >= Other.Align &&
// FIXME: This perhaps should be stricter, but the current legality
// rules are written only considering the size.
MemTy.getSizeInBits() == Other.MemTy.getSizeInBits();
Copy link
Collaborator

@topperc topperc Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing the min value is losing the scalable bit. A fixed size and scalable size should not be compatible right?

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp Outdated Show resolved Hide resolved
@jiahanxie353 jiahanxie353 changed the title [RISCV][GlobalISel] Legalize Scalable Vector Loads [RISCV][GlobalISel] Legalize Scalable Vector Loads and Stores Mar 26, 2024
@@ -3330,16 +3330,17 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerLoad(GAnyLoad &LoadMI) {
LLT MemTy = MMO.getMemoryType();
MachineFunction &MF = MIRBuilder.getMF();

unsigned MemSizeInBits = MemTy.getSizeInBits();
unsigned MemStoreSizeInBits = 8 * MemTy.getSizeInBytes();
unsigned MinMemSizeInBits = MemTy.getSizeInBits().getKnownMinValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can do this as TypeSize MemSizeInBits and TypeSize MinMemStoreSizeInBits as Craig Pointed out, which allows us to avoid calling getKnownMinValue. That way we avoid the case comparing fixed vector and scalable vector.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood.
I used getKnownMinValue because we have cases for doing scalar arithmetic operations below. For example:

if (!isPowerOf2_32(MinMemSizeInBits)) {
// This load needs splitting into power of 2 sized loads.
LargeSplitSize = llvm::bit_floor(MinMemSizeInBits);
SmallSplitSize = MinMemSizeInBits - LargeSplitSize;
} else {
// This is already a power of 2, but we still need to split this in half.
//
// Assume we're being asked to decompose an unaligned load.
// TODO: If this requires multiple splits, handle them all at once.
auto &Ctx = MF.getFunction().getContext();
if (TLI.allowsMemoryAccess(Ctx, MIRBuilder.getDataLayout(), MemTy, MMO))
return UnableToLegalize;
SmallSplitSize = LargeSplitSize = MinMemSizeInBits / 2;
}

bit_floor has to take a scalar.

So are you suggesting we make them TypeSize MemSizeInBits but call getKnownMinValue on demand?

unsigned MemSizeInBits = MemTy.getSizeInBits();
unsigned MemStoreSizeInBits = 8 * MemTy.getSizeInBytes();
unsigned MinMemSizeInBits = MemTy.getSizeInBits().getKnownMinValue();
unsigned MinMemStoreSizeInBits =
Copy link
Collaborator

@topperc topperc Mar 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might need to rewrite this as MemTy.getSizeInBytes() * 8 to use TypeSize. I think I remember that the constant multiplier has to be on the right hand side for the operator overloading.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -249,7 +249,15 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
.legalForTypesWithMemDesc({{s32, p0, s8, 8},
{s32, p0, s16, 16},
{s32, p0, s32, 32},
{p0, p0, sXLen, XLen}});
{p0, p0, sXLen, XLen},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about s16, s32, sXLen and their vector types?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be sXLen, but we do need s64 when ST.hasVInstructionsI64() is true.

nxv1s8 shouldn't be supported unless getELEN() == 64. Same as the type legality rules for G_ADD.

@jiahanxie353
Copy link
Contributor Author

When we have alignment restrictions for load, we need to check allowsMemoryAccessForAlignment, like what SelectionDAG did. Therefore, equivalently, we can insert customIf(!allowsMemoryAccessForAlignment) around here

But to get such information, we need to have LLVMContext, DataLayout, etc. However, RISCVLegalizerInfo does not have those fields in it. Nor can I figure out a way to use LegalityQuery to obtain those required arguments.

Can I get some help on this? Thanks!

@michaelmaitland
Copy link
Contributor

When we have alignment restrictions for load, we need to check allowsMemoryAccessForAlignment, like what SelectionDAG did. Therefore, equivalently, we can insert customIf(!allowsMemoryAccessForAlignment) around here

But to get such information, we need to have LLVMContext, DataLayout, etc. However, RISCVLegalizerInfo does not have those fields in it. Nor can I figure out a way to use LegalityQuery to obtain those required arguments.

Can I get some help on this? Thanks!

Check out legalizeIntrinsic in RISCVLegalizerInfo. You can get the DataLayout from MIRBuilder.getDataLayout(). You can get LLVMContext from MF.getFunction.getContext().

These are all available in legalizeCustom but not in the LegalityQuery, since the LegalityQuery gets built before seeing any instructions or functions. That means you will need to write a customIf that sends us to legalizeCustom more circumstances than are actually legal -- we will likely have to return false in some instances from legalizeCustom.

@topperc
Copy link
Collaborator

topperc commented May 28, 2024

When we have alignment restrictions for load, we need to check allowsMemoryAccessForAlignment, like what SelectionDAG did. Therefore, equivalently, we can insert customIf(!allowsMemoryAccessForAlignment) around here

But to get such information, we need to have LLVMContext, DataLayout, etc. However, RISCVLegalizerInfo does not have those fields in it. Nor can I figure out a way to use LegalityQuery to obtain those required arguments.

Can I get some help on this? Thanks!

The DataLayout belongs to Module. The LegalizerInfo constructor runs before the Module exists or at least without knowledge of the Module. I don't think you can call allowsMemoryAccessForAlignment.

A misaligned vector can be handled by converting it to a vle8. Not sure if the generic legalization can handle that or we need to do custom legalization for RISC-V. SelectionDAG uses custom legalization, see RISCVTargetLowering::expandUnalignedRVVLoad

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants