diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h index fe85908e476ff..8e0a2e2efa7e2 100644 --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -543,6 +543,14 @@ using has_get_checked_method = decltype(T::getChecked(std::declval()...)); /// the base `StorageUserBase::getChecked` template (e.g. ArrayAttr), that /// template instantiation requires a complete storage type which may not be /// available in the bytecode reading TU. +/// +/// NOTE: callers must ensure T has an explicit no-context +/// `getChecked(emitError, params...)` overload (i.e. the overload set does not +/// consist solely of inherited StorageUserBase templates). Calling this helper +/// for types that only inherit the base template triggers a compiler bug in +/// clang 21/22 (assertion `isInitialized()` in ImplicitConversionSequence) +/// during SFINAE evaluation. The generated bytecode read functions use +/// `get` for such types instead of this helper. template auto getChecked(function_ref emitError, MLIRContext *context, Ts &&...params) { diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td index 8a1f3d5e5b2e0..d0b13a20a9040 100644 --- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td @@ -117,7 +117,7 @@ def FileLineColLoc : DialectAttribute<(attr } let cType = "FusedLoc", - cBuilder = "cast(getChecked([&]() { return reader.emitError(); }, context, $_args))" in { + cBuilder = "cast(get(context, $_args))" in { def FusedLoc : DialectAttribute<(attr Array:$locations )> { @@ -144,7 +144,7 @@ def DenseResourceElementsAttr : DialectAttribute<(attr ResourceHandle<"DenseResourceElementsHandle">:$rawHandle )> { // Note: order of serialization does not match order of builder. - let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, type, *rawHandle)"; + let cBuilder = "get<$_resultType>(context, type, *rawHandle)"; } let cType = "RankedTensorType" in { @@ -153,6 +153,9 @@ def RankedTensorType : DialectType<(type Type:$elementType )> { let printerPredicate = "!$_val.getEncoding()"; + // Use getChecked to return null (and emit a diagnostic) instead of asserting + // when the element type is invalid. + let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)"; } def RankedTensorTypeWithEncoding : DialectType<(type @@ -237,7 +240,11 @@ def Float128Type : DialectType<(type)>; def ComplexType : DialectType<(type Type:$elementType -)>; +)> { + // Use getChecked to return null (and emit a diagnostic) instead of asserting + // when the element type is not a valid complex element type. + let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)"; +} def MemRefLayout: WithType<"MemRefLayoutAttrInterface", Attribute>; @@ -248,6 +255,9 @@ def MemRefType : DialectType<(type MemRefLayout:$layout )> { let printerPredicate = "!$_val.getMemorySpace()"; + // Use getChecked to return null (and emit a diagnostic) instead of asserting + // when element type or layout is invalid. + let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)"; } def MemRefTypeWithMemSpace : DialectType<(type @@ -288,7 +298,11 @@ def UnrankedMemRefTypeWithMemSpace : DialectType<(type def UnrankedTensorType : DialectType<(type Type:$elementType -)>; +)> { + // Use getChecked to return null (and emit a diagnostic) instead of asserting + // when the element type is invalid. + let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)"; +} let cType = "VectorType" in { def VectorType : DialectType<(type diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td index 184c81e6a5f7d..2e31c97f11f91 100644 --- a/mlir/include/mlir/IR/BytecodeBase.td +++ b/mlir/include/mlir/IR/BytecodeBase.td @@ -147,11 +147,23 @@ class DialectAttrOrType { class DialectAttribute : DialectAttrOrType, AttributeKind { let cParser = "succeeded($_reader.readAttribute<$_resultType>($_var))"; - let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)"; + // Use get<> by default. Types with an explicit no-context getChecked + // (verifyInvariants that should return null instead of asserting) should + // override cBuilder to call T::getChecked or getChecked explicitly. + // Avoid using the generic getChecked helper for types that only inherit + // StorageUserBase::getChecked, because the SFINAE probe in that helper + // triggers a clang 21/22 compiler crash for such types. + let cBuilder = "get<$_resultType>(context, $_args)"; } class DialectType : DialectAttrOrType, TypeKind { let cParser = "succeeded($_reader.readType<$_resultType>($_var))"; - let cBuilder = "getChecked<$_resultType>([&]() { return reader.emitError(); }, context, $_args)"; + // Use get<> by default. Types with an explicit no-context getChecked + // (verifyInvariants that should return null instead of asserting) should + // override cBuilder to call T::getChecked or getChecked explicitly. + // Avoid using the generic getChecked helper for types that only inherit + // StorageUserBase::getChecked, because the SFINAE probe in that helper + // triggers a clang 21/22 compiler crash for such types. + let cBuilder = "get<$_resultType>(context, $_args)"; } class DialectAttributes { diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 2c9e9c040d460..8769f0655759f 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1320,8 +1320,8 @@ Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, loc, parent->getRegion(0).front(), elementCount, builder, integerType); Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); - Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType, - builder.getI32IntegerAttr(offset)); + Value offsetOp = spirv::ConstantOp::create( + builder, loc, integerType, builder.getIntegerAttr(integerType, offset)); auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp); auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp})); diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index f0a210a2ededb..3edb9af671150 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -511,6 +511,17 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, // Do (2) BitVector successorNonLive = markLives(operandValues, nonLiveSet, la).flip(); + // A block argument should not be considered dead if the liveness analysis + // determines it is live. This can happen when this branch is in a + // statically unreachable (dead) block: the forwarded operand appears dead + // because it is in the dead block, but the successor block argument may + // still be live because it is also forwarded from other live predecessor + // branches. + for (auto [index, blockArg] : + llvm::enumerate(successorBlock->getArguments())) { + if (successorNonLive[index] && hasLive({blockArg}, nonLiveSet, la)) + successorNonLive.reset(index); + } collectNonLiveValues(nonLiveSet, successorBlock->getArguments(), successorNonLive); diff --git a/mlir/test/Bytecode/invalid/invalid-type-remapping.mlir b/mlir/test/Bytecode/invalid/invalid-type-remapping.mlir index 44d0a4eb8bb4a..cfe6f79bc29d5 100644 --- a/mlir/test/Bytecode/invalid/invalid-type-remapping.mlir +++ b/mlir/test/Bytecode/invalid/invalid-type-remapping.mlir @@ -43,6 +43,18 @@ module { // ----- +// CHECK: invalid element type for complex +// CHECK: failed to read bytecode +// ComplexType whose element type is replaced by one that is neither int nor +// float — previously crashed with an assertion when using get. +module { + func.func @complex_unsupported_elem_type(%arg0: complex) { + return + } +} + +// ----- + // CHECK: DenseTypedElementsAttr element type must implement DenseElementTypeInterface, but got: '!test.i32' // CHECK: failed to read bytecode // DenseTypedElementsAttr whose element type is replaced by one that does not diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 64088ce15cd48..d2f5c4ab95574 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -868,3 +868,29 @@ module @func_with_non_call_users { } spirv.EntryPoint "GLCompute" @callee } + +// ----- + +// Regression test: verify that the pass does not crash when a branch op +// forwards a value through intermediate block arguments to a join block. +// Previously, processBranchOp incorrectly used forwarded operand liveness +// to determine successor block argument liveness, causing incorrect marking +// of live values as dead. (https://github.com/llvm/llvm-project/issues/182263) +// +// CHECK-LABEL: func.func @branch_forwarded_block_arg_liveness +// CHECK-CANONICALIZE-LABEL: func.func @branch_forwarded_block_arg_liveness +func.func @branch_forwarded_block_arg_liveness(%x: i64) -> i64 { + %c1 = arith.constant 1 : i64 + %c2 = arith.constant 2 : i64 + %cmp = arith.cmpi slt, %c1, %c2 : i64 + cf.cond_br %cmp, ^bb1(%c1 : i64), ^bb2(%c1 : i64) +^bb1(%a: i64): + cf.br ^bb3(%a : i64) +^bb2(%b: i64): + cf.br ^bb3(%b : i64) +^bb3(%arg0: i64): + // CHECK: arith.addi + // CHECK-CANONICALIZE: arith.addi + %final = arith.addi %c1, %arg0 : i64 + func.return %final : i64 +} diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp index dd178b5e5d232..d8004a663aaee 100644 --- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp +++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp @@ -206,10 +206,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType, auto funScope = ios.scope("{\n", "}"); if (args.empty()) { - ios << formatv( - "return getChecked<{0}>([&]() {{ return reader.emitError(); }, " - "context);\n", - returnType); + ios << formatv("return get<{0}>(context);\n", returnType); return; }