[mlir][tensor][bufferize] Implement getBufferType for CastOp
This interface method is used to compute the buffer type of a value during bufferization. It was missing from the tensor.CastOp implementation and is needed during loop bufferization.
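For illustration, a condensed sketch of the loop pattern this enables (adapted from the regression test added in mlir/test/Dialect/SCF/one-shot-bufferize.mlir below; function and value names are illustrative): bufferizing the scf.while requires computing the memref type of the yielded value, which is produced by tensor.cast ops.

// Sketch only; adapted from the regression test in this commit.
func.func @loop_with_cast(%init: tensor<2xindex>, %cond: i1) -> tensor<2xindex> {
  %r = scf.while (%arg0 = %init) : (tensor<2xindex>) -> tensor<2xindex> {
    scf.condition(%cond) %arg0 : tensor<2xindex>
  } do {
  ^bb0(%arg0: tensor<2xindex>):
    // Computing the buffer type of %cast_back (the yielded value) goes
    // through tensor::CastOp's getBufferType implementation.
    %cast = tensor.cast %arg0 : tensor<2xindex> to tensor<?xindex>
    %cast_back = tensor.cast %cast : tensor<?xindex> to tensor<2xindex>
    scf.yield %cast_back : tensor<2xindex>
  }
  return %r : tensor<2xindex>
}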

Also fix a bug where a cast from an unranked tensor to a ranked tensor type did not always apply a fully dynamic layout map on the result memref.
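A sketch of the fixed behavior (mirroring the updated test in mlir/test/Dialect/Tensor/bufferize.mlir below; names are illustrative):

// Casting from an unranked to a ranked tensor cannot infer static offsets or
// strides, so the result memref now gets a fully dynamic layout.
func.func @cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
  return %0 : tensor<2xf32>
}
// After bufferization, roughly:
//   %m = bufferization.to_memref %arg0 : memref<*xf32>
//   %c = memref.cast %m : memref<*xf32> to memref<2xf32, strided<[?], offset: ?>>
//   ... = bufferization.to_tensor %c : memref<2xf32, strided<[?], offset: ?>>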

Differential Revision: https://reviews.llvm.org/D143063
matthias-springer committed Feb 1, 2023
1 parent 0256280 commit b6ae3f8
Showing 4 changed files with 89 additions and 16 deletions.
55 changes: 41 additions & 14 deletions mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -51,6 +51,39 @@ struct CastOpInterface
    return BufferRelation::Equivalent;
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
    auto castOp = cast<tensor::CastOp>(op);
    auto maybeSrcBufferType =
        bufferization::getBufferType(castOp.getSource(), options, fixedTypes);
    if (failed(maybeSrcBufferType))
      return failure();
    Attribute memorySpace = maybeSrcBufferType->getMemorySpace();

    // Note: `getMemRefTypeWithFullyDynamicLayout` returns an unranked memref
    // type in case the input is an unranked tensor type.

    // Case 1: Casting an unranked tensor
    if (castOp.getSource().getType().isa<UnrankedTensorType>()) {
      // When casting to a ranked tensor, we cannot infer any static offset or
      // strides from the source. Assume fully dynamic.
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 2: Casting to an unranked tensor type
    if (castOp.getType().isa<UnrankedTensorType>()) {
      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
    }

    // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
    // change.
    auto rankedResultType = castOp.getType().cast<RankedTensorType>();
    return MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
        maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);
@@ -60,25 +93,19 @@ struct CastOpInterface
         getBuffer(rewriter, castOp.getSource(), options);
     if (failed(resultBuffer))
       return failure();
-    auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
-    TensorType resultTensorType =
-        castOp.getResult().getType().cast<TensorType>();
-    MemRefLayoutAttrInterface layout;
-
-    if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
-      if (resultTensorType.isa<RankedTensorType>())
-        layout = rankedMemRefType.getLayout();
-
-    // Compute the new memref type.
-    Type resultMemRefType = getMemRefType(castOp.getResult(), options, layout,
-                                          sourceMemRefType.getMemorySpace());
+    // Compute the new type.
+    auto resultMemRefType =
+        bufferization::getBufferType(castOp.getResult(), options);
+    if (failed(resultMemRefType))
+      return failure();

     // Replace the op with a memref.cast.
     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
-                                             resultMemRefType) &&
+                                             *resultMemRefType) &&
            "CallOp::bufferize: cast incompatible");
-    replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
-                                                 *resultBuffer);
+    replaceOpWithNewBufferizedOp<memref::CastOp>(
+        rewriter, op, *resultMemRefType, *resultBuffer);

     return success();
   }
23 changes: 23 additions & 0 deletions mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -925,3 +925,26 @@ func.func @non_block_argument_yield() {
  }
  return
}

// -----

// This is a regression test. Make sure that bufferization succeeds.

// CHECK-LABEL: func @regression_cast_in_loop(
func.func @regression_cast_in_loop() -> tensor<2xindex> {
  %false = arith.constant false
  %c0 = arith.constant 0 : index
  %0 = bufferization.alloc_tensor() : tensor<2xindex>
  // CHECK: scf.while (%{{.*}} = %{{.*}}) : (memref<2xindex>) -> memref<2xindex>
  %1 = scf.while (%arg0 = %0) : (tensor<2xindex>) -> tensor<2xindex> {
    scf.condition(%false) %arg0 : tensor<2xindex>
  } do {
  // CHECK: ^bb0(%{{.*}}: memref<2xindex>):
  ^bb0(%arg0: tensor<2xindex>):
    %cast = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>
    %inserted = tensor.insert %c0 into %cast[%c0] : tensor<?xindex>
    %cast_0 = tensor.cast %inserted : tensor<?xindex> to tensor<2xindex>
    scf.yield %cast_0 : tensor<2xindex>
  }
  return %1 : tensor<2xindex>
}
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Tensor/bufferize.mlir
@@ -40,8 +40,8 @@ func.func @tensor.cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK-LABEL: func @tensor.cast_from_unranked(
 // CHECK-SAME:  %[[TENSOR:.*]]: tensor<*xf32>) -> tensor<2xf32> {
 // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<*xf32>
-// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32>
-// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32>
+// CHECK: %[[CASTED_MEMREF:.*]] = memref.cast %[[MEMREF]] : memref<*xf32> to memref<2xf32, strided<[?], offset: ?>>
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[CASTED_MEMREF]] : memref<2xf32, strided<[?], offset: ?>>
 // CHECK: return %[[RET]] : tensor<2xf32>
 func.func @tensor.cast_from_unranked(%arg0: tensor<*xf32>) -> tensor<2xf32> {
   %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<2xf32>
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -347,3 +347,26 @@ func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index)
  %1 = tensor.dim %t, %c0 : tensor<?xf32>
  return %0, %1 : tensor<?xf32>, index
}

// -----

// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0 + 5)>
// CHECK-LABEL: func.func @cast_retains_buffer_layout(
// CHECK-SAME: %[[t:.*]]: memref<?xf32, #[[$map]]>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, #[[$map]]> to memref<10xf32, #[[$map]]>
// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, #[[$map]]> to memref<?xf32, strided<[1], offset: 7>>
// CHECK: return %[[slice]]
func.func @cast_retains_buffer_layout(
    %t: tensor<?xf32>
        {bufferization.buffer_layout = affine_map<(d0) -> (d0 + 5)>},
    %sz: index)
  -> (tensor<10xf32>, tensor<?xf32>)
{
  %casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
  %slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor<?xf32>

  // Note: The %casted return type is folded away because both buffers are
  // equivalent. Therefore, we currently lose some static type information
  // in the caller.
  return %casted, %slice : tensor<10xf32>, tensor<?xf32>
}
