diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index b37a14f0eb7a09..90c6a0374e945a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -389,6 +389,11 @@ OwningRewritePatternList getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); struct LinalgBaseTilingPattern : public RewritePattern { + // Entry point to match any LinalgOp OpInterface. + LinalgBaseTilingPattern(LinalgTilingOptions options, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1); + // Entry point to match a specific Linalg op. LinalgBaseTilingPattern(StringRef opName, MLIRContext *context, LinalgTilingOptions options, LinalgMarker marker = LinalgMarker(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 97c3dafe57a8ff..804ae6681f8cbd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -111,6 +111,11 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( : RewritePattern(opName, {}, benefit, context), marker(marker), options(options) {} +mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( + LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit) + : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker), + options(options) {} + LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase( Operation *op, PatternRewriter &rewriter, SmallVectorImpl &tensorResults) const { diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 8e60312bf4fd8c..f44bb6769e6168 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -128,12 +128,12 @@ void GenerateLoopNest::doit( ArrayRef iteratorTypes, function_ref bodyBuilderFn, Optional distributionOptions) { - // Create procInfo so it dominate loops, if appropriate. + // Create procInfo so it dominates loops, if appropriate. OpBuilder &builder = edsc::ScopedContext::getBuilderRef(); Location loc = edsc::ScopedContext::getLocation(); SmallVector procInfo; if (distributionOptions.hasValue()) - procInfo = distributionOptions->procInfo(builder, loc, ArrayRef{}); + procInfo = distributionOptions->procInfo(builder, loc, loopRanges); SmallVector lbs, ubs, steps; unpackRanges(loopRanges, lbs, ubs, steps); @@ -143,11 +143,12 @@ void GenerateLoopNest::doit( if (!distributionOptions.hasValue() || loopNest.loops.empty()) return; - // TODO: support distributionMethod, which is currently ignored. + // Only supports cyclic distribution for now. for (auto it : llvm::zip(loopNest.loops, procInfo, distributionOptions->distributionMethod)) - mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId, - std::get<1>(it).nprocs); + if (std::get<2>(it) == DistributionMethod::Cyclic) + mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId, + std::get<1>(it).nprocs); } /// Specialization to build affine "for" nest. diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index 9e3efcf4166499..c2b4c7b9c82142 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -415,8 +415,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context, { LinalgLoopDistributionOptions cyclicNprocsEqNiters; - cyclicNprocsEqNiters.distributionMethod.resize( - 2, DistributionMethod::CyclicNumProcsEqNumIters); + cyclicNprocsEqNiters.distributionMethod.resize(2, + DistributionMethod::Cyclic); cyclicNprocsEqNiters.procInfo = getGpuProcIds; patterns.insert>(