mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/RegionKindInterface.td"

@@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
                      Transform_AnyOpType:$new_ops);
  let assemblyFormat = "$target attr-dict `:` type($target)";
  let hasVerifier = 1;
}

  let builders = [
    OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
    OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
  ];
//===----------------------------------------------------------------------===//
// PromoteTensorOp
//===----------------------------------------------------------------------===//

def PromoteTensorOp : Op<Transform_Dialect, "structured.promote_tensor",
    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     SameOperandsAndResultType]> {
  let summary = "Request a tensor value to live in a specific memory space "
                "after bufferization";
  let description = [{
    Requests that a tensor value live in a specific memory space for its
    lifetime. This is achieved by allocating a new tensor in the desired
    memory space with `bufferization.alloc_tensor` and optionally
    materializing the source value into that allocation with
    `bufferization.materialize_in_destination`. All uses of the original
    value are then redirected to the promoted value.

    The generated code for promoting a tensor value %0 resembles the
    following:

      %1 = bufferization.alloc_tensor(<dynamic dims of %0>)
             { memory_space = memory_space }
      // Note: the materialization is omitted if %0 is never read and is only
      // written into (i.e., it behaves as a result tensor).
      %2 = bufferization.materialize_in_destination %0 in %1
      // ...
      <all users of %0 now use %2 instead>

    Deallocation is not handled by this transform.

    Return modes:
    - Produces a silenceable failure if the given handle does not point to
      tensor-typed values.
    - Succeeds otherwise and returns a handle to the promoted value(s), i.e.,
      the result of the materialization if present and the allocation
      otherwise.
  }];

  let arguments = (ins TransformValueHandleTypeInterface:$tensor,
                       OptionalAttr<AnyAttr>:$memory_space);
  let results = (outs TransformValueHandleTypeInterface:$promoted);

  let assemblyFormat =
      "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)";
}

//===----------------------------------------------------------------------===//
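To make the new op concrete, here is a minimal transform sequence using it, mirroring the tests added at the end of this change (the payload op and handle names are illustrative):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    // Find the payload op and take a value handle to its first operand.
    %mm = transform.structured.match ops{["linalg.matmul"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %in0 = transform.get_operand %mm[0]
        : (!transform.any_op) -> !transform.any_value
    // Request that this tensor live in memory space 1 after bufferization;
    // the result is a handle to the materialized (or allocated) value.
    %promoted = transform.structured.promote_tensor to 1 %in0
        : !transform.any_value
    transform.yield
  }
}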
116 changes: 90 additions & 26 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -42,6 +42,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
@@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               OperationState &result,
                                               Value target,
                                               Attribute memorySpace) {
  SmallVector<Type> resultTypes;
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
               /*resultTypes=*/resultTypes,
               /*target=*/target,
               /*memory_space=*/memorySpace);
}

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               OperationState &result,
                                               Value target,
                                               int64_t memorySpace) {
  SmallVector<Type> resultTypes;
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
               /*resultTypes=*/resultTypes,
               /*target=*/target,
               /*memory_space=*/b.getI64IntegerAttr(memorySpace));
}

namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
@@ -408,6 +383,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
  return success();
}

//===----------------------------------------------------------------------===//
// PromoteTensorOp
//===----------------------------------------------------------------------===//

/// Return true if the operand may be read from by its owner. This is currently
/// very conservative and only looks inside linalg operations to prevent
/// unintentional data loss.
static bool mayBeRead(OpOperand &operand) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());

  // Be conservative about ops we cannot analyze deeper.
  if (!linalgOp)
    return true;

  // Look inside linalg ops.
  Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
  return !blockArgument.use_empty();
}

/// Return true if the value may be read through any of its uses.
static bool mayBeRead(Value value) {
  // If the value has reference semantics, it may be read through any alias.
  if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
    return true;
  return llvm::any_of(value.getUses(),
                      static_cast<bool (&)(OpOperand &)>(mayBeRead));
}
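To illustrate the analysis, consider the case exercised by the @promote_out test below: an elementwise linalg op never reads the block argument tied to its init operand, so promoting the init skips the copy. A sketch, assuming %init has no other uses:

// The body of linalg.add never reads the block argument matching %init,
// so mayBeRead(%init) returns false and the promotion rewires the init to
// a fresh allocation without emitting a
// bufferization.materialize_in_destination copy.
%0 = linalg.add ins(%a, %b : tensor<?xf32>, tensor<?xf32>)
                outs(%init : tensor<?xf32>) -> tensor<?xf32>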

DiagnosedSilenceableFailure
transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  SmallVector<Value> promoted;
  for (Value tensor : state.getPayloadValues(getTensor())) {
    auto type = dyn_cast<RankedTensorType>(tensor.getType());
    if (!type) {
      return emitSilenceableError() << "non-tensor type: " << tensor;
    }

    Operation *definingOp = tensor.getDefiningOp();
    if (definingOp)
      rewriter.setInsertionPointAfter(definingOp);
    else
      rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());

    // Compute this before creating new uses of the value below (tensor.dim,
    // the materialization), which would otherwise be conservatively counted
    // as reads.
    bool needsMaterialization = mayBeRead(tensor);

    SmallVector<Value> dynamicDims;
    llvm::SmallPtrSet<Operation *, 4> preservedOps;
    for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
      if (!ShapedType::isDynamic(dim))
        continue;
      Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
      auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
      preservedOps.insert(dimOp);
      dynamicDims.push_back(dimOp);
    }
    auto allocation = rewriter.create<bufferization::AllocTensorOp>(
        tensor.getLoc(), type, dynamicDims);
    // Set the memory space if provided.
    if (getMemorySpaceAttr())
      allocation.setMemorySpaceAttr(getMemorySpaceAttr());
    Value allocated = allocation;

    // Only insert a materialization (which typically bufferizes to a copy)
    // when the value may be read from.
    if (needsMaterialization) {
      auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
          tensor.getLoc(), tensor, allocated);
      preservedOps.insert(copy);
      promoted.push_back(copy.getResult());
    } else {
      promoted.push_back(allocated);
    }
    rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
  }
  results.setValues(cast<OpResult>(getPromoted()), promoted);
  return DiagnosedSilenceableFailure::success();
}

void transform::PromoteTensorOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTensorMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}
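Putting the pieces together, promoting the first matmul input %arg0 : tensor<?x42xf32> to memory space 1, as in the first test below, yields IR along these lines (a sketch of the expected output, not verbatim compiler output); since the matmul reads %arg0, the materialization is kept and the matmul is rewired to use it:

// One tensor.dim per dynamic dimension feeds the allocation.
%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x42xf32>
%alloc = bufferization.alloc_tensor(%dim) {memory_space = 1 : i64}
    : tensor<?x42xf32>
// %arg0 may be read, so it is materialized into the allocation.
%mat = bufferization.materialize_in_destination %arg0 in %alloc
    : (tensor<?x42xf32>, tensor<?x42xf32>) -> tensor<?x42xf32>
%0 = linalg.matmul ins(%mat, %arg1 : tensor<?x42xf32>, tensor<42x?xf32>)
                   outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>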

//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
6 changes: 0 additions & 6 deletions mlir/python/mlir/dialects/transform/structured.py
@@ -44,18 +44,12 @@ def __init__(
        loc=None,
        ip=None,
    ):
        # No other types are allowed, so hard-code those here.
        allocated_buffer_type = transform.AnyValueType.get()
        new_ops_type = transform.AnyOpType.get()

        if isinstance(memory_space, int):
            memory_space = str(memory_space)
        if isinstance(memory_space, str):
            memory_space = Attribute.parse(memory_space)

        super().__init__(
            allocated_buffer_type,
            new_ops_type,
            target,
            memory_space=memory_space,
            memcpy_op=memcpy_op,
104 changes: 104 additions & 0 deletions mlir/test/Dialect/Transform/test-promote-tensors.mlir
@@ -0,0 +1,104 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// CHECK-LABEL: @promote_in0
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}, %{{.*}})
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64}
// CHECK: %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in %[[ALLOC]]
// CHECK: linalg.matmul ins(%[[MAT]], %{{.*}}
func.func @promote_in0(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
                     outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %mm = transform.structured.match ops{["linalg.matmul"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %op0 = transform.get_operand %mm[0]
        : (!transform.any_op) -> !transform.any_value
    transform.structured.promote_tensor to 1 %op0 : !transform.any_value
    transform.yield
  }
}

// -----

// CHECK-LABEL: @promote_out
// CHECK-SAME: (%{{.*}}: tensor<?x42xf32>, %{{.*}}: tensor<?x42xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
func.func @promote_out(%arg0: tensor<?x42xf32>, %arg1: tensor<?x42xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[C0:.+]] = arith.constant 0
  // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
  // CHECK: %[[C1:.+]] = arith.constant 1
  // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
  // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64}
  // CHECK-NOT: materialize_in_destination
  // CHECK: linalg.add {{.*}} outs(%[[ALLOC]]
  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x42xf32>, tensor<?x42xf32>)
                  outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %la = transform.structured.match ops{["linalg.add"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %init = transform.get_operand %la[2]
        : (!transform.any_op) -> !transform.any_value
    transform.structured.promote_tensor to 1 %init : !transform.any_value

    transform.yield
  }
}

// -----

// CHECK-LABEL: @promote_in0_out_bufferize
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}: tensor<42x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
func.func @promote_in0_out_bufferize(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: %[[IN1:.+]] = bufferization.to_buffer %arg1 : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[IN0:.+]] = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
  // CHECK: %{{.+}} = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
  // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[C1:.+]] = arith.constant 1 : index
  // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref<?x?xf32, 1>
  // CHECK: %{{.+}} = arith.constant 0 : index
  // CHECK: %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref<?x42xf32, strided<[?, ?], offset: ?>>
  // CHECK: %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref<?x42xf32, 1>
  // CHECK: memref.copy %[[IN0]], %[[ALLOC_IN]] : memref<?x42xf32, strided<[?, ?], offset: ?>> to memref<?x42xf32, 1>
  // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref<?x42xf32, 1>, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref<?x?xf32, 1>)
  %0 = linalg.add ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
                  outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    %la = transform.structured.match ops{["linalg.add"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %op0 = transform.get_operand %la[0]
        : (!transform.any_op) -> !transform.any_value
    transform.structured.promote_tensor to 1 %op0 : !transform.any_value

    %init = transform.get_operand %la[2]
        : (!transform.any_op) -> !transform.any_value
    transform.structured.promote_tensor to 1 %init : !transform.any_value

    %func = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op

    %bufferized = transform.bufferization.one_shot_bufferize %func
        : (!transform.any_op) -> !transform.any_op

    transform.yield
  }
}