
[AArch64][GlobalISel] Basic SVE and fadd #72976

Open

wants to merge 1 commit into main
Conversation

davemgreen
Collaborator

This appears to be the minimum needed to get SVE fadd working. It still needs more testing; I'm putting it up now to show that it works OK so far.
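The patch adds `llvm/test/CodeGen/AArch64/sve-add.ll` (+12 lines), but its contents are cut off by the diff truncation. The following is a hypothetical reduction of what such a test might look like: a scalable-vector `fadd`, with the GlobalISel SVE path enabled via the `-aarch64-disable-sve-gisel=false` flag this patch introduces. The RUN and CHECK lines are illustrative, not taken from the actual file.

```llvm
; Hypothetical sketch only: the real sve-add.ll is not visible in the
; truncated diff, so the RUN line and CHECK pattern are assumptions.
; RUN: llc -mtriple=aarch64 -mattr=+sve -global-isel \
; RUN:     -aarch64-disable-sve-gisel=false < %s | FileCheck %s

define <vscale x 4 x float> @fadd_nxv4f32(<vscale x 4 x float> %a,
                                          <vscale x 4 x float> %b) {
; CHECK-LABEL: fadd_nxv4f32:
; CHECK: fadd z0.s, z0.s, z1.s
  %r = fadd <vscale x 4 x float> %a, %b
  ret <vscale x 4 x float> %r
}
```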

@llvmbot
Collaborator

llvmbot commented Nov 21, 2023

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-regalloc

Author: David Green (davemgreen)

Changes

This appears to be the minimum needed to get SVE fadd working. It still needs more testing; I'm putting it up now to show that it works OK so far.


Patch is 25.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72976.diff

13 Files Affected:

  • (modified) llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/RegisterBankInfo.cpp (+3-3)
  • (modified) llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def (+7-5)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+8-3)
  • (modified) llvm/lib/Target/AArch64/AArch64RegisterBanks.td (+1-1)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp (+8-6)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+20-7)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+6-1)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp (+33-29)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.h (+3-3)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+2-4)
  • (added) llvm/test/CodeGen/AArch64/sve-add.ll (+12)
  • (modified) llvm/utils/TableGen/InfoByHwMode.cpp (+4-5)
diff --git a/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp b/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
index baea773cf528e92..f04e4cdb764f2a3 100644
--- a/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
@@ -277,7 +277,8 @@ bool InstructionSelect::runOnMachineFunction(MachineFunction &MF) {
     }
 
     const LLT Ty = MRI.getType(VReg);
-    if (Ty.isValid() && Ty.getSizeInBits() > TRI.getRegSizeInBits(*RC)) {
+    if (Ty.isValid() &&
+        TypeSize::isKnownGT(Ty.getSizeInBits(), TRI.getRegSizeInBits(*RC))) {
       reportGISelFailure(
           MF, TPC, MORE, "gisel-select",
           "VReg's low-level type and register class have different sizes", *MI);
diff --git a/llvm/lib/CodeGen/RegisterBankInfo.cpp b/llvm/lib/CodeGen/RegisterBankInfo.cpp
index 6a96bb40f56aed9..5548430d1b0ae88 100644
--- a/llvm/lib/CodeGen/RegisterBankInfo.cpp
+++ b/llvm/lib/CodeGen/RegisterBankInfo.cpp
@@ -565,9 +565,9 @@ bool RegisterBankInfo::ValueMapping::verify(const RegisterBankInfo &RBI,
     OrigValueBitWidth =
         std::max(OrigValueBitWidth, PartMap.getHighBitIdx() + 1);
   }
-  assert(MeaningfulBitWidth.isScalable() ||
-         OrigValueBitWidth >= MeaningfulBitWidth &&
-             "Meaningful bits not covered by the mapping");
+  assert((MeaningfulBitWidth.isScalable() ||
+          OrigValueBitWidth >= MeaningfulBitWidth) &&
+         "Meaningful bits not covered by the mapping");
   APInt ValueMask(OrigValueBitWidth, 0);
   for (const RegisterBankInfo::PartialMapping &PartMap : *this) {
     // Check that the union of the partial mappings covers the whole value,
diff --git a/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def b/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
index b87421e5ee46ae5..0b3557e67240520 100644
--- a/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
+++ b/llvm/lib/Target/AArch64/AArch64GenRegisterBankInfo.def
@@ -136,8 +136,8 @@ bool AArch64GenRegisterBankInfo::checkValueMapImpl(unsigned Idx,
                                                    unsigned Size,
                                                    unsigned Offset) {
   unsigned PartialMapBaseIdx = Idx - PartialMappingIdx::PMI_Min;
-  const ValueMapping &Map =
-      AArch64GenRegisterBankInfo::getValueMapping((PartialMappingIdx)FirstInBank, Size)[Offset];
+  const ValueMapping &Map = AArch64GenRegisterBankInfo::getValueMapping(
+      (PartialMappingIdx)FirstInBank, TypeSize::Fixed(Size))[Offset];
   return Map.BreakDown == &PartMappings[PartialMapBaseIdx] &&
          Map.NumBreakDowns == 1;
 }
@@ -167,7 +167,7 @@ bool AArch64GenRegisterBankInfo::checkPartialMappingIdx(
 }
 
 unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
-                                                             unsigned Size) {
+                                                             TypeSize Size) {
   if (RBIdx == PMI_FirstGPR) {
     if (Size <= 32)
       return 0;
@@ -178,6 +178,8 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
     return -1;
   }
   if (RBIdx == PMI_FirstFPR) {
+    if (Size.isScalable())
+      return 3;
     if (Size <= 16)
       return 0;
     if (Size <= 32)
@@ -197,7 +199,7 @@ unsigned AArch64GenRegisterBankInfo::getRegBankBaseIdxOffset(unsigned RBIdx,
 
 const RegisterBankInfo::ValueMapping *
 AArch64GenRegisterBankInfo::getValueMapping(PartialMappingIdx RBIdx,
-                                            unsigned Size) {
+                                            TypeSize Size) {
   assert(RBIdx != PartialMappingIdx::PMI_None && "No mapping needed for that");
   unsigned BaseIdxOffset = getRegBankBaseIdxOffset(RBIdx, Size);
   if (BaseIdxOffset == -1u)
@@ -221,7 +223,7 @@ const AArch64GenRegisterBankInfo::PartialMappingIdx
 
 const RegisterBankInfo::ValueMapping *
 AArch64GenRegisterBankInfo::getCopyMapping(unsigned DstBankID,
-                                           unsigned SrcBankID, unsigned Size) {
+                                           unsigned SrcBankID, TypeSize Size) {
   assert(DstBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
   assert(SrcBankID < AArch64::NumRegisterBanks && "Invalid bank ID");
   PartialMappingIdx DstRBIdx = BankIDToCopyMapIdx[DstBankID];
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d42ae4ff93a4442..2dc7ffbd7be4335 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -144,6 +144,11 @@ static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
 static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
                                  cl::desc("Maximum of xors"));
 
+cl::opt<bool> DisableSVEGISel(
+    "aarch64-disable-sve-gisel", cl::Hidden,
+    cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
+    cl::init(true));
+
 /// Value type used for condition codes.
 static const MVT MVT_CC = MVT::i32;
 
@@ -25277,15 +25282,15 @@ bool AArch64TargetLowering::shouldLocalize(
 }
 
 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  if (Inst.getType()->isScalableTy())
+  if (DisableSVEGISel && Inst.getType()->isScalableTy())
     return true;
 
   for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (Inst.getOperand(i)->getType()->isScalableTy())
+    if (DisableSVEGISel && Inst.getOperand(i)->getType()->isScalableTy())
       return true;
 
   if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
-    if (AI->getAllocatedType()->isScalableTy())
+    if (DisableSVEGISel && AI->getAllocatedType()->isScalableTy())
       return true;
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
index 615ce7d51d9ba74..9e2ed356299e2bc 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterBanks.td
@@ -13,7 +13,7 @@
 def GPRRegBank : RegisterBank<"GPR", [XSeqPairsClass]>;
 
 /// Floating Point/Vector Registers: B, H, S, D, Q.
-def FPRRegBank : RegisterBank<"FPR", [QQQQ]>;
+def FPRRegBank : RegisterBank<"FPR", [QQQQ, ZPR]>;
 
 /// Conditional register: NZCV.
 def CCRegBank : RegisterBank<"CC", [CCR]>;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
index 84057ea8d2214ac..f8f321c5881b68e 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -51,6 +51,8 @@
 
 using namespace llvm;
 
+extern cl::opt<bool> DisableSVEGISel;
+
 AArch64CallLowering::AArch64CallLowering(const AArch64TargetLowering &TLI)
   : CallLowering(&TLI) {}
 
@@ -387,8 +389,8 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
       // i1 is a special case because SDAG i1 true is naturally zero extended
       // when widened using ANYEXT. We need to do it explicitly here.
       auto &Flags = CurArgInfo.Flags[0];
-      if (MRI.getType(CurVReg).getSizeInBits() == 1 && !Flags.isSExt() &&
-          !Flags.isZExt()) {
+      if (MRI.getType(CurVReg).getSizeInBits() == TypeSize::Fixed(1) &&
+          !Flags.isSExt() && !Flags.isZExt()) {
         CurVReg = MIRBuilder.buildZExt(LLT::scalar(8), CurVReg).getReg(0);
       } else if (TLI.getNumRegistersForCallingConv(Ctx, CC, SplitEVTs[i]) ==
                  1) {
@@ -523,10 +525,10 @@ static void handleMustTailForwardedRegisters(MachineIRBuilder &MIRBuilder,
 
 bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
   auto &F = MF.getFunction();
-  if (F.getReturnType()->isScalableTy() ||
-      llvm::any_of(F.args(), [](const Argument &A) {
-        return A.getType()->isScalableTy();
-      }))
+  if (DisableSVEGISel && (F.getReturnType()->isScalableTy() ||
+                          llvm::any_of(F.args(), [](const Argument &A) {
+                            return A.getType()->isScalableTy();
+                          })))
     return true;
   const auto &ST = MF.getSubtarget<AArch64Subtarget>();
   if (!ST.hasNEON() || !ST.hasFPARMv8()) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index bdaae4dd724d536..9ad1e30c802dad9 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -595,11 +595,12 @@ getRegClassForTypeOnBank(LLT Ty, const RegisterBank &RB,
 /// Given a register bank, and size in bits, return the smallest register class
 /// that can represent that combination.
 static const TargetRegisterClass *
-getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
+getMinClassForRegBank(const RegisterBank &RB, TypeSize SizeInBits,
                       bool GetAllRegSet = false) {
   unsigned RegBankID = RB.getID();
 
   if (RegBankID == AArch64::GPRRegBankID) {
+    assert(!SizeInBits.isScalable() && "Unexpected scalable register size");
     if (SizeInBits <= 32)
       return GetAllRegSet ? &AArch64::GPR32allRegClass
                           : &AArch64::GPR32RegClass;
@@ -611,6 +612,12 @@ getMinClassForRegBank(const RegisterBank &RB, unsigned SizeInBits,
   }
 
   if (RegBankID == AArch64::FPRRegBankID) {
+    if (SizeInBits.isScalable()) {
+      assert(SizeInBits == TypeSize::Scalable(128) &&
+             "Unexpected scalable register size");
+      return &AArch64::ZPRRegClass;
+    }
+
     switch (SizeInBits) {
     default:
       return nullptr;
@@ -937,8 +944,8 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
   Register SrcReg = I.getOperand(1).getReg();
   const RegisterBank &DstRegBank = *RBI.getRegBank(DstReg, MRI, TRI);
   const RegisterBank &SrcRegBank = *RBI.getRegBank(SrcReg, MRI, TRI);
-  unsigned DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
-  unsigned SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
+  TypeSize DstSize = RBI.getSizeInBits(DstReg, MRI, TRI);
+  TypeSize SrcSize = RBI.getSizeInBits(SrcReg, MRI, TRI);
 
   // Special casing for cross-bank copies of s1s. We can technically represent
   // a 1-bit value with any size of register. The minimum size for a GPR is 32
@@ -948,8 +955,9 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
   // then we can pull it into the helpers that get the appropriate class for a
   // register bank. Or make a new helper that carries along some constraint
   // information.
-  if (SrcRegBank != DstRegBank && (DstSize == 1 && SrcSize == 1))
-    SrcSize = DstSize = 32;
+  if (SrcRegBank != DstRegBank &&
+      (DstSize == TypeSize::Fixed(1) && SrcSize == TypeSize::Fixed(1)))
+    SrcSize = DstSize = TypeSize::Fixed(32);
 
   return {getMinClassForRegBank(SrcRegBank, SrcSize, true),
           getMinClassForRegBank(DstRegBank, DstSize, true)};
@@ -1014,10 +1022,15 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII,
       return false;
     }
 
-    unsigned SrcSize = TRI.getRegSizeInBits(*SrcRC);
-    unsigned DstSize = TRI.getRegSizeInBits(*DstRC);
+    TypeSize SrcSize = TRI.getRegSizeInBits(*SrcRC);
+    TypeSize DstSize = TRI.getRegSizeInBits(*DstRC);
     unsigned SubReg;
 
+    if (SrcSize.isScalable()) {
+      assert(DstSize.isScalable() && "Unhandled scalable copy");
+      return true;
+    }
+
     // If the source bank doesn't support a subregister copy small enough,
     // then we first need to copy to the destination bank.
     if (getMinSizeForRegBank(SrcRegBank) > DstSize) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 000fd648595222b..e55cf5400565215 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -59,6 +59,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v4s32 = LLT::fixed_vector(4, 32);
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);
+  const LLT nxv16s8 = LLT::scalable_vector(16, 8);
+  const LLT nxv8s16 = LLT::scalable_vector(8, 16);
+  const LLT nxv4s32 = LLT::scalable_vector(4, 32);
+  const LLT nxv2s64 = LLT::scalable_vector(2, 64);
 
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
@@ -238,7 +242,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                                G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR,
                                G_FRINT, G_FNEARBYINT, G_INTRINSIC_TRUNC,
                                G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
-      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
+      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64, nxv16s8, nxv8s16,
+                 nxv4s32, nxv2s64})
       .legalIf([=](const LegalityQuery &Query) {
         const auto &Ty = Query.Types[0];
         return (Ty == v8s16 || Ty == v4s16) && HasFP16;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 4ca5b3674461d89..1466570cf317a7e 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -161,17 +161,18 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
     unsigned PartialMapSrcIdx = PMI_##RBNameSrc##Size - PMI_Min;               \
     (void)PartialMapDstIdx;                                                    \
     (void)PartialMapSrcIdx;                                                    \
-    const ValueMapping *Map = getCopyMapping(                                  \
-        AArch64::RBNameDst##RegBankID, AArch64::RBNameSrc##RegBankID, Size);  \
+    const ValueMapping *Map =                                                  \
+        getCopyMapping(AArch64::RBNameDst##RegBankID,                          \
+                       AArch64::RBNameSrc##RegBankID, TypeSize::Fixed(Size));  \
     (void)Map;                                                                 \
     assert(Map[0].BreakDown ==                                                 \
                &AArch64GenRegisterBankInfo::PartMappings[PartialMapDstIdx] &&  \
-           Map[0].NumBreakDowns == 1 && #RBNameDst #Size                       \
-           " Dst is incorrectly initialized");                                 \
+           Map[0].NumBreakDowns == 1 &&                                        \
+           #RBNameDst #Size " Dst is incorrectly initialized");                \
     assert(Map[1].BreakDown ==                                                 \
                &AArch64GenRegisterBankInfo::PartMappings[PartialMapSrcIdx] &&  \
-           Map[1].NumBreakDowns == 1 && #RBNameSrc #Size                       \
-           " Src is incorrectly initialized");                                 \
+           Map[1].NumBreakDowns == 1 &&                                        \
+           #RBNameSrc #Size " Src is incorrectly initialized");                \
                                                                                \
   } while (false)
 
@@ -255,6 +256,9 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::QQRegClassID:
   case AArch64::QQQRegClassID:
   case AArch64::QQQQRegClassID:
+  case AArch64::ZPR_3bRegClassID:
+  case AArch64::ZPR_4bRegClassID:
+  case AArch64::ZPRRegClassID:
     return getRegBank(AArch64::FPRRegBankID);
   case AArch64::GPR32commonRegClassID:
   case AArch64::GPR32RegClassID:
@@ -299,8 +303,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
   case TargetOpcode::G_OR: {
     // 32 and 64-bit or can be mapped on either FPR or
     // GPR for the same cost.
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 32 && Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(32) && Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -320,8 +324,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     return AltMappings;
   }
   case TargetOpcode::G_BITCAST: {
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 32 && Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(32) && Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -341,15 +345,13 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     const InstructionMapping &GPRToFPRMapping = getInstructionMapping(
         /*ID*/ 3,
         /*Cost*/
-        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
-                 TypeSize::Fixed(Size)),
+        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
         getCopyMapping(AArch64::FPRRegBankID, AArch64::GPRRegBankID, Size),
         /*NumOperands*/ 2);
     const InstructionMapping &FPRToGPRMapping = getInstructionMapping(
         /*ID*/ 3,
         /*Cost*/
-        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
-                 TypeSize::Fixed(Size)),
+        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
         getCopyMapping(AArch64::GPRRegBankID, AArch64::FPRRegBankID, Size),
         /*NumOperands*/ 2);
 
@@ -360,8 +362,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     return AltMappings;
   }
   case TargetOpcode::G_LOAD: {
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -372,15 +374,17 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     InstructionMappings AltMappings;
     const InstructionMapping &GPRMapping = getInstructionMapping(
         /*ID*/ 1, /*Cost*/ 1,
-        getOperandsMapping({getValueMapping(PMI_FirstGPR, Size),
-                            // Addresses are GPR 64-bit.
-                            getValueMapping(PMI_FirstGPR, 64)}),
+        getOperandsMapping(
+            {getValueMapping(PMI_FirstGPR, Size),
+             // Addresses are GPR 64-bit.
+             getValueMapping(PMI_FirstGPR, TypeSize::Fixed(64))}),
         /*NumOperands*/ 2);
     const InstructionMapping &FPRMapping = getInstructionMapping(
         /*ID*/ 2, /*Cost*/ 1,
-        getOperandsMapping({getValueMapping(PMI_FirstFPR, Size),
-                            // Addresses are GPR 64-bit.
-                            getValueMapping(PMI_FirstGPR, 64)}),
+        getOperandsMapping(
+            {getValueMapping(PMI_FirstFPR, Size),
+             // Addresses are GPR 64-bit.
+             getValueMapping(PMI_FirstGPR, TypeSize::Fixed(64))}),
         /*NumOperands*/ 2);
 
     AltMappings.push_back(&GPRMapping);
@@ -458,7 +462,7 @@ AArch64RegisterBankInfo::getSameKindOfOperandsMapping(
          "This code is for instructions with 3 or less operands");
 
   LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-  unsigned Size = Ty.getSizeInBits();
+  TypeSize Size = Ty.getSizeInBits();
   bool IsFPR = Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc);
 
   PartialMappingIdx RBIdx = IsFPR ? PMI_FirstFPR : PMI_FirstGPR;
@@ -711,9 +715,9 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
       // If both RB are null that means both registers are generic.
       // We shouldn't be here.
       assert(DstRB && SrcRB && "Both RegBank were nullptr");
-      unsigned Size = getSizeInBits(DstReg, MRI, TRI);
+      TypeSize Size = getSizeInBits(DstReg, MRI, TRI);
       return getInstructionMapping(
-          DefaultMappingID, copyCost(*DstRB, *SrcRB, TypeSize::Fixed(Size)),
+          DefaultMappingID, copyCost(*DstRB, *SrcRB, Size),
           getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
           // We only care about the mapping of the destination.
           /*NumOperands*/ 1);
@@ -724,7 +728,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_BITCAST: {
     LLT DstTy = MRI.getT...
[truncated]

@llvmbot
Collaborator

llvmbot commented Nov 21, 2023

@llvm/pr-subscribers-llvm-globalisel

   // a 1-bit value with any size of register. The minimum size for a GPR is 32
@@ -948,8 +955,9 @@ getRegClassesForCopy(MachineInstr &I, const TargetInstrInfo &TII,
   // then we can pull it into the helpers that get the appropriate class for a
   // register bank. Or make a new helper that carries along some constraint
   // information.
-  if (SrcRegBank != DstRegBank && (DstSize == 1 && SrcSize == 1))
-    SrcSize = DstSize = 32;
+  if (SrcRegBank != DstRegBank &&
+      (DstSize == TypeSize::Fixed(1) && SrcSize == TypeSize::Fixed(1)))
+    SrcSize = DstSize = TypeSize::Fixed(32);
 
   return {getMinClassForRegBank(SrcRegBank, SrcSize, true),
           getMinClassForRegBank(DstRegBank, DstSize, true)};
@@ -1014,10 +1022,15 @@ static bool selectCopy(MachineInstr &I, const TargetInstrInfo &TII,
       return false;
     }
 
-    unsigned SrcSize = TRI.getRegSizeInBits(*SrcRC);
-    unsigned DstSize = TRI.getRegSizeInBits(*DstRC);
+    TypeSize SrcSize = TRI.getRegSizeInBits(*SrcRC);
+    TypeSize DstSize = TRI.getRegSizeInBits(*DstRC);
     unsigned SubReg;
 
+    if (SrcSize.isScalable()) {
+      assert(DstSize.isScalable() && "Unhandled scalable copy");
+      return true;
+    }
+
     // If the source bank doesn't support a subregister copy small enough,
     // then we first need to copy to the destination bank.
     if (getMinSizeForRegBank(SrcRegBank) > DstSize) {
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 000fd648595222b..e55cf5400565215 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -59,6 +59,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   const LLT v4s32 = LLT::fixed_vector(4, 32);
   const LLT v2s64 = LLT::fixed_vector(2, 64);
   const LLT v2p0 = LLT::fixed_vector(2, p0);
+  const LLT nxv16s8 = LLT::scalable_vector(16, 8);
+  const LLT nxv8s16 = LLT::scalable_vector(8, 16);
+  const LLT nxv4s32 = LLT::scalable_vector(4, 32);
+  const LLT nxv2s64 = LLT::scalable_vector(2, 64);
 
   std::initializer_list<LLT> PackedVectorAllTypeList = {/* Begin 128bit types */
                                                         v16s8, v8s16, v4s32,
@@ -238,7 +242,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                                G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR,
                                G_FRINT, G_FNEARBYINT, G_INTRINSIC_TRUNC,
                                G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
-      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
+      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64, nxv16s8, nxv8s16,
+                 nxv4s32, nxv2s64})
       .legalIf([=](const LegalityQuery &Query) {
         const auto &Ty = Query.Types[0];
         return (Ty == v8s16 || Ty == v4s16) && HasFP16;
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
index 4ca5b3674461d89..1466570cf317a7e 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -161,17 +161,18 @@ AArch64RegisterBankInfo::AArch64RegisterBankInfo(
     unsigned PartialMapSrcIdx = PMI_##RBNameSrc##Size - PMI_Min;               \
     (void)PartialMapDstIdx;                                                    \
     (void)PartialMapSrcIdx;                                                    \
-    const ValueMapping *Map = getCopyMapping(                                  \
-        AArch64::RBNameDst##RegBankID, AArch64::RBNameSrc##RegBankID, Size);  \
+    const ValueMapping *Map =                                                  \
+        getCopyMapping(AArch64::RBNameDst##RegBankID,                          \
+                       AArch64::RBNameSrc##RegBankID, TypeSize::Fixed(Size));  \
     (void)Map;                                                                 \
     assert(Map[0].BreakDown ==                                                 \
                &AArch64GenRegisterBankInfo::PartMappings[PartialMapDstIdx] &&  \
-           Map[0].NumBreakDowns == 1 && #RBNameDst #Size                       \
-           " Dst is incorrectly initialized");                                 \
+           Map[0].NumBreakDowns == 1 &&                                        \
+           #RBNameDst #Size " Dst is incorrectly initialized");                \
     assert(Map[1].BreakDown ==                                                 \
                &AArch64GenRegisterBankInfo::PartMappings[PartialMapSrcIdx] &&  \
-           Map[1].NumBreakDowns == 1 && #RBNameSrc #Size                       \
-           " Src is incorrectly initialized");                                 \
+           Map[1].NumBreakDowns == 1 &&                                        \
+           #RBNameSrc #Size " Src is incorrectly initialized");                \
                                                                                \
   } while (false)
 
@@ -255,6 +256,9 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
   case AArch64::QQRegClassID:
   case AArch64::QQQRegClassID:
   case AArch64::QQQQRegClassID:
+  case AArch64::ZPR_3bRegClassID:
+  case AArch64::ZPR_4bRegClassID:
+  case AArch64::ZPRRegClassID:
     return getRegBank(AArch64::FPRRegBankID);
   case AArch64::GPR32commonRegClassID:
   case AArch64::GPR32RegClassID:
@@ -299,8 +303,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
   case TargetOpcode::G_OR: {
     // 32 and 64-bit or can be mapped on either FPR or
     // GPR for the same cost.
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 32 && Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(32) && Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -320,8 +324,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     return AltMappings;
   }
   case TargetOpcode::G_BITCAST: {
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 32 && Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(32) && Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -341,15 +345,13 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     const InstructionMapping &GPRToFPRMapping = getInstructionMapping(
         /*ID*/ 3,
         /*Cost*/
-        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
-                 TypeSize::Fixed(Size)),
+        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
         getCopyMapping(AArch64::FPRRegBankID, AArch64::GPRRegBankID, Size),
         /*NumOperands*/ 2);
     const InstructionMapping &FPRToGPRMapping = getInstructionMapping(
         /*ID*/ 3,
         /*Cost*/
-        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank,
-                 TypeSize::Fixed(Size)),
+        copyCost(AArch64::GPRRegBank, AArch64::FPRRegBank, Size),
         getCopyMapping(AArch64::GPRRegBankID, AArch64::FPRRegBankID, Size),
         /*NumOperands*/ 2);
 
@@ -360,8 +362,8 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     return AltMappings;
   }
   case TargetOpcode::G_LOAD: {
-    unsigned Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
-    if (Size != 64)
+    TypeSize Size = getSizeInBits(MI.getOperand(0).getReg(), MRI, TRI);
+    if (Size != TypeSize::Fixed(64))
       break;
 
     // If the instruction has any implicit-defs or uses,
@@ -372,15 +374,17 @@ AArch64RegisterBankInfo::getInstrAlternativeMappings(
     InstructionMappings AltMappings;
     const InstructionMapping &GPRMapping = getInstructionMapping(
         /*ID*/ 1, /*Cost*/ 1,
-        getOperandsMapping({getValueMapping(PMI_FirstGPR, Size),
-                            // Addresses are GPR 64-bit.
-                            getValueMapping(PMI_FirstGPR, 64)}),
+        getOperandsMapping(
+            {getValueMapping(PMI_FirstGPR, Size),
+             // Addresses are GPR 64-bit.
+             getValueMapping(PMI_FirstGPR, TypeSize::Fixed(64))}),
         /*NumOperands*/ 2);
     const InstructionMapping &FPRMapping = getInstructionMapping(
         /*ID*/ 2, /*Cost*/ 1,
-        getOperandsMapping({getValueMapping(PMI_FirstFPR, Size),
-                            // Addresses are GPR 64-bit.
-                            getValueMapping(PMI_FirstGPR, 64)}),
+        getOperandsMapping(
+            {getValueMapping(PMI_FirstFPR, Size),
+             // Addresses are GPR 64-bit.
+             getValueMapping(PMI_FirstGPR, TypeSize::Fixed(64))}),
         /*NumOperands*/ 2);
 
     AltMappings.push_back(&GPRMapping);
@@ -458,7 +462,7 @@ AArch64RegisterBankInfo::getSameKindOfOperandsMapping(
          "This code is for instructions with 3 or less operands");
 
   LLT Ty = MRI.getType(MI.getOperand(0).getReg());
-  unsigned Size = Ty.getSizeInBits();
+  TypeSize Size = Ty.getSizeInBits();
   bool IsFPR = Ty.isVector() || isPreISelGenericFloatingPointOpcode(Opc);
 
   PartialMappingIdx RBIdx = IsFPR ? PMI_FirstFPR : PMI_FirstGPR;
@@ -711,9 +715,9 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
       // If both RB are null that means both registers are generic.
       // We shouldn't be here.
       assert(DstRB && SrcRB && "Both RegBank were nullptr");
-      unsigned Size = getSizeInBits(DstReg, MRI, TRI);
+      TypeSize Size = getSizeInBits(DstReg, MRI, TRI);
       return getInstructionMapping(
-          DefaultMappingID, copyCost(*DstRB, *SrcRB, TypeSize::Fixed(Size)),
+          DefaultMappingID, copyCost(*DstRB, *SrcRB, Size),
           getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
           // We only care about the mapping of the destination.
           /*NumOperands*/ 1);
@@ -724,7 +728,7 @@ AArch64RegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case TargetOpcode::G_BITCAST: {
     LLT DstTy = MRI.getT...
[truncated]

@llvmbot
Collaborator

llvmbot commented Nov 21, 2023

@llvm/pr-subscribers-backend-amdgpu

Author: David Green (davemgreen)

   case TargetOpcode::G_BITCAST: {
     LLT DstTy = MRI.getT...
[truncated]

@@ -51,6 +51,8 @@

using namespace llvm;

extern cl::opt<bool> DisableSVEGISel;
Contributor

Seems unnecessary. I don't think this should be a long-lived option.

@tschuett
Member

tschuett commented Jan 5, 2024

Can there be tests for all 4 legalized types?


github-actions bot commented Jan 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@davemgreen
Collaborator Author

I was rebasing this yesterday but didn't get as far as pushing it, apparently.

I think it might be good to get it in so there is some testing for scalable vectors, even if it's relatively bare bones at the moment. As for fully supporting all the operations, that would be a lot of work and it might be best to keep focussing on the base architecture for the time being. The option lets us keep this as something that can be worked on in the background.

This appears to be the minimum needed to get SVE fadd working.
@tschuett
Member

tschuett commented Jan 5, 2024

const bool HasSVE = ST.  ;

@@ -238,7 +241,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR,
G_FRINT, G_FNEARBYINT, G_INTRINSIC_TRUNC,
G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
Member

Maybe move G_FADD into a separate builder and slowly move/legalize/merge the two builders back into one? G_INTRINSIC_TRUNC is legal for SVE?!?

@tschuett
Member

tschuett commented Jan 5, 2024

Would this eventually work in SelectionDAGCompat.td?

def : GINodeEquiv<G_FADD_PRED, fadd>;

@@ -238,7 +241,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR,
G_FRINT, G_FNEARBYINT, G_INTRINSIC_TRUNC,
G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
.legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
.legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64, nxv8s16, nxv4s32,
Member

HasSVE?

@tschuett
Member

tschuett commented Jan 5, 2024

Thanks. I volunteer to legalize.

@arsenm
Contributor

arsenm commented Feb 6, 2024

ping?

Him188 added a commit to Him188/llvm-project that referenced this pull request May 14, 2024
…TORE

This patch adds basic support for scalable vector types in load & store instructions for AArch64 with GISel.
Only scalable vector types with a 128-bit base size are supported, e.g. <vscale x 4 x i32>, <vscale x 16 x i8>.

This patch adapted some ideas from a similar abandoned patch llvm#72976.
Him188 added a commit that referenced this pull request May 30, 2024
…TORE (#92130)

This patch adds basic support for scalable vector types in load & store
instructions for AArch64 with GISel.

Only scalable vector types with a 128-bit base size are supported, e.g.
`<vscale x 4 x i32>`, `<vscale x 16 x i8>`.

This patch adapted some ideas from a similar abandoned patch
#72976.
HendrikHuebner pushed a commit to HendrikHuebner/llvm-project that referenced this pull request Jun 2, 2024
…TORE (llvm#92130)

This patch adds basic support for scalable vector types in load & store
instructions for AArch64 with GISel.

Only scalable vector types with a 128-bit base size are supported, e.g.
`<vscale x 4 x i32>`, `<vscale x 16 x i8>`.

This patch adapted some ideas from a similar abandoned patch
llvm#72976.
4 participants