diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index a19cce4b919a8..8f3232f01544f 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2017,8 +2017,8 @@ def TileReductionUsingForallOp : DefaultValuedAttr:$tile_sizes, OptionalAttr:$mapping); let results = (outs Variadic:$fill_op, - TransformHandleTypeInterface:$split_linalg_op, - TransformHandleTypeInterface:$combining_linalg_op, + TransformHandleTypeInterface:$split_op, + TransformHandleTypeInterface:$combining_op, TransformHandleTypeInterface:$forall_op); let builders = [ @@ -2042,7 +2042,7 @@ def TileReductionUsingForallOp : let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, - ::mlir::linalg::LinalgOp target, + Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index f0c1f4485b054..f3db8f7ccfaa1 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build( } DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( - transform::TransformRewriter &rewriter, LinalgOp target, + transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); + + auto partialReductionOp = dyn_cast(target); + if (!partialReductionOp) { + return emitSilenceableFailure( + target->getLoc(), + "Operation should implement PartialReductionOpInterface"); + } SmallVector numThreads = getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); SmallVector tileSizes = @@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne( extractFromIntegerArrayAttr(getReductionDims()); if (reductionDims.empty()) { for (auto [idx, iteratorType] : - llvm::enumerate(target.getIteratorTypesArray())) { + llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) { if (iteratorType == utils::IteratorType::reduction) reductionDims.push_back(idx); } } options.setReductionDims(reductionDims); - FailureOr result = scf::tileUsingSCF( - rewriter, cast(target.getOperation()), options); + FailureOr result = + scf::tileUsingSCF(rewriter, partialReductionOp, options); if (failed(result)) { auto diag = emitSilenceableError() << "could not tile reduction";