Skip to content

Commit

Permalink
[mlir][Bufferization] Add support for controlled bufferization of all…
Browse files Browse the repository at this point in the history
…oc_tensor (#70957)

This revision adds support to
`transform.structured.bufferize_to_allocation` to bufferize
`bufferization.alloc_tensor()` ops.
    
This is useful as a means path to control the bufferization of
`tensor.empty` ops that have bene previously
`bufferization.empty_tensor_to_alloc_tensor`'ed.
  • Loading branch information
nicolasvasilache committed Nov 2, 2023
1 parent 65bad23 commit 3a223f4
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <utility>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
Expand All @@ -28,6 +29,7 @@

namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
} // namespace bufferization

Expand Down Expand Up @@ -110,6 +112,18 @@ Value bufferizeToAllocation(RewriterBase &rewriter,
vector::MaskOp maskOp, Attribute memorySpace = {},
Operation *insertionPoint = nullptr);

/// Materialize a buffer allocation for the given bufferization.alloc_tensor op
/// and lower the op to memref.alloc + memref.tensor_store.
///
/// In addition to rewriting the IR, this function returns the newly allocated
/// buffer. The `insertionPoint` parameter can be used to specify a custom
/// insertion point for the buffer allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
const BufferizeToAllocationOptions &options,
bufferization::AllocTensorOp allocTensorOp,
Attribute memorySpace = {},
Operation *insertionPoint = nullptr);

/// Bufferize the given op with tensor semantics and materialize the result in
/// a newly allocated buffer.
///
Expand Down
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,27 @@ Value linalg::bufferizeToAllocation(
return alloc;
}

Value linalg::bufferizeToAllocation(
RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
Operation *insertionPoint) {
Location loc = allocTensorOp.getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
bufferization::BufferizationOptions bufferizationOptions;

// Create buffer allocation.
Value alloc = createAllocationForTensor(
rewriter, loc, allocTensorOp.getResult(), options, memorySpace);

// Create bufferization.to_tensor with "restrict" and "writable". The returned
// tensor is a new buffer allocation, so it does not alias with any buffer.
Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
loc, alloc, /*restrict=*/true, /*writable=*/true);
rewriter.replaceOp(allocTensorOp, toTensorOp);
return alloc;
}

/// Lower tensor.from_elements to a sequence of chained tensor.insert.
FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
Expand Down Expand Up @@ -454,6 +475,8 @@ Value linalg::bufferizeToAllocation(
return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
if (auto maskOp = dyn_cast<vector::MaskOp>(op))
return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);

// Only bufferizable ops are supported.
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -215,3 +215,26 @@ func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %
}
return
}

// -----

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
: (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
%2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
{alloc_op = "memref.alloca"}
: !transform.op<"bufferization.alloc_tensor">
transform.yield
}
}

// Expect `bufferization.bufferize_to_allocation` to create an alloc.
// CHECK-LABEL: func.func @empty_to_tensor_alloc()
func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
// CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32>
// CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32>
// CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32>
%0 = bufferization.alloc_tensor() : tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

0 comments on commit 3a223f4

Please sign in to comment.