Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][bufferize] Add vector-bufferize pass and remove obsolete patte…
…rns from Linalg Bufferize Differential Revision: https://reviews.llvm.org/D119444
- Loading branch information
1 parent
8527859
commit 73e880f
Showing
14 changed files
with
223 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
# This dialect does currently not have any passes. | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector) | ||
add_public_tablegen_target(MLIRVectorTransformsIncGen) | ||
|
||
add_mlir_doc(Passes VectorPasses ./ -gen-pass-doc) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ | ||
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ | ||
|
||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir { | ||
namespace vector { | ||
/// Creates an instance of the `vector` dialect bufferization pass. | ||
std::unique_ptr<Pass> createVectorBufferizePass(); | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Registration | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Generate the code for registering passes. | ||
#define GEN_PASS_REGISTRATION | ||
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc" | ||
} // namespace vector | ||
|
||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
//===-- Passes.td - Vector pass definition file ------------*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES | ||
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def VectorBufferize : Pass<"vector-bufferize", "FuncOp"> { | ||
let summary = "Bufferize Vector dialect ops"; | ||
let constructor = "mlir::vector::createVectorBufferizePass()"; | ||
} | ||
|
||
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
//===- Bufferize.cpp - Bufferization for `vector` dialect ops -------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements bufferization of `vector` dialect ops | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" | ||
#include "PassDetail.h" | ||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" | ||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" | ||
#include "mlir/Dialect/Vector/Transforms/Passes.h" | ||
|
||
using namespace mlir; | ||
using namespace bufferization; | ||
|
||
namespace { | ||
struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> { | ||
void runOnOperation() override { | ||
BufferizationOptions options = getPartialBufferizationOptions(); | ||
options.allowDialectInFilter<vector::VectorDialect>(); | ||
|
||
if (failed(bufferizeOp(getOperation(), options))) | ||
signalPassFailure(); | ||
} | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect, | ||
tensor::TensorDialect, vector::VectorDialect>(); | ||
vector::registerBufferizableOpInterfaceExternalModels(registry); | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<Pass> mlir::vector::createVectorBufferizePass() { | ||
return std::make_unique<VectorBufferizePass>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
//===- PassDetail.h - Vector Pass class details -----------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_ | ||
#define DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_ | ||
|
||
#include "mlir/Pass/Pass.h" | ||
|
||
namespace mlir { | ||
|
||
namespace bufferization { | ||
class BufferizationDialect; | ||
} // namespace bufferization | ||
|
||
namespace memref { | ||
class MemRefDialect; | ||
} // namespace memref | ||
|
||
#define GEN_PASS_CLASSES | ||
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc" | ||
|
||
} // namespace mlir | ||
|
||
#endif // DIALECT_VECTOR_TRANSFORMS_PASSDETAIL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// RUN: mlir-opt %s -vector-bufferize -split-input-file | FileCheck %s | ||
|
||
// CHECK-LABEL: func @transfer_read( | ||
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[pad:.*]]: f32) | ||
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32> | ||
// CHECK: %[[r:.*]] = vector.transfer_read %[[m]][%[[o1]], %[[o2]]], %[[pad]] {in_bounds = [true, false]} : memref<?x?xf32>, vector<5x6xf32> | ||
// CHECK: return %[[r]] | ||
func @transfer_read(%t: tensor<?x?xf32>, %o1: index, | ||
%o2: index, %pad: f32) -> vector<5x6xf32> { | ||
%0 = vector.transfer_read %t[%o1, %o2], %pad {in_bounds = [true, false]} | ||
: tensor<?x?xf32>, vector<5x6xf32> | ||
return %0 : vector<5x6xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @transfer_write( | ||
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[o1:.*]]: index, %[[o2:.*]]: index, %[[vec:.*]]: vector<5x6xf32>) | ||
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?x?xf32> | ||
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}}, %{{.*}}) {{.*}} : memref<?x?xf32> | ||
// CHECK: memref.copy %[[m]], %[[alloc]] | ||
// CHECK: vector.transfer_write %[[vec]], %[[alloc]][%[[o1]], %[[o2]]] {in_bounds = [true, false]} : vector<5x6xf32>, memref<?x?xf32> | ||
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]] : memref<?x?xf32> | ||
// CHECK: return %[[r]] | ||
func @transfer_write(%t: tensor<?x?xf32>, %o1: index, | ||
%o2: index, %vec: vector<5x6xf32>) -> tensor<?x?xf32> { | ||
%0 = vector.transfer_write %vec, %t[%o1, %o2] {in_bounds = [true, false]} | ||
: vector<5x6xf32>, tensor<?x?xf32> | ||
return %0 : tensor<?x?xf32> | ||
} |
Oops, something went wrong.