Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM][IR] Add native vector support to ConstantInt & ConstantFP. #74502

Merged
merged 3 commits into from Feb 22, 2024

Conversation

paulwalker-arm
Copy link
Collaborator

@paulwalker-arm paulwalker-arm commented Dec 5, 2023

NOTE: For brevity the following talks about ConstantInt but
everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby
each lane has the same value, it achieves this via other constants:

  • ConstantVector for fixed-length vectors
  • ConstantExprs for scalable vectors

However, ConstantExprs are being deprecated and ConstantVector is
not space efficient for larger vector types. By extending ConstantInt
we can represent vector splats by only storing the underlying scalar
value.

More specifically:

  • ConstantInt gains an ElementCount variant of get().
  • LLVMContext is extended to map <EC,APInt>->ConstantInt.
  • BitcodeReader/Writer support is extended to allow vector types.

Whilst this patch adds the base support, more work is required
before it's production ready. For example, there's likely to be
many places where isa assumes a scalar type. Accordingly
the default behaviour of ConstantInt::get() remains unchanged but a
set of flags are added to allow wider testing and thus help with the
migration:

--use-constant-int-for-fixed-length-splat
--use-constant-fp-for-fixed-length-splat
--use-constant-int-for-scalable-splat
--use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and
values are handled separately.

NOTE: For similar reasons as above, code generation doesn't work
out-the-box.

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen llvm:ir llvm:analysis llvm:transforms clang:openmp OpenMP related changes to Clang labels Dec 5, 2023
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 5, 2023

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Paul Walker (paulwalker-arm)

Changes

[LLVM][IR] Add native vector support to ConstantInt & ConstantFP.

NOTE: For brevity the following takes about ConstantInt but
everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby
each lane has the same value, it achieves this via other constants:

  • ConstantVector for fixed-length vectors
  • ConstantExprs for scalable vectors

ConstantExprs are being deprecated and ConstantVector is not space
efficient for larger vector types. This patch introduces an
alternative by allowing ConstantInt to natively support vector
splats via the IR syntax:

<N x ty> splat(ty <imm>)

More specifically:

  • IR parsing is extended to support the new syntax.
  • ConstantInt gains the interface getSplat().
  • LLVMContext is extended to map <EC,APInt>->ConstantInt.
  • BitCodeReader/Writer is extended to support vector types.

Whilst this patch adds the base support, more work is required
before it's production ready. For example, there's likely to be
many places where isa<ConstantInt> assumes a scalar type. Accordingly
the default behaviour of ConstantInt::get() remains unchanged but a
set of flag are added to allow wider testing and thus help with the
migration:

--use-constant-int-for-fixed-length-splat
--use-constant-fp-for-fixed-length-splat
--use-constant-int-for-scalable-splat
--use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and
values are handled separately.

NOTE: Code generation doesn't work out-the-box but the issues look
limited to calls to ConstantInt::getBitWidth() that will need to be
ported.


Patch is 41.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74502.diff

