Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
atLeastOneReplacement |= replaceConstantUsesOf(
builder, getLoc(), getStrides(), getConstifiedMixedStrides());

// extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
if (auto prev = getSource().getDefiningOp<CastOp>())
if (isa<MemRefType>(prev.getSource().getType())) {
getSourceMutable().assign(prev.getSource());
atLeastOneReplacement = true;
}

return success(atLeastOneReplacement);
}

Expand Down
87 changes: 0 additions & 87 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};

/// Fold a `memref.cast` into the `extract_strided_metadata` that consumes it.
///
/// Rewrites
/// ```
///   base, offset, sizes, strides =
///       extract_strided_metadata(cast(src) to dstTy)
/// ```
/// into
/// ```
///   base, srcOffset, srcSizes, srcStrides = extract_strided_metadata(src)
///   offset     = dstTy.offset is static     ? dstTy.offset     : srcOffset
///   sizes[i]   = dstTy.sizes[i] is static   ? dstTy.sizes[i]   : srcSizes[i]
///   strides[i] = dstTy.strides[i] is static ? dstTy.strides[i] : srcStrides[i]
/// ```
///
/// Put differently: the `cast` is consumed. Every offset/size/stride that is
/// static in the cast's result type is materialized as an index constant, and
/// everything else (including the base buffer) is read from a new
/// `extract_strided_metadata` applied directly to `src`.
///
/// The pattern bails out when return-type inference for
/// `extract_strided_metadata` fails on the type of `src`.
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    // Only applies when the operand is produced by a memref.cast.
    Value castResult = extractStridedMetadataOp.getSource();
    auto castOp = castResult.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();

    // Make sure extract_strided_metadata can be applied to the cast's input
    // before creating any new operation.
    SmallVector<Type> inferredReturnTypes;
    LogicalResult inferStatus = extractStridedMetadataOp.inferReturnTypes(
        rewriter.getContext(), loc, {castOp.getSource()},
        /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
        inferredReturnTypes);
    if (failed(inferStatus))
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");

    // Static information is taken from the type the cast produces.
    auto castResultType = cast<MemRefType>(castResult.getType());
    unsigned rank = castResultType.getRank();

    // Final layout of the replacement values:
    //   [base_buffer, offset, size_0 .. size_{rank-1},
    //    stride_0 .. stride_{rank-1}]
    SmallVector<OpFoldResult> replacements;
    replacements.reserve(rank * 2 + 2);

    auto metadataOfCastInput = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, castOp.getSource());

    // The base buffer always comes from the new op.
    replacements.push_back(metadataOfCastInput.getBaseBuffer());

    // Picks the static value as an index attribute when there is one, and
    // falls back to the given dynamic value otherwise.
    auto pickStaticOrDynamic = [&rewriter](int64_t maybeStatic,
                                           OpFoldResult dynamic) -> OpFoldResult {
      if (!ShapedType::isStatic(maybeStatic))
        return dynamic;
      return OpFoldResult(rewriter.getIndexAttr(maybeStatic));
    };

    auto [typeStrides, typeOffset] = castResultType.getStridesAndOffset();
    assert(typeStrides.size() == rank && "unexpected number of strides");

    // Offset.
    replacements.push_back(
        pickStaticOrDynamic(typeOffset, metadataOfCastInput.getOffset()));

    ArrayRef<int64_t> typeSizes = castResultType.getShape();
    SmallVector<OpFoldResult> dynSizes = metadataOfCastInput.getSizes();
    SmallVector<OpFoldResult> dynStrides = metadataOfCastInput.getStrides();

    // Sizes first, then strides, to match the result order of the op.
    for (unsigned dim = 0; dim < rank; ++dim)
      replacements.push_back(
          pickStaticOrDynamic(typeSizes[dim], dynSizes[dim]));
    for (unsigned dim = 0; dim < rank; ++dim)
      replacements.push_back(
          pickStaticOrDynamic(typeStrides[dim], dynStrides[dim]));

    rewriter.replaceOp(
        extractStridedMetadataOp,
        getValueOrCreateConstantIndexOp(rewriter, loc, replacements));
    return success();
  }
};

/// Replace `base, offset, sizes, strides = extract_strided_metadata(
/// memory_space_cast(src) to dstTy)`
/// with
Expand Down Expand Up @@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
Expand All @@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
Expand Down
126 changes: 126 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,132 @@ func.func @scope_merge_without_terminator() {

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// When we apply the transformation the resulting offset, sizes and strides
// should come straight from the inputs of the cast.
// Additionally the folder on extract_strided_metadata should propagate the
// static information.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
//
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
func.func @extract_strided_metadata_of_cast(
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
-> (memref<i32>, index,
index, index,
index, index) {

// This cast drops the static size (3) and the static stride (4) from the
// memref type, producing a fully dynamic type.
%cast =
memref.cast %arg :
memref<3x?xi32, strided<[4, ?], offset: ?>> to
memref<?x?xi32, strided<[?, ?], offset: ?>>

// The metadata is queried on the cast result, not on %arg directly.
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index

// Return all six metadata values.
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
// in the destination type.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
func.func @extract_strided_metadata_of_cast_w_csts(
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
-> (memref<i32>, index,
index, index,
index, index) {

// This cast refines the fully dynamic input type with a static size (4),
// a static stride (18) and a static offset (25).
%cast =
memref.cast %arg :
memref<?x?xi32, strided<[?, ?], offset: ?>> to
memref<4x?xi32, strided<[?, 18], offset: 25>>

// The metadata is queried on the cast result, not on %arg directly.
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
-> memref<i32>, index,
index, index,
index, index

// Return all six metadata values.
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}

// -----

// Check that we don't simplify extract_strided_metadata of
// cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
//
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_cast_unranked(
%arg : memref<*xi32>)
-> (memref<i32>, index,
index, index,
index, index) {

// The cast input is unranked; only the cast result has a rank (2).
%cast =
memref.cast %arg :
memref<*xi32> to
memref<?x?xi32, strided<[?, ?], offset: ?>>

// The metadata is queried on the ranked cast result.
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index

// Return all six metadata values.
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}

// -----

// CHECK-LABEL: func @reinterpret_noop
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
// CHECK-NEXT: return %[[ARG]]
Expand Down
127 changes: 0 additions & 127 deletions mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1376,133 +1376,6 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
memref<i32>, index, index, index, index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// When we apply the transformation the resulting offset, sizes and strides
// should come straight from the inputs of the cast.
// Additionally the folder on extract_strided_metadata should propagate the
// static information.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
//
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
func.func @extract_strided_metadata_of_cast(
%arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
-> (memref<i32>, index,
index, index,
index, index) {

// This cast drops the static size (3) and the static stride (4) from the
// memref type, producing a fully dynamic type.
%cast =
memref.cast %arg :
memref<3x?xi32, strided<[4, ?], offset: ?>> to
memref<?x?xi32, strided<[?, ?], offset: ?>>

// The metadata is queried on the cast result, not on %arg directly.
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index

// Return all six metadata values.
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}

// -----

// Check that we simplify extract_strided_metadata of cast
// when the source of the cast is compatible with what
// `extract_strided_metadata`s accept.
//
// Same as extract_strided_metadata_of_cast but with constant sizes and strides
// in the destination type.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
//
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
//
// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
func.func @extract_strided_metadata_of_cast_w_csts(
%arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
-> (memref<i32>, index,
index, index,
index, index) {

// This cast refines the fully dynamic input type with a static size (4),
// a static stride (18) and a static offset (25).
%cast =
memref.cast %arg :
memref<?x?xi32, strided<[?, ?], offset: ?>> to
memref<4x?xi32, strided<[?, 18], offset: 25>>

// The metadata is queried on the cast result, not on %arg directly.
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
-> memref<i32>, index,
index, index,
index, index

// Return all six metadata values.
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}

// -----

// Check that we don't simplify extract_strided_metadata of
// cast when the source of the cast is unranked.
// Unranked memrefs cannot feed into extract_strided_metadata operations.
// Note: Technically we could still fold the sizes and strides.
//
// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
//
// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
//
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
func.func @extract_strided_metadata_of_cast_unranked(
%arg : memref<*xi32>)
-> (memref<i32>, index,
index, index,
index, index) {

// The cast input is unranked; only the cast result has a rank (2).
%cast =
memref.cast %arg :
memref<*xi32> to
memref<?x?xi32, strided<[?, ?], offset: ?>>

// The metadata is queried on the ranked cast result.
%base, %base_offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
-> memref<i32>, index,
index, index,
index, index

// Return all six metadata values.
return %base, %base_offset,
%sizes#0, %sizes#1,
%strides#0, %strides#1 :
memref<i32>, index,
index, index,
index, index
}


// -----

memref.global "private" @dynamicShmem : memref<0xf16,3>
Expand Down