Skip to content

Commit

Permalink
[mlir][bufferization][NFC] Change signature of getMemRefType
Browse files Browse the repository at this point in the history
These functions now accep unsigned attributes for address spaces instead of Attributes.

Differential Revision: https://reviews.llvm.org/D128275
  • Loading branch information
matthias-springer committed Jun 27, 2022
1 parent 43c84e4 commit b06614e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -513,18 +513,17 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
BaseMemRefType getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
unsigned memorySpace = 0);

/// Return a MemRef type with fully dynamic layout. If the given tensor type
/// is unranked, return an unranked MemRef type.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
Attribute memorySpace = {});
unsigned memorySpace = 0);

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
Attribute memorySpace = {});
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
unsigned memorySpace = 0);

} // namespace bufferization
} // namespace mlir
Expand Down
21 changes: 14 additions & 7 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,20 +596,23 @@ bool bufferization::isFunctionArgument(Value value) {
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
unsigned memorySpace) {
auto memorySpaceAttr = IntegerAttr::get(
IntegerType::get(tensorType.getContext(), 64), memorySpace);

// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
assert(!layout && "UnrankedTensorType cannot have a layout map");
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
memorySpace);
memorySpaceAttr);
}

// Case 2: Ranked memref type with specified layout.
auto rankedTensorType = tensorType.cast<RankedTensorType>();
if (layout) {
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), layout,
memorySpace);
memorySpaceAttr);
}

// Case 3: Configured with "fully dynamic layout maps".
Expand All @@ -627,14 +630,16 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,

BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
Attribute memorySpace) {
unsigned memorySpace) {
// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
memorySpace);
}

// Case 2: Ranked memref type.
auto memorySpaceAttr = IntegerAttr::get(
IntegerType::get(tensorType.getContext(), 64), memorySpace);
auto rankedTensorType = tensorType.cast<RankedTensorType>();
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
Expand All @@ -643,14 +648,14 @@ bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
dynamicStrides, dynamicOffset, rankedTensorType.getContext());
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), stridedLayout,
memorySpace);
memorySpaceAttr);
}

/// Return a MemRef type with a static identity layout (i.e., no layout map). If
/// the given tensor type is unranked, return an unranked MemRef type.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
Attribute memorySpace) {
unsigned memorySpace) {
// Case 1: Unranked memref type.
if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
Expand All @@ -659,8 +664,10 @@ bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,

// Case 2: Ranked memref type.
auto rankedTensorType = tensorType.cast<RankedTensorType>();
auto memorySpaceAttr = IntegerAttr::get(
IntegerType::get(tensorType.getContext(), 64), memorySpace);
MemRefLayoutAttrInterface layout = {};
return MemRefType::get(rankedTensorType.getShape(),
rankedTensorType.getElementType(), layout,
memorySpace);
memorySpaceAttr);
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ struct CastOpInterface
// The result buffer still has the old (pre-cast) type.
Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
Attribute memorySpace = sourceMemRefType.getMemorySpace();
TensorType resultTensorType =
castOp.getResult().getType().cast<TensorType>();
MemRefLayoutAttrInterface layout;
Expand All @@ -65,7 +64,8 @@ struct CastOpInterface

// Compute the new memref type.
Type resultMemRefType =
getMemRefType(resultTensorType, options, layout, memorySpace);
getMemRefType(resultTensorType, options, layout,
sourceMemRefType.getMemorySpaceAsInt());

// Replace the op with a memref.cast.
assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
Expand Down

0 comments on commit b06614e

Please sign in to comment.