22 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+4-3)
  • (modified) llvm/include/llvm/AsmParser/LLParser.h (+3-1)
  • (modified) llvm/include/llvm/AsmParser/LLToken.h (+1)
  • (modified) llvm/include/llvm/IR/Constants.h (+23-5)
  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+1-1)
  • (modified) llvm/lib/AsmParser/LLLexer.cpp (+1)
  • (modified) llvm/lib/AsmParser/LLParser.cpp (+50-6)
  • (modified) llvm/lib/Bitcode/Reader/BitcodeReader.cpp (+41-29)
  • (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+3-3)
  • (modified) llvm/lib/IR/AsmWriter.cpp (+32-5)
  • (modified) llvm/lib/IR/ConstantFold.cpp (+1-1)
  • (modified) llvm/lib/IR/Constants.cpp (+92-1)
  • (modified) llvm/lib/IR/LLVMContextImpl.cpp (+2)
  • (modified) llvm/lib/IR/LLVMContextImpl.h (+4)
  • (modified) llvm/lib/IR/Verifier.cpp (+2-2)
  • (modified) llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp (+3-3)
  • (modified) llvm/lib/Transforms/IPO/OpenMPOpt.cpp (+8-7)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp (+2-2)
  • (modified) llvm/lib/Transforms/Scalar/ConstantHoisting.cpp (+3-3)
  • (modified) llvm/lib/Transforms/Scalar/LoopFlatten.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+4-4)
  • (added) llvm/test/Bitcode/constant-splat.ll (+53)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 65d9862621061..8dc828abf8aec 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3218,7 +3218,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     Value *AlignmentValue = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(AlignmentValue);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(PtrValue, Ptr,
@@ -17010,7 +17010,7 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
     Value *Op1 = EmitScalarExpr(E->getArg(1));
     ConstantInt *AlignmentCI = cast<ConstantInt>(Op0);
     if (AlignmentCI->getValue().ugt(llvm::Value::MaximumAlignment))
-      AlignmentCI = ConstantInt::get(AlignmentCI->getType(),
+      AlignmentCI = ConstantInt::get(AlignmentCI->getIntegerType(),
                                      llvm::Value::MaximumAlignment);
 
     emitAlignmentAssumption(Op1, E->getArg(1),
@@ -17248,7 +17248,8 @@ Value *CodeGenFunction::EmitPPCBuiltinExpr(unsigned BuiltinID,
         Op0, llvm::FixedVectorType::get(ConvertType(E->getType()), 2));
 
     if (getTarget().isLittleEndian())
-      Index = ConstantInt::get(Index->getType(), 1 - Index->getZExtValue());
+      Index =
+          ConstantInt::get(Index->getIntegerType(), 1 - Index->getZExtValue());
 
     return Builder.CreateExtractElement(Unpacked, Index);
   }
diff --git a/llvm/include/llvm/AsmParser/LLParser.h b/llvm/include/llvm/AsmParser/LLParser.h
index 810f3668d05d4..38f6f08b8f3a1 100644
--- a/llvm/include/llvm/AsmParser/LLParser.h
+++ b/llvm/include/llvm/AsmParser/LLParser.h
@@ -59,7 +59,9 @@ namespace llvm {
       t_Constant,                      // Value in ConstantVal.
       t_InlineAsm,                     // Value in FTy/StrVal/StrVal2/UIntVal.
       t_ConstantStruct,                // Value in ConstantStructElts.
-      t_PackedConstantStruct           // Value in ConstantStructElts.
+      t_PackedConstantStruct,          // Value in ConstantStructElts.
+      t_APSIntSplat,                   // Value in APSIntVal.
+      t_APFloatSplat                   // Value in APFloatVal.
     } Kind = t_LocalID;
 
     LLLexer::LocTy Loc;
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 0683291faae72..dd55afee21033 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -335,6 +335,7 @@ enum Kind {
   kw_extractelement,
   kw_insertelement,
   kw_shufflevector,
+  kw_splat,
   kw_extractvalue,
   kw_insertvalue,
   kw_blockaddress,
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 2f7fc5652c2cd..b76cb1beecf3c 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -81,6 +81,7 @@ class ConstantInt final : public ConstantData {
   APInt Val;
 
   ConstantInt(IntegerType *Ty, const APInt &V);
+  ConstantInt(VectorType *Ty, const APInt &V);
 
   void destroyConstantImpl();
 
@@ -98,6 +99,13 @@ class ConstantInt final : public ConstantData {
   /// value. Otherwise return a ConstantInt for the given value.
   static Constant *get(Type *Ty, uint64_t V, bool IsSigned = false);
 
+  /// WARNING: Incomplete support, do not use. These methods exist for early
+  /// prototyping, for most use cases ConstantInt::get() should be used.
+  /// Return a ConstantInt with a splat of the given value.
+  static ConstantInt *getSplat(LLVMContext &Context, ElementCount EC,
+                               const APInt &V);
+  static ConstantInt *getSplat(const VectorType *Ty, const APInt &V);
+
   /// Return a ConstantInt with the specified integer value for the specified
   /// type. If the type is wider than 64 bits, the value will be zero-extended
   /// to fit the type, unless IsSigned is true, in which case the value will
@@ -136,7 +144,11 @@ class ConstantInt final : public ConstantData {
   inline const APInt &getValue() const { return Val; }
 
   /// getBitWidth - Return the bitwidth of this constant.
-  unsigned getBitWidth() const { return Val.getBitWidth(); }
+  unsigned getBitWidth() const {
+    assert(getType()->isIntegerTy() &&
+           "Returning the bitwidth of a vector constant is not support!");
+    return Val.getBitWidth();
+  }
 
   /// Return the constant as a 64-bit unsigned integer value after it
   /// has been zero extended as appropriate for the type of this constant. Note
@@ -170,10 +182,9 @@ class ConstantInt final : public ConstantData {
   /// Determine if this constant's value is same as an unsigned char.
   bool equalsInt(uint64_t V) const { return Val == V; }
 
-  /// getType - Specialize the getType() method to always return an IntegerType,
-  /// which reduces the amount of casting needed in parts of the compiler.
-  ///
-  inline IntegerType *getType() const {
+  /// Variant of the getType() method to always return an IntegerType, which
+  /// reduces the amount of casting needed in parts of the compiler.
+  inline IntegerType *getIntegerType() const {
     return cast<IntegerType>(Value::getType());
   }
 
@@ -279,6 +290,13 @@ class ConstantFP final : public ConstantData {
   /// value. Otherwise return a ConstantFP for the given value.
   static Constant *get(Type *Ty, const APFloat &V);
 
+  /// WARNING: Incomplete support, do not use. These methods exist for early
+  /// prototyping, for most use cases ConstantFP::get() should be used.
+  /// Return a ConstantFP with a splat of the given value.
+  static ConstantFP *getSplat(LLVMContext &Context, ElementCount EC,
+                              const APFloat &V);
+  static ConstantFP *getSplat(const VectorType *Ty, const APFloat &V);
+
   static Constant *get(Type *Ty, StringRef Str);
   static ConstantFP *get(LLVMContext &Context, const APFloat &V);
   static Constant *getNaN(Type *Ty, bool Negative = false,
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index cef9f6ec179ba..c24bb1bb2cf9f 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6081,7 +6081,7 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset,
   Type *Int32Ty = Type::getInt32Ty(Ptr->getContext());
 
   auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset);
-  if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64)
+  if (!OffsetConstInt || OffsetConstInt->getIntegerType()->getBitWidth() > 64)
     return nullptr;
 
   APInt OffsetInt = OffsetConstInt->getValue().sextOrTrunc(
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 09a205c445dbe..eb47284feb218 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -697,6 +697,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(uinc_wrap);
   KEYWORD(udec_wrap);
 
+  KEYWORD(splat);
   KEYWORD(vscale);
   KEYWORD(x);
   KEYWORD(blockaddress);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index d236b6cfa9000..94e1a51aa2e75 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -3952,6 +3952,31 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) {
     return false;
   }
 
+  case lltok::kw_splat: {
+    Lex.Lex();
+    if (parseToken(lltok::lparen, "expected '(' after vector splat"))
+      return true;
+    Constant *C;
+    if (parseGlobalTypeAndValue(C))
+      return true;
+    if (parseToken(lltok::rparen, "expected ')' at end of vector splat"))
+      return true;
+
+    if (auto *CI = dyn_cast<ConstantInt>(C)) {
+      ID.APSIntVal = CI->getValue();
+      ID.Kind = ValID::t_APSIntSplat;
+      return false;
+    }
+
+    if (auto *CFP = dyn_cast<ConstantFP>(C)) {
+      ID.APFloatVal = CFP->getValue();
+      ID.Kind = ValID::t_APFloatSplat;
+      return false;
+    }
+
+    return tokError("invalid splat operand");
+  }
+
   case lltok::kw_getelementptr:
   case lltok::kw_shufflevector:
   case lltok::kw_insertelement:
@@ -5716,9 +5741,23 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
     ID.APSIntVal = ID.APSIntVal.extOrTrunc(Ty->getPrimitiveSizeInBits());
     V = ConstantInt::get(Context, ID.APSIntVal);
     return false;
+  case ValID::t_APSIntSplat:
+    if (!Ty->isVectorTy() || !Ty->getScalarType()->isIntegerTy())
+      return error(ID.Loc, "expected an integer vector result");
+    if (ID.APSIntVal.getBitWidth() !=
+        cast<IntegerType>(Ty->getScalarType())->getBitWidth())
+      return error(ID.Loc, "operand type must match result element type");
+    V = ConstantInt::getSplat(cast<VectorType>(Ty), ID.APSIntVal);
+    return false;
   case ValID::t_APFloat:
-    if (!Ty->isFloatingPointTy() ||
-        !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
+  case ValID::t_APFloatSplat: {
+    if ((ID.Kind == ValID::t_APFloat && !Ty->isFloatingPointTy()) ||
+        (ID.Kind == ValID::t_APFloatSplat && !Ty->isVectorTy()))
+      return error(ID.Loc, "floating point constant invalid for type");
+
+    Type *ScalarTy = Ty->getScalarType();
+    if (!ScalarTy->isFloatingPointTy() ||
+        !ConstantFP::isValueValidForType(ScalarTy, ID.APFloatVal))
       return error(ID.Loc, "floating point constant invalid for type");
 
     // The lexer has no type info, so builds all half, bfloat, float, and double
@@ -5727,13 +5766,13 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
       // Check for signaling before potentially converting and losing that info.
       bool IsSNAN = ID.APFloatVal.isSignaling();
       bool Ignored;
-      if (Ty->isHalfTy())
+      if (ScalarTy->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isBFloatTy())
+      else if (ScalarTy->isBFloatTy())
         ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
                               &Ignored);
-      else if (Ty->isFloatTy())
+      else if (ScalarTy->isFloatTy())
         ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                               &Ignored);
       if (IsSNAN) {
@@ -5745,13 +5784,18 @@ bool LLParser::convertValIDToValue(Type *Ty, ValID &ID, Value *&V,
                                          ID.APFloatVal.isNegative(), &Payload);
       }
     }
-    V = ConstantFP::get(Context, ID.APFloatVal);
+
+    if (auto *VTy = dyn_cast<VectorType>(Ty))
+      V = ConstantFP::getSplat(VTy, ID.APFloatVal);
+    else
+      V = ConstantFP::get(Context, ID.APFloatVal);
 
     if (V->getType() != Ty)
       return error(ID.Loc, "floating point constant does not have type '" +
                                getTypeString(Ty) + "'");
 
     return false;
+  }
   case ValID::t_Null:
     if (!Ty->isPointerTy())
       return error(ID.Loc, "null must be a pointer type");
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index e4c3770946b3a..b661d36fb6854 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -3022,50 +3022,62 @@ Error BitcodeReader::parseConstants() {
       V = Constant::getNullValue(CurTy);
       break;
     case bitc::CST_CODE_INTEGER:   // INTEGER: [intval]
-      if (!CurTy->isIntegerTy() || Record.empty())
+      if (!CurTy->isIntOrIntVectorTy() || Record.empty())
         return error("Invalid integer const record");
-      V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
+
+      if (auto *VTy = dyn_cast<VectorType>(CurTy)) {
+        auto *ScalarTy = cast<IntegerType>(VTy->getScalarType());
+        unsigned BitWidth = ScalarTy->getBitWidth();
+        APInt VInt(BitWidth, decodeSignRotatedValue(Record[0]));
+        V = ConstantInt::getSplat(VTy, VInt);
+      } else
+        V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
       break;
     case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
-      if (!CurTy->isIntegerTy() || Record.empty())
+      if (!CurTy->isIntOrIntVectorTy() || Record.empty())
         return error("Invalid wide integer const record");
 
-      APInt VInt =
-          readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
-      V = ConstantInt::get(Context, VInt);
-
+      auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
+      APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
+      if (auto *VTy = dyn_cast<VectorType>(CurTy))
+        V = ConstantInt::getSplat(VTy, VInt);
+      else
+        V = ConstantInt::get(Context, VInt);
       break;
     }
     case bitc::CST_CODE_FLOAT: {    // FLOAT: [fpval]
       if (Record.empty())
         return error("Invalid float const record");
-      if (CurTy->isHalfTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
-                                             APInt(16, (uint16_t)Record[0])));
-      else if (CurTy->isBFloatTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
-                                             APInt(16, (uint32_t)Record[0])));
-      else if (CurTy->isFloatTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
-                                             APInt(32, (uint32_t)Record[0])));
-      else if (CurTy->isDoubleTy())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
-                                             APInt(64, Record[0])));
-      else if (CurTy->isX86_FP80Ty()) {
+
+      APFloat Val(APFloat::Bogus());
+      auto *ScalarTy = CurTy->getScalarType();
+      if (ScalarTy->isHalfTy())
+        Val = APFloat(APFloat::IEEEhalf(), APInt(16, (uint16_t)Record[0]));
+      else if (ScalarTy->isBFloatTy())
+        Val = APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0]));
+      else if (ScalarTy->isFloatTy())
+        Val = APFloat(APFloat::IEEEsingle(), APInt(32, (uint32_t)Record[0]));
+      else if (ScalarTy->isDoubleTy())
+        Val = APFloat(APFloat::IEEEdouble(), APInt(64, Record[0]));
+      else if (ScalarTy->isX86_FP80Ty()) {
         // Bits are not stored the same way as a normal i80 APInt, compensate.
         uint64_t Rearrange[2];
         Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
         Rearrange[1] = Record[0] >> 48;
-        V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
-                                             APInt(80, Rearrange)));
-      } else if (CurTy->isFP128Ty())
-        V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
-                                             APInt(128, Record)));
-      else if (CurTy->isPPC_FP128Ty())
-        V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
-                                             APInt(128, Record)));
-      else
+        Val = APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange));
+      } else if (ScalarTy->isFP128Ty())
+        Val = APFloat(APFloat::IEEEquad(), APInt(128, Record));
+      else if (ScalarTy->isPPC_FP128Ty())
+        Val = APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record));
+      else {
         V = UndefValue::get(CurTy);
+        break;
+      }
+
+      if (auto *VTy = dyn_cast<VectorType>(CurTy))
+        V = ConstantFP::getSplat(VTy, Val);
+      else
+        V = ConstantFP::get(Context, Val);
       break;
     }
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 8239775d04865..0f5b9ff9ebd72 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -2577,18 +2577,18 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
     } else if (isa<UndefValue>(C)) {
       Code = bitc::CST_CODE_UNDEF;
     } else if (const ConstantInt *IV = dyn_cast<ConstantInt>(C)) {
-      if (IV->getBitWidth() <= 64) {
+      if (IV->getValue().getBitWidth() <= 64) {
         uint64_t V = IV->getSExtValue();
         emitSignedInt64(Record, V);
         Code = bitc::CST_CODE_INTEGER;
         AbbrevToUse = CONSTANTS_INTEGER_ABBREV;
-      } else {                             // Wide integers, > 64 bits in size.
+      } else { // Wide integers, > 64 bits in size.
         emitWideAPInt(Record, IV->getValue());
         Code = bitc::CST_CODE_WIDE_INTEGER;
       }
     } else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       Code = bitc::CST_CODE_FLOAT;
-      Type *Ty = CFP->getType();
+      Type *Ty = CFP->getType()->getScalarType();
       if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
           Ty->isDoubleTy()) {
         Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index fabc79adbd33d..e37da7c460f26 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1394,16 +1394,32 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
 static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   AsmWriterContext &WriterCtx) {
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
-    if (CI->getType()->isIntegerTy(1)) {
-      Out << (CI->getZExtValue() ? "true" : "false");
-      return;
+    if (CI->getType()->isVectorTy()) {
+      Out << "splat (";
+      WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out);
+      Out << " ";
     }
-    Out << CI->getValue();
+
+    if (CI->getType()->getScalarType()->isIntegerTy(1))
+      Out << (CI->getZExtValue() ? "true" : "false");
+    else
+      Out << CI->getValue();
+
+    if (CI->getType()->isVectorTy())
+      Out << ")";
+
     return;
   }
 
   if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
     const APFloat &APF = CFP->getValueAPF();
