diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 3b148f9021666..826d7547716e4 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -134,7 +134,7 @@ class AffineYieldOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(AffineYieldOp op, PatternRewriter &rewriter) const override { - if (isa(op->getParentOp())) { + if (isa(op->getParentOp())) { // Terminator is rewritten as part of the "affine.parallel" lowering // pattern. return failure(); @@ -230,8 +230,12 @@ class AffineParallelLowering : public OpRewritePattern { static_cast(cast(reduction).getInt())); assert(reductionOp && "Reduction operation cannot be of None Type"); arith::AtomicRMWKind reductionOpValue = *reductionOp; - identityVals.push_back( - arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc)); + Value identityVal = + arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc); + if (!identityVal) + return rewriter.notifyMatchFailure( + op, "unsupported reduction kind for identity value"); + identityVals.push_back(identityVal); } parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 6f368604df65a..a25c0ed23456a 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2799,9 +2799,10 @@ std::optional mlir::arith::getNeutralElement(Operation *op) { Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue) { - auto attr = - getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); - return arith::ConstantOp::create(builder, loc, attr); + if (auto attr = +getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue)) + return arith::ConstantOp::create(builder, loc, attr); + return {}; } /// Return the value obtained by applying the reduction operation kind diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-invalid.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-invalid.mlir new file mode 100644 index 0000000000000..a357b27a2430e --- /dev/null +++ b/mlir/test/Conversion/AffineToStandard/lower-affine-invalid.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt --lower-affine %s 2>&1 | FileCheck %s + +// Test that affine.parallel with an unsupported reduction kind ("assign") +// does not crash but emits a proper error message. Previously, +// getIdentityValue would be called with a null TypedAttr and crash inside +// arith::ConstantOp::build with "Failed to infer result type(s)". + +// CHECK: Reduction operation type not supported +// CHECK-NOT: Failed to infer result type + +func.func @affine_parallel_assign_reduction_no_crash(%n: index) -> i32 { + %0 = affine.parallel (%i) = (0) to (%n) reduce ("assign") -> i32 { + %c0 = arith.constant 0 : i32 + affine.yield %c0 : i32 + } + return %0 : i32 +}