diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h index a35964c5eab4bc..389e5cc6d1fb90 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -27,6 +27,8 @@ #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" + namespace mlir { namespace linalg { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td index 845873ff83dfea..1e1546407a56df 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -213,6 +213,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { return {range.begin(), range.begin() + $_op.getNumInputs()}; }] >, + InterfaceMethod< + /*desc=*/[{ + Return the range over the input operands that are of buffer type. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getInputBuffers", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::to_vector<4>(llvm::make_filter_range( + getInputs(), [](Value in){ return in.getType().isa(); })); + }] + >, InterfaceMethod< /*desc=*/[{ Return the subset of input operands that are of ranked tensor type. @@ -337,6 +350,18 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { return this->getOperation()->getOperand(i); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the number of output buffers + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumOutputBuffers", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getNumOutputs() - this->getOperation()->getNumResults(); + }] + >, InterfaceMethod< /*desc=*/[{ Return the number of inputs and outputs, irrespective of their buffer or @@ -404,6 +429,49 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { return getInitTensors()[i]; }] >, + InterfaceMethod< + /*desc=*/[{ + Return true if the shaped operand index `i` is the index of an init + tensor. + }], + /*retTy=*/"bool", + /*methodName=*/"isIndexOfAnInitTensor", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i < $_op.getNumShapedOperands() && "overflowing shaped operand index"); + return i >= $_op.getNumInputs() + getNumOutputBuffers(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the relative init tensor index of the shaped operand index. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getInitTensorIndexFromShapedIndex", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(isIndexOfAnInitTensor(i) && "expected an init tensor index"); + return i - $_op.getNumInputs() - getNumOutputBuffers(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the index of the given init tensor value, or `None` if the value + is not part of the init tensors. + }], + /*retTy=*/"llvm::Optional", + /*methodName=*/"getIndexOfInitTensor", + /*args=*/(ins "Value":$value), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto it = llvm::find(getInitTensors(), value); + if (it != getInitTensors().end()) + return it - getInitTensors().begin(); + return llvm::None; + }] + >, InterfaceMethod< /*desc=*/[{ Return the number of inputs, output buffers and init tensors operands. @@ -416,6 +484,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors(); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the `i`-th shaped operand value, which can be an arbitrary input + tensor/buffer, init tensor or output buffer. + }], + /*retTy=*/"Value", + /*methodName=*/"getShapedOperand", + /*args=*/(ins "unsigned":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i < $_op.getNumShapedOperands()); + return this->getOperation()->getOperand(i); + }] + >, InterfaceMethod< /*desc=*/[{ Return the range over inputs, output buffers and init tensors. @@ -473,19 +555,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { >, InterfaceMethod< /*desc=*/[{ - Return the position of buffer in inputs + outputs list + Return the position of the shaped operand in the operand list. }], /*retTy=*/"Optional", - /*methodName=*/"getIndexOfInputAndOutputBuffer", + /*methodName=*/"getIndexOfShapedOperand", /*args=*/(ins "Value":$value), /*methodBody=*/"", /*defaultImplementation=*/[{ Optional inputIndex = getIndexOfInput(value); if (inputIndex.hasValue()) return inputIndex.getValue(); Optional outputIndex = getIndexOfOutputBuffer(value); - if (outputIndex.hasValue()) { + if (outputIndex.hasValue()) return $_op.getNumInputs() + outputIndex.getValue(); - } + Optional initTensorIndex = getIndexOfInitTensor(value); + if (initTensorIndex.hasValue()) + return $_op.getNumInputs() + $_op.getNumOutputBuffers() + initTensorIndex.getValue(); return llvm::None; }] >, @@ -628,8 +712,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { InterfaceMethod< /*desc=*/[{ Clone the current operation with the given location and operands. This - is used to abstract away the optional underlying region creation. This - does not change the balance between input, output_buffer and + is used to abstract away the optional underlying region creation. This + does not change the balance between input, output_buffer and init_tensors operands. }], /*retTy=*/"Operation *", diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 61367dd795486b..9b343b3b04ab2c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -32,6 +32,9 @@ class PatternRewriter; namespace linalg { class LinalgDependenceGraph; +/// A struct containing the Linalg producer before and after fusion. +/// When operating on tensors, `fusedProducer` may feed into a `tensor_cast` op +/// before the consumer Linalg op, until enough canonicalizations have applied. struct FusionInfo { LinalgOp originalProducer; LinalgOp fusedProducer; @@ -81,13 +84,25 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, /// Fuses producer into consumer if the producer is structurally feasible and /// the fusion would not violate dependencies. +/// Implements the fusion part of the "tileAndFuse on buffers" +/// transformation and thus requires the `consumerdIdx`^th operand of `consumer` +/// to be a `subview` op (generally obtained by applying the tiling +/// transformation). /// When non-null, the optional pointer `folder` is used to call into the /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. -Optional fuseProducerOf(OpBuilder &b, LinalgOp consumer, - unsigned consumerIdx, - const LinalgDependenceGraph &graph, - OperationFolder *folder = nullptr); +Optional fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer, + unsigned consumerIdx, + const LinalgDependenceGraph &graph, + OperationFolder *folder = nullptr); +/// Tensor counterpart of `fuseProducerOfBuffer`. +/// This implements the fusion part of the "tileAndFuse on tensors" +/// transformation and thus requires the `consumerdIdx`^th operand of `consumer` +/// to be the result of a `subtensor` op (generally obtained by applying the +/// tiling transformation). +Optional fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer, + unsigned consumerIdx, + OperationFolder *folder); /// Fuse linalg operation on tensors, with the producer of the operand at /// position `consumerIdx` of the consumer. diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index bffd9bd1bd0c37..58a8b3eddc3be9 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -147,13 +147,9 @@ LinalgDependenceGraph::getDependencesInto( } void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { - assert(src.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(dst.hasBufferSemantics() && - "expected linalg op with buffer semantics"); for (auto srcView : src.getOutputBuffers()) { // W // RAW graph - for (auto dstView : dst.getInputs()) { // R + for (auto dstView : dst.getInputBuffers()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAW addDependenceElem(DependenceType::RAW, LinalgOpView{src.getOperation(), srcView}, @@ -169,9 +165,9 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } } } - for (auto srcView : src.getInputs()) { // R + for (auto srcView : src.getInputBuffers()) { // R // RAR graph - for (auto dstView : dst.getInputs()) { // R + for (auto dstView : dst.getInputBuffers()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAR addDependenceElem(DependenceType::RAR, LinalgOpView{src.getOperation(), srcView}, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 585b8810fdc25e..8542c2afb0863c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -41,97 +41,131 @@ using folded_std_constant_index = FoldedValueBuilder; using llvm::dbgs; -/// Implements a simple high-level fusion pass of linalg library operations. +/// Implements a simple high-level fusion pass on linalg structured operations. /// /// In each block, linalg ops are processed in reverse textual order. /// Given a linalg op `O`, fusion occurs by: -/// 1. inspecting the linalg ops that write into the views read by `O`. This -/// uses the SSA value of the views and a simple subview/slice analysis to -/// determine producer-consumer dependences; -/// 2. greedily fuse the linalg ops that produce subview +/// 1. inspecting the linalg ops that write into the views read by `O`. There +/// are 2 cases: +/// a) buffer case: use the SSA value of the views and a simple alias +/// analysis on subview ops to determine producer-consumer dependences; +/// b) tensor case: use SSA use-def chains on subtensor ops; +/// 2. greedily fuse the linalg ops that produce the subview/subtensor. /// 3. inspect the fused ops and determine whether they have other remaining /// LinalgOp uses. If not, then erase the original producing linalg op. /// /// More advanced use cases, analyses as well as profitability heuristics are /// left for future work. +// Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed +// by `permutationMap`. +static void inferShapeComponents(AffineMap permutationMap, + ArrayRef loopRanges, + SmallVectorImpl &offsets, + SmallVectorImpl &sizes, + SmallVectorImpl &strides) { + assert(permutationMap.isProjectedPermutation() && + "expected some subset of a permutation map"); + SmallVector shapeRanges(permutationMap.getNumResults()); + unsigned idx = 0; + for (AffineExpr e : permutationMap.getResults()) { + // loopToOperandRangesMaps are permutations-only, just swap indices. + unsigned loopPos = e.cast().getPosition(); + shapeRanges[idx++] = loopRanges[loopPos]; + } + // Construct a new subshape for the tile. + unsigned rank = shapeRanges.size(); + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (auto r : shapeRanges) { + offsets.push_back(r.offset); + sizes.push_back(r.size); + strides.push_back(r.stride); + } +} + // Return a cloned version of `op` that operates on `loopRanges`, assumed to be // a subset of the original loop ranges of `op`. // This is achieved by applying the `loopToOperandRangesMaps` permutation maps // to the `loopRanges` in order to obtain view ranges. static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, ArrayRef loopRanges) { - assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); - auto maps = op.indexing_maps(); - SmallVector clonedViews; - clonedViews.reserve(op.getNumInputsAndOutputs()); - // Iterate over the inputs and outputs in order. + SmallVector clonedShapes; + clonedShapes.reserve(op.getNumShapedOperands()); + + // Iterate over the shape operands in order. // Extract the subranges from the linearized ranges. - SmallVector ios(op.getInputsAndOutputBuffers()); - for (auto en : llvm::enumerate(ios)) { - unsigned idx = en.index(); - auto map = maps[idx].cast().getValue(); - LLVM_DEBUG(dbgs() << "map: " << map << "\n"); - Value view = en.value(); - SmallVector viewRanges(map.getNumResults()); - for (auto en2 : llvm::enumerate(map.getResults())) { - unsigned d = en2.index(); - // loopToOperandRangesMaps are permutations-only. - unsigned loopPos = en2.value().cast().getPosition(); - viewRanges[d] = loopRanges[loopPos]; - LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() - << "\t" - << "loopPos: " << loopPos << "\t" << viewRanges[d]); - } - // Construct a new subview for the tile. - unsigned rank = viewRanges.size(); + for (auto en : llvm::enumerate(op.getShapedOperands())) { + unsigned shapedOperandIdx = en.index(); + AffineMap map = op.getIndexingMap(shapedOperandIdx); + LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx + << " with indexingMap: " << map << "\n"); SmallVector offsets, sizes, strides; - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); - for (auto r : viewRanges) { - offsets.push_back(r.offset); - sizes.push_back(r.size); - strides.push_back(r.stride); - } - clonedViews.push_back( - b.create(loc, view, offsets, sizes, strides)); + inferShapeComponents(map, loopRanges, offsets, sizes, strides); + Value shape = en.value(); + Value sub = shape.getType().isa() + ? b.create(loc, shape, offsets, sizes, strides) + .getResult() + : b.create(loc, shape, offsets, sizes, strides) + .getResult(); + clonedShapes.push_back(sub); } + // Append the other operands. auto operands = op.getAssumedNonShapedOperands(); - clonedViews.append(operands.begin(), operands.end()); + clonedShapes.append(operands.begin(), operands.end()); + + // Iterate over the results in order. + // Extract the subtensor type from the linearized range. + // Since we do not enforce any canonicalizations on the fly, this is always + // fully dynamic at construction time. + SmallVector resultTypes; + resultTypes.reserve(op.getOperation()->getNumResults()); + for (RankedTensorType t : op.getOutputTensorTypes()) { + unsigned rank = t.getRank(); + SmallVector staticOffsetsVector( + rank, ShapedType::kDynamicStrideOrOffset); + SmallVector staticSizesVector(rank, ShapedType::kDynamicSize); + SmallVector staticStridesVector( + rank, ShapedType::kDynamicStrideOrOffset); + resultTypes.push_back(SubTensorOp::inferResultType( + t.cast(), staticOffsetsVector, staticSizesVector, + staticStridesVector)); + } - Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews); - // When the producer is an IndexedGenercOp, we have to transform its block + Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes); + // When the producer is an IndexedGenericOp, we have to transform its block // IV arguments according to the tiling of the consumer, i.e. offset them by // the values computed in `loopRanges`. if (auto indexedGenericOp = dyn_cast(clonedOp)) { auto &block = indexedGenericOp.region().front(); - OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(&block); for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) { Value oldIndex = block.getArgument(i); + // TODO: replace by an affine_apply. AddIOp newIndex = b.create(indexedGenericOp.getLoc(), oldIndex, loopRanges[i].offset); oldIndex.replaceAllUsesExcept(newIndex, SmallPtrSet{newIndex}); } } + return clonedOp; } -struct ViewDimension { - Value view; +struct ShapeDimension { + Value shape; unsigned dimension; }; -// Given an `op`, returns the first (`view`, `dimension`) pair that identifies +// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps // guarantees at least one such dimension is found. If multiple candidates exist // they must agree by construction (i.e. have the same size) and we just return // the first one. -static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { - assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); +static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, + unsigned loopDepth) { auto maps = op.indexing_maps(); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. @@ -139,43 +173,47 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx].cast().getValue(); - LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); - LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); - Value view = en.value(); - SmallVector viewRanges(map.getNumResults(), nullptr); + LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n"); + LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n"); + Value shape = en.value(); + SmallVector shapeRanges(map.getNumResults(), nullptr); for (auto en2 : llvm::enumerate(map.getResults())) { if (loopDepth == en2.value().cast().getPosition()) { - LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth + LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: " + << loopDepth << "\n"); + LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape << "\n"); - LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n"); - return ViewDimension{view, static_cast(en2.index())}; + return ShapeDimension{shape, static_cast(en2.index())}; } } } - llvm_unreachable("Expect to be able to extract a view defining loop range"); + llvm_unreachable("Expect to be able to extract a shape defining loop range"); } +/// Fuses the producer of `producerIdx` into the loop immediately enclosing +/// `consumer`. This is achieved by "recomputing" the `producer` at the time it +/// is needed just before the `consumer. +/// +/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are +/// 2 cases: +/// 1. Buffer case: `producerIdx` is the index of the buffer in +/// `producer.getOutputBuffers()`. +/// 2. Tensor case: `producerIdx` is the index of the tensor in +/// `producer.getResults()`. static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, LinalgOp consumer, unsigned consumerIdx, OperationFolder *folder = nullptr) { - assert(producer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(consumer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - - auto subView = dyn_cast_or_null( - consumer.getBuffer(consumerIdx).getDefiningOp()); - auto slice = dyn_cast_or_null( - consumer.getBuffer(consumerIdx).getDefiningOp()); - assert(subView || slice); - (void)subView; - (void)slice; + Operation *shapeProducingOp = + consumer.getShapedOperand(consumerIdx).getDefiningOp(); + assert((isa(shapeProducingOp) || + isa(shapeProducingOp)) && + "SubviewOp or SubTensorOp expected"); // loopToOperandRangesMaps are permutations-only by construction: // we can always identify a data dimension with a (at least one) loop // dimension. - AffineMap producerMap = - producer.indexing_maps()[producerIdx].cast().getValue(); + // TODO: extend this with range inference. + AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx << ", producer map: " << producerMap << "\n"); @@ -190,20 +228,24 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, for (auto en : llvm::enumerate(producerMap.getResults())) { unsigned posInProducerLoop = en.value().cast().getPosition(); loopRanges[posInProducerLoop] = - subView.getOrCreateRanges(b, loc)[en.index()]; + isa(shapeProducingOp) + ? cast(shapeProducingOp) + .getOrCreateRanges(b, loc)[en.index()] + : cast(shapeProducingOp) + .getOrCreateRanges(b, loc)[en.index()]; } // Iterate over all dimensions. For the dimensions not identified by the - // producer map for `producerIdx`, we need to explicitly compute the view that - // defines the loop ranges using the `producer`. + // producer map for `producerIdx`, we need to explicitly compute the shape + // that defines the loop ranges using the `producer`. for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { if (loopRanges[i].offset) LLVM_DEBUG(llvm::dbgs() << "existing LoopRange: " << loopRanges[i] << "\n"); else { - auto viewDim = getViewDefiningLoopRange(producer, i); + auto shapeDim = getShapeDefiningLoopRange(producer, i); loopRanges[i] = Range{folded_std_constant_index(folder, 0), - std_dim(viewDim.view, viewDim.dimension), + std_dim(shapeDim.shape, shapeDim.dimension), folded_std_constant_index(folder, 1)}; LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); } @@ -269,7 +311,7 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, "expected linalg op with buffer semantics"); if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) return false; - // Check for any fusion-preventing dependence to any view read/written that + // Check for any fusion-preventing dependence to any shape read/written that // would violate dependences. if (!graph.findCoveringDependences(producer, consumer).empty()) { LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" @@ -308,7 +350,7 @@ static bool isSameSubView(Value a, Value b) { return false; if (sva.static_strides() != svb.static_strides()) return false; - /// Skip the "viewSource" operand. + /// Skip the "source" operand. for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx) if (sva.getOperand(idx) != svb.getOperand(idx)) return false; @@ -354,7 +396,7 @@ findFusableProducer(LinalgOp consumer, unsigned consumerIdx, return {}; } -Optional mlir::linalg::fuseProducerOf( +Optional mlir::linalg::fuseProducerOfBuffer( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, const LinalgDependenceGraph &graph, OperationFolder *folder) { Optional fusableDependence = @@ -381,7 +423,7 @@ Optional mlir::linalg::fuseProducerOf( ScopedContext scope(b, consumer.getLoc()); LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); Optional producerIdxOpt = - producerOp.getIndexOfInputAndOutputBuffer(producerView); + producerOp.getIndexOfOutputBuffer(producerView); assert(producerIdxOpt.hasValue() && "incorrect operand index"); unsigned producerIdx = producerIdxOpt.getValue(); @@ -390,10 +432,75 @@ Optional mlir::linalg::fuseProducerOf( return FusionInfo{producerOp, fusedProducer}; } +/// Walk back use-def chain through scf::For yields. +/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp +static void getProducerOfTensor(Value tensor, LinalgOp &producer, + unsigned &outputIndex) { + if (!tensor.getType().isa()) + return; + + while (true) { + if (auto linalgOp = tensor.getDefiningOp()) { + producer = linalgOp; + outputIndex = tensor.cast().getResultNumber(); + return; + } + if (auto subTensorOp = tensor.getDefiningOp()) { + tensor = subTensorOp.source(); + continue; + } + if (auto blockArg = tensor.dyn_cast()) { + if (auto forOp = blockArg.getDefiningOp()) { + tensor = forOp.getResult(blockArg.getArgNumber()); + continue; + } + } + return; + } +} + +Optional +mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer, + unsigned consumerIdx, + OperationFolder *folder) { + Value inputTensor = consumer.getInput(consumerIdx); + LinalgOp producerOp; + unsigned producerIdx; + getProducerOfTensor(inputTensor, producerOp, producerIdx); + + // Must be a subtensor to guarantee there are loops we can fuse into. + auto subTensor = inputTensor.getDefiningOp(); + if (!subTensor || !producerOp) { + LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)"); + return {}; + } + + // Insert fused `producer` just before `consumer`. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(consumer.getOperation()); + ScopedContext scope(b, consumer.getLoc()); + LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); + LinalgOp fusedProducer = + fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder); + + // Replace use. + // Canonicalizations are not guaranteed to have happened before constructing + // `fusedProducer`. In the tensor case this can result in temporary type + // mismatches. Insert a `tensor_cast` op to propagate the transformation + // invariant that types are compatible. + Value def = fusedProducer.getOperation()->getResult(producerIdx); + OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx); + Type consumerType = use.get().getType(); + if (consumerType != def.getType()) + def = b.create(fusedProducer.getLoc(), consumerType, def); + use.set(def); + return FusionInfo{producerOp, fusedProducer}; +} + /// Returns the positions of the loop in `op` that can be tiled based on the /// operations that are to be fused with it. For example, in a /// -/// linalg. matmul ins(%a, %b : ...) outs(%c : ...) +/// linalg.matmul ins(%a, %b : ...) outs(%c : ...) /// /// if the producer of %a needs to be fused with this op, only the `i` loop of /// the matmul can be tiled while fusing. If producer of %a, and %b are to be @@ -475,7 +582,7 @@ static DenseSet collectTileAndFuseLoops( SmallVector, 1> commonTilableLoops; for (auto dependence : fusableDependences) { unsigned consumerIdx = - op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue(); + op.getIndexOfShapedOperand(dependence.indexingView).getValue(); AffineMap consumerAccess = op.getIndexingMap(consumerIdx); // Previously asserted that the consumerAccess map is a projected // permutation, so all results are known to be AffineDimExprs. To remove @@ -522,8 +629,8 @@ findAllFusableDependences(LinalgOp op, LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); Value producerView = fusableDependence->dependentOpView.view; unsigned producerIdx = - producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue(); - AffineMap producerMap = producerOp.getIndexingMap(producerIdx); + producerOp.getIndexOfOutputBuffer(producerView).getValue(); + AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx); if (!producerMap.isProjectedPermutation()) { op.emitError("unhandled non permutation indexing map for fused view in " "producer for operand at index ") @@ -531,8 +638,7 @@ findAllFusableDependences(LinalgOp op, return llvm::None; } Value consumerView = fusableDependence->indexingView; - unsigned consumerIdx = - op.getIndexOfInputAndOutputBuffer(consumerView).getValue(); + unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue(); if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) { op.emitError( "unhandled case where indexing map for fused view in the consumer is " @@ -644,13 +750,11 @@ tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, // Fuse the operands. for (auto producer : enumerate(fusableDependences)) { LinalgOp producerOp = cast(producer.value().dependentOpView.op); - unsigned producerIdx = producerOp - .getIndexOfInputAndOutputBuffer( - producer.value().dependentOpView.view) - .getValue(); - unsigned consumerIdx = - op.getIndexOfInputAndOutputBuffer(producer.value().indexingView) + unsigned producerIdx = + producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view) .getValue(); + unsigned consumerIdx = + op.getIndexOfShapedOperand(producer.value().indexingView).getValue(); LinalgOp fusedOp = fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx); ret.fusedProducers.push_back(fusedOp); @@ -703,34 +807,52 @@ static void fuseLinalgOpsGreedily(FuncOp f) { // Save original Linalg ops, we only want to make a pass over those. SmallVector linalgOps; f.walk([&](LinalgOp op) { - if (op.hasBufferSemantics()) + // TODO: support multi-results. + if (op.getOperation()->getNumResults() <= 1) linalgOps.push_back(op); }); - // TODO: LinalgDependenceGraph should be able to update itself. - // The current naive and expensive reconstruction of the graph should be - // removed. + // Tile and Fuse for tensors inputs (TODO: all tensor operands). for (auto *op : llvm::reverse(linalgOps)) { - for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers(); - id < e; ++id) { - linalg::Aliases aliases; - linalg::LinalgDependenceGraph graph(aliases, linalgOps); - if (auto info = fuseProducerOf(b, op, id, graph, &folder)) { - auto *originalOp = info->originalProducer.getOperation(); - eraseSet.insert(originalOp); - auto *originalOpInLinalgOpsVector = - std::find(linalgOps.begin(), linalgOps.end(), originalOp); - *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); + LinalgOp linalgOp = cast(op); + for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) { + if (en.value().getType().isa()) { + // TODO: LinalgDependenceGraph should be able to update itself. + // The current naive and expensive reconstruction of the graph should be + // removed. + linalg::Aliases aliases; + linalg::LinalgDependenceGraph graph(aliases, linalgOps); + if (auto info = + fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) { + auto *originalOp = info->originalProducer.getOperation(); + eraseSet.insert(originalOp); + auto *originalOpInLinalgOpsVector = + std::find(linalgOps.begin(), linalgOps.end(), originalOp); + *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); + } + } else { + assert(en.value().getType().isa()); + // Tile and Fuse tensor input (TODO: init_tensors too). + if (en.index() >= linalgOp.getNumInputs()) + continue; + if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) { + auto *originalOp = info->originalProducer.getOperation(); + auto *originalOpInLinalgOpsVector = + std::find(linalgOps.begin(), linalgOps.end(), originalOp); + *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); + // Don't mark for erasure in the tensor case, let DCE handle this. + } } } } - // The `fuseProducerOf` function performs structural checks and in particular - // that no covering read or write exist between the consumer and the producer. - // As a consequence, the only fusions that may occur preserve subsequent - // dependences and are guaranteed by construction to produce the whole view. - // We may thus erase the producer once it is fused. + // The `fuseProducerOfBuffer` function performs structural checks and in + // particular that no covering read or write exist between the consumer and + // the producer. As a consequence, the only fusions that may occur preserve + // subsequent dependences and are guaranteed by construction to produce the + // whole view. We may thus erase the producer once it is fused. for (auto *e : eraseSet) e->erase(); + LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); } diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir new file mode 100644 index 00000000000000..e43f261632e9a1 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -0,0 +1,84 @@ +// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s +// RUN: mlir-opt %s -linalg-fusion -canonicalize -cse -split-input-file | FileCheck %s --check-prefix=CANONICALIZED + +#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)> +#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)> +#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)> +#map3 = affine_map<(d0, d1) -> (2, d0 - d1)> +#map4 = affine_map<(d0, d1) -> (3, d0 - d1)> + +func @matmul_tensors(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %t0 = linalg.matmul ins(%arg0, %arg1: tensor, tensor) + init(%arg2: tensor) + -> tensor + + %c4 = constant 4 : index + %c2 = constant 2 : index + %c0 = constant 0 : index + %c3 = constant 3 : index + %c1 = constant 1 : index + %0 = dim %t0, %c0 : tensor + %1 = dim %t0, %c1 : tensor + %2 = dim %arg1, %c1 : tensor + %3 = scf.for %arg3 = %c0 to %0 step %c2 iter_args(%arg4 = %arg2) -> (tensor) { + %4 = scf.for %arg5 = %c0 to %2 step %c3 iter_args(%arg6 = %arg4) -> (tensor) { + %5 = scf.for %arg7 = %c0 to %1 step %c4 iter_args(%arg8 = %arg6) -> (tensor) { + %6 = subtensor %t0[%arg3, %arg7][%c2, 4][1, 1] : tensor to tensor + %7 = subtensor %arg1[%arg7, %arg5][4, %c3][1, 1] : tensor to tensor<4x?xf32> + %8 = subtensor %arg8[%arg3, %arg5][%c2, %c3][1, 1] : tensor to tensor + %9 = linalg.matmul ins(%6, %7 : tensor, tensor<4x?xf32>) init(%8 : tensor) -> tensor + %10 = subtensor_insert %9 into %arg8[%arg3, %arg5] [%c2, %c3] [1, 1] : tensor into tensor + scf.yield %10 : tensor + } + scf.yield %5 : tensor + } + scf.yield %4 : tensor + } + return %3 : tensor +} + +// CHECK-LABEL: func @matmul_tensors( +// CHECK-SAME: %[[A:[0-9a-z]*]]: tensor +// CHECK-SAME: %[[B:[0-9a-z]*]]: tensor +// CHECK-SAME: %[[C:[0-9a-z]*]]: tensor +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: scf.for %[[I:[0-9a-z]*]] +// CHECK-NEXT: scf.for %[[J:[0-9a-z]*]] +// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] +// +// subtensor of the original program, first one refers to the unfused matmul and becomes a dead SSA value. +// CHECK: subtensor %{{.*}}[%[[I]], %[[K]]] {{.*}} : tensor to tensor +// CHECK: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] {{.*}} : tensor to tensor<4x?xf32> +// CHECK: %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] {{.*}} : tensor to tensor +// +// subtensors of the producing matmul. +// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], %[[C0]]] {{.*}} : tensor to tensor +// CHECK-NEXT: %[[stB2:.*]] = subtensor %[[B]][%[[C0]], %[[K]]] {{.*}} : tensor to tensor +// CHECK-NEXT: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] {{.*}} : tensor to tensor +// CHECK-NEXT: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor, tensor) init(%[[stC]] : tensor) -> tensor +// CHECK-NEXT: %[[stD2:.*]] = tensor_cast %[[stD]] : tensor to tensor +// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD2]], %[[stB1]] : tensor, tensor<4x?xf32>) init(%[[stF]] : tensor) -> tensor +// CHECK-NEXT: subtensor_insert %[[stG]] + + +// CANONICALIZED-LABEL: func @matmul_tensors( +// CANONICALIZED-SAME: %[[A:[0-9a-z]*]]: tensor +// CANONICALIZED-SAME: %[[B:[0-9a-z]*]]: tensor +// CANONICALIZED-SAME: %[[C:[0-9a-z]*]]: tensor +// CANONICALIZED: %[[C0:.*]] = constant 0 : index +// CANONICALIZED: %[[C1:.*]] = constant 1 : index +// CANONICALIZED: scf.for %[[I:[0-9a-z]*]] +// CANONICALIZED-NEXT: scf.for %[[J:[0-9a-z]*]] +// CANONICALIZED-NEXT: scf.for %[[K:[0-9a-z]*]] +// +// CANONICALIZED: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor to tensor<4x3xf32> +// CANONICALIZED: %[[stF:.*]] = subtensor %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : tensor to tensor<2x3xf32> +// +// subtensors of the producing matmul. +// CANONICALIZED: %[[dA1:.*]] = dim %[[A]], %[[C1]] : tensor +// CANONICALIZED: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor to tensor<2x?xf32> +// CANONICALIZED-NEXT: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor to tensor +// CANONICALIZED-NEXT: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor to tensor<2x4xf32> +// CANONICALIZED-NEXT: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor) init(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32> +// CANONICALIZED-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) init(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32> +// CANONICALIZED-NEXT: subtensor_insert %[[stG]]