+
+    if (CFP->getType()->isVectorTy()) {
+      Out << "splat (";
+      WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out);
+      Out << " ";
+    }
+
     if (&APF.getSemantics() == &APFloat::IEEEsingle() ||
         &APF.getSemantics() == &APFloat::IEEEdouble()) {
       // We would like to output the FP constant value in exponential notation,
@@ -1429,6 +1445,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
         // Reparse stringized version!
         if (APFloat(APFloat::IEEEdouble(), StrVal).convertToDouble() == Val) {
           Out << StrVal;
+
+          if (CFP->getType()->isVectorTy())
+            Out << ")";
+
           return;
         }
       }
@@ -1454,6 +1474,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
         }
       }
       Out << format_hex(apf.bitcastToAPInt().getZExtValue(), 0, /*Upper=*/true);
+
+      if (CFP->getType()->isVectorTy())
+        Out << ")";
+
       return;
     }
 
@@ -1468,7 +1492,6 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   /*Upper=*/true);
       Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
                                   /*Upper=*/true);
-      return;
     } else if (&APF.getSemantics() == &APFloat::IEEEquad()) {
       Out << 'L';
       Out << format_hex_no_prefix(API.getLoBits(64).getZExtValue(), 16,
@@ -1491,6 +1514,10 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
                                   /*Upper=*/true);
     } else
       llvm_unreachable("Unsupported floating point type");
