Skip to content

Commit

Permalink
[mlir][linalg][bufferize] Add FuncOp bufferization pass
Browse files Browse the repository at this point in the history
This passes bufferizes FuncOp bodies, but not FuncOp boundaries.

Differential Revision: https://reviews.llvm.org/D114671
  • Loading branch information
matthias-springer committed Dec 7, 2021
1 parent e7f53ec commit 8a23263
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 0 deletions.
69 changes: 69 additions & 0 deletions mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir
@@ -0,0 +1,69 @@
// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s

// Run fuzzer with different seeds.
// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null

// CHECK-LABEL: func @use_tensor_func_arg(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
func @use_tensor_func_arg(%A : tensor<?xf32>) -> (vector<4xf32>) {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32

// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
// CHECK: %[[res:.*]] = vector.transfer_read %[[A_memref]]
%0 = vector.transfer_read %A[%c0], %f0 : tensor<?xf32>, vector<4xf32>

// CHECK: return %[[res]]
return %0 : vector<4xf32>
}

// -----

// CHECK-LABEL: func @return_tensor(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
%c0 = arith.constant 0 : index

// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
// CHECK: %[[dim:.*]] = tensor.dim %[[A]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK: memref.copy %[[A_memref]], %[[casted]]
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>

// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
// CHECK: return %[[res_tensor]]
return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @func_without_tensor_args
func @func_without_tensor_args(%v : vector<10xf32>) -> () {
// CHECK: %[[alloc:.*]] = memref.alloc()
%0 = linalg.init_tensor[10] : tensor<10xf32>

%c0 = arith.constant 0 : index
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
%1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>

%cst = arith.constant 0.0 : f32
// CHECK: vector.transfer_read %[[alloc]]
%r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>

vector.print %r : vector<11xf32>
return
}

// -----

// CHECK-LABEL: func private @private_func
func private @private_func(tensor<?xf32>) -> ()

// CHECK-LABEL: func @empty_func()
func @empty_func() -> () {
return
}
29 changes: 29 additions & 0 deletions mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Expand Up @@ -979,3 +979,32 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
}
return %1: tensor<?xf32>
}

// -----

// CHECK-LABEL: func @func_without_tensor_args
func @func_without_tensor_args(%v : vector<10xf32>) -> () {
// CHECK: %[[alloc:.*]] = memref.alloc()
%0 = linalg.init_tensor[10] : tensor<10xf32>

%c0 = arith.constant 0 : index
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
%1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>

%cst = arith.constant 0.0 : f32
// CHECK: vector.transfer_read %[[alloc]]
%r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>

vector.print %r : vector<11xf32>
return
}

// -----

// CHECK-LABEL: func private @private_func
func private @private_func(tensor<?xf32>) -> ()

// CHECK-LABEL: func @empty_func()
func @empty_func() -> () {
return
}
13 changes: 13 additions & 0 deletions mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRLinalgTestPasses
TestComprehensiveBufferize.cpp
TestConvVectorization.cpp
TestLinalgCodegenStrategy.cpp
TestLinalgDistribution.cpp
Expand All @@ -12,13 +13,25 @@ add_mlir_library(MLIRLinalgTestPasses

LINK_LIBS PUBLIC
MLIRAffine
MLIRAffineBufferizableOpInterfaceImpl
MLIRArithBufferizableOpInterfaceImpl
MLIRArithmetic
MLIRBufferizableOpInterface
MLIRComprehensiveBufferize
MLIRGPUTransforms
MLIRLinalg
MLIRLinalgBufferizableOpInterfaceImpl
MLIRLinalgTransforms
MLIRLLVMToLLVMIRTranslation
MLIRMemRef
MLIRPass
MLIRSCF
MLIRSCFBufferizableOpInterfaceImpl
MLIRStandard
MLIRTensor
MLIRTensorBufferizableOpInterfaceImpl
MLIRTransformUtils
MLIRVector
MLIRVectorBufferizableOpInterfaceImpl
MLIRVectorToSCF
)
124 changes: 124 additions & 0 deletions mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -0,0 +1,124 @@
//===- TestComprehensiveBufferize.cpp - Test Comprehensive Bufferize ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Comprehensive Bufferize.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalg::comprehensive_bufferize;

