diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index b67ff776ea4a9..896b31835fb45 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -89,11 +89,30 @@ LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector); /// Returns a list of PromotionInfo which hold the promoted buffer and the /// full and partial views indexing into the buffer. // TODO: revisit dynamicBuffers option. -LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, - llvm::SetVector subViews, - bool dynamicBuffers = false, - int64_t alignment = 0, - OperationFolder *folder = nullptr); +struct LinalgPromotionOptions { + /// Indices of subViews to promote. If `None`, try to promote all operands. + Optional> operandsToPromote = None; + LinalgPromotionOptions &setOperandsToPromote(ArrayRef operands) { + operandsToPromote = DenseSet(); + operandsToPromote->insert(operands.begin(), operands.end()); + return *this; + } + /// Allow the use of dynamicaly-sized buffers. + bool dynamicBuffers = false; + LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) { + dynamicBuffers = dynamic; + return *this; + } + /// Alignment of promoted buffer. If `None` do not specify alignment. + Optional alignment = None; + LinalgPromotionOptions &setAlignment(unsigned align) { + alignment = align; + return *this; + } +}; +LinalgOp promoteSubViews(OpBuilder &b, LinalgOp op, + LinalgPromotionOptions options, + OperationFolder *folder = nullptr); /// Emit a suitable vector form for a Linalg op with fully static shape. void vectorizeLinalgOp(OpBuilder &builder, Operation *op); @@ -125,8 +144,8 @@ interchangeGenericLinalgOpPrecondition(Operation *op, ArrayRef interchangeVector); /// Promote std.subviews feeding linalg operations. -LogicalResult promoteSubviewsLinalgOpPrecondition( - Operation *op, Optional> operandIndicesToPromote = None); +LogicalResult promoteSubviewsPrecondition(Operation *op, + LinalgPromotionOptions options); /// Rewrite a linalg.generic into a suitable vector.contraction op. LogicalResult vectorizeLinalgOpPrecondition(Operation *op); @@ -242,13 +261,12 @@ struct LinalgInterchangePattern : public LinalgBaseInterchangePattern { /// /// Linalg promotion patterns. /// -/// Apply the `promoteSubViewOperands` transformation as a pattern. +/// Apply the `promoteSubViews` transformation as a pattern. /// `marker` controls LinalgTransformMarker matching and update when specified. -/// See `promoteSubViewOperands` for more details. +/// See `promoteSubViews` for more details. struct LinalgBasePromotionPattern : public RewritePattern { LinalgBasePromotionPattern(StringRef opName, MLIRContext *context, - ArrayRef operandsToPromote = {}, - unsigned alignment = 0, + LinalgPromotionOptions options, LinalgMarker marker = LinalgMarker(), PatternBenefit benefit = 1); LogicalResult matchAndRewrite(Operation *op, @@ -257,35 +275,17 @@ struct LinalgBasePromotionPattern : public RewritePattern { private: /// LinalgTransformMarker handles special attribute manipulations. LinalgMarker marker; - /// Indices of subViews to promote. - SmallVector operandsToPromote; - /// Alignment of promoted buffer. - unsigned alignment; + /// Promotion options. + LinalgPromotionOptions options; }; template struct LinalgPromotionPattern : public LinalgBasePromotionPattern { - LinalgPromotionPattern(MLIRContext *context, - ArrayRef operandsToPromote = {}, - unsigned alignment = 0, + LinalgPromotionPattern(MLIRContext *context, LinalgPromotionOptions options, LinalgMarker marker = LinalgMarker(), PatternBenefit benefit = 1) - : LinalgBasePromotionPattern(OpTy::getOperationName(), context, - operandsToPromote, alignment, marker, - benefit) {} - LinalgPromotionPattern(MLIRContext *context, - ArrayRef operandsToPromote, - LinalgMarker marker = LinalgMarker(), - PatternBenefit benefit = 1) - : LinalgPromotionPattern(context, operandsToPromote, 0, marker, benefit) { - } - LinalgPromotionPattern(MLIRContext *context, unsigned alignment, - LinalgMarker marker = LinalgMarker(), - PatternBenefit benefit = 1) - : LinalgPromotionPattern(context, {}, alignment, marker, benefit) {} - LinalgPromotionPattern(MLIRContext *context, LinalgMarker marker, - PatternBenefit benefit = 1) - : LinalgPromotionPattern(context, {}, 0, marker, benefit) {} + : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options, + marker, benefit) {} }; /// @@ -342,8 +342,6 @@ struct LinalgLoweringPattern : public RewritePattern { return failure(); if (failed(marker.checkAndNotify(rewriter, linalgOp))) return failure(); - if (failed(promoteSubviewsLinalgOpPrecondition(op))) - return failure(); if (loweringType == LinalgLoweringType::LibraryCall) { // TODO: Move lowering to library calls here. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 8e93ea355a127..86c5ceaef5790 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -45,7 +45,45 @@ using folded_std_view = FoldedValueBuilder; #define DEBUG_TYPE "linalg-promotion" -/// If `size` comes from an AffineMinOp and one of the dimensions of AffineMin +namespace { + +/// Helper struct that captures the information required to apply the +/// transformation on each op. This bridges the abstraction gap with the +/// user-facing API which exposes positional arguments to control which operands +/// are promoted. +struct LinalgOpInstancePromotionOptions { + LinalgOpInstancePromotionOptions(LinalgOp op, + const LinalgPromotionOptions &options); + /// SubViews to promote. + SetVector subViews; + /// Allow the use of dynamicaly-sized buffers. + bool dynamicBuffers; + /// Alignment of promoted buffer. + Optional alignment; +}; +} // namespace + +LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( + LinalgOp linalgOp, const LinalgPromotionOptions &options) + : subViews(), dynamicBuffers(options.dynamicBuffers), + alignment(options.alignment) { + if (options.operandsToPromote.hasValue()) { + for (unsigned idx : options.operandsToPromote.getValue()) { + auto *op = linalgOp.getBuffer(idx).getDefiningOp(); + if (auto sv = dyn_cast_or_null(op)) + subViews.insert(sv); + } + } else { + unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers(); + for (unsigned idx = 0; idx < nBuffers; ++idx) { + auto *op = linalgOp.getBuffer(idx).getDefiningOp(); + if (auto sv = dyn_cast_or_null(op)) + subViews.insert(sv); + } + } +} + +/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp /// is a constant then return a new value set to the smallest such constant. /// Otherwise return size. static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc, @@ -53,25 +91,26 @@ static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc, auto affineMinOp = dyn_cast_or_null(size.getDefiningOp()); if (!affineMinOp) return size; - if (!llvm::any_of(affineMinOp.getAffineMap().getResults(), [](AffineExpr e) { - return e.dyn_cast(); - })) - return size; int64_t minConst = std::numeric_limits::max(); for (auto e : affineMinOp.getAffineMap().getResults()) if (auto cst = e.dyn_cast()) minConst = std::min(minConst, cst.getValue()); - assert(minConst != std::numeric_limits::max()); - return b.create(loc, minConst); + return (minConst == std::numeric_limits::max()) + ? size + : b.create(loc, minConst); } +/// Alloc a new buffer of `size`. If `dynamicBuffers` is true allocate exactly +/// the size needed, otherwise try to allocate a static bounding box. static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers, - OperationFolder *folder, int64_t alignment = 0) { + OperationFolder *folder, + Optional alignment = None) { auto *ctx = size.getContext(); auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); IntegerAttr alignment_attr; - if (alignment) - alignment_attr = IntegerAttr::get(IntegerType::get(64, ctx), alignment); + if (alignment.hasValue()) + alignment_attr = + IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue()); if (!dynamicBuffers) if (auto cst = dyn_cast_or_null(size.getDefiningOp())) return std_alloc( @@ -100,11 +139,11 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers, // To account for general boundary effects, padding must be performed on the // boundary tiles. For now this is done with an unconditional `fill` op followed // by a partial `copy` op. -static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, - SubViewOp subView, - bool dynamicBuffers, - int64_t alignment, - OperationFolder *folder) { +static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, + SubViewOp subView, + bool dynamicBuffers, + Optional alignment, + OperationFolder *folder) { auto zero = folded_std_constant_index(folder, 0); auto one = folded_std_constant_index(folder, 1); @@ -117,8 +156,10 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, for (auto en : llvm::enumerate(subView.getRanges())) { auto rank = en.index(); auto rangeValue = en.value(); - // Try to extract a tight constant + // Try to extract a tight constant. + LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n"); Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size); + LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); allocSize = folded_std_muli(folder, allocSize, size); fullSizes.push_back(size); partialSizes.push_back(folded_std_dim(folder, subView, rank)); @@ -136,26 +177,26 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, return PromotionInfo{buffer, fullLocalView, partialLocalView}; } -SmallVector -mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, - ArrayRef subViews, bool dynamicBuffers, - int64_t alignment, OperationFolder *folder) { - if (subViews.empty()) +static SmallVector +promoteSubViews(OpBuilder &b, Location loc, + LinalgOpInstancePromotionOptions options, + OperationFolder *folder) { + if (options.subViews.empty()) return {}; ScopedContext scope(b, loc); SmallVector res; - res.reserve(subViews.size()); + res.reserve(options.subViews.size()); DenseMap promotionInfoMap; - for (auto v : subViews) { + for (auto v : options.subViews) { SubViewOp subView = cast(v.getDefiningOp()); - auto promotionInfo = promoteFullTileBuffer(b, loc, subView, dynamicBuffers, - alignment, folder); + auto promotionInfo = promoteSubviewAsNewBuffer( + b, loc, subView, options.dynamicBuffers, options.alignment, folder); promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo)); res.push_back(promotionInfo); } - for (auto v : subViews) { + for (auto v : options.subViews) { SubViewOp subView = cast(v.getDefiningOp()); auto info = promotionInfoMap.find(v); if (info == promotionInfoMap.end()) @@ -172,7 +213,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, linalg_fill(info->second.fullLocalView, fillVal); } - for (auto v : subViews) { + for (auto v : options.subViews) { auto info = promotionInfoMap.find(v); if (info == promotionInfoMap.end()) continue; @@ -182,11 +223,9 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, return res; } -LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, - SetVector subViews, - bool dynamicBuffers, - int64_t alignment, - OperationFolder *folder) { +static void promoteSubViews(OpBuilder &b, LinalgOp op, + LinalgOpInstancePromotionOptions options, + OperationFolder *folder) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); if (auto convOp = dyn_cast(op.getOperation())) { @@ -196,17 +235,15 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, } // 1. Promote the specified views and use them in the new op. - ScopedContext scope(b, op.getLoc()); - auto promotedBufferAndViews = - promoteSubViews(b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, - alignment, folder); + auto loc = op.getLoc(); + auto promotedBufferAndViews = promoteSubViews(b, loc, options, folder); SmallVector opViews; opViews.reserve(op.getNumInputsAndOutputs()); SmallVector, 8> writebackViews; - writebackViews.reserve(subViews.size()); + writebackViews.reserve(promotedBufferAndViews.size()); unsigned promotedIdx = 0; for (auto view : op.getInputsAndOutputBuffers()) { - if (subViews.count(view) != 0) { + if (options.subViews.count(view) != 0) { opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); writebackViews.emplace_back(std::make_pair( view, promotedBufferAndViews[promotedIdx].partialLocalView)); @@ -219,67 +256,55 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, // 2. Append all other operands as they appear, this enforces that such // operands are not views. This is to support cases such as FillOp taking // extra scalars etc. - auto operands = getAssumedNonViewOperands(op); - opViews.append(operands.begin(), operands.end()); - LinalgOp res = op.clone(b, op.getLoc(), opViews); + // Keep a reference to output buffers; + DenseSet originalOutputs(op.getOutputBuffers().begin(), + op.getOutputBuffers().end()); + op.getOperation()->setOperands(0, opViews.size(), opViews); + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointAfter(op); + ScopedContext scope(b, loc); // 3. Emit write-back for the promoted output views: copy the partial view. - for (auto viewAndPartialLocalView : writebackViews) { - // WARNING: MUST use the old op to determine whether the operand view is an - // output. - bool isOutput = - op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue(); - if (isOutput) + for (auto viewAndPartialLocalView : writebackViews) + if (originalOutputs.count(viewAndPartialLocalView.first)) linalg_copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first); - } - // 4. Dealloc local buffers. + // 4. Dealloc all local buffers. for (const auto &pi : promotedBufferAndViews) std_dealloc(pi.buffer); - - return res; } -static void promoteSubViews(FuncOp f, bool dynamicBuffers) { - SmallVector toErase; - OperationFolder folder(f.getContext()); - f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { - if (!op.hasBufferSemantics()) - return; - - // TODO(ntv) some heuristic here to decide what to promote. Atm only float - // and integer buffers can be promoted. - SetVector subViews; - OpBuilder b(op); - for (auto it : op.getInputsAndOutputBuffers()) - if (auto sv = dyn_cast_or_null(it.getDefiningOp())) - if (sv.getType().getElementType().isSignlessIntOrFloat()) - subViews.insert(sv); - if (!subViews.empty()) { - promoteSubViewOperands(b, op, subViews, dynamicBuffers, 0, &folder); - toErase.push_back(op); - } - }); - for (auto op : toErase) - op.erase(); -} - -LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition( - Operation *op, llvm::Optional> operandIndicesToPromote) { +LogicalResult +mlir::linalg::promoteSubviewsPrecondition(Operation *op, + LinalgPromotionOptions options) { LinalgOp linOp = dyn_cast(op); // Transformation applies to buffers only. if (!linOp || !linOp.hasBufferSemantics()) return failure(); + // Check that at least one of the requested operands is indeed a subview. for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) { auto sv = isa_and_nonnull(en.value().getDefiningOp()); - if (sv && (!operandIndicesToPromote.hasValue() || - operandIndicesToPromote->count(en.index()))) - return success(); + if (sv) { + if (!options.operandsToPromote.hasValue() || + options.operandsToPromote->count(en.index())) + return success(); + } } + // TODO: Check all subviews requested are bound by a static constant. + // TODO: Check that the total footprint fits within a given size. return failure(); } +LinalgOp mlir::linalg::promoteSubViews(OpBuilder &b, LinalgOp linalgOp, + LinalgPromotionOptions options, + OperationFolder *folder) { + LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options); + ::promoteSubViews( + b, linalgOp, LinalgOpInstancePromotionOptions(linalgOp, options), folder); + return linalgOp; +} + namespace { struct LinalgPromotionPass : public LinalgPromotionBase { LinalgPromotionPass() = default; @@ -288,11 +313,20 @@ struct LinalgPromotionPass : public LinalgPromotionBase { } void runOnFunction() override { - promoteSubViews(getFunction(), dynamicBuffers); + OperationFolder folder(&getContext()); + getFunction().walk([this, &folder](LinalgOp op) { + auto options = LinalgPromotionOptions().setDynamicBuffers(dynamicBuffers); + if (failed(promoteSubviewsPrecondition(op, options))) + return; + LLVM_DEBUG(llvm::dbgs() << "Promote: " << *(op.getOperation()) << "\n"); + OpBuilder b(op); + promoteSubViews(b, op, options, &folder); + }); } }; } // namespace +// TODO: support more transformation options in the pass. std::unique_ptr> mlir::createLinalgPromotionPass(bool dynamicBuffers) { return std::make_unique(dynamicBuffers); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index e229b10072f0c..175c6c8ef0968 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -160,51 +160,23 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( } mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( - StringRef opName, MLIRContext *context, - ArrayRef operandsToPromote, unsigned alignment, + StringRef opName, MLIRContext *context, LinalgPromotionOptions options, LinalgMarker marker, PatternBenefit benefit) : RewritePattern(opName, {}, benefit, context), marker(marker), - operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()), - alignment(alignment) {} + options(options) {} LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { - LinalgOp linalgOp = dyn_cast(op); - if (!linalgOp) + if (failed(marker.checkAndNotify(rewriter, op))) return failure(); - if (failed(marker.checkAndNotify(rewriter, linalgOp))) + if (failed(promoteSubviewsPrecondition(op, options))) return failure(); - if (operandsToPromote.empty()) { - if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None))) - return failure(); - } else { - DenseSet set; - set.insert(operandsToPromote.begin(), operandsToPromote.end()); - if (failed(promoteSubviewsLinalgOpPrecondition(op, set))) - return failure(); - } - - llvm::SetVector subViews; - if (!operandsToPromote.empty()) { - for (unsigned idx : operandsToPromote) { - auto *op = linalgOp.getBuffer(idx).getDefiningOp(); - if (auto sv = dyn_cast_or_null(op)) - subViews.insert(sv); - } - } else { - unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers(); - for (unsigned idx = 0; idx < nBuffers; ++idx) { - auto *op = linalgOp.getBuffer(idx).getDefiningOp(); - if (auto sv = dyn_cast_or_null(op)) - subViews.insert(sv); - } - } - - auto promotedOp = - promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false, - /*alignment=*/alignment); - marker.replaceLinalgMarker(rewriter, promotedOp.getOperation()); - rewriter.eraseOp(op); + rewriter.updateRootInPlace(op, [&]() { + auto promotedOp = promoteSubViews(rewriter, op, options); + (void)promotedOp; + assert(promotedOp && "Unexpected pattern failure"); + marker.replaceLinalgMarker(rewriter, op); + }); return success(); } diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index f3861c38fa601..eb27a7ae00341 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -120,15 +120,13 @@ static void applyPatterns(FuncOp funcOp) { // Linalg subview operands promotion. //===--------------------------------------------------------------------===// patterns.insert>( - ctx, LinalgMarker({"_promote_views_"}, "_views_promoted_")); + ctx, LinalgPromotionOptions(), + LinalgMarker({"_promote_views_"}, "_views_promoted_")); patterns.insert>( - ctx, - /*operandsToPromote=*/ArrayRef{0}, + ctx, LinalgPromotionOptions().setOperandsToPromote({0}), LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_")); patterns.insert>( - ctx, - /*operandsToPromote=*/ArrayRef{0}, - /*alignment=*/32, + ctx, LinalgPromotionOptions().setOperandsToPromote({0}).setAlignment(32), LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_")); applyPatternsAndFoldGreedily(funcOp, patterns);