[IR] Add getelementptr nusw and nuw flags #90824

Merged: 7 commits, May 27, 2024

nikic (Contributor) commented May 2, 2024

This implements the nusw and nuw flags for getelementptr as proposed at https://discourse.llvm.org/t/rfc-add-nusw-and-nuw-flags-for-getelementptr/78672.

The three possible flags are encapsulated in the new GEPNoWrapFlags class. Currently this class has a ctor from bool, interpreted as the InBounds flag. This ctor should be removed in the future, as code gets migrated to handle all flags.

There are a few places annotated with TODO(gep_nowrap), where I've had to touch code but opted to not infer or precisely preserve the new flags, so as to keep this as NFC as possible and make sure any changes of that kind get test coverage when they are made.
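
As a reading aid, the interaction between the three flags can be sketched in standalone C++. This is only a toy model under stated assumptions: `ToyGEPFlags` and `toyFlagsDemo` are hypothetical names, not the `GEPNoWrapFlags` class from this PR, and the bit layout and setter behaviour are copied from the `GEPOperator` hunk in `llvm/include/llvm/IR/Operator.h` in the diff. The point it illustrates is that setting `inbounds` also sets `nusw`, while clearing `nusw` also clears `inbounds`.

```cpp
#include <cassert>

// Toy model of the flag storage. This is NOT the real GEPNoWrapFlags class;
// all names here are hypothetical and exist only for illustration.
class ToyGEPFlags {
  // Bit layout mirrors the GEPOperator enum shown in the patch.
  enum : unsigned {
    IsInBounds = 1u << 0,
    HasNoUnsignedSignedWrap = 1u << 1,
    HasNoUnsignedWrap = 1u << 2,
  };
  unsigned Bits = 0;

public:
  ToyGEPFlags() = default;
  // Stand-in for the transitional constructor from bool, which the PR
  // description says is interpreted as the inbounds flag.
  ToyGEPFlags(bool InBounds) { setIsInBounds(InBounds); }

  void setIsInBounds(bool B) {
    // Setting inbounds also sets nusw; clearing inbounds leaves nusw alone.
    Bits &= ~IsInBounds;
    if (B)
      Bits |= IsInBounds | HasNoUnsignedSignedWrap;
  }

  void setHasNoUnsignedSignedWrap(bool B) {
    // Clearing nusw also clears inbounds, because inbounds implies nusw.
    if (B)
      Bits |= HasNoUnsignedSignedWrap;
    else
      Bits &= ~(IsInBounds | HasNoUnsignedSignedWrap);
  }

  void setHasNoUnsignedWrap(bool B) {
    // nuw is independent of the other two flags.
    Bits = (Bits & ~HasNoUnsignedWrap) | (B ? HasNoUnsignedWrap : 0u);
  }

  bool isInBounds() const { return (Bits & IsInBounds) != 0; }
  bool hasNoUnsignedSignedWrap() const {
    return (Bits & HasNoUnsignedSignedWrap) != 0;
  }
  bool hasNoUnsignedWrap() const { return (Bits & HasNoUnsignedWrap) != 0; }
};

// Small usage example exercising the implication in both directions.
inline void toyFlagsDemo() {
  ToyGEPFlags F(true); // bool is read as inbounds, which implies nusw
  assert(F.isInBounds() && F.hasNoUnsignedSignedWrap());
  F.setHasNoUnsignedSignedWrap(false); // dropping nusw drops inbounds too
  assert(!F.isInBounds() && !F.hasNoUnsignedSignedWrap());
}
```

Keeping the implication inside the setters means no caller can construct the inconsistent state "inbounds without nusw", which matches the LangRef statement in the diff that `nusw` is not printed when `inbounds` is present.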

llvmbot (Collaborator) commented May 2, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-backend-amdgpu

Author: Nikita Popov (nikic)

Changes

This implements the nusw and nuw flags for getelementptr as proposed at https://discourse.llvm.org/t/rfc-add-nusw-and-nuw-flags-for-getelementptr/78672.

There are a bunch of places annotated with TODO(gep_nowrap), where I've had to touch code but opted to not infer or precisely preserve the new flags, so as to keep this as NFC as possible and make sure any changes of that kind get test coverage when they are made.


Patch is 47.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90824.diff

29 Files Affected:

  • (modified) clang/lib/CodeGen/CGVTT.cpp (+3-1)
  • (modified) clang/lib/CodeGen/ItaniumCXXABI.cpp (+3-1)
  • (modified) llvm/docs/LangRef.rst (+43-18)
  • (modified) llvm/docs/ReleaseNotes.rst (+1)
  • (modified) llvm/include/llvm/AsmParser/LLToken.h (+1)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+10-1)
  • (modified) llvm/include/llvm/IR/Constants.h (+6-5)
  • (modified) llvm/include/llvm/IR/Instructions.h (+14)
  • (modified) llvm/include/llvm/IR/Operator.h (+25-1)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+9-3)
  • (modified) llvm/lib/AsmParser/LLLexer.cpp (+1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+32-7)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+39-16)
  • (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+11-5)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+4)
  • (modified) llvm/lib/IR/ConstantFold.cpp (+4-1)
  • (modified) llvm/lib/IR/Constants.cpp (+11-2)
  • (modified) llvm/lib/IR/Instruction.cpp (+18-4)
  • (modified) llvm/lib/IR/Instructions.cpp (+16)
  • (modified) llvm/lib/IR/Operator.cpp (+2-1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp (+2-1)
  • (modified) llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp (+13)
  • (modified) llvm/lib/Transforms/Utils/FunctionComparator.cpp (+6)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+5)
  • (modified) llvm/test/Assembler/flags.ll (+79)
  • (modified) llvm/test/Transforms/InstCombine/freeze.ll (+22)
  • (modified) llvm/test/Transforms/SimplifyCFG/HoistCode.ll (+60)
  • (modified) llvm/test/tools/llvm-reduce/reduce-flags.ll (+13-5)
  • (modified) llvm/tools/llvm-reduce/deltas/ReduceInstructionFlags.cpp (+4)
diff --git a/clang/lib/CodeGen/CGVTT.cpp b/clang/lib/CodeGen/CGVTT.cpp
index d2376b14dd5826..8c72f3dccfd6e3 100644
--- a/clang/lib/CodeGen/CGVTT.cpp
+++ b/clang/lib/CodeGen/CGVTT.cpp
@@ -87,8 +87,10 @@ CodeGenVTables::EmitVTTDefinition(llvm::GlobalVariable *VTT,
      unsigned Offset = ComponentSize * AddressPoint.AddressPointIndex;
      llvm::ConstantRange InRange(llvm::APInt(32, -Offset, true),
                                  llvm::APInt(32, VTableSize - Offset, true));
+     // TODO(gep_nowrap): Set nuw as well.
      llvm::Constant *Init = llvm::ConstantExpr::getGetElementPtr(
-         VTable->getValueType(), VTable, Idxs, /*InBounds=*/true, InRange);
+         VTable->getValueType(), VTable, Idxs, /*InBounds=*/true, /*NUSW=*/true,
+         /*NUW=*/false, InRange);
 
      VTTComponents.push_back(Init);
   }
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index 18acf7784f714b..0138915ad35996 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -1901,8 +1901,10 @@ ItaniumCXXABI::getVTableAddressPoint(BaseSubobject Base,
   unsigned Offset = ComponentSize * AddressPoint.AddressPointIndex;
   llvm::ConstantRange InRange(llvm::APInt(32, -Offset, true),
                               llvm::APInt(32, VTableSize - Offset, true));
+  // TODO(gep_nowrap): Set nuw as well.
   return llvm::ConstantExpr::getGetElementPtr(
-      VTable->getValueType(), VTable, Indices, /*InBounds=*/true, InRange);
+      VTable->getValueType(), VTable, Indices, /*InBounds=*/true, /*NUSW=*/true,
+      /*NUW=*/false, InRange);
 }
 
 // Check whether all the non-inline virtual methods for the class have the
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 6291a4e57919a5..a4340f060d6f07 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11180,6 +11180,8 @@ Syntax:
 
       <result> = getelementptr <ty>, ptr <ptrval>{, <ty> <idx>}*
       <result> = getelementptr inbounds <ty>, ptr <ptrval>{, <ty> <idx>}*
+      <result> = getelementptr nusw <ty>, ptr <ptrval>{, <ty> <idx>}*
+      <result> = getelementptr nuw <ty>, ptr <ptrval>{, <ty> <idx>}*
       <result> = getelementptr inrange(S,E) <ty>, ptr <ptrval>{, <ty> <idx>}*
       <result> = getelementptr <ty>, <N x ptr> <ptrval>, <vector index type> <idx>
 
@@ -11295,27 +11297,46 @@ memory though, even if it happens to point into allocated storage. See the
 :ref:`Pointer Aliasing Rules <pointeraliasing>` section for more
 information.
 
-If the ``inbounds`` keyword is present, the result value of a
-``getelementptr`` with any non-zero indices is a
-:ref:`poison value <poisonvalues>` if one of the following rules is violated:
-
-*  The base pointer has an *in bounds* address of an allocated object, which
+The ``getelementptr`` instruction may have a number of attributes that impose
+additional rules. If any of the rules are violated, the result value is a
+:ref:`poison value <poisonvalues>`. In cases where the base is a vector of
+pointers, the attributes apply to each computation element-wise.
+
+For ``nusw`` (no unsigned signed wrap):
+
+ * If the type of an index is larger than the pointer index type, the
+   truncation to the pointer index type preserves the signed value
+   (``trunc nsw``).
+ * The multiplication of an index by the type size does not wrap the pointer
+   index type in a signed sense (``mul nsw``).
+ * The successive addition of each offset (without adding the base address)
+   does not wrap the pointer index type in a signed sense (``add nsw``).
+ * The successive addition of the current address, truncated to the index type
+   and interpreted as an unsigned number, and each offset, interpreted as
+   a signed number, does not wrap the index type.
+
+For ``nuw`` (no unsigned wrap):
+
+ * If the type of an index is larger than the pointer index type, the
+   truncation to the pointer index type preserves the unsigned value
+   (``trunc nuw``).
+ * The multiplication of an index by the type size does not wrap the pointer
+   index type in an unsigned sense (``mul nuw``).
+ * The successive addition of each offset (without adding the base address)
+   does not wrap the pointer index type in an unsigned sense (``add nuw``).
+ * The successive addition of the current address, truncated to the index type
+   and interpreted as an unsigned number, and each offset, also interpreted as
+   an unsigned number, does not wrap the index type (``add nuw``).
+
+For ``inbounds`` all rules of the ``nusw`` attribute apply. Additionally,
+if the ``getelementptr`` has any non-zero indices, the following rules apply:
+
+ * The base pointer has an *in bounds* address of an allocated object, which
    means that it points into an allocated object, or to its end. Note that the
    object does not have to be live anymore; being in-bounds of a deallocated
    object is sufficient.
-*  If the type of an index is larger than the pointer index type, the
-   truncation to the pointer index type preserves the signed value.
-*  The multiplication of an index by the type size does not wrap the pointer
-   index type in a signed sense (``nsw``).
-*  The successive addition of each offset (without adding the base address) does
-   not wrap the pointer index type in a signed sense (``nsw``).
-*  The successive addition of the current address, interpreted as an unsigned
-   number, and each offset, interpreted as a signed number, does not wrap the
-   unsigned address space and remains *in bounds* of the allocated object.
-   As a corollary, if the added offset is non-negative, the addition does not
-   wrap in an unsigned sense (``nuw``).
-*  In cases where the base is a vector of pointers, the ``inbounds`` keyword
-   applies to each of the computations element-wise.
+ * During the successive addition of offsets to the address, the resulting
+   pointer must remain *in bounds* of the allocated object at each step.
 
 Note that ``getelementptr`` with all-zero indices is always considered to be
 ``inbounds``, even if the base pointer does not point to an allocated object.
@@ -11326,6 +11347,10 @@ These rules are based on the assumption that no allocated object may cross
 the unsigned address space boundary, and no allocated object may be larger
 than half the pointer index type space.
 
+If ``inbounds`` is present on a ``getelementptr`` instruction, the ``nusw``
+attribute will be automatically set as well. For this reason, the ``nusw``
+will also not be printed in textual IR if ``inbounds`` is already present.
+
 If the ``inrange(Start, End)`` attribute is present, loading from or
 storing to any pointer derived from the ``getelementptr`` has undefined
 behavior if the load or store would access memory outside the half-open range
diff --git a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst
index d8cc667723f554..412b85456cdbeb 100644
--- a/llvm/docs/ReleaseNotes.rst
+++ b/llvm/docs/ReleaseNotes.rst
@@ -51,6 +51,7 @@ Changes to the LLVM IR
 ----------------------
 
 * Added Memory Model Relaxation Annotations (MMRAs).
+* Added ``nusw`` and ``nuw`` flags to ``getelementptr`` instruction.
 * Renamed ``llvm.experimental.vector.reverse`` intrinsic to ``llvm.vector.reverse``.
 * Renamed ``llvm.experimental.vector.splice`` intrinsic to ``llvm.vector.splice``.
 * Renamed ``llvm.experimental.vector.interleave2`` intrinsic to ``llvm.vector.interleave2``.
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 0cbcdcd9ffac77..df61ec6ed30e0b 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -109,6 +109,7 @@ enum Kind {
   kw_fast,
   kw_nuw,
   kw_nsw,
+  kw_nusw,
   kw_exact,
   kw_disjoint,
   kw_inbounds,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a9..d3b9e96520f88a 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -385,7 +385,7 @@ enum ConstantsCodes {
   CST_CODE_CSTRING = 9,          // CSTRING:       [values]
   CST_CODE_CE_BINOP = 10,        // CE_BINOP:      [opcode, opval, opval]
   CST_CODE_CE_CAST = 11,         // CE_CAST:       [opcode, opty, opval]
-  CST_CODE_CE_GEP = 12,          // CE_GEP:        [n x operands]
+  CST_CODE_CE_GEP_OLD = 12,      // CE_GEP:        [n x operands]
   CST_CODE_CE_SELECT = 13,       // CE_SELECT:     [opval, opval, opval]
   CST_CODE_CE_EXTRACTELT = 14,   // CE_EXTRACTELT: [opty, opval, opval]
   CST_CODE_CE_INSERTELT = 15,    // CE_INSERTELT:  [opval, opval, opval]
@@ -412,6 +412,7 @@ enum ConstantsCodes {
                                       //                 asmdialect|unwind,
                                       //                 asmstr,conststr]
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
+  CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
@@ -524,6 +525,14 @@ enum PossiblyExactOperatorOptionalFlags { PEO_EXACT = 0 };
 /// PossiblyDisjointInst's SubclassOptionalData contents.
 enum PossiblyDisjointInstOptionalFlags { PDI_DISJOINT = 0 };
 
+/// GetElementPtrOptionalFlags - Flags for serializing
+/// GEPOperator's SubclassOptionalData contents.
+enum GetElementPtrOptionalFlags {
+  GEP_INBOUNDS = 0,
+  GEP_NUSW = 1,
+  GEP_NUW = 2,
+};
+
 /// Encoded AtomicOrdering values.
 enum AtomicOrderingCodes {
   ORDERING_NOTATOMIC = 0,
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 9ec81903f09c96..28ee766a6843e5 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1198,26 +1198,27 @@ class ConstantExpr : public Constant {
   /// \param OnlyIfReducedTy see \a getWithOperands() docs.
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, ArrayRef<Constant *> IdxList,
-                   bool InBounds = false,
+                   bool InBounds = false, bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr) {
     return getGetElementPtr(
         Ty, C, ArrayRef((Value *const *)IdxList.data(), IdxList.size()),
-        InBounds, InRange, OnlyIfReducedTy);
+        InBounds, NUSW, NUW, InRange, OnlyIfReducedTy);
   }
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, Constant *Idx, bool InBounds = false,
+                   bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr) {
     // This form of the function only exists to avoid ambiguous overload
     // warnings about whether to convert Idx to ArrayRef<Constant *> or
     // ArrayRef<Value *>.
-    return getGetElementPtr(Ty, C, cast<Value>(Idx), InBounds, InRange,
-                            OnlyIfReducedTy);
+    return getGetElementPtr(Ty, C, cast<Value>(Idx), InBounds, NUSW, NUW,
+                            InRange, OnlyIfReducedTy);
   }
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, ArrayRef<Value *> IdxList,
-                   bool InBounds = false,
+                   bool InBounds = false, bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr);
 
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index d7ec3c16bec21c..8c0db7b7bfdb2e 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1171,9 +1171,23 @@ class GetElementPtrInst : public Instruction {
   /// See LangRef.html for the meaning of inbounds on a getelementptr.
   void setIsInBounds(bool b = true);
 
+  /// Set or clear the nusw flag on this GEP instruction.
+  /// See LangRef.html for the meaning of nusw on a getelementptr.
+  void setHasNoUnsignedSignedWrap(bool B = true);
+
+  /// Set or clear the nuw flag on this GEP instruction.
+  /// See LangRef.html for the meaning of nuw on a getelementptr.
+  void setHasNoUnsignedWrap(bool B = true);
+
   /// Determine whether the GEP has the inbounds flag.
   bool isInBounds() const;
 
+  /// Determine whether the GEP has the nusw flag.
+  bool hasNoUnsignedSignedWrap() const;
+
+  /// Determine whether the GEP has the nuw flag.
+  bool hasNoUnsignedWrap() const;
+
   /// Accumulate the constant address offset of this GEP if possible.
   ///
   /// This routine accepts an APInt into which it will accumulate the constant
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index b2307948bbbc4f..637542397cd5d8 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -405,11 +405,27 @@ class GEPOperator
 
   enum {
     IsInBounds = (1 << 0),
+    HasNoUnsignedSignedWrap = (1 << 1),
+    HasNoUnsignedWrap = (1 << 2),
   };
 
   void setIsInBounds(bool B) {
+    // Also set nusw when inbounds is set.
+    SubclassOptionalData = (SubclassOptionalData & ~IsInBounds) |
+                           (B * (IsInBounds | HasNoUnsignedSignedWrap));
+  }
+
+  void setHasNoUnsignedSignedWrap(bool B) {
+    // Also unset inbounds when nusw is unset.
+    if (B)
+      SubclassOptionalData |= HasNoUnsignedSignedWrap;
+    else
+      SubclassOptionalData &= ~(IsInBounds | HasNoUnsignedSignedWrap);
+  }
+
+  void setHasNoUnsignedWrap(bool B) {
     SubclassOptionalData =
-      (SubclassOptionalData & ~IsInBounds) | (B * IsInBounds);
+        (SubclassOptionalData & ~HasNoUnsignedWrap) | (B * HasNoUnsignedWrap);
   }
 
 public:
@@ -421,6 +437,14 @@ class GEPOperator
     return SubclassOptionalData & IsInBounds;
   }
 
+  bool hasNoUnsignedSignedWrap() const {
+    return SubclassOptionalData & HasNoUnsignedSignedWrap;
+  }
+
+  bool hasNoUnsignedWrap() const {
+    return SubclassOptionalData & HasNoUnsignedWrap;
+  }
+
   /// Returns the offset of the index with an inrange attachment, or
   /// std::nullopt if none.
   std::optional<ConstantRange> getInRange() const;
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 749374a3aa48af..1cbcb6868eeef9 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -856,8 +856,10 @@ Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops,
   if (!Any)
     return nullptr;
 
+  // TODO(gep_nowrap): Preserve NUSW/NUW here.
   Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ops[0], NewIdxs,
-                                               InBounds, InRange);
+                                               InBounds, /*NUSW=*/InBounds,
+                                               /*NUW=*/false, InRange);
   return ConstantFoldConstant(C, DL, TLI);
 }
 
@@ -980,7 +982,9 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
     NewIdxs.push_back(ConstantInt::get(
         Type::getIntNTy(Ptr->getContext(), Index.getBitWidth()), Index));
 
+  // TODO(gep_nowrap): Preserve NUSW/NUW.
   return ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs, InBounds,
+                                        /*NUSW=*/InBounds, /*NUW=*/false,
                                         InRange);
 }
 
@@ -1028,8 +1032,10 @@ Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode,
     if (Constant *C = SymbolicallyEvaluateGEP(GEP, Ops, DL, TLI))
       return C;
 
-    return ConstantExpr::getGetElementPtr(SrcElemTy, Ops[0], Ops.slice(1),
-                                          GEP->isInBounds(), GEP->getInRange());
+    return ConstantExpr::getGetElementPtr(
+        SrcElemTy, Ops[0], Ops.slice(1), GEP->isInBounds(),
+        GEP->hasNoUnsignedSignedWrap(), GEP->hasNoUnsignedWrap(),
+        GEP->getInRange());
   }
 
   if (auto *CE = dyn_cast<ConstantExpr>(InstOrCE)) {
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 8ded07ffd8bd25..20a1bd29577124 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -566,6 +566,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(fast);
   KEYWORD(nuw);
   KEYWORD(nsw);
+  KEYWORD(nusw);
   KEYWORD(exact);
   KEYWORD(disjoint);
   KEYWORD(inbounds);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 2902bd9fe17c48..fa4d87ca8d5ffe 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4216,7 +4216,7 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   case lltok::kw_extractelement: {
     unsigned Opc = Lex.getUIntVal();
     SmallVector<Constant*, 16> Elts;
-    bool InBounds = false;
+    bool InBounds = false, HasNUSW = false, HasNUW = false;
     bool HasInRange = false;
     APSInt InRangeStart;
     APSInt InRangeEnd;
@@ -4224,7 +4224,17 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     Lex.Lex();
 
     if (Opc == Instruction::GetElementPtr) {
-      InBounds = EatIfPresent(lltok::kw_inbounds);
+      while (true) {
+        if (EatIfPresent(lltok::kw_inbounds))
+          InBounds = true;
+        else if (EatIfPresent(lltok::kw_nusw))
+          HasNUSW = true;
+        else if (EatIfPresent(lltok::kw_nuw))
+          HasNUW = true;
+        else
+          break;
+      }
+
       if (EatIfPresent(lltok::kw_inrange)) {
         if (parseToken(lltok::lparen, "expected '('"))
           return true;
@@ -4303,8 +4313,8 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
       if (!GetElementPtrInst::getIndexedType(Ty, Indices))
         return error(ID.Loc, "invalid getelementptr indices");
 
-      ID.ConstantVal = ConstantExpr::getGetElementPtr(Ty, Elts[0], Indices,
-                                                      InBounds, InRange);
+      ID.ConstantVal = ConstantExpr::getGetElementPtr(
+          Ty, Elts[0], Indices, InBounds, HasNUSW, HasNUW, InRange);
     } else if (Opc == Instruction::ShuffleVector) {
       if (Elts.size() != 3)
         return error(ID.Loc, "expected three operands to shufflevector");
@@ -8340,7 +8350,17 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) {
   Value *Val = nullptr;
   LocTy Loc, EltLoc;
 
-  bool InBounds = EatIfPresent(lltok::kw_inbounds);
+  bool InBounds = false, NUSW = false, NUW = false;
+  while (true) {
+    if (EatIfPresent(lltok::kw_inbounds))
+      InBounds = true;
+    else if (EatIfPresent(lltok::kw_nusw))
+      NUSW = true;
+    else if (EatIfPresent(lltok::kw_nuw))
+      NUW = true;
+    else
+      break;
+  }
 
   Type *Ty = nullptr;
   if (parseType(Ty) ||
@@ -8393,9 +8413,14 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) {
 
   if (!GetElementPtrInst::getIndexedType(Ty, Indices))
     return error(Loc, "invalid getelementptr indices");
-  Inst = GetElementPtrInst::Create(Ty, Ptr, Indices);
+  GetElementPtrInst *GEP = GetElementPtrInst::Create(Ty, Ptr, Indices);
+  Inst = GEP;
   if (InBounds)
-    cast<GetElementPtrInst>(Inst)->setIsInBounds(true);
+    GEP->setIsInBounds(true);
+  if (NUSW)
+    GEP->setHasNoUnsignedSignedWrap(true);
+  if (NUW)
+    GEP->setHasNoUnsignedWrap(true);
   return AteExtraComma ? InstExtraComma : InstNormal;
 }
 
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index a0779f955cf28d..278d2c6adae6a7 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1613,9 +1613,11 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
           C = ConstantExpr::getCompare(BC->Flags, ConstOps[0], ConstOps[1]);
           break;
         case Instruction::GetElementPtr:
-          C = ConstantExpr::getGetElementPtr(BC->SrcElemTy, ConstOps[0],
-                                             ArrayRef(ConstOps).drop_front(),
-                                             BC->Flags, BC->getInRange());
+          C = ConstantExpr::getGetElementPtr(
+              BC->SrcElemTy, ConstOps[0], ArrayRef(ConstOps).drop_front(),
+              (BC->Flags & (1 << bitc::GEP_INBOUNDS)) != 0,
+              (BC->Flags & (1 << bitc::GEP_NUSW)) != 0,
+              (BC->Flags & (1 << bitc::GEP_NUW)) != 0, BC->getInRange());
           break;
         case Instruction::ExtractElement:
           C = ...
[truncated]
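
The `nusw` and `nuw` rules added to LangRef in the diff above can be illustrated with a small standalone C++ model. It is a sketch under several assumptions, not LLVM code: the pointer index type is taken to be 64 bits, the GEP has a single index that is already of the index type (so the truncation rules never fire), and `std::nullopt` stands in for a poison result. `ToyFlags`, `toyGEP` and `toyGEPDemo` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical flag bundle for the model; not an LLVM type.
struct ToyFlags {
  bool NUSW = false;
  bool NUW = false;
};

// Models `getelementptr [flags] T, ptr Base, i64 Index` with sizeof(T) ==
// ElemSize. Returns the resulting address, or std::nullopt for poison.
std::optional<uint64_t> toyGEP(uint64_t Base, int64_t Index, uint32_t ElemSize,
                               ToyFlags F) {
  constexpr int64_t SMax = std::numeric_limits<int64_t>::max();
  constexpr int64_t SMin = std::numeric_limits<int64_t>::min();
  constexpr uint64_t UMax = std::numeric_limits<uint64_t>::max();

  // Without flags, the computation simply wraps modulo 2^64.
  const uint64_t UIndex = static_cast<uint64_t>(Index);
  const uint64_t UOffset = UIndex * ElemSize;
  const uint64_t Result = Base + UOffset;

  if (F.NUSW) {
    const int64_t Size = ElemSize;
    // Index * Size must not wrap in a signed sense (mul nsw).
    if (Size != 0) {
      if (Index > 0 && Index > SMax / Size)
        return std::nullopt;
      if (Index < 0 && Index < SMin / Size)
        return std::nullopt;
    }
    const int64_t Offset = Index * Size; // cannot overflow after the checks
    // Unsigned base plus signed offset must stay within [0, 2^64).
    if (Offset >= 0) {
      if (Base > UMax - static_cast<uint64_t>(Offset))
        return std::nullopt;
    } else {
      const uint64_t Magnitude = uint64_t{0} - static_cast<uint64_t>(Offset);
      if (Base < Magnitude)
        return std::nullopt;
    }
  }

  if (F.NUW) {
    // Index, read as unsigned, times Size must not wrap (mul nuw).
    if (ElemSize != 0 && UIndex > UMax / ElemSize)
      return std::nullopt;
    // Base plus the unsigned offset must not wrap (add nuw).
    if (Base > UMax - UOffset)
      return std::nullopt;
  }

  return Result;
}

// Usage example: a negative index is acceptable under nusw but not under nuw.
inline void toyGEPDemo() {
  assert(toyGEP(1000, -1, 4, ToyFlags{true, false}).value() == 996u);
  assert(!toyGEP(1000, -1, 4, ToyFlags{false, true}).has_value());
}
```

In this model a negative index is fine under `nusw` as long as the address does not drop below zero, whereas under `nuw` the same index is read as a very large unsigned number and the multiplication already wraps. The additional `inbounds` requirement (staying within an allocated object) is not modeled, since it depends on allocation information the sketch does not have.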

llvmbot (Collaborator) commented May 2, 2024

@llvm/pr-subscribers-llvm-transforms

diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 0cbcdcd9ffac77..df61ec6ed30e0b 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -109,6 +109,7 @@ enum Kind {
   kw_fast,
   kw_nuw,
   kw_nsw,
+  kw_nusw,
   kw_exact,
   kw_disjoint,
   kw_inbounds,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a9..d3b9e96520f88a 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -385,7 +385,7 @@ enum ConstantsCodes {
   CST_CODE_CSTRING = 9,          // CSTRING:       [values]
   CST_CODE_CE_BINOP = 10,        // CE_BINOP:      [opcode, opval, opval]
   CST_CODE_CE_CAST = 11,         // CE_CAST:       [opcode, opty, opval]
-  CST_CODE_CE_GEP = 12,          // CE_GEP:        [n x operands]
+  CST_CODE_CE_GEP_OLD = 12,      // CE_GEP:        [n x operands]
   CST_CODE_CE_SELECT = 13,       // CE_SELECT:     [opval, opval, opval]
   CST_CODE_CE_EXTRACTELT = 14,   // CE_EXTRACTELT: [opty, opval, opval]
   CST_CODE_CE_INSERTELT = 15,    // CE_INSERTELT:  [opval, opval, opval]
@@ -412,6 +412,7 @@ enum ConstantsCodes {
                                       //                 asmdialect|unwind,
                                       //                 asmstr,conststr]
   CST_CODE_CE_GEP_WITH_INRANGE = 31,  // [opty, flags, range, n x operands]
+  CST_CODE_CE_GEP = 32,               // [opty, flags, n x operands]
 };
 
 /// CastOpcodes - These are values used in the bitcode files to encode which
@@ -524,6 +525,14 @@ enum PossiblyExactOperatorOptionalFlags { PEO_EXACT = 0 };
 /// PossiblyDisjointInst's SubclassOptionalData contents.
 enum PossiblyDisjointInstOptionalFlags { PDI_DISJOINT = 0 };
 
+/// GetElementPtrOptionalFlags - Flags for serializing
+/// GEPOperator's SubclassOptionalData contents.
+enum GetElementPtrOptionalFlags {
+  GEP_INBOUNDS = 0,
+  GEP_NUSW = 1,
+  GEP_NUW = 2,
+};
+
 /// Encoded AtomicOrdering values.
 enum AtomicOrderingCodes {
   ORDERING_NOTATOMIC = 0,
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 9ec81903f09c96..28ee766a6843e5 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -1198,26 +1198,27 @@ class ConstantExpr : public Constant {
   /// \param OnlyIfReducedTy see \a getWithOperands() docs.
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, ArrayRef<Constant *> IdxList,
-                   bool InBounds = false,
+                   bool InBounds = false, bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr) {
     return getGetElementPtr(
         Ty, C, ArrayRef((Value *const *)IdxList.data(), IdxList.size()),
-        InBounds, InRange, OnlyIfReducedTy);
+        InBounds, NUSW, NUW, InRange, OnlyIfReducedTy);
   }
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, Constant *Idx, bool InBounds = false,
+                   bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr) {
     // This form of the function only exists to avoid ambiguous overload
     // warnings about whether to convert Idx to ArrayRef<Constant *> or
     // ArrayRef<Value *>.
-    return getGetElementPtr(Ty, C, cast<Value>(Idx), InBounds, InRange,
-                            OnlyIfReducedTy);
+    return getGetElementPtr(Ty, C, cast<Value>(Idx), InBounds, NUSW, NUW,
+                            InRange, OnlyIfReducedTy);
   }
   static Constant *
   getGetElementPtr(Type *Ty, Constant *C, ArrayRef<Value *> IdxList,
-                   bool InBounds = false,
+                   bool InBounds = false, bool NUSW = false, bool NUW = false,
                    std::optional<ConstantRange> InRange = std::nullopt,
                    Type *OnlyIfReducedTy = nullptr);
 
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index d7ec3c16bec21c..8c0db7b7bfdb2e 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1171,9 +1171,23 @@ class GetElementPtrInst : public Instruction {
   /// See LangRef.html for the meaning of inbounds on a getelementptr.
   void setIsInBounds(bool b = true);
 
+  /// Set or clear the nusw flag on this GEP instruction.
+  /// See LangRef.html for the meaning of nusw on a getelementptr.
+  void setHasNoUnsignedSignedWrap(bool B = true);
+
+  /// Set or clear the nuw flag on this GEP instruction.
+  /// See LangRef.html for the meaning of nuw on a getelementptr.
+  void setHasNoUnsignedWrap(bool B = true);
+
   /// Determine whether the GEP has the inbounds flag.
   bool isInBounds() const;
 
+  /// Determine whether the GEP has the nusw flag.
+  bool hasNoUnsignedSignedWrap() const;
+
+  /// Determine whether the GEP has the nuw flag.
+  bool hasNoUnsignedWrap() const;
+
   /// Accumulate the constant address offset of this GEP if possible.
   ///
   /// This routine accepts an APInt into which it will accumulate the constant
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
index b2307948bbbc4f..637542397cd5d8 100644
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -405,11 +405,27 @@ class GEPOperator
 
   enum {
     IsInBounds = (1 << 0),
+    HasNoUnsignedSignedWrap = (1 << 1),
+    HasNoUnsignedWrap = (1 << 2),
   };
 
   void setIsInBounds(bool B) {
+    // Also set nusw when inbounds is set.
+    SubclassOptionalData = (SubclassOptionalData & ~IsInBounds) |
+                           (B * (IsInBounds | HasNoUnsignedSignedWrap));
+  }
+
+  void setHasNoUnsignedSignedWrap(bool B) {
+    // Also unset inbounds when nusw is unset.
+    if (B)
+      SubclassOptionalData |= HasNoUnsignedSignedWrap;
+    else
+      SubclassOptionalData &= ~(IsInBounds | HasNoUnsignedSignedWrap);
+  }
+
+  void setHasNoUnsignedWrap(bool B) {
     SubclassOptionalData =
-      (SubclassOptionalData & ~IsInBounds) | (B * IsInBounds);
+        (SubclassOptionalData & ~HasNoUnsignedWrap) | (B * HasNoUnsignedWrap);
   }
 
 public:
@@ -421,6 +437,14 @@ class GEPOperator
     return SubclassOptionalData & IsInBounds;
   }
 
+  bool hasNoUnsignedSignedWrap() const {
+    return SubclassOptionalData & HasNoUnsignedSignedWrap;
+  }
+
+  bool hasNoUnsignedWrap() const {
+    return SubclassOptionalData & HasNoUnsignedWrap;
+  }
+
   /// Returns the offset of the index with an inrange attachment, or
   /// std::nullopt if none.
   std::optional<ConstantRange> getInRange() const;
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 749374a3aa48af..1cbcb6868eeef9 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -856,8 +856,10 @@ Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops,
   if (!Any)
     return nullptr;
 
+  // TODO(gep_nowrap): Preserve NUSW/NUW here.
   Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ops[0], NewIdxs,
-                                               InBounds, InRange);
+                                               InBounds, /*NUSW=*/InBounds,
+                                               /*NUW=*/false, InRange);
   return ConstantFoldConstant(C, DL, TLI);
 }
 
@@ -980,7 +982,9 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
     NewIdxs.push_back(ConstantInt::get(
         Type::getIntNTy(Ptr->getContext(), Index.getBitWidth()), Index));
 
+  // TODO(gep_nowrap): Preserve NUSW/NUW.
   return ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs, InBounds,
+                                        /*NUSW=*/InBounds, /*NUW=*/false,
                                         InRange);
 }
 
@@ -1028,8 +1032,10 @@ Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode,
     if (Constant *C = SymbolicallyEvaluateGEP(GEP, Ops, DL, TLI))
       return C;
 
-    return ConstantExpr::getGetElementPtr(SrcElemTy, Ops[0], Ops.slice(1),
-                                          GEP->isInBounds(), GEP->getInRange());
+    return ConstantExpr::getGetElementPtr(
+        SrcElemTy, Ops[0], Ops.slice(1), GEP->isInBounds(),
+        GEP->hasNoUnsignedSignedWrap(), GEP->hasNoUnsignedWrap(),
+        GEP->getInRange());
   }
 
   if (auto *CE = dyn_cast<ConstantExpr>(InstOrCE)) {
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 8ded07ffd8bd25..20a1bd29577124 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -566,6 +566,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(fast);
   KEYWORD(nuw);
   KEYWORD(nsw);
+  KEYWORD(nusw);
   KEYWORD(exact);
   KEYWORD(disjoint);
   KEYWORD(inbounds);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 2902bd9fe17c48..fa4d87ca8d5ffe 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -4216,7 +4216,7 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
   case lltok::kw_extractelement: {
     unsigned Opc = Lex.getUIntVal();
     SmallVector<Constant*, 16> Elts;
-    bool InBounds = false;
+    bool InBounds = false, HasNUSW = false, HasNUW = false;
     bool HasInRange = false;
     APSInt InRangeStart;
     APSInt InRangeEnd;
@@ -4224,7 +4224,17 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     Lex.Lex();
 
     if (Opc == Instruction::GetElementPtr) {
-      InBounds = EatIfPresent(lltok::kw_inbounds);
+      while (true) {
+        if (EatIfPresent(lltok::kw_inbounds))
+          InBounds = true;
+        else if (EatIfPresent(lltok::kw_nusw))
+          HasNUSW = true;
+        else if (EatIfPresent(lltok::kw_nuw))
+          HasNUW = true;
+        else
+          break;
+      }
+
       if (EatIfPresent(lltok::kw_inrange)) {
         if (parseToken(lltok::lparen, "expected '('"))
           return true;
@@ -4303,8 +4313,8 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
       if (!GetElementPtrInst::getIndexedType(Ty, Indices))
         return error(ID.Loc, "invalid getelementptr indices");
 
-      ID.ConstantVal = ConstantExpr::getGetElementPtr(Ty, Elts[0], Indices,
-                                                      InBounds, InRange);
+      ID.ConstantVal = ConstantExpr::getGetElementPtr(
+          Ty, Elts[0], Indices, InBounds, HasNUSW, HasNUW, InRange);
     } else if (Opc == Instruction::ShuffleVector) {
       if (Elts.size() != 3)
         return error(ID.Loc, "expected three operands to shufflevector");
@@ -8340,7 +8350,17 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) {
   Value *Val = nullptr;
   LocTy Loc, EltLoc;
 
-  bool InBounds = EatIfPresent(lltok::kw_inbounds);
+  bool InBounds = false, NUSW = false, NUW = false;
+  while (true) {
+    if (EatIfPresent(lltok::kw_inbounds))
+      InBounds = true;
+    else if (EatIfPresent(lltok::kw_nusw))
+      NUSW = true;
+    else if (EatIfPresent(lltok::kw_nuw))
+      NUW = true;
+    else
+      break;
+  }
 
   Type *Ty = nullptr;
   if (parseType(Ty) ||
@@ -8393,9 +8413,14 @@ int LLParser::parseGetElementPtr(Instruction *&Inst, PerFunctionState &PFS) {
 
   if (!GetElementPtrInst::getIndexedType(Ty, Indices))
     return error(Loc, "invalid getelementptr indices");
-  Inst = GetElementPtrInst::Create(Ty, Ptr, Indices);
+  GetElementPtrInst *GEP = GetElementPtrInst::Create(Ty, Ptr, Indices);
+  Inst = GEP;
   if (InBounds)
-    cast<GetElementPtrInst>(Inst)->setIsInBounds(true);
+    GEP->setIsInBounds(true);
+  if (NUSW)
+    GEP->setHasNoUnsignedSignedWrap(true);
+  if (NUW)
+    GEP->setHasNoUnsignedWrap(true);
   return AteExtraComma ? InstExtraComma : InstNormal;
 }
 
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index a0779f955cf28d..278d2c6adae6a7 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1613,9 +1613,11 @@ Expected<Value *> BitcodeReader::materializeValue(unsigned StartValID,
           C = ConstantExpr::getCompare(BC->Flags, ConstOps[0], ConstOps[1]);
           break;
         case Instruction::GetElementPtr:
-          C = ConstantExpr::getGetElementPtr(BC->SrcElemTy, ConstOps[0],
-                                             ArrayRef(ConstOps).drop_front(),
-                                             BC->Flags, BC->getInRange());
+          C = ConstantExpr::getGetElementPtr(
+              BC->SrcElemTy, ConstOps[0], ArrayRef(ConstOps).drop_front(),
+              (BC->Flags & (1 << bitc::GEP_INBOUNDS)) != 0,
+              (BC->Flags & (1 << bitc::GEP_NUSW)) != 0,
+              (BC->Flags & (1 << bitc::GEP_NUW)) != 0, BC->getInRange());
           break;
         case Instruction::ExtractElement:
           C = ...
[truncated]

github-actions bot commented May 2, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

tschuett commented May 2, 2024

Could you please add a TODO here:

uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {

Thanks.

@antoniofrighetto
Contributor

Are the TODOs encompassing all the cases? Why don't we want to set the flags in PHITransAddr as well?

llvm/docs/LangRef.rst (outdated, resolved)
cast<GetElementPtrInst>(I)->setIsInBounds();
if (BC->Flags & (1 << bitc::GEP_NUSW))
Contributor

nit: this could be else if.

Contributor

likewise elsewhere.

Contributor Author

I think using else if would be correct here as setIsInBounds will also set nusw by itself, but keeping all cases separate seems cleaner?

Otherwise we'll have the following code here...

        if (BC->Flags & (1 << bitc::GEP_INBOUNDS))
          cast<GetElementPtrInst>(I)->setIsInBounds();
        else if (BC->Flags & (1 << bitc::GEP_NUSW))
          cast<GetElementPtrInst>(I)->setHasNoUnsignedSignedWrap();
        if (BC->Flags & (1 << bitc::GEP_NUW))
          cast<GetElementPtrInst>(I)->setHasNoUnsignedWrap();

...at which point it will probably need a comment to explain why there is only one else if there.
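
The equivalence discussed in this thread can be checked with a small standalone model. The sketch below is not LLVM code: `GepFlagsModel`, `decodeSeparateIfs`, and `decodeElseIf` are hypothetical names, and the setters only mirror the bit manipulation shown in the Operator.h hunk of this patch. It assumes the record bits are laid out as in the patch's `GetElementPtrOptionalFlags` enum.

```cpp
#include <cassert>

// Bit positions in the bitcode record, copied from the patch's
// GetElementPtrOptionalFlags enum.
enum GepRecordBit : unsigned { GEP_INBOUNDS = 0, GEP_NUSW = 1, GEP_NUW = 2 };

// Hypothetical stand-in for GEPOperator's SubclassOptionalData handling.
struct GepFlagsModel {
  enum : unsigned {
    IsInBounds = 1u << 0,
    HasNoUnsignedSignedWrap = 1u << 1,
    HasNoUnsignedWrap = 1u << 2,
  };
  unsigned Data = 0;

  // Setting inbounds also sets nusw, as in the patch.
  void setIsInBounds(bool B = true) {
    Data = (Data & ~IsInBounds) |
           (B ? (IsInBounds | HasNoUnsignedSignedWrap) : 0u);
  }

  // Clearing nusw also clears inbounds, as in the patch.
  void setHasNoUnsignedSignedWrap(bool B = true) {
    if (B)
      Data |= HasNoUnsignedSignedWrap;
    else
      Data &= ~(IsInBounds | HasNoUnsignedSignedWrap);
  }

  void setHasNoUnsignedWrap(bool B = true) {
    Data = (Data & ~HasNoUnsignedWrap) | (B ? HasNoUnsignedWrap : 0u);
  }
};

// Variant kept in the PR: every flag is handled by its own `if`.
GepFlagsModel decodeSeparateIfs(unsigned Record) {
  GepFlagsModel G;
  if (Record & (1u << GEP_INBOUNDS))
    G.setIsInBounds();
  if (Record & (1u << GEP_NUSW))
    G.setHasNoUnsignedSignedWrap();
  if (Record & (1u << GEP_NUW))
    G.setHasNoUnsignedWrap();
  return G;
}

// Variant suggested in review: nusw is only checked when inbounds is absent.
GepFlagsModel decodeElseIf(unsigned Record) {
  GepFlagsModel G;
  if (Record & (1u << GEP_INBOUNDS))
    G.setIsInBounds();
  else if (Record & (1u << GEP_NUSW))
    G.setHasNoUnsignedSignedWrap();
  if (Record & (1u << GEP_NUW))
    G.setHasNoUnsignedWrap();
  return G;
}
```

Under this model both variants produce the same flag word for all eight possible records, because `setIsInBounds` already sets the nusw bit. The choice between them is therefore about readability, not behavior.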

nikic commented May 3, 2024

> Are the TODOs encompassing all the cases? Why don't we want to set the flags in PHITransAddr as well?

No, the TODOs only cover places where I had to modify code and deliberately used a non-optimal implementation to minimize the initial impact. Many more places need to be modified to fully propagate the flags everywhere.

Contributor

aeubanks left a comment

thanks, I think abstracting out GEPNoWrapFlags is good

llvm/include/llvm/IR/GEPNoWrapFlags.h (outdated, resolved)
llvm/include/llvm/IR/GEPNoWrapFlags.h (resolved)
GEPNoWrapFlags() : Flags(0) {}
// For historical reasons, interpret plain boolean as InBounds.
// TODO: Migrate users to pass explicit GEPNoWrapFlags and remove this ctor.
GEPNoWrapFlags(bool IsInBounds)
Contributor

Why do we need this ctor? It seems to be entirely covered by inBounds, which is more expressive. It also seems a bit non-intuitive to have a single bool constructor for a flags class that represents three flags...

Contributor Author

This is a backwards-compatibility ctor because we're replacing existing bool IsInBounds parameters with GEPNoWrapFlags NW. As the TODO indicates, this ctor will be removed in the future. It avoids having to touch large amounts of code in this PR in sub-optimal ways.
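
To illustrate the migration argument, here is a minimal sketch of a flags wrapper with an implicit bool constructor. It is a simplified model and not the contents of GEPNoWrapFlags.h: the class name `NoWrapFlagsSketch`, its member names, and the `migratedCallee` function are all hypothetical, and the assumption that the bool path also sets nusw follows the patch's rule that inbounds implies nusw.

```cpp
#include <cassert>

// Simplified, hypothetical model of a GEP no-wrap flags wrapper.
class NoWrapFlagsSketch {
  enum : unsigned {
    InBoundsFlag = 1u << 0,
    NUSWFlag = 1u << 1,
    NUWFlag = 1u << 2,
  };
  unsigned Flags = 0;

  // Private helper so that raw bit patterns never compete with the
  // public bool constructor during overload resolution.
  static NoWrapFlagsSketch fromRaw(unsigned Raw) {
    NoWrapFlagsSketch F;
    F.Flags = Raw;
    return F;
  }

public:
  NoWrapFlagsSketch() = default;

  // Legacy path: a plain bool is read as "inbounds", which implies nusw.
  // Deliberately implicit so existing `f(..., true)` call sites still compile.
  NoWrapFlagsSketch(bool IsInBounds) {
    if (IsInBounds)
      Flags = InBoundsFlag | NUSWFlag;
  }

  // Explicit, self-describing factories for migrated callers.
  static NoWrapFlagsSketch inBounds() { return fromRaw(InBoundsFlag | NUSWFlag); }
  static NoWrapFlagsSketch noUnsignedSignedWrap() { return fromRaw(NUSWFlag); }
  static NoWrapFlagsSketch noUnsignedWrap() { return fromRaw(NUWFlag); }

  NoWrapFlagsSketch operator|(NoWrapFlagsSketch Other) const {
    return fromRaw(Flags | Other.Flags);
  }

  bool isInBounds() const { return (Flags & InBoundsFlag) != 0; }
  bool hasNoUnsignedSignedWrap() const { return (Flags & NUSWFlag) != 0; }
  bool hasNoUnsignedWrap() const { return (Flags & NUWFlag) != 0; }
};

// Hypothetical callee whose parameter used to be `bool IsInBounds`.
NoWrapFlagsSketch migratedCallee(NoWrapFlagsSketch NW) { return NW; }
```

With this shape, an old call such as `migratedCallee(true)` keeps its meaning, while new code can pass `NoWrapFlagsSketch::inBounds() | NoWrapFlagsSketch::noUnsignedWrap()`. Once every caller uses the factories, the bool constructor can be deleted without changing behavior.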

@goldsteinn
Contributor

LGTM. Please wait for some additional approvals before pushing.

nikic commented May 22, 2024

There's already another approval from @aeubanks, so I plan to merge this next Monday if there's no more feedback.

@nikic nikic merged commit 8cdecd4 into llvm:main May 27, 2024
8 checks passed
@nikic nikic deleted the gep-nowrap branch May 27, 2024 14:05
tschuett added a commit to tschuett/llvm-project that referenced this pull request May 30, 2024