+
+    if (CFP->getType()->isVectorTy())
+      Out << ")";
+
     return;
   }
 
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index d499d74f7ba01..c478040234078 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -868,7 +868,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
           }
 
           if (GVAlign > 1) {
-            unsigned DstWidth = CI2->getType()->getBitWidth();
+            unsigned DstWidth = CI2->getIntegerType()->getBitWidth();
             unsigned SrcWidth = std::min(DstWidth, Log2(GVAlign));
             APInt BitsNotSet(APInt::...
[truncated]

Copy link

github-actions bot commented Dec 5, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@paulwalker-arm
Copy link
Collaborator Author

The PR contains a couple of commits that unless there's disagreement I'm tempted to land directly but have held off just in case there's any buyer remorse about extending ConstantInt/ConstantFP to cover vector types.

For similar reasons I've not updated the LangRef as I don't really want people using the support until at least code generation works.

/// Return a ConstantInt with a splat of the given value.
static ConstantInt *getSplat(LLVMContext &Context, ElementCount EC,
const APInt &V);
static ConstantInt *getSplat(const VectorType *Ty, const APInt &V);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these APIs should exist. ConstantInt::get() already supports creation of splats, they just aren't represented as ConstantInts.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're in a transition period and thus I need an absolute way to create a vector ConstantInt (e.g. when parsing ll files and bitcode). Today ConstantInt::get() returns other Constant types to represent splats and that must be maintained for correctness because there are many code paths for which a vector ConstantInt will break.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this isn't the right way to phase in the change. I think the splat syntax should just return whatever ConstantVector::getSplat() produces, and what that produces can be controlled by the opt flags you have introduced.

That means that the splat syntax becomes usable right away as a short-hand for producing the representations we currently use, and will switch to producing plain ConstantInt/ConstantFP once the flag is flipped (or in tests that explicitly flip it).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, thanks. It seems I've incorrectly assumed from an IR parsing and printing point of view there is a requirement for IR_out == IR_in. Your suggestion certainly means I can break the work up some more so I get that sorted.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've moved the textual IR side of things to #74620 following the suggestion to have splat(x) be synonymous with ConstantInt/FP::get().

unsigned getBitWidth() const { return Val.getBitWidth(); }
unsigned getBitWidth() const {
assert(getType()->isIntegerTy() &&
"Returning the bitwidth of a vector constant is not support!");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? I think this API should work with vectors as well (with the current implementation, i.e. returning the bitwidth of the scalar value). You can just adjust the comment to clarify.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ultimately I think this should be more explicit, for example getScalarBitWidth(). For this patch though the need was tiny so I made this change purely to trigger asserts at this level when failure cases are hit once I start expanding the testing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO this is not necessary. If the name were getSizeInBits() I would agree, but the term "bit width" implies that we're talking about scalar. We don't use the term "bit width" to refer to full size of a vector.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough.

@@ -343,7 +343,7 @@ static bool verifyTripCount(Value *RHS, Loop *L,
// If the RHS of the compare is equal to the backedge taken count we need
// to add one to get the trip count.
if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
ConstantInt *One = ConstantInt::get(ConstantRHS->getIntegerType(), 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you changing usages like these? This code should work fine with getType().

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much the same reason as with getBitWidth() but the need for immediate change is greater so I renamed this method so changes to existing code paths are minimal whilst still providing a route to trigger asserts once testing is expanded.

I did consider just removing the method and adding the necessary casts but figured somebody went to the trouble of adding the override in the first place so I maintained this but under a modified name. Do you think the shorthand is not worth it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really follow here. Even if you rename the overload on ConstantInt, there will still be the method inherited from Value::getType(), and using that method here should work (and be forward-compatible with vector ConstantInt), because ConstantInt::get doesn't actually require that the argument is an IntegerType.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be static ConstantInt *get(IntegerType *Ty, uint64_t V, bool IsSigned = false);?

I don't think I made this change up. I wanted a mechanical change so just removed the overload and the compiler told me all the places that relied on it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooooh, I think the issue is that it's stored in a ConstantInt * variable. Changing it to Constant * should make it pick the other overload, I think.

tschuett added a commit to tschuett/llvm-project that referenced this pull request Dec 8, 2023
Cleanups, preparation for more combines, add known bits for constant
conditions, combine selects where the false and true register are
constants, and improve support for vector conditions.

AMDGPU supports vector conditions. X86 has a todo for vector
conditions. AArch64 SVE supports SEL for vector conditions. How to
implement vector conditions with NEON (with bsl), see arm64-vselect.ll
? Vector select asserts in the instruction selector.

buildNot does not support scalable vectors. We cannot create scalable
constant vectors of -1 and there is no G_Not. AArch64 SVE has a NOT
and a DUP for broadcasting. Something akin to G_CONSTANT_SPLAT,
G_CONSTANT_VECTOR, G_SPLAT_VECTOR, G_BRODCAST, or G_HOMOGENOUS_VECTOR
that takes an immediate and creates a (fixed or scalable) vector where
all elements are the immediate might solve the buildNot challenge,
facilitates new combines, pattern matching, and new selecting
optimizations.

P.S. We need to support integer and float.

llvm#74502

```c
<vscale x 4 x i32> splat (i32 -1)
```
tschuett added a commit to tschuett/llvm-project that referenced this pull request Dec 14, 2023
tschuett added a commit that referenced this pull request Dec 14, 2023
tschuett added a commit to tschuett/llvm-project that referenced this pull request Dec 20, 2023
Cleanups, preparation for more combines, add known bits for constant
conditions, combine selects where the false and true register are
constants, and improve support for vector conditions.

AMDGPU supports vector conditions. X86 has a todo for vector
conditions. AArch64 SVE supports SEL for vector conditions. How to
implement vector conditions with NEON (with bsl), see arm64-vselect.ll
? Vector select asserts in the instruction selector.

buildNot does not support scalable vectors. We cannot create scalable
constant vectors of -1 and there is no G_Not. AArch64 SVE has a NOT
and a DUP for broadcasting. Something akin to G_CONSTANT_SPLAT,
G_CONSTANT_VECTOR, G_SPLAT_VECTOR, G_BRODCAST, or G_HOMOGENOUS_VECTOR
that takes an immediate and creates a (fixed or scalable) vector where
all elements are the immediate might solve the buildNot challenge,
facilitates new combines, pattern matching, and new selecting
optimizations.

P.S. We need to support integer and float.

llvm#74502

```c
<vscale x 4 x i32> splat (i32 -1)
```
@paulwalker-arm
Copy link
Collaborator Author

Hi @nikic, I had to put this on the back burner but I'm back now. I believe the rebase last year was the completion point for this stage of the work so the patch should be good to review. Given the time though I'll rebase and retest tomorrow but wanted to check if your assessment matches mine or whether you feel more is needed?

Separately to this I will create some NFC patches to port tests to the newer splat syntax so I can more easily find the failure paths when enabling the new functionality.

@paulwalker-arm paulwalker-arm force-pushed the constant-vector-splats branch 2 times, most recently from 70b4fb8 to 6c6baf8 Compare February 9, 2024 15:49
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks conceptually fine to me.

llvm/lib/IR/AsmWriter.cpp Outdated Show resolved Hide resolved
llvm/lib/IR/AsmWriter.cpp Outdated Show resolved Hide resolved
llvm/lib/IR/Constants.cpp Show resolved Hide resolved
NOTE: For brevity the following talks about ConstantInt but
everything extends to cover ConstantFP as well.

Whilst ConstantInt::get() supports the creation of vectors whereby
each lane has the same value, it achieves this via other constants:

  * ConstantVector for fixed-length vectors
  * ConstantExprs for scalable vectors

However, ConstantExprs are being deprecated and ConstantVector is
not space efficient for larger vector types. By extending ConstantInt
we can represent vector splats by only storing the underlying scalar
value.

More specifically:

 * ConstantInt gains an ElementCount variant of get().
 * LLVMContext is extended to map <EC,APInt>->ConstantInt.
 * BitcodeReader/Writer support is extended to allow vector types.

Whilst this patch adds the base support, more work is required
before it's production ready. For example, there's likely to be
many places where isa<ConstantInt> assumes a scalar type. Accordingly
the default behaviour of ConstantInt::get() remains unchanged but a
set of flags are added to allow wider testing and thus help with the
migration:

  --use-constant-int-for-fixed-length-splat
  --use-constant-fp-for-fixed-length-splat
  --use-constant-int-for-scalable-splat
  --use-constant-fp-for-scalable-splat

NOTE: No change is required to the bitcode format because types and
values are handled separately.

NOTE: For similar reasons as above, code generation doesn't work
out-the-box.
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -1384,6 +1440,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {

Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
if (!EC.isScalable()) {
// Maintain special handling of zero.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering whether this is something you want to keep long term or just initially?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy to take offers but see it as temporary.

@nikic
Copy link
Contributor

nikic commented Feb 19, 2024

Something important I forgot: Can you please test what happens with <2 x i8> <i8 42, i8 42> in use-constant-int-for-fixed-length-splat mode? This needs to create the same representation as <2 x i8> splat (i8 42), to maintain a canonical form of constants.

(This applies to a lesser degree to the scalable case -- in that case it's "just" a missed constant folding opportunity.)

@paulwalker-arm paulwalker-arm merged commit cbb24e1 into llvm:main Feb 22, 2024
4 checks passed
@paulwalker-arm paulwalker-arm deleted the constant-vector-splats branch March 19, 2024 18:20
qihangkong pushed a commit to rvgpu/rvgpu-llvm that referenced this pull request Apr 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category llvm:analysis llvm:ir llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants