Skip to content

Commit

Permalink
[mlir][bufferize] Add vector-bufferize pass and remove obsolete patte…
Browse files Browse the repository at this point in the history
…rns from Linalg Bufferize

Differential Revision: https://reviews.llvm.org/D119444
  • Loading branch information
matthias-springer committed Feb 15, 2022
1 parent 8527859 commit 73e880f
Show file tree
Hide file tree
Showing 14 changed files with 223 additions and 63 deletions.
Expand Up @@ -130,8 +130,34 @@ struct BufferizationOptions {
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
opFilter.push_back(
OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
allowOperationInFilter(filterFn);
}

/// Deny the given op and activate the filter (`hasFilter`).
///
/// This function adds a DENY filter.
void denyOperationInFilter(StringRef opName) {
hasFilter = true;
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
denyOperationInFilter(filterFn);
}

/// Allow ops that are matched by `fn` and activate the filter (`hasFilter`).
///
/// This function adds an ALLOW filter.
void allowOperationInFilter(OpFilterEntry::FilterFn fn) {
hasFilter = true;
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW});
}

/// Deny ops that are matched by `fn` and activate the filter (`hasFilter`).
///
/// This function adds a DENY filter.
void denyOperationInFilter(OpFilterEntry::FilterFn fn) {
hasFilter = true;
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY});
}

/// Try to cast the given op to BufferizableOpInterface if the op is allow
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1 +1,5 @@
# This dialect does currently not have any passes.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector)
add_public_tablegen_target(MLIRVectorTransformsIncGen)

add_mlir_doc(Passes VectorPasses ./ -gen-pass-doc)
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -0,0 +1,30 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace vector {
/// Creates an instance of the `vector` dialect bufferization pass.
std::unique_ptr<Pass> createVectorBufferizePass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector

} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -0,0 +1,19 @@
//===-- Passes.td - Vector pass definition file ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

def VectorBufferize : Pass<"vector-bufferize", "FuncOp"> {
let summary = "Bufferize Vector dialect ops";
let constructor = "mlir::vector::createVectorBufferizePass()";
}

#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllPasses.h
Expand Up @@ -32,6 +32,7 @@
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"

#include <cstdlib>
Expand Down Expand Up @@ -71,6 +72,7 @@ inline void registerAllPasses() {
registerStandardPasses();
tensor::registerTensorPasses();
tosa::registerTosaOptPasses();
vector::registerVectorPasses();

// Dialect pipelines
sparse_tensor::registerSparseTensorPipelines();
Expand Down
44 changes: 1 addition & 43 deletions mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
Expand Up @@ -268,43 +268,6 @@ class InsertSliceOpConverter
return success();
}
};

class VectorTransferReadOpConverter
: public OpConversionPattern<vector::TransferReadOp> {
public:
using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (readOp.getShapedType().isa<MemRefType>())
return failure();
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
adaptor.permutation_mapAttr(), adaptor.padding(), adaptor.mask(),
adaptor.in_boundsAttr());
return success();
}
};

class VectorTransferWriteOpConverter
: public OpConversionPattern<vector::TransferWriteOp> {
public:
using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (writeOp.getShapedType().isa<MemRefType>())
return failure();
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
adaptor.permutation_mapAttr(),
adaptor.in_bounds() ? adaptor.in_boundsAttr() : ArrayAttr());
rewriter.replaceOp(writeOp, adaptor.source());
return success();
}
};
} // namespace

namespace {
Expand All @@ -329,9 +292,6 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
return typeConverter.isLegal(op);
};
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
target
.addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
isLegalOperation);

RewritePatternSet patterns(&context);
populateLinalgBufferizePatterns(typeConverter, patterns);
Expand All @@ -358,9 +318,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
BufferizeTensorReshapeOp<tensor::ExpandShapeOp>,
BufferizeTensorReshapeOp<tensor::CollapseShapeOp>,
ExtractSliceOpConverter,
InsertSliceOpConverter,
VectorTransferReadOpConverter,
VectorTransferWriteOpConverter
InsertSliceOpConverter
>(typeConverter, patterns.getContext());
// clang-format on
patterns.add<GeneralizePadOpPattern>(patterns.getContext());
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt
Expand Up @@ -19,4 +19,5 @@ add_mlir_dialect_library(MLIRSparseTensorPipelines
MLIRStandardOpsTransforms
MLIRTensorTransforms
MLIRVectorToLLVM
MLIRVectorTransforms
)
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;
Expand All @@ -31,6 +32,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
pm.addPass(createSparseTensorConversionPass());
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(vector::createVectorBufferizePass());
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass());
pm.addNestedPass<FuncOp>(createConvertSCFToCFPass());
Expand Down
46 changes: 46 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/Bufferize.cpp
@@ -0,0 +1,46 @@
//===- Bufferize.cpp - Bufferization for `vector` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `vector` dialect ops
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"

