diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 4eab357f1b33b..10e1223825193 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -707,11 +707,15 @@ class TargetTransformInfo { /// The type may be VoidTy, in which case only return true if the addressing /// mode is legal for a load/store of any legal type. /// If target returns true in LSRWithInstrQueries(), I may be valid. + /// \param ScalableOffset represents a quantity of bytes multiplied by vscale, + /// an invariant value known only at runtime. Most targets should not accept + /// a scalable offset. + /// /// TODO: Handle pre/postinc as well. bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, - unsigned AddrSpace = 0, - Instruction *I = nullptr) const; + unsigned AddrSpace = 0, Instruction *I = nullptr, + int64_t ScalableOffset = 0) const; /// Return true if LSR cost of C1 is lower than C2. 
bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1, @@ -1839,7 +1843,8 @@ class TargetTransformInfo::Concept { virtual bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace, - Instruction *I) = 0; + Instruction *I, + int64_t ScalableOffset) = 0; virtual bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1, const TargetTransformInfo::LSRCost &C2) = 0; virtual bool isNumRegsMajorCostOfLSR() = 0; @@ -2300,9 +2305,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { } bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace, - Instruction *I) override { + Instruction *I, int64_t ScalableOffset) override { return Impl.isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale, - AddrSpace, I); + AddrSpace, I, ScalableOffset); } bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1, const TargetTransformInfo::LSRCost &C2) override { diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 7f661bb4a1df2..07eeceeeaa22a 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -220,7 +220,8 @@ class TargetTransformInfoImplBase { bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace, - Instruction *I = nullptr) const { + Instruction *I = nullptr, + int64_t ScalableOffset = 0) const { // Guess that only reg and reg+reg addressing is allowed. This heuristic is // taken from the implementation of LSR. 
return !BaseGV && BaseOffset == 0 && (Scale == 0 || Scale == 1); diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index 61f6564e8cd79..721900038ddd5 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -333,13 +333,15 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { } bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, - bool HasBaseReg, int64_t Scale, - unsigned AddrSpace, Instruction *I = nullptr) { + bool HasBaseReg, int64_t Scale, unsigned AddrSpace, + Instruction *I = nullptr, + int64_t ScalableOffset = 0) { TargetLoweringBase::AddrMode AM; AM.BaseGV = BaseGV; AM.BaseOffs = BaseOffset; AM.HasBaseReg = HasBaseReg; AM.Scale = Scale; + AM.ScalableOffset = ScalableOffset; return getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace, I); } diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 2f164a460db84..4753d8e8a5125 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -2722,17 +2722,19 @@ class TargetLoweringBase { } /// This represents an addressing mode of: - /// BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + /// BaseGV + BaseOffs + BaseReg + Scale*ScaleReg + ScalableOffset*vscale /// If BaseGV is null, there is no BaseGV. /// If BaseOffs is zero, there is no base offset. /// If HasBaseReg is false, there is no base register. /// If Scale is zero, there is no ScaleReg. Scale of 1 indicates a reg with /// no scale. + /// If ScalableOffset is zero, there is no scalable offset. 
struct AddrMode { GlobalValue *BaseGV = nullptr; int64_t BaseOffs = 0; bool HasBaseReg = false; int64_t Scale = 0; + int64_t ScalableOffset = 0; AddrMode() = default; }; diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 15311be4dba27..4b113e6d3798c 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -403,9 +403,10 @@ bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace, - Instruction *I) const { + Instruction *I, + int64_t ScalableOffset) const { return TTIImpl->isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, - Scale, AddrSpace, I); + Scale, AddrSpace, I, ScalableOffset); } bool TargetTransformInfo::isLSRCostLess(const LSRCost &C1, diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp index 8ac55ee6a5d0c..9990556f89ed8 100644 --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -2011,6 +2011,10 @@ bool TargetLoweringBase::isLegalAddressingMode(const DataLayout &DL, // The default implementation of this implements a conservative RISCy, r+r and // r+i addr mode. + // Scalable offsets not supported + if (AM.ScalableOffset) + return false; + // Allows a sign-extended 16-bit immediate field. 
if (AM.BaseOffs <= -(1LL << 16) || AM.BaseOffs >= (1LL << 16)-1) return false; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 5b7a36d2eba76..2bfaed146e9b6 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -16671,15 +16671,29 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL, if (Ty->isScalableTy()) { if (isa<ScalableVectorType>(Ty)) { + // See if we have a foldable vscale-based offset, for vector types which + // are either legal or smaller than the minimum; more work will be + // required if we need to consider addressing for types which need + // legalization by splitting. + uint64_t VecNumBytes = DL.getTypeSizeInBits(Ty).getKnownMinValue() / 8; + if (AM.HasBaseReg && !AM.BaseOffs && AM.ScalableOffset && !AM.Scale && + (AM.ScalableOffset % VecNumBytes == 0) && VecNumBytes <= 16 && + isPowerOf2_64(VecNumBytes)) + return isInt<4>(AM.ScalableOffset / (int64_t)VecNumBytes); + uint64_t VecElemNumBytes = DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8; - return AM.HasBaseReg && !AM.BaseOffs && + return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset && (AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes); } - return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale; + return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset && !AM.Scale; } + // No scalable offsets allowed for non-scalable types. 
+ if (AM.ScalableOffset) + return false; + // check reg + imm case: // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12 uint64_t NumBytes = 0; diff --git a/llvm/unittests/Target/AArch64/AddressingModes.cpp b/llvm/unittests/Target/AArch64/AddressingModes.cpp index 284ea7ae9233e..0af18d886791a 100644 --- a/llvm/unittests/Target/AArch64/AddressingModes.cpp +++ b/llvm/unittests/Target/AArch64/AddressingModes.cpp @@ -13,11 +13,13 @@ using namespace llvm; namespace { struct AddrMode : public TargetLowering::AddrMode { - constexpr AddrMode(GlobalValue *GV, int64_t Offs, bool HasBase, int64_t S) { + constexpr AddrMode(GlobalValue *GV, int64_t Offs, bool HasBase, int64_t S, + int64_t SOffs = 0) { BaseGV = GV; BaseOffs = Offs; HasBaseReg = HasBase; Scale = S; + ScalableOffset = SOffs; } }; struct TestCase { @@ -153,6 +155,45 @@ const std::initializer_list<TestCase> Tests = { {{nullptr, 4096 + 1, true, 0}, 8, false}, }; + +struct SVETestCase { + AddrMode AM; + unsigned TypeBits; + unsigned NumElts; + bool Result; +}; + +const std::initializer_list<SVETestCase> SVETests = { + // {BaseGV, BaseOffs, HasBaseReg, Scale, SOffs}, EltBits, Count, Result + // Test immediate range -- [-8,7] vector's worth. 
+ // <vscale x 16 x i8>, increment by one vector + {{nullptr, 0, true, 0, 16}, 8, 16, true}, + // <vscale x 4 x i32>, increment by eight vectors + {{nullptr, 0, true, 0, 128}, 32, 4, false}, + // <vscale x 8 x i16>, increment by seven vectors + {{nullptr, 0, true, 0, 112}, 16, 8, true}, + // <vscale x 2 x i64>, decrement by eight vectors + {{nullptr, 0, true, 0, -128}, 64, 2, true}, + // <vscale x 16 x i8>, decrement by nine vectors + {{nullptr, 0, true, 0, -144}, 8, 16, false}, + + // Half the size of a vector register, but allowable with extending + // loads and truncating stores + // <vscale x 8 x i8>, increment by three vectors + {{nullptr, 0, true, 0, 24}, 8, 8, true}, + + // Test invalid types or offsets + // <vscale x 5 x i32>, increment by one vector (base size > 16B) + {{nullptr, 0, true, 0, 20}, 32, 5, false}, + // <vscale x 8 x i16>, increment by half a vector + {{nullptr, 0, true, 0, 8}, 16, 8, false}, + // <vscale x 3 x i8>, increment by 3 vectors (non-power-of-two) + {{nullptr, 0, true, 0, 9}, 8, 3, false}, + + // Scalable and fixed offsets + // <vscale x 16 x i8>, increment by 32 then decrement by vscale x 16 + {{nullptr, 32, true, 0, -16}, 8, 16, false}, +}; } // namespace TEST(AddressingModes, AddressingModes) { @@ -179,4 +220,11 @@ TEST(AddressingModes, AddressingModes) { Type *Typ = Type::getIntNTy(Ctx, Test.TypeBits); ASSERT_EQ(TLI->isLegalAddressingMode(DL, Test.AM, Typ, 0), Test.Result); } + + for (const auto &SVETest : SVETests) { + Type *Ty = VectorType::get(Type::getIntNTy(Ctx, SVETest.TypeBits), + ElementCount::getScalable(SVETest.NumElts)); + ASSERT_EQ(TLI->isLegalAddressingMode(DL, SVETest.AM, Ty, 0), + SVETest.Result); + } }