From 13632a62fed2d772519312aa8c29be7f5295fda0 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 18 Oct 2023 08:28:17 -0700 Subject: [PATCH] [mlir] Add debug messages for failures of isValidIntOrFloat I have run into assertion failures quite often when calling this method via `DenseElementsAttr::get`, and I think this would help, at the very least, by printing out the bit width size mismatches, rather than a plain assertion failure. I included all the other cases in the method for completeness --- mlir/lib/IR/BuiltinAttributes.cpp | 51 +++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 64949a1710729..89b1ed67f5d06 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -20,9 +20,12 @@ #include "llvm/ADT/APSInt.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/Endian.h" #include +#define DEBUG_TYPE "builtinattributes" + using namespace mlir; using namespace mlir::detail; @@ -1098,24 +1101,44 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type, static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, bool isSigned) { // Make sure that the data element size is the same as the type element width. - if (getDenseElementBitWidth(type) != - static_cast(dataEltSize * CHAR_BIT)) + auto denseEltBitWidth = getDenseElementBitWidth(type); + auto dataSize = static_cast(dataEltSize * CHAR_BIT); + if (denseEltBitWidth != dataSize) { + LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width " + << denseEltBitWidth << " to match data size " + << dataSize << " for type " << type << "\n"); return false; + } // Check that the element type is either float or integer or index. - if (!isInt) - return llvm::isa(type); + if (!isInt) { + bool valid = llvm::isa(type); + if (!valid) + LLVM_DEBUG(llvm::dbgs() + << "expected float type when isInt is false, but found " + << type << "\n"); + return valid; + } if (type.isIndex()) return true; auto intType = llvm::dyn_cast(type); - if (!intType) + if (!intType) { + LLVM_DEBUG(llvm::dbgs() + << "expected integer type when isInt is true, but found " << type + << "\n"); return false; + } // Make sure signedness semantics is consistent. if (intType.isSignless()) return true; - return intType.isSigned() ? isSigned : !isSigned; + + bool valid = intType.isSigned() == isSigned; + if (!valid) + LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned + << " to match type " << type << "\n"); + return valid; } /// Defaults down the subclass implementation. @@ -1247,12 +1270,14 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { DenseElementsAttr DenseElementsAttr::mapValues(Type newElementType, function_ref mapping) const { - return llvm::cast(*this).mapValues(newElementType, mapping); + return llvm::cast(*this).mapValues(newElementType, + mapping); } DenseElementsAttr DenseElementsAttr::mapValues( Type newElementType, function_ref mapping) const { - return llvm::cast(*this).mapValues(newElementType, mapping); + return llvm::cast(*this).mapValues(newElementType, + mapping); } ShapedType DenseElementsAttr::getType() const { @@ -1331,8 +1356,9 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, bool isInt, bool isSigned) { assert(::isValidIntOrFloat( - llvm::cast(type.getElementType()).getElementType(), - dataEltSize / 2, isInt, isSigned)); + llvm::cast(type.getElementType()).getElementType(), + dataEltSize / 2, isInt, isSigned) && + "Try re-running with -debug-only=builtinattributes"); int64_t numElements = data.size() / dataEltSize; (void)numElements; @@ -1347,8 +1373,9 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, int64_t dataEltSize, bool isInt, bool isSigned) { - assert( - ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)); + assert(::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, + isSigned) && + "Try re-running with -debug-only=builtinattributes"); int64_t numElements = data.size() / dataEltSize; assert(numElements == 1 || numElements == type.getNumElements());