Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][bufferization] Add OneShotBufferize transform op
This commit allows for One-Shot Bufferize to be used through the transform dialect. No op handle is currently returned for the bufferized IR. Differential Revision: https://reviews.llvm.org/D125098
- Loading branch information
1 parent
cedfb54
commit 461dafd
Showing
10 changed files
with
391 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
add_subdirectory(IR) | ||
add_subdirectory(TransformOps) | ||
add_subdirectory(Transforms) |
30 changes: 30 additions & 0 deletions
30
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
//===- BufferizationTransformOps.h - Buff. transf. ops ----------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H | ||
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H | ||
|
||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" | ||
#include "mlir/IR/OpImplementation.h" | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Bufferization Transform Operations | ||
//===----------------------------------------------------------------------===// | ||
|
||
#define GET_OP_CLASSES | ||
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h.inc" | ||
|
||
namespace mlir { | ||
class DialectRegistry; | ||
|
||
namespace bufferization { | ||
void registerTransformDialectExtension(DialectRegistry ®istry); | ||
} // namespace bufferization | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H |
58 changes: 58 additions & 0 deletions
58
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
//===- BufferizationTransformOps.td - Buff. transf. ops ----*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef BUFFERIZATION_TRANSFORM_OPS | ||
#define BUFFERIZATION_TRANSFORM_OPS | ||
|
||
include "mlir/Dialect/Transform/IR/TransformDialect.td" | ||
include "mlir/Dialect/Transform/IR/TransformEffects.td" | ||
include "mlir/Dialect/Transform/IR/TransformInterfaces.td" | ||
include "mlir/Dialect/PDL/IR/PDLTypes.td" | ||
include "mlir/Interfaces/SideEffectInterfaces.td" | ||
include "mlir/IR/OpBase.td" | ||
|
||
def OneShotBufferizeOp | ||
: Op<Transform_Dialect, "bufferization.one_shot_bufferize", | ||
[DeclareOpInterfaceMethods<TransformOpInterface>, | ||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> { | ||
let description = [{ | ||
Indicates that the given `target` op should be bufferized with One-Shot | ||
Bufferize. The bufferization can be configured with various attributes that | ||
corresponding to options in `BufferizationOptions` and the | ||
`one-shot-bufferize` pass. More information can be found in the pass | ||
documentation. | ||
|
||
If `target_is_module` is set, `target` must be a module. In that case the | ||
`target` handle can be reused by other transform ops. When bufferizing other | ||
ops, the `target` handled is freed after bufferization and can no longer be | ||
used. | ||
|
||
Note: Only ops that implement `BufferizableOpInterface` are bufferized. All | ||
other ops are ignored if `allow_unknown_ops`. If `allow_unknown_ops` is | ||
unset, this transform fails when an unknown/non-bufferizable op is found. | ||
Many ops implement `BufferizableOpInterface` via an external model. These | ||
external models must be registered when applying this transform op; | ||
otherwise, said ops would be considered non-bufferizable. | ||
}]; | ||
|
||
let arguments = ( | ||
ins PDL_Operation:$target, | ||
DefaultValuedAttr<BoolAttr, "false">:$allow_return_allocs, | ||
DefaultValuedAttr<BoolAttr, "false">:$allow_unknown_ops, | ||
DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries, | ||
DefaultValuedAttr<BoolAttr, "true">:$create_deallocs, | ||
DefaultValuedAttr<BoolAttr, "true">:$target_is_module, | ||
DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only, | ||
DefaultValuedAttr<BoolAttr, "false">:$print_conflicts); | ||
|
||
let results = (outs); | ||
|
||
let assemblyFormat = "$target attr-dict"; | ||
} | ||
|
||
#endif // BUFFERIZATION_TRANSFORM_OPS |
4 changes: 4 additions & 0 deletions
4
mlir/include/mlir/Dialect/Bufferization/TransformOps/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
set(LLVM_TARGET_DEFINITIONS BufferizationTransformOps.td) | ||
mlir_tablegen(BufferizationTransformOps.h.inc -gen-op-decls) | ||
mlir_tablegen(BufferizationTransformOps.cpp.inc -gen-op-defs) | ||
add_public_tablegen_target(MLIRBufferizationTransformOpsIncGen) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
add_subdirectory(IR) | ||
add_subdirectory(TransformOps) | ||
add_subdirectory(Transforms) |
96 changes: 96 additions & 0 deletions
96
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
//===- BufferizationTransformOps.h - Bufferization transform ops ----------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" | ||
|
||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" | ||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" | ||
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/PDL/IR/PDL.h" | ||
#include "mlir/Dialect/PDL/IR/PDLTypes.h" | ||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::bufferization; | ||
using namespace mlir::transform; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// OneShotBufferizeOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
LogicalResult | ||
transform::OneShotBufferizeOp::apply(TransformResults &transformResults, | ||
TransformState &state) { | ||
OneShotBufferizationOptions options; | ||
options.allowReturnAllocs = getAllowReturnAllocs(); | ||
options.allowUnknownOps = getAllowUnknownOps(); | ||
options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries(); | ||
options.createDeallocs = getCreateDeallocs(); | ||
options.testAnalysisOnly = getTestAnalysisOnly(); | ||
options.printConflicts = getPrintConflicts(); | ||
|
||
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); | ||
for (Operation *target : payloadOps) { | ||
auto moduleOp = dyn_cast<ModuleOp>(target); | ||
if (getTargetIsModule() && !moduleOp) | ||
return emitError("expected ModuleOp target"); | ||
if (options.bufferizeFunctionBoundaries) { | ||
if (!moduleOp) | ||
return emitError("expected ModuleOp target"); | ||
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) | ||
return emitError("bufferization failed"); | ||
} else { | ||
if (failed(bufferization::runOneShotBufferize(target, options))) | ||
return emitError("bufferization failed"); | ||
} | ||
} | ||
|
||
return success(); | ||
} | ||
|
||
void transform::OneShotBufferizeOp::getEffects( | ||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | ||
effects.emplace_back(MemoryEffects::Read::get(), getTarget(), | ||
TransformMappingResource::get()); | ||
|
||
// Handles that are not modules are not longer usable. | ||
if (!getTargetIsModule()) | ||
effects.emplace_back(MemoryEffects::Free::get(), getTarget(), | ||
TransformMappingResource::get()); | ||
} | ||
//===----------------------------------------------------------------------===// | ||
// Transform op registration | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace { | ||
/// Registers new ops and declares PDL as dependent dialect since the additional | ||
/// ops are using PDL types for operands and results. | ||
class BufferizationTransformDialectExtension | ||
: public transform::TransformDialectExtension< | ||
BufferizationTransformDialectExtension> { | ||
public: | ||
BufferizationTransformDialectExtension() { | ||
declareDependentDialect<bufferization::BufferizationDialect>(); | ||
declareDependentDialect<pdl::PDLDialect>(); | ||
declareDependentDialect<memref::MemRefDialect>(); | ||
registerTransformOps< | ||
#define GET_OP_LIST | ||
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" | ||
>(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
#define GET_OP_CLASSES | ||
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" | ||
|
||
void mlir::bufferization::registerTransformDialectExtension( | ||
DialectRegistry ®istry) { | ||
registry.addExtensions<BufferizationTransformDialectExtension>(); | ||
} |
18 changes: 18 additions & 0 deletions
18
mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
add_mlir_dialect_library(MLIRBufferizationTransformOps | ||
BufferizationTransformOps.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization/TransformOps | ||
|
||
DEPENDS | ||
MLIRBufferizationTransformOpsIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRBufferization | ||
MLIRBufferizationTransforms | ||
MLIRParser | ||
MLIRPDL | ||
MLIRSideEffectInterfaces | ||
MLIRTransformDialect | ||
) |
125 changes: 125 additions & 0 deletions
125
mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
// RUN: mlir-opt --test-transform-dialect-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s | ||
|
||
// Test One-Shot Bufferize. | ||
|
||
transform.with_pdl_patterns { | ||
^bb0(%arg0: !pdl.operation): | ||
sequence %arg0 { | ||
^bb0(%arg1: !pdl.operation): | ||
%0 = pdl_match @pdl_target in %arg1 | ||
transform.bufferization.one_shot_bufferize %0 | ||
{target_is_module = false} | ||
} | ||
|
||
pdl.pattern @pdl_target : benefit(1) { | ||
%0 = operation "func.func" | ||
rewrite %0 with "transform.dialect" | ||
} | ||
} | ||
|
||
// CHECK-LABEL: func @test_function( | ||
// CHECK-SAME: %[[A:.*]]: tensor<?xf32> | ||
func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) { | ||
%c0 = arith.constant 0 : index | ||
|
||
// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] | ||
// CHECK: %[[dim:.*]] = memref.dim %[[A_memref]] | ||
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) | ||
// CHECK: memref.copy %[[A_memref]], %[[alloc]] | ||
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]] | ||
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]] | ||
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32> | ||
|
||
// CHECK: memref.dealloc %[[alloc]] | ||
// CHECK: return %[[res_tensor]] | ||
return %0 : tensor<?xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// Test analysis of One-Shot Bufferize only. | ||
|
||
transform.with_pdl_patterns { | ||
^bb0(%arg0: !pdl.operation): | ||
sequence %arg0 { | ||
^bb0(%arg1: !pdl.operation): | ||
%0 = pdl_match @pdl_target in %arg1 | ||
transform.bufferization.one_shot_bufferize %0 | ||
{target_is_module = false, test_analysis_only = true} | ||
} | ||
|
||
pdl.pattern @pdl_target : benefit(1) { | ||
%0 = operation "func.func" | ||
rewrite %0 with "transform.dialect" | ||
} | ||
} | ||
|
||
// CHECK-LABEL: func @test_function_analysis( | ||
// CHECK-SAME: %[[A:.*]]: tensor<?xf32> | ||
func.func @test_function_analysis(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) { | ||
%c0 = arith.constant 0 : index | ||
// CHECK: vector.transfer_write | ||
// CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]} | ||
// CHECK-SAME: tensor<?xf32> | ||
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32> | ||
return %0 : tensor<?xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// Test One-Shot Bufferize transform failure with an unknown op. This would be | ||
// allowed with `allow_unknown_ops`. | ||
|
||
transform.with_pdl_patterns { | ||
^bb0(%arg0: !pdl.operation): | ||
sequence %arg0 { | ||
^bb0(%arg1: !pdl.operation): | ||
%0 = pdl_match @pdl_target in %arg1 | ||
// expected-error @+1 {{bufferization failed}} | ||
transform.bufferization.one_shot_bufferize %0 {target_is_module = false} | ||
} | ||
|
||
pdl.pattern @pdl_target : benefit(1) { | ||
%0 = operation "func.func" | ||
rewrite %0 with "transform.dialect" | ||
} | ||
} | ||
|
||
func.func @test_unknown_op_failure() -> (tensor<?xf32>) { | ||
// expected-error @+1 {{op was not bufferized}} | ||
%0 = "test.dummy_op"() : () -> (tensor<?xf32>) | ||
return %0 : tensor<?xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// Test One-Shot Bufferize transform failure with a module op. | ||
|
||
transform.with_pdl_patterns { | ||
^bb0(%arg0: !pdl.operation): | ||
sequence %arg0 { | ||
^bb0(%arg1: !pdl.operation): | ||
// %arg1 is the module | ||
transform.bufferization.one_shot_bufferize %arg1 | ||
} | ||
} | ||
|
||
module { | ||
// CHECK-LABEL: func @test_function( | ||
// CHECK-SAME: %[[A:.*]]: tensor<?xf32> | ||
func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) { | ||
%c0 = arith.constant 0 : index | ||
|
||
// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]] | ||
// CHECK: %[[dim:.*]] = memref.dim %[[A_memref]] | ||
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) | ||
// CHECK: memref.copy %[[A_memref]], %[[alloc]] | ||
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]] | ||
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]] | ||
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32> | ||
|
||
// CHECK: memref.dealloc %[[alloc]] | ||
// CHECK: return %[[res_tensor]] | ||
return %0 : tensor<?xf32> | ||
} | ||
} |
Oops, something went wrong.