using namespace mlir;
using namespace bufferization;

namespace {
struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
options.allowDialectInFilter<vector::VectorDialect>();

if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
tensor::TensorDialect, vector::VectorDialect>();
vector::registerBufferizableOpInterfaceExternalModels(registry);
}
};
} // namespace

std::unique_ptr<Pass> mlir::vector::createVectorBufferizePass() {
return std::make_unique<VectorBufferizePass>();
}
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
VectorDropLeadUnitDim.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorMultiDimReductionTransforms.cpp
Expand All @@ -12,17 +13,22 @@ add_mlir_dialect_library(MLIRVectorTransforms
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms

DEPENDS
MLIRVectorTransformsIncGen

LINK_LIBS PUBLIC
MLIRAffine
MLIRAffineAnalysis
MLIRAffineUtils
MLIRArithmetic
MLIRBufferization
MLIRBufferizationTransforms
MLIRDialectUtils
MLIRIR
MLIRLinalg
MLIRMemRef
MLIRSCF
MLIRTransforms
MLIRVector
MLIRVectorInterfaces
MLIRVectorUtils
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/PassDetail.h
@@ -0,0 +1,29 @@
//===- PassDetail.h - Vector Pass class details -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_

#include "mlir/Pass/Pass.h"

namespace mlir {

namespace bufferization {
class BufferizationDialect;
} // namespace bufferization

namespace memref {
class MemRefDialect;
} // namespace memref

#define GEN_PASS_CLASSES
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"

} // namespace mlir

#endif // DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_
17 changes: 0 additions & 17 deletions mlir/test/Dialect/Linalg/bufferize.mlir
Expand Up @@ -303,23 +303,6 @@ func @pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tens
// CHECK: return %[[OUT_TENSOR]] : tensor<4x?x?x?xf32>
// CHECK: }


// -----

// CHECK-LABEL: func @vector_transfer
func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%read = vector.transfer_read %in[%c0], %cst {in_bounds = [true]}
: tensor<4xf32>, vector<4xf32>
%tanh = math.tanh %read : vector<4xf32>
%write = vector.transfer_write %tanh, %out[%c0] {in_bounds = [true]}
: vector<4xf32>, tensor<4xf32>
return
// CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32>
}

// -----

// CHECK-LABEL: func @bufferize_dot
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/Vector/bufferize.mlir
@@ -0,0 +1,30 @@
// RUN: mlir-opt %s -vector-bufferize -split-input-file | FileCheck %s

// CHECK-LABEL: func @transfer_read(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[pad:.*]]: f32)
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32>
// CHECK: %[[r:.*]] = vector.transfer_read %[[m]][%[[o1]], %[[o2]]], %[[pad]] {in_bounds = [true, false]} : memref<?x?xf32>, vector<5x6xf32>
// CHECK: return %[[r]]
func @transfer_read(%t: tensor<?x?xf32>, %o1: index,
%o2: index, %pad: f32) -> vector<5x6xf32> {
%0 = vector.transfer_read %t[%o1, %o2], %pad {in_bounds = [true, false]}
: tensor<?x?xf32>, vector<5x6xf32>
return %0 : vector<5x6xf32>
}

// -----

// CHECK-LABEL: func @transfer_write(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>)
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32>
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}, %{{.*}}) {{.*}} : memref<?x?xf32>
// CHECK: memref.copy %[[m]], %[[alloc]]
// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]] {in_bounds = [true, false]} : vector<5x6xf32>, memref<?x?xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] : memref<?x?xf32>
// CHECK: return %[[r]]
func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
%o2: index, %vec: vector<5x6xf32>) -> tensor<?x?xf32> {
%0 = vector.transfer_write %vec, %t[%o1, %o2] {in_bounds = [true, false]}
: vector<5x6xf32>, tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

0 comments on commit 73e880f

Please sign in to comment.