namespace {
/// A helper struct for FunctionBufferize and ModuleBufferize. Both passes are
/// mostly identical.
struct TestComprehensiveFunctionBufferize
: public PassWrapper<TestComprehensiveFunctionBufferize, FunctionPass> {
StringRef getArgument() const final {
return "test-comprehensive-function-bufferize";
}

StringRef getDescription() const final {
return "Test Comprehensive Bufferize of FuncOps (body only).";
}

TestComprehensiveFunctionBufferize() = default;
TestComprehensiveFunctionBufferize(
const TestComprehensiveFunctionBufferize &pass) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
memref::MemRefDialect, tensor::TensorDialect,
vector::VectorDialect, scf::SCFDialect,
arith::ArithmeticDialect, AffineDialect>();
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
}

void runOnFunction() override;

Option<bool> allowReturnMemref{
*this, "allow-return-memref",
llvm::cl::desc("Allow returning/yielding memrefs from functions/blocks"),
llvm::cl::init(false)};
Option<bool> allowUnknownOps{
*this, "allow-unknown-ops",
llvm::cl::desc(
"Allows the return of memrefs (for testing purposes only)"),
llvm::cl::init(false)};
Option<bool> testAnalysisOnly{
*this, "test-analysis-only",
llvm::cl::desc(
"Only runs inplaceability analysis (for testing purposes only)"),
llvm::cl::init(false)};
Option<unsigned> analysisFuzzerSeed{
*this, "analysis-fuzzer-seed",
llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"),
llvm::cl::init(0)};
};
} // namespace

void TestComprehensiveFunctionBufferize::runOnFunction() {
BufferizationOptions options;

// Enable InitTensorOp elimination.
options.addPostAnalysisStep<
linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
// TODO: Find a way to enable this step automatically when bufferizing
// tensor dialect ops.
options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();

options.allowReturnMemref = allowReturnMemref;
options.allowUnknownOps = allowUnknownOps;
options.testAnalysisOnly = testAnalysisOnly;
options.analysisFuzzerSeed = analysisFuzzerSeed;

Operation *op = getFunction().getOperation();
if (failed(runComprehensiveBufferize(op, options)))
return;

OpPassManager cleanupPipeline("builtin.func");
cleanupPipeline.addPass(createCanonicalizerPass());
cleanupPipeline.addPass(createCSEPass());
cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
(void)this->runPipeline(cleanupPipeline, op);
}

namespace mlir {
namespace test {
void registerTestComprehensiveFunctionBufferize() {
PassRegistration<TestComprehensiveFunctionBufferize>();
}
} // namespace test
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Expand Up @@ -64,6 +64,7 @@ void registerTestAffineLoopParametricTilingPass();
void registerTestAliasAnalysisPass();
void registerTestBuiltinAttributeInterfaces();
void registerTestCallGraphPass();
void registerTestComprehensiveFunctionBufferize();
void registerTestConstantFold();
void registerTestConvVectorization();
void registerTestGpuSerializeToCubinPass();
Expand Down Expand Up @@ -159,6 +160,7 @@ void registerTestPasses() {
#if MLIR_ROCM_CONVERSIONS_ENABLED
mlir::test::registerTestGpuSerializeToHsacoPass();
#endif
mlir::test::registerTestComprehensiveFunctionBufferize();
mlir::test::registerTestConvVectorization();
mlir::test::registerTestDecomposeCallGraphTypes();
mlir::test::registerTestDataLayoutQuery();
Expand Down
12 changes: 12 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Expand Up @@ -381,15 +381,27 @@ cc_library(
deps = [
"//llvm:Support",
"//mlir:Affine",
"//mlir:AffineBufferizableOpInterfaceImpl",
"//mlir:ArithBufferizableOpInterfaceImpl",
"//mlir:ArithmeticDialect",
"//mlir:BufferizableOpInterface",
"//mlir:BufferizationDialect",
"//mlir:ComprehensiveBufferize",
"//mlir:GPUDialect",
"//mlir:IR",
"//mlir:LinalgBufferizableOpInterfaceImpl",
"//mlir:LinalgOps",
"//mlir:LinalgTransforms",
"//mlir:MemRefDialect",
"//mlir:Pass",
"//mlir:SCFBufferizableOpInterfaceImpl",
"//mlir:SCFDialect",
"//mlir:SCFTransforms",
"//mlir:StandardOps",
"//mlir:TensorBufferizableOpInterfaceImpl",
"//mlir:TensorDialect",
"//mlir:TransformUtils",
"//mlir:VectorBufferizableOpInterfaceImpl",
"//mlir:VectorOps",
"//mlir:VectorToSCF",
],
Expand Down

0 comments on commit 8a23263

Please sign in to comment.