diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index 61d9357e19bb4..66dd7c8f36e6b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -117,6 +117,16 @@ class BroadcastOpLowering : public OpRewritePattern { VectorType resType = VectorType::get(dstType.getShape().drop_front(), eltType, dstType.getScalableDims().drop_front()); + + // For "stretch not at start" with a scalable outer dimension we would need + // to emit an scf.for loop, which is not yet supported. Check before + // creating any IR so that returning failure() does not violate the pattern + // API contract. + if (m != 0 && dstType.getScalableDims()[0]) { + // TODO: For scalable vectors we should emit an scf.for loop. + return failure(); + } + Value result = ub::PoisonOp::create(rewriter, loc, dstType); if (m == 0) { // Stetch at start. @@ -126,10 +136,6 @@ class BroadcastOpLowering : public OpRewritePattern { result = vector::InsertOp::create(rewriter, loc, bcst, result, d); } else { // Stetch not at start. - if (dstType.getScalableDims()[0]) { - // TODO: For scalable vectors we should emit an scf.for loop. - return failure(); - } for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) { Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d); Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);