diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 16203e5459b97a..942581b4bbaf55 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -31,13 +31,6 @@ struct LinalgTilingOptions; //===----------------------------------------------------------------------===// using LinalgLoops = SmallVector; -struct TiledLinalgOp { - LinalgOp op; - SmallVector loops; - SmallVector tensorResults; - TiledLinalgOp &operator=(const TiledLinalgOp &) = default; -}; - /// Populates patterns for vectorization of all ConvN-D ops. void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns, @@ -63,6 +56,12 @@ void populateLinalgBufferizePatterns(MLIRContext *context, /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be /// integers, in the range 0..`tileSizes.size()` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). +struct TiledLinalgOp { + LinalgOp op; + SmallVector loops; + SmallVector tensorResults; + TiledLinalgOp &operator=(const TiledLinalgOp &) = default; +}; Optional tileLinalgOp(OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options); @@ -264,7 +263,12 @@ Optional promoteSubViews(OpBuilder &b, LinalgOp op, OperationFolder *folder = nullptr); /// Emit a suitable vector form for a Linalg op with fully static shape. -void vectorizeLinalgOp(OpBuilder &builder, Operation *op); +struct VectorizedLinalgOp { + SmallVector tensorResults; + VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default; +}; +Optional vectorizeLinalgOp(OpBuilder &builder, + Operation *op); /// Emits a loop nest of `LoopTy` with the proper body for `op`. 
template diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 8dac82a57de58c..b80b6fb090e794 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -468,10 +468,13 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( return failure(); if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); - if (failed(vectorizeLinalgOpPrecondition(op))) + Optional res = vectorizeLinalgOp(rewriter, op); + if (!res) return failure(); - vectorizeLinalgOp(rewriter, op); - rewriter.eraseOp(op); + if (!res->tensorResults.empty()) + rewriter.replaceOp(op, res->tensorResults); + else + rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 6e5b4912584550..a9a43e194d75e4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -248,8 +248,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op, /// TODO: Reuse opportunities for RAR dependencies. /// 4. Register CustomVectorizationHook for YieldOp to capture the results. /// 5. Iteratively call vectorizeOneOp on the region operations. -/// 6. RAUW the linalg op by the results captured vectorizing the YieldOp. -static LogicalResult vectorizeAsLinalgGeneric( +static Optional vectorizeAsLinalgGeneric( OpBuilder &builder, LinalgOp linalgOp, ArrayRef customVectorizationHooks = {}) { // 1. Certain Linalg ops do not have a region but only a region builder. 
@@ -306,7 +305,7 @@ static LogicalResult vectorizeAsLinalgGeneric( VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks); if (result.status == VectorizationStatus::Failure) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op); - return failure(); + return llvm::None; } if (result.status == VectorizationStatus::NewOp) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: " @@ -315,10 +314,7 @@ static LogicalResult vectorizeAsLinalgGeneric( } } - // 6. RAUW the linalg op by the results captured vectorizing the YieldOp. - if (!results.empty()) - linalgOp->replaceAllUsesWith(results); - return success(); + return VectorizedLinalgOp{{results}}; } /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp. @@ -357,7 +353,8 @@ static bool isElementwise(Operation *op) { return hasOnlyScalarElementwiseOp(genericOp.getRegion()); } -static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) { +static Optional vectorizeContraction(OpBuilder &builder, + LinalgOp linalgOp) { assert(isaContractionOpInterface(linalgOp) && "expected vectorizeContraction preconditions to be met"); Location loc = linalgOp.getLoc(); @@ -384,11 +381,7 @@ static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) { linalgOp.indexing_maps(), linalgOp.iterator_types()); return VectorizationResult{VectorizationStatus::NewOp, contract}; }; - auto status = - vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction}); - (void)status; - assert(succeeded(status) && - "Unexpected vectorization failed despite preconditions"); + return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction}); } LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { @@ -408,8 +401,10 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { return success(isaContractionOpInterface(linalgOp)); } -void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) { - 
assert(succeeded(vectorizeLinalgOpPrecondition(op))); +Optional mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, + Operation *op) { + if (failed(vectorizeLinalgOpPrecondition(op))) + return llvm::None; edsc::ScopedContext scope(builder, op->getLoc()); // In the case of 0-D memrefs, return null and special case to scalar load or @@ -418,8 +413,10 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) { // Vectorize fill as a vector.broadcast. LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " << "Rewrite linalg.fill as vector.broadcast: " << *op); - buildVectorWrite(builder, fillOp.value(), fillOp.output()); - return; + VectorizedLinalgOp res; + if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output())) + res.tensorResults.push_back(v); + return res; } if (auto copyOp = dyn_cast(op)) { // Vectorize copy as a vector.transfer_read+vector.transfer_write. @@ -428,21 +425,26 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) { "vector.transfer_write: " << *op); Value vector = buildVectorRead(builder, copyOp.input()); - buildVectorWrite(builder, vector, copyOp.output()); - return; + VectorizedLinalgOp res; + if (Value v = buildVectorWrite(builder, vector, copyOp.output())) + res.tensorResults.push_back(v); + return res; } - if (isElementwise(op)) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " - << "Rewrite linalg op as vector.transfer_read + " << *op); - auto status = vectorizeAsLinalgGeneric(builder, cast(op)); - (void)status; - assert(succeeded(status) && - "Unexpected vectorization failed despite preconditions"); - return; + << "Vectorize linalg op as a generic: " << *op); + return vectorizeAsLinalgGeneric(builder, cast(op)); } - vectorizeContraction(builder, cast(op)); + // TODO: as soon as Copy and FillOp. 
get a region builder, replace all the + // above by: + // if (isa(op) || isElementwise(op)) { + // LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: " + // << "Vectorize linalg op as a generic: " << *op); + // return vectorizeAsLinalgGeneric(builder, cast(op)); + // } + + return vectorizeContraction(builder, cast(op)); } //----------------------------------------------------------------------------// diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index 3904353287c527..12841a4b6803bf 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s + +// For pass debug output, add -debug-only=linalg-vectorization to the RUN line (asserts builds only). // -----