[GISEL] Add G_SPLAT_VECTOR_PARTS to represent 64-bit splat vectors on… #86970

Closed

Conversation

michaelmaitland
Contributor

… i32 targets

We'd like to represent the construction of a splat vector when the target has 32-bit integers but supports vectors with 64-bit elements. This opcode makes that possible. It is the GlobalISel equivalent of ISD::SPLAT_VECTOR_PARTS, except that the ISD version takes a list of scalars whereas this opcode accepts exactly two.
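
As a rough illustration of the intended semantics (plain C++, not LLVM code), the sketch below joins two 32-bit halves into one 64-bit element and copies it into every lane. The name `splatVectorParts` and the use of `std::vector` are assumptions made for this example only; `NumElts` stands in for the lane count, which for a scalable vector is only known at run time.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a "splat vector parts" operation for 64-bit elements.
// Lo supplies the least significant 32 bits, Hi the most significant 32 bits.
std::vector<uint64_t> splatVectorParts(uint32_t Lo, uint32_t Hi,
                                       std::size_t NumElts) {
  // Join the two halves into a single 64-bit element.
  uint64_t Elt = (static_cast<uint64_t>(Hi) << 32) | Lo;
  // Every lane of the result holds the same joined element.
  return std::vector<uint64_t>(NumElts, Elt);
}
```

Swapping the two arguments changes the result, which is why the operand order (lo first, then hi) is part of the opcode's contract.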

@llvmbot
Collaborator

llvmbot commented Mar 28, 2024

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-llvm-globalisel

Author: Michael Maitland (michaelmaitland)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/86970.diff

7 Files Affected:

  • (modified) llvm/docs/GlobalISel/GenericOpcode.rst (+9)
  • (modified) llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h (+14)
  • (modified) llvm/include/llvm/Support/TargetOpcodes.def (+3)
  • (modified) llvm/include/llvm/Target/GenericOpcodes.td (+10)
  • (modified) llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp (+11)
  • (modified) llvm/lib/CodeGen/MachineVerifier.cpp (+16)
  • (added) llvm/test/MachineVerifier/test_g_splat_vector_parts.mir (+53)
diff --git a/llvm/docs/GlobalISel/GenericOpcode.rst b/llvm/docs/GlobalISel/GenericOpcode.rst
index cae2c21b80d7e7..b0b9bce3fc90b0 100644
--- a/llvm/docs/GlobalISel/GenericOpcode.rst
+++ b/llvm/docs/GlobalISel/GenericOpcode.rst
@@ -690,6 +690,15 @@ G_SPLAT_VECTOR
 
 Create a vector where all elements are the scalar from the source operand.
 
+G_SPLAT_VECTOR_PARTS
+^^^^^^^^^^^^^^^^^^^^
+
+Create a vector where all elements are the scalar created by joining the
+operands together. This allows representing a 64-bit splat on a target with
+32-bit integers. The total width of the scalars must cover the element width
+exactly. The lo operand contains the least significant bits and the hi operand
+contains the most significant bits.
+
 Vector Reduction Operations
 ---------------------------
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
index 16a7fc446fbe1d..8000e610447ce9 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
@@ -1107,6 +1107,20 @@ class MachineIRBuilder {
   /// \return a MachineInstrBuilder for the newly created instruction.
   MachineInstrBuilder buildSplatVector(const DstOp &Res, const SrcOp &Val);
 
+  /// Build and insert \p Res = G_SPLAT_VECTOR_PARTS \p Lo Hi.
+  ///
+  /// \p Lo contains the least significant bits of the value. \p Hi contains the
+  /// most significant bits of the value.
+  ///
+  /// \pre setBasicBlock or setMI must have been called.
+  /// \pre \p Res must be a generic virtual register with vector type.
+  /// \pre \p Lo must be a generic virtual register with scalar type.
+  /// \pre \p Hi must be a generic virtual register with scalar type.
+  ///
+  /// \return a MachineInstrBuilder for the newly created instruction.
+  MachineInstrBuilder buildSplatVectorParts(const DstOp &Res, const SrcOp &Lo,
+                                            const SrcOp &Hi);
+
   /// Build and insert \p Res = G_CONCAT_VECTORS \p Op0, ...
   ///
   /// G_CONCAT_VECTORS creates a vector from the concatenation of 2 or more
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 5765926d6d93d3..b01622be31ef09 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -748,6 +748,9 @@ HANDLE_TARGET_OPCODE(G_SHUFFLE_VECTOR)
 /// Generic splatvector.
 HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR)
 
+/// Generic splatvector parts.
+HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR_PARTS)
+
 /// Generic count trailing zeroes.
 HANDLE_TARGET_OPCODE(G_CTTZ)
 
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index d0f471eb29b6fd..6d6eade99ca314 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1480,6 +1480,16 @@ def G_SPLAT_VECTOR: GenericInstruction {
   let hasSideEffects = false;
 }
 
+// Generic splatvector parts. This allows representing a 64-bit splat on a
+// target with 32-bit integers. The total width of the scalars must cover the
+// element width exactly. The lo operand contains the least significant bits
+// and the hi operand contains the most significant bits.
+def G_SPLAT_VECTOR_PARTS : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type1:$lo, type1:$hi);
+  let hasSideEffects = false;
+}
+
 //------------------------------------------------------------------------------
 // Vector reductions
 //------------------------------------------------------------------------------
diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
index 07d4cb5eaa23c8..f025e631aa0b2f 100644
--- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
@@ -749,6 +749,17 @@ MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
   return buildInstr(TargetOpcode::G_SPLAT_VECTOR, Res, Src);
 }
 
+MachineInstrBuilder MachineIRBuilder::buildSplatVectorParts(const DstOp &Res,
+                                                            const SrcOp &Lo,
+                                                            const SrcOp &Hi) {
+  TypeSize LoSize = Lo.getLLTTy(*getMRI()).getSizeInBits();
+  TypeSize HiSize = Hi.getLLTTy(*getMRI()).getSizeInBits();
+  TypeSize EltSize = Res.getLLTTy(*getMRI()).getElementType().getSizeInBits();
+  assert(LoSize + HiSize == EltSize &&
+         "Expected scalar sizes to cover Dst element size");
+  return buildInstr(TargetOpcode::G_SPLAT_VECTOR_PARTS, {Res}, {Lo, Hi});
+}
+
 MachineInstrBuilder MachineIRBuilder::buildShuffleVector(const DstOp &Res,
                                                          const SrcOp &Src1,
                                                          const SrcOp &Src2,
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index e4e05ce9278caf..f936c47ad65b1b 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -1781,6 +1781,22 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
 
     break;
   }
+  case TargetOpcode::G_SPLAT_VECTOR_PARTS: {
+    LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
+    LLT LoTy = MRI->getType(MI->getOperand(1).getReg());
+    LLT HiTy = MRI->getType(MI->getOperand(2).getReg());
+
+    if (!DstTy.isScalableVector())
+      report("Destination type must be a scalable vector", MI);
+
+    if (!LoTy.isScalar() || !HiTy.isScalar())
+      report("Source types must be scalar", MI);
+    else if (LoTy.getSizeInBits() + HiTy.getSizeInBits() !=
+             DstTy.getScalarSizeInBits())
+      report("Source types must cover the element type", MI);
+
+    break;
+  }
   case TargetOpcode::G_DYN_STACKALLOC: {
     const MachineOperand &DstOp = MI->getOperand(0);
     const MachineOperand &AllocOp = MI->getOperand(1);
diff --git a/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir b/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir
new file mode 100644
index 00000000000000..4d0d7f71046b6f
--- /dev/null
+++ b/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir
@@ -0,0 +1,53 @@
+# RUN: not --crash llc -o - -mtriple=arm64 -run-pass=none -verify-machineinstrs %s 2>&1 | FileCheck %s
+# REQUIRES: aarch64-registered-target
+---
+name:            g_splat_vector_parts
+tracksRegLiveness: true
+liveins:
+body:             |
+  bb.0:
+    %0:_(s32) = G_CONSTANT i32 0
+    %1:_(<2 x s32>) = G_IMPLICIT_DEF
+    %2:_(<vscale x 2 x s32>) = G_IMPLICIT_DEF
+
+    ; CHECK: Destination type must be a scalable vector
+    %3:_(s32) = G_SPLAT_VECTOR_PARTS %0, %0
+
+    ; CHECK: Destination type must be a scalable vector
+    %4:_(<2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %0
+
+    ; CHECK: Source types must be scalar
+    %5:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %1, %0
+
+    ; CHECK: Source types must be scalar
+    %6:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %1
+
+    ; CHECK: Source types must be scalar
+    %7:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %1, %1
+
+    ; CHECK: Source types must be scalar
+    %8:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %2
+
+    ; CHECK: Source types must be scalar
+    %9:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %2, %0
+
+    ; CHECK: Source types must be scalar
+    %10:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %2, %2
+
+    %11:_(s16) = G_CONSTANT i16 0
+
+    ; CHECK: Source types must cover the element type
+    %12:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %11, %11
+
+    ; CHECK: Source types must cover the element type
+    %13:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %11, %0
+
+    %14:_(s64) = G_CONSTANT i64 0
+
+    ; CHECK: Source types must cover the element type
+    %15:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %14, %14
+
+    ; CHECK: Source types must cover the element type
+    %16:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %14, %0
+
+...
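
For readers who want to experiment with the verifier rules outside of LLVM, here is a minimal standalone sketch of the three checks. It is an interpretation rather than the patch's code: `Ty` is a hypothetical, simplified stand-in for `LLT`, `verifySplatVectorParts` is an invented name, and the size rule follows the documented wording (the scalars must cover the destination's element width exactly). The size rule is only evaluated once both sources are known to be scalars.

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a low-level type; not the LLVM class.
struct Ty {
  bool IsVector;    // false for a scalar
  bool IsScalable;  // only meaningful when IsVector is true
  unsigned EltBits; // scalar width, or element width for vectors
};

// Returns the diagnostics a splat-vector-parts style check would emit.
std::vector<std::string> verifySplatVectorParts(const Ty &Dst, const Ty &Lo,
                                                const Ty &Hi) {
  std::vector<std::string> Errors;
  // Rule 1: the destination must be a scalable vector.
  if (!(Dst.IsVector && Dst.IsScalable))
    Errors.push_back("Destination type must be a scalable vector");
  // Rule 2: both sources must be scalars.
  if (Lo.IsVector || Hi.IsVector)
    Errors.push_back("Source types must be scalar");
  // Rule 3: the two scalars together must match the element width.
  else if (Lo.EltBits + Hi.EltBits != Dst.EltBits)
    Errors.push_back("Source types must cover the element type");
  return Errors;
}
```

With this model, two s32 sources and a `<vscale x 2 x s64>` destination produce no diagnostics, while two s16 sources with the same destination trip only the third rule.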

@tschuett
Member

I assume that this is again for patterns, but isn't G_SPLAT_VECTOR(G_MERGE_VALUES(lower, upper)) the same?

@michaelmaitland
Contributor Author

> I assume that this is again for patterns, but isn't G_SPLAT_VECTOR(G_MERGE_VALUES(lower, upper)) the same?

This is for patterns. There are no patterns for SPLAT_VECTOR with an i64 operand on RV32.

@aemerson
Contributor

aemerson commented Mar 28, 2024

Which is the "lower level" representation? SPLAT_VECTOR_PARTS or a SPLAT of G_MERGE? Can we lower SPLAT_VECTOR_PARTS to the merge representation for targets that don't need it?

I'm also wondering whether, since only RISCV seems to need this, we can just make it a RISCV G_ opcode.

@tschuett tschuett requested a review from aemerson March 28, 2024 17:07
@tschuett
Member

Will AArch64 also need the parts or is this an SDAG legacy artefact?

@michaelmaitland
Contributor Author

michaelmaitland commented Mar 28, 2024

> Which is the "lower level" representation? SPLAT_VECTOR_PARTS or a SPLAT of G_MERGE? Can we lower SPLAT_VECTOR_PARTS to the merge representation for targets that don't need it?

I think targets that don't need it can lower to the merge representation. Which of the two is the lower-level representation depends on the target.

> I'm also wondering whether, since only RISCV seems to need this, we can just make it a RISCV G_ opcode.

I am okay with this. This way we don't need to worry about other targets needing to lower it. I wonder if in the long run we could remove this opcode and replace it with splat patterns that take an i64 operand. I don't have the bandwidth to do this right now, though. I would appreciate the ability to reuse patterns.

@aemerson
Contributor

> Will AArch64 also need the parts or is this an SDAG legacy artefact?

AFAICT no one except RISC-V is using it in SDAG. AArch64 doesn't have this illegal scalar type oddity.

@aemerson
Contributor

> > Which is the "lower level" representation? SPLAT_VECTOR_PARTS or a SPLAT of G_MERGE? Can we lower SPLAT_VECTOR_PARTS to the merge representation for targets that don't need it?
>
> I think targets that don't need it can lower to the merge representation. Which of the two is the lower-level representation depends on the target.
>
> > I'm also wondering whether, since only RISCV seems to need this, we can just make it a RISCV G_ opcode.
>
> I am okay with this. This way we don't need to worry about other targets needing to lower it. I wonder if in the long run we could remove this opcode and replace it with splat patterns that take an i64 operand. I don't have the bandwidth to do this right now, though. I would appreciate the ability to reuse patterns.

You can check AArch64InstrGISel.td for examples of where we define our target-specific G_ opcodes. That said, if you do this, please add RISCV to the name, as in G_RISCV_SPLAT_VECTOR. We need to rename the AArch64 ones too.

@michaelmaitland
Contributor Author

> You can check AArch64InstrGISel.td for examples of where we define our target-specific G_ opcodes. That said, if you do this, please add RISCV to the name, as in G_RISCV_SPLAT_VECTOR. We need to rename the AArch64 ones too.

I thought that these opcodes are part of the RISCV namespace. Do they need to duplicate RISCV in the name too?

@topperc
Collaborator

topperc commented Mar 28, 2024

This node exists in SelectionDAG because type legalization and op legalization are separate steps. That's not true in GISel where there's only one legalizer. There's no G_SPLAT_VECTOR_PARTS instruction in RISC-V. So if the legalizer in GISel creates a G_SPLAT_VECTOR_PARTS when does it get removed?

@michaelmaitland
Contributor Author

michaelmaitland commented Mar 28, 2024

> This node exists in SelectionDAG because type legalization and op legalization are separate steps. That's not true in GISel where there's only one legalizer. There's no G_SPLAT_VECTOR_PARTS instruction in RISC-V. So if the legalizer in GISel creates a G_SPLAT_VECTOR_PARTS when does it get removed?

Do you think it makes more sense to lower G_SPLAT_VECTOR with i64 elements to a RISCV::G_SPLAT_VECTOR_PARTS_I64 directly?

@topperc
Collaborator

topperc commented Mar 28, 2024

> > This node exists in SelectionDAG because type legalization and op legalization are separate steps. That's not true in GISel where there's only one legalizer. There's no G_SPLAT_VECTOR_PARTS instruction in RISC-V. So if the legalizer in GISel creates a G_SPLAT_VECTOR_PARTS when does it get removed?
>
> Do you think it makes more sense to lower G_SPLAT_VECTOR with i64 elements to a RISCV::G_SPLAT_VECTOR_PARTS_I64 directly?

I think so.

@michaelmaitland
Contributor Author

I am going to lower G_SPLAT_VECTOR directly into the RISCV opcode space.
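
One way to picture that lowering: on RV32, the s64 splat operand has to be split into two s32 halves before it can feed a target-specific parts-style opcode (a name such as RISCV::G_SPLAT_VECTOR_PARTS_I64 was floated earlier in this thread). The sketch below models only that split in plain C++; `splitI64` is a hypothetical helper for illustration, not an LLVM API.

```cpp
#include <cstdint>
#include <utility>

// Hypothetical helper: split a 64-bit scalar into {Lo, Hi} 32-bit halves.
// Lo holds the least significant bits, Hi the most significant bits.
std::pair<uint32_t, uint32_t> splitI64(uint64_t Val) {
  uint32_t Lo = static_cast<uint32_t>(Val);       // keeps bits 0..31
  uint32_t Hi = static_cast<uint32_t>(Val >> 32); // keeps bits 32..63
  return {Lo, Hi};
}
```

Joining the two halves again as `(Hi << 32) | Lo` recovers the original value, so no information is lost by passing the operand as parts.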
