diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index 1af552362a26a..f3e239bedb7f7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -111,9 +111,6 @@ struct ScanToArithOps : public OpRewritePattern { if (!isValidKind(isInt, scanOp.getKind())) return failure(); - VectorType resType = destType; - Value result = arith::ConstantOp::create(rewriter, loc, resType, - rewriter.getZeroAttr(resType)); int64_t reductionDim = scanOp.getReductionDim(); bool inclusive = scanOp.getInclusive(); int64_t destRank = destType.getRank(); @@ -123,10 +120,16 @@ struct ScanToArithOps : public OpRewritePattern { SmallVector reductionShape(destShape); SmallVector reductionScalableDims(destType.getScalableDims()); + // Check before creating any IR so that returning failure() does not + // violate the pattern API contract. if (reductionScalableDims[reductionDim]) return rewriter.notifyMatchFailure( scanOp, "Trying to reduce scalable dimension - not yet supported!"); + VectorType resType = destType; + Value result = arith::ConstantOp::create(rewriter, loc, resType, + rewriter.getZeroAttr(resType)); + // The reduction dimension, after reducing, becomes 1. It's a fixed-width // dimension - no need to touch the scalability flag. reductionShape[reductionDim] = 1;