Skip to content

Commit

Permalink
[mlir][bufferization] Add OneShotBufferize transform op
Browse files Browse the repository at this point in the history
This commit allows for One-Shot Bufferize to be used through the transform dialect. No op handle is currently returned for the bufferized IR.

Differential Revision: https://reviews.llvm.org/D125098
  • Loading branch information
matthias-springer committed Jun 9, 2022
1 parent cedfb54 commit 461dafd
Show file tree
Hide file tree
Showing 10 changed files with 391 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Bufferization/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(TransformOps)
add_subdirectory(Transforms)
@@ -0,0 +1,30 @@
//===- BufferizationTransformOps.h - Buff. transf. ops ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"

//===----------------------------------------------------------------------===//
// Bufferization Transform Operations
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h.inc"

namespace mlir {
class DialectRegistry;

namespace bufferization {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace bufferization
} // namespace mlir

#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
@@ -0,0 +1,58 @@
//===- BufferizationTransformOps.td - Buff. transf. ops ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef BUFFERIZATION_TRANSFORM_OPS
#define BUFFERIZATION_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def OneShotBufferizeOp
: Op<Transform_Dialect, "bufferization.one_shot_bufferize",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Indicates that the given `target` op should be bufferized with One-Shot
Bufferize. The bufferization can be configured with various attributes that
corresponding to options in `BufferizationOptions` and the
`one-shot-bufferize` pass. More information can be found in the pass
documentation.

If `target_is_module` is set, `target` must be a module. In that case the
`target` handle can be reused by other transform ops. When bufferizing other
ops, the `target` handled is freed after bufferization and can no longer be
used.

Note: Only ops that implement `BufferizableOpInterface` are bufferized. All
other ops are ignored if `allow_unknown_ops`. If `allow_unknown_ops` is
unset, this transform fails when an unknown/non-bufferizable op is found.
Many ops implement `BufferizableOpInterface` via an external model. These
external models must be registered when applying this transform op;
otherwise, said ops would be considered non-bufferizable.
}];

let arguments = (
ins PDL_Operation:$target,
DefaultValuedAttr<BoolAttr, "false">:$allow_return_allocs,
DefaultValuedAttr<BoolAttr, "false">:$allow_unknown_ops,
DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
DefaultValuedAttr<BoolAttr, "true">:$create_deallocs,
DefaultValuedAttr<BoolAttr, "true">:$target_is_module,
DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only,
DefaultValuedAttr<BoolAttr, "false">:$print_conflicts);

let results = (outs);

let assemblyFormat = "$target attr-dict";
}

#endif // BUFFERIZATION_TRANSFORM_OPS
@@ -0,0 +1,4 @@
set(LLVM_TARGET_DEFINITIONS BufferizationTransformOps.td)
mlir_tablegen(BufferizationTransformOps.h.inc -gen-op-decls)
mlir_tablegen(BufferizationTransformOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRBufferizationTransformOpsIncGen)
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
Expand Down Expand Up @@ -107,6 +108,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
// clang-format on

// Register all dialect extensions.
bufferization::registerTransformDialectExtension(registry);
linalg::registerTransformDialectExtension(registry);
scf::registerTransformDialectExtension(registry);

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(TransformOps)
add_subdirectory(Transforms)
@@ -0,0 +1,96 @@
//===- BufferizationTransformOps.h - Bufferization transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::transform;

//===----------------------------------------------------------------------===//
// OneShotBufferizeOp
//===----------------------------------------------------------------------===//

LogicalResult
transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
TransformState &state) {
OneShotBufferizationOptions options;
options.allowReturnAllocs = getAllowReturnAllocs();
options.allowUnknownOps = getAllowUnknownOps();
options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
options.createDeallocs = getCreateDeallocs();
options.testAnalysisOnly = getTestAnalysisOnly();
options.printConflicts = getPrintConflicts();

ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
for (Operation *target : payloadOps) {
auto moduleOp = dyn_cast<ModuleOp>(target);
if (getTargetIsModule() && !moduleOp)
return emitError("expected ModuleOp target");
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitError("expected ModuleOp target");
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
return emitError("bufferization failed");
} else {
if (failed(bufferization::runOneShotBufferize(target, options)))
return emitError("bufferization failed");
}
}

return success();
}

void transform::OneShotBufferizeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
TransformMappingResource::get());

// Handles that are not modules are not longer usable.
if (!getTargetIsModule())
effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
TransformMappingResource::get());
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
/// Registers new ops and declares PDL as dependent dialect since the additional
/// ops are using PDL types for operands and results.
class BufferizationTransformDialectExtension
: public transform::TransformDialectExtension<
BufferizationTransformDialectExtension> {
public:
BufferizationTransformDialectExtension() {
declareDependentDialect<bufferization::BufferizationDialect>();
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<memref::MemRefDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"

void mlir::bufferization::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<BufferizationTransformDialectExtension>();
}
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
@@ -0,0 +1,18 @@
add_mlir_dialect_library(MLIRBufferizationTransformOps
BufferizationTransformOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization/TransformOps

DEPENDS
MLIRBufferizationTransformOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRBufferization
MLIRBufferizationTransforms
MLIRParser
MLIRPDL
MLIRSideEffectInterfaces
MLIRTransformDialect
)
125 changes: 125 additions & 0 deletions mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -0,0 +1,125 @@
// RUN: mlir-opt --test-transform-dialect-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s

// Test One-Shot Bufferize.

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
transform.bufferization.one_shot_bufferize %0
{target_is_module = false}
}

pdl.pattern @pdl_target : benefit(1) {
%0 = operation "func.func"
rewrite %0 with "transform.dialect"
}
}

// CHECK-LABEL: func @test_function(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
%c0 = arith.constant 0 : index

// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
// CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
// CHECK: memref.copy %[[A_memref]], %[[alloc]]
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>

// CHECK: memref.dealloc %[[alloc]]
// CHECK: return %[[res_tensor]]
return %0 : tensor<?xf32>
}

// -----

// Test analysis of One-Shot Bufferize only.

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
transform.bufferization.one_shot_bufferize %0
{target_is_module = false, test_analysis_only = true}
}

pdl.pattern @pdl_target : benefit(1) {
%0 = operation "func.func"
rewrite %0 with "transform.dialect"
}
}

// CHECK-LABEL: func @test_function_analysis(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
func.func @test_function_analysis(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
%c0 = arith.constant 0 : index
// CHECK: vector.transfer_write
// CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]}
// CHECK-SAME: tensor<?xf32>
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
return %0 : tensor<?xf32>
}

// -----

// Test One-Shot Bufferize transform failure with an unknown op. This would be
// allowed with `allow_unknown_ops`.

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
// expected-error @+1 {{bufferization failed}}
transform.bufferization.one_shot_bufferize %0 {target_is_module = false}
}

pdl.pattern @pdl_target : benefit(1) {
%0 = operation "func.func"
rewrite %0 with "transform.dialect"
}
}

func.func @test_unknown_op_failure() -> (tensor<?xf32>) {
// expected-error @+1 {{op was not bufferized}}
%0 = "test.dummy_op"() : () -> (tensor<?xf32>)
return %0 : tensor<?xf32>
}

// -----

// Test One-Shot Bufferize transform failure with a module op.

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
// %arg1 is the module
transform.bufferization.one_shot_bufferize %arg1
}
}

module {
// CHECK-LABEL: func @test_function(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
%c0 = arith.constant 0 : index

// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
// CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
// CHECK: memref.copy %[[A_memref]], %[[alloc]]
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>

// CHECK: memref.dealloc %[[alloc]]
// CHECK: return %[[res_tensor]]
return %0 : tensor<?xf32>
}
}

0 comments on commit 461dafd

Please sign in to comment.