From 73e880fbf17fc4d8dc3fdfec2a18eb26d1804750 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 15 Feb 2022 21:16:50 +0900 Subject: [PATCH] [mlir][bufferize] Add vector-bufferize pass and remove obsolete patterns from Linalg Bufferize Differential Revision: https://reviews.llvm.org/D119444 --- .../IR/BufferizableOpInterface.h | 30 +++++++++++- .../Dialect/Vector/Transforms/CMakeLists.txt | 6 ++- .../mlir/Dialect/Vector/Transforms/Passes.h | 30 ++++++++++++ .../mlir/Dialect/Vector/Transforms/Passes.td | 19 ++++++++ mlir/include/mlir/InitAllPasses.h | 2 + .../Dialect/Linalg/Transforms/Bufferize.cpp | 44 +----------------- .../SparseTensor/Pipelines/CMakeLists.txt | 1 + .../Pipelines/SparseTensorPipelines.cpp | 2 + .../Dialect/Vector/Transforms/Bufferize.cpp | 46 +++++++++++++++++++ .../Dialect/Vector/Transforms/CMakeLists.txt | 6 +++ .../Dialect/Vector/Transforms/PassDetail.h | 29 ++++++++++++ mlir/test/Dialect/Linalg/bufferize.mlir | 17 ------- mlir/test/Dialect/Vector/bufferize.mlir | 30 ++++++++++++ .../llvm-project-overlay/mlir/BUILD.bazel | 24 ++++++++++ 14 files changed, 223 insertions(+), 63 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Vector/Transforms/Passes.h create mode 100644 mlir/include/mlir/Dialect/Vector/Transforms/Passes.td create mode 100644 mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp create mode 100644 mlir/lib/Dialect/Vector/Transforms/PassDetail.h create mode 100644 mlir/test/Dialect/Vector/bufferize.mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 257f67a2db440..433da0d953086 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -130,8 +130,34 @@ struct BufferizationOptions { OpFilterEntry::FilterFn filterFn = [=](Operation *op) { return op->getName().getStringRef() == opName; }; - opFilter.push_back( - OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW}); + allowOperationInFilter(filterFn); + } + + /// Deny the given op and activate the filter (`hasFilter`). + /// + /// This function adds a DENY filter. + void denyOperationInFilter(StringRef opName) { + hasFilter = true; + OpFilterEntry::FilterFn filterFn = [=](Operation *op) { + return op->getName().getStringRef() == opName; + }; + denyOperationInFilter(filterFn); + } + + /// Allow ops that are matched by `fn` and activate the filter (`hasFilter`). + /// + /// This function adds an ALLOW filter. + void allowOperationInFilter(OpFilterEntry::FilterFn fn) { + hasFilter = true; + opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW}); + } + + /// Deny ops that are matched by `fn` and activate the filter (`hasFilter`). + /// + /// This function adds a DENY filter. + void denyOperationInFilter(OpFilterEntry::FilterFn fn) { + hasFilter = true; + opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY}); } /// Try to cast the given op to BufferizableOpInterface if the op is allow diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt index bcf3de2010254..35868d1e69233 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt @@ -1 +1,5 @@ -# This dialect does currently not have any passes. +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector) +add_public_tablegen_target(MLIRVectorTransformsIncGen) + +add_mlir_doc(Passes VectorPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h new file mode 100644 index 0000000000000..7734fe9d9cc98 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h @@ -0,0 +1,30 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace vector { +/// Creates an instance of the `vector` dialect bufferization pass. +std::unique_ptr createVectorBufferizePass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Vector/Transforms/Passes.h.inc" +} // namespace vector + +} // namespace mlir + +#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td new file mode 100644 index 0000000000000..6bca0dad1bf85 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td @@ -0,0 +1,19 @@ +//===-- Passes.td - Vector pass definition file ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES +#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def VectorBufferize : Pass<"vector-bufferize", "FuncOp"> { + let summary = "Bufferize Vector dialect ops"; + let constructor = "mlir::vector::createVectorBufferizePass()"; +} + +#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index d36bf6082876a..eb74ca02a8eac 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -32,6 +32,7 @@ #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Transforms/Passes.h" #include @@ -71,6 +72,7 @@ inline void registerAllPasses() { registerStandardPasses(); tensor::registerTensorPasses(); tosa::registerTosaOptPasses(); + vector::registerVectorPasses(); // Dialect pipelines sparse_tensor::registerSparseTensorPipelines(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp index 6ffb357264566..9ef78ea15d1c9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -268,43 +268,6 @@ class InsertSliceOpConverter return success(); } }; - -class VectorTransferReadOpConverter - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (readOp.getShapedType().isa()) - return failure(); - rewriter.replaceOpWithNewOp( - readOp, readOp.getType(), adaptor.source(), adaptor.indices(), - adaptor.permutation_mapAttr(), adaptor.padding(), adaptor.mask(), - adaptor.in_boundsAttr()); - return success(); - } -}; - -class VectorTransferWriteOpConverter - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (writeOp.getShapedType().isa()) - return failure(); - rewriter.create( - writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(), - adaptor.permutation_mapAttr(), - adaptor.in_bounds() ? adaptor.in_boundsAttr() : ArrayAttr()); - rewriter.replaceOp(writeOp, adaptor.source()); - return success(); - } -}; } // namespace namespace { @@ -329,9 +292,6 @@ struct LinalgBufferizePass : public LinalgBufferizeBase { return typeConverter.isLegal(op); }; target.addDynamicallyLegalDialect(isLegalOperation); - target - .addDynamicallyLegalOp( - isLegalOperation); RewritePatternSet patterns(&context); populateLinalgBufferizePatterns(typeConverter, patterns); @@ -358,9 +318,7 @@ void mlir::linalg::populateLinalgBufferizePatterns( BufferizeTensorReshapeOp, BufferizeTensorReshapeOp, ExtractSliceOpConverter, - InsertSliceOpConverter, - VectorTransferReadOpConverter, - VectorTransferWriteOpConverter + InsertSliceOpConverter >(typeConverter, patterns.getContext()); // clang-format on patterns.add(patterns.getContext()); diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt index fc0f59b232922..e76261f292fe5 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt @@ -19,4 +19,5 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines MLIRStandardOpsTransforms MLIRTensorTransforms MLIRVectorToLLVM + MLIRVectorTransforms ) diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp index 819ecc6e5c882..25487e431708b 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" using namespace mlir; @@ -31,6 +32,7 @@ void mlir::sparse_tensor::buildSparseCompiler( pm.addPass(createSparsificationPass(options.sparsificationOptions())); pm.addPass(createSparseTensorConversionPass()); pm.addNestedPass(createLinalgBufferizePass()); + pm.addNestedPass(vector::createVectorBufferizePass()); pm.addNestedPass(createConvertLinalgToLoopsPass()); pm.addNestedPass(createConvertVectorToSCFPass()); pm.addNestedPass(createConvertSCFToCFPass()); diff --git a/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp new file mode 100644 index 0000000000000..4ed2dd629c1b9 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp @@ -0,0 +1,46 @@ +//===- Bufferize.cpp - Bufferization for `vector` dialect ops -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements bufferization of `vector` dialect ops +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "PassDetail.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" + +using namespace mlir; +using namespace bufferization; + +namespace { +struct VectorBufferizePass : public VectorBufferizeBase { + void runOnOperation() override { + BufferizationOptions options = getPartialBufferizationOptions(); + options.allowDialectInFilter(); + + if (failed(bufferizeOp(getOperation(), options))) + signalPassFailure(); + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + vector::registerBufferizableOpInterfaceExternalModels(registry); + } +}; +} // namespace + +std::unique_ptr mlir::vector::createVectorBufferizePass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index cf6be4ca59539..bfba400c59994 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRVectorTransforms BufferizableOpInterfaceImpl.cpp + Bufferize.cpp VectorDropLeadUnitDim.cpp VectorInsertExtractStridedSliceRewritePatterns.cpp VectorMultiDimReductionTransforms.cpp @@ -12,17 +13,22 @@ add_mlir_dialect_library(MLIRVectorTransforms ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms + DEPENDS + MLIRVectorTransformsIncGen + LINK_LIBS PUBLIC MLIRAffine MLIRAffineAnalysis MLIRAffineUtils MLIRArithmetic MLIRBufferization + MLIRBufferizationTransforms MLIRDialectUtils MLIRIR MLIRLinalg MLIRMemRef MLIRSCF + MLIRTransforms MLIRVector MLIRVectorInterfaces MLIRVectorUtils diff --git a/mlir/lib/Dialect/Vector/Transforms/PassDetail.h b/mlir/lib/Dialect/Vector/Transforms/PassDetail.h new file mode 100644 index 0000000000000..2ef3176bf67a3 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/PassDetail.h @@ -0,0 +1,29 @@ +//===- PassDetail.h - Vector Pass class details -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_ +#define DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +namespace bufferization { +class BufferizationDialect; +} // namespace bufferization + +namespace memref { +class MemRefDialect; +} // namespace memref + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/Vector/Transforms/Passes.h.inc" + +} // namespace mlir + +#endif // DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_ diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir index 3fb7d49c04f4b..614f207bb3354 100644 --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -303,23 +303,6 @@ func @pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tens // CHECK: return %[[OUT_TENSOR]] : tensor<4x?x?x?xf32> // CHECK: } - -// ----- - -// CHECK-LABEL: func @vector_transfer -func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) { - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f32 - %read = vector.transfer_read %in[%c0], %cst {in_bounds = [true]} - : tensor<4xf32>, vector<4xf32> - %tanh = math.tanh %read : vector<4xf32> - %write = vector.transfer_write %tanh, %out[%c0] {in_bounds = [true]} - : vector<4xf32>, tensor<4xf32> - return - // CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32> - // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32> -} - // ----- // CHECK-LABEL: func @bufferize_dot diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir new file mode 100644 index 0000000000000..b68f4a6d3a17d --- /dev/null +++ b/mlir/test/Dialect/Vector/bufferize.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt %s -vector-bufferize -split-input-file | FileCheck %s + +// CHECK-LABEL: func @transfer_read( +// CHECK-SAME: %[[t:.*]]: tensor, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[pad:.*]]: f32) +// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref +// CHECK: %[[r:.*]] = vector.transfer_read %[[m]][%[[o1]], %[[o2]]], %[[pad]] {in_bounds = [true, false]} : memref, vector<5x6xf32> +// CHECK: return %[[r]] +func @transfer_read(%t: tensor, %o1: index, + %o2: index, %pad: f32) -> vector<5x6xf32> { + %0 = vector.transfer_read %t[%o1, %o2], %pad {in_bounds = [true, false]} + : tensor, vector<5x6xf32> + return %0 : vector<5x6xf32> +} + +// ----- + +// CHECK-LABEL: func @transfer_write( +// CHECK-SAME: %[[t:.*]]: tensor, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>) +// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref +// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}, %{{.*}}) {{.*}} : memref +// CHECK: memref.copy %[[m]], %[[alloc]] +// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]] {in_bounds = [true, false]} : vector<5x6xf32>, memref +// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] : memref +// CHECK: return %[[r]] +func @transfer_write(%t: tensor, %o1: index, + %o2: index, %vec: vector<5x6xf32>) -> tensor { + %0 = vector.transfer_write %vec, %t[%o1, %o2] {in_bounds = [true, false]} + : vector<5x6xf32>, tensor + return %0 : tensor +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index a63365cf385d0..8ef56dbc8c684 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1994,6 +1994,7 @@ cc_library( ":StandardOpsTransforms", ":TensorTransforms", ":VectorToLLVM", + ":VectorTransforms", ], ) @@ -2906,11 +2907,29 @@ cc_library( ], ) +gentbl_cc_library( + name = "VectorPassIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=Vector", + ], + "include/mlir/Dialect/Vector/Transforms/Passes.h.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Vector/Transforms/Passes.td", + deps = [":PassBaseTdFiles"], +) + cc_library( name = "VectorTransforms", srcs = glob( [ "lib/Dialect/Vector/Transforms/*.cpp", + "lib/Dialect/Vector/Transforms/*.h", ], ), hdrs = glob([ @@ -2923,16 +2942,20 @@ cc_library( ":Analysis", ":ArithmeticDialect", ":BufferizationDialect", + ":BufferizationTransforms", ":DialectUtils", ":IR", ":LinalgOps", ":MemRefDialect", + ":Pass", ":SCFDialect", ":StandardOps", ":Support", ":TensorDialect", + ":Transforms", ":VectorInterfaces", ":VectorOps", + ":VectorPassIncGen", ":VectorUtils", "//llvm:Support", ], @@ -5911,6 +5934,7 @@ cc_library( ":VectorToROCDL", ":VectorToSCF", ":VectorToSPIRV", + ":VectorTransforms", ":X86Vector", ":X86VectorTransforms", ],