diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index 8c9a1e3ad1d80..3e20412b17369 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -448,6 +448,12 @@ def PerformConcurrentlyOp : SCF_Op<"foreach_thread.perform_concurrently", [
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
+  // The default builders do not create the region with an empty block, so add a custom builder that does.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins)>,
+  ];
+
   // TODO: Add a `PerformConcurrentlyOpInterface` interface for ops that can
   // appear inside perform_concurrently.
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index a160ba028928c..ecb66faee1999 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1138,10 +1138,11 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder,
   result.addOperands(numThreads);
 
   Region *bodyRegion = result.addRegion();
-  {
-    OpBuilder::InsertionGuard g(builder);
-    builder.createBlock(bodyRegion);
-  }
+  OpBuilder::InsertionGuard g(builder);
+  // createBlock sets the insertion point inside the new block. We would
+  // normally restore it right away, but the default ensureTerminator
+  // implementation expects the insertion point to stay inside the block.
+  builder.createBlock(bodyRegion);
   Block &bodyBlock = bodyRegion->front();
   bodyBlock.addArguments(
       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
@@ -1158,8 +1159,9 @@ void ForeachThreadOp::build(
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
   result.addOperands(numThreads);
 
+  OpBuilder::InsertionGuard g(builder);
   Region *bodyRegion = result.addRegion();
-  bodyRegion->push_back(new Block);
+  builder.createBlock(bodyRegion);
   Block &bodyBlock = bodyRegion->front();
   bodyBlock.addArguments(
       SmallVector<Type>(numThreads.size(), builder.getIndexType()),
@@ -1167,9 +1169,11 @@ void ForeachThreadOp::build(
 
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToStart(&bodyBlock);
-  bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+  bodyBuilder(builder, result.location, bodyBlock.getArguments());
   auto terminator =
-      llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+      llvm::dyn_cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+  assert(terminator &&
+         "expected bodyBuilder to create PerformConcurrentlyOp terminator");
   result.addTypes(terminator.yieldedTypes());
 }
 
@@ -1272,6 +1276,13 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
 // PerformConcurrentlyOp
 //===----------------------------------------------------------------------===//
 
+// Build a PerformConcurrentlyOp with an empty body block.
+void PerformConcurrentlyOp::build(OpBuilder &b, OperationState &result) {
+  OpBuilder::InsertionGuard g(b);
+  Region *bodyRegion = result.addRegion();
+  b.createBlock(bodyRegion);
+}
+
 LogicalResult PerformConcurrentlyOp::verify() {
   // TODO: PerformConcurrentlyOpInterface.
   for (const Operation &op : getRegion().front().getOperations())
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 2a6a95e5e0d1c..79bef06dfc5f3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -812,6 +813,289 @@ struct YieldOpInterface
   }
 };
 
+using tensor::ExtractSliceOp;
+
+/// Return the destinations that a ForeachThreadOp is inserting into. One per
+/// ParallelInsertSliceOp.
+static SmallVector<OpOperand *>
+getInsertionDest(ForeachThreadOp foreachThreadOp) {
+  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
+  SmallVector<OpOperand *> result;
+  terminator.walk([&](ParallelInsertSliceOp insertOp) {
+    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
+  });
+  return result;
+}
+
+/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of
+/// the region. There are op interfaces for the terminators
+/// (PerformConcurrentlyOp and ParallelInsertSliceOp), but these are only used
+/// during analysis, not during bufferization.
+struct ForeachThreadOpInterface
+    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
+                                                    ForeachThreadOp> {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const AnalysisState &state) const {
+    // Get the OpOperand (dest) of the corresponding ParallelInsertSliceOp.
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
+  }
+
+  bool isMemoryWrite(Operation *op, OpResult opResult,
+                     const AnalysisState &state) const {
+    // This op is a memory write. Stop lookup here to avoid finding false
+    // conflicts involving this op and one of the ops in the region. This is
+    // similar to how scf.if ops are analyzed.
+    return true;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    OpBuilder::InsertionGuard g(b);
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+
+    // Gather the new results of the ForeachThreadOp.
+    SmallVector<Value> newResults;
+    for (OpResult opResult : foreachThreadOp->getOpResults()) {
+      SmallVector<OpOperand *> insertDestOperands =
+          state.getAnalysisState().getAliasingOpOperand(opResult);
+      assert(insertDestOperands.size() == 1 &&
+             "expected exactly one aliasing OpOperand");
+      // Insert copies right before the PerformConcurrentlyOp terminator. They
+      // should not be inside the terminator (which would be the default
+      // insertion point).
+      Value buffer = *state.getBuffer(b, *insertDestOperands.front(),
+                                      /*forceInPlace=*/llvm::None,
+                                      /*customCopyInsertionPoint=*/op);
+      newResults.push_back(buffer);
+    }
+
+    // Create a new ForeachThreadOp without any results and drop the
+    // automatically introduced terminator.
+    TypeRange newResultTypes;
+    auto newForeachThreadOp =
+        b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
+                                  foreachThreadOp.getNumThreads());
+    newForeachThreadOp.getBody()->getTerminator()->erase();
+
+    // Move over the block contents of the old op.
+    b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
+                  {newForeachThreadOp.getBody()->getArguments()});
+
+    // Bufferize the terminator.
+    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
+        newForeachThreadOp.getBody()->getTerminator());
+    b.setInsertionPoint(performConcurrentlyOp);
+    unsigned resultCounter = 0;
+    WalkResult walkResult =
+        performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
+          Location loc = insertOp.getLoc();
+          Type srcType = getMemRefType(
+              insertOp.getSource().getType().cast<RankedTensorType>(),
+              state.getOptions());
+          // ParallelInsertSliceOp bufferizes to a copy.
+          auto srcMemref = b.create<bufferization::ToMemrefOp>(
+              loc, srcType, insertOp.getSource());
+          Value destMemref = newResults[resultCounter++];
+          Value subview = b.create<memref::SubViewOp>(
+              loc, destMemref, insertOp.getMixedOffsets(),
+              insertOp.getMixedSizes(), insertOp.getMixedStrides());
+          // This memcpy will fold away if everything bufferizes in-place.
+          if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
+                                                     srcMemref, subview)))
+            return WalkResult::interrupt();
+          b.eraseOp(insertOp);
+          return WalkResult::advance();
+        });
+    if (walkResult.wasInterrupted())
+      return failure();
+
+    // Replace the op.
+    replaceOpWithBufferizedValues(b, op, newResults);
+
+    return success();
+  }
+};
+
+/// Nothing to do for PerformConcurrentlyOp.
+struct PerformConcurrentlyOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    assert(false && "op does not have any tensor OpOperands / OpResults");
+    return failure();
+  }
+};
+
+/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches,
+/// i.e. equivalent operand / result and the same offsets/sizes/strides.
+static bool areEquivalentExtractSliceOps(const AnalysisState &state,
+                                         ExtractSliceOp st,
+                                         ParallelInsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  if (st != sti &&
+      !state.areEquivalentBufferizedValues(st.source(), sti.getDest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
+/// Return true if `value` originates from an ExtractSliceOp that matches the
+/// given InsertSliceOp.
+static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
+                                      ParallelInsertSliceOp insertOp) {
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
+/// Analysis of ParallelInsertSliceOp.
+struct ParallelInsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    if (&opOperand != &op->getOpOperand(1) /*dest*/)
+      return {};
+
+    // ParallelInsertSliceOp itself has no results. Tensors are returned via
+    // the parent op.
+    auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
+    assert(foreachThreadOp &&
+           "could not find valid owner of parallel_insert_slice");
+
+    // The destination of the i-th ParallelInsertSliceOp is returned as the
+    // i-th OpResult of the parent ForeachThreadOp.
+    Block *block = op->getBlock();
+    unsigned int opIdx = 0;
+    for (ParallelInsertSliceOp insertOp :
+         block->getOps<ParallelInsertSliceOp>()) {
+      if (insertOp.getOperation() == op)
+        break;
+      ++opIdx;
+    }
+    assert(opIdx < foreachThreadOp->getNumResults() &&
+           "could not find op inside terminator op");
+
+    return {foreachThreadOp->getResult(opIdx)};
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    // Will be bufferized as part of ForeachThreadOp.
+    return failure();
+  }
+
+  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
+  // the code.
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const AnalysisState &state) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+    // uRead is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that InsertSliceOp reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via an InsertSliceOp is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // InsertSliceOp is writing.
+        //
+        // In the above example:
+        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // If uConflictingWrite is an InsertSliceOp...
+    if (auto insertSliceOp =
+            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of vector.transfer_read
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      // lastWrite         = %1
+      //
+      // This is not a conflict because the InsertSliceOp overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          state.areEquivalentBufferizedValues(uRead->get(),
+                                              insertSliceOp.getSource()) &&
+          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+                                    insertSliceOp))
+        return true;
+
+    return false;
+  }
+};
+
 } // namespace
 } // namespace scf
 } // namespace mlir
@@ -822,6 +1106,11 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
     ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
     ForOp::attachInterface<ForOpInterface>(*ctx);
     IfOp::attachInterface<IfOpInterface>(*ctx);
+    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
+    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+        *ctx);
+    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
+        *ctx);
     WhileOp::attachInterface<WhileOpInterface>(*ctx);
     YieldOp::attachInterface<YieldOpInterface>(*ctx);
   });
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 888eea82bbf79..00c977d52b99d 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -486,3 +486,127 @@ func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_no_conflict(
+// CHECK-SAME:   %[[idx:.*]]: index, %[[idx2:.*]]: index,
+// CHECK-SAME:   %[[arg1:.*]]: memref<?xf32, #{{.*}}>,
+// CHECK-SAME:   %[[arg2:.*]]: memref<?xf32, #{{.*}}>
+func.func @parallel_insert_slice_no_conflict(
+    %idx: index,
+    %idx2: index,
+    %arg1: tensor<?xf32> {bufferization.writable = true},
+    %arg2: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>, f32) {
+  %cst = arith.constant 4.200000e+01 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) -> ()
+  %2 = scf.foreach_thread (%arg3) in (%idx2) -> (tensor<?xf32>) {
+    // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
+    %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
+    // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref<?xf32
+    %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+    // Self-copy will DCE away later.
+    // CHECK: memref.copy %[[subview]], %[[subview]]
+
+    // Empty terminator is elided from pretty-printing.
+    // CHECK-NOT: scf.foreach_thread.perform_concurrently
+    // CHECK-NOT: parallel_insert_slice
+    scf.foreach_thread.perform_concurrently {
+      scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+        tensor<?xf32> into tensor<?xf32>
+    }
+  }
+
+  // CHECK: %[[load:.*]] = memref.load %[[arg2]]
+  %f = tensor.extract %2[%c0] : tensor<?xf32>
+
+  // CHECK: return %[[load]] : f32
+  return %2, %f : tensor<?xf32>, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice_with_conflict(
+// CHECK-SAME:   %[[idx:.*]]: index, %[[idx2:.*]]: index,
+// CHECK-SAME:   %[[arg1:.*]]: memref<?xf32, #{{.*}}>,
+// CHECK-SAME:   %[[arg2:.*]]: memref<?xf32, #{{.*}}>
+func.func @parallel_insert_slice_with_conflict(
+    %idx: index,
+    %idx2: index,
+    %arg1: tensor<?xf32> {bufferization.writable = true},
+    %arg2: tensor<?xf32> {bufferization.writable = true}) -> (f32, f32)
+{
+  %cst = arith.constant 4.200000e+01 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  // The parallel_insert_slice op bufferizes out-of-place due to a RAW conflict
+  // on %arg2, so we need an allocation.
+  // CHECK: %[[alloc1:.*]] = memref.alloc
+  // CHECK: memref.copy %[[arg2]], %[[alloc1]]
+
+  // CHECK: scf.foreach_thread (%[[tidx:.*]]) in (%[[idx2]]) -> ()
+  %2 = scf.foreach_thread (%arg3) in (%idx2) -> (tensor<?xf32>) {
+    // Another alloc for the extract_slice op.
+    // CHECK: %[[alloc2:.*]] = memref.alloc
+    %6 = tensor.extract_slice %arg2[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
+
+    // CHECK: linalg.fill ins(%{{.*}}) outs(%[[alloc2]] : memref<?xf32
+    %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+
+    // Now the copy of the actual insert_slice.
+    // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
+    //
+    // CHECK: memref.copy %[[alloc2]], %[[subview1]]
+    // CHECK: memref.dealloc %[[alloc2]]
+
+    // Empty terminator is elided from pretty-printing.
+    // CHECK-NOT: scf.foreach_thread.perform_concurrently
+    // CHECK-NOT: parallel_insert_slice
+    scf.foreach_thread.perform_concurrently {
+      scf.foreach_thread.parallel_insert_slice %8 into %arg2[5] [%idx] [%c1] :
+        tensor<?xf32> into tensor<?xf32>
+    }
+  }
+
+  // CHECK: %[[load:.*]] = memref.load %[[arg2]]
+  // CHECK: %[[load2:.*]] = memref.load %[[alloc1]]
+  // CHECK: memref.dealloc %[[alloc1]]
+  %f = tensor.extract %arg2[%c0] : tensor<?xf32>
+  %f2 = tensor.extract %2[%c0] : tensor<?xf32>
+
+  // CHECK: return %[[load2]], %[[load]] : f32, f32
+  return %f2, %f : f32, f32
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0 * 4)>
+#map1 = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK: #[[$DYN_LAYOUT_MAP:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK-LABEL: func.func @matmul
+func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32> {bufferization.writable = true}) -> tensor<8x8xf32> {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // CHECK: scf.foreach_thread {{.*}} -> ()
+  %0 = scf.foreach_thread (%arg3, %arg4) in (%c2, %c4) -> (tensor<8x8xf32>) {
+    %1 = affine.apply #map0(%arg3)
+    %3 = tensor.extract_slice %arg0[%1, 0] [4, 8] [1, 1] : tensor<8x8xf32> to tensor<4x8xf32>
+    %4 = affine.apply #map1(%arg4)
+    %6 = tensor.extract_slice %arg1[0, %4] [8, 4] [1, 1] : tensor<8x8xf32> to tensor<8x4xf32>
+    %7 = tensor.extract_slice %arg2[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>
+
+    // CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, #[[$DYN_LAYOUT_MAP]]>, memref<8x4xf32, #[[$DYN_LAYOUT_MAP]]>) outs({{.*}} : memref<4x4xf32, #[[$DYN_LAYOUT_MAP]]>)
+    %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32>
+    scf.foreach_thread.perform_concurrently {
+      scf.foreach_thread.parallel_insert_slice %8 into %arg2[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
+    }
+  }
+  return %0 : tensor<8x8xf32>
+}
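
Usage note (illustrative sketch, not part of the patch): with the builder changes above, the callback-based ForeachThreadOp::build now passes the full thread-index range to the callback and asserts that the callback created the scf.foreach_thread.perform_concurrently terminator itself; the new no-argument PerformConcurrentlyOp builder makes that easy because it already creates the op with an empty body block. A minimal C++ sketch of a caller, assuming only the signatures shown in this diff (the helper name buildTiledComputation and its arguments are hypothetical):

  #include "mlir/Dialect/SCF/SCF.h"

  using namespace mlir;

  // Hypothetical helper: create an scf.foreach_thread op whose body ends in an
  // explicitly constructed perform_concurrently terminator.
  static scf::ForeachThreadOp buildTiledComputation(OpBuilder &builder,
                                                    Location loc,
                                                    ValueRange numThreads) {
    return builder.create<scf::ForeachThreadOp>(
        loc, numThreads,
        [&](OpBuilder &b, Location nestedLoc, ValueRange threadIds) {
          // ... create the per-thread computation here, indexed by `threadIds`
          // (e.g. tensor.extract_slice + linalg ops) ...

          // The callback must create the terminator itself; the new assert in
          // ForeachThreadOp::build fires otherwise. The custom no-argument
          // PerformConcurrentlyOp builder creates the op with an empty block.
          auto term = b.create<scf::PerformConcurrentlyOp>(nestedLoc);

          // Ops such as scf.foreach_thread.parallel_insert_slice would then be
          // created inside the terminator's block:
          b.setInsertionPointToStart(&term.getRegion().front());
          // b.create<scf::ParallelInsertSliceOp>(nestedLoc, /*...*/);
        });
  }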