diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 8b76930aed35a..d40cb5ee8a064 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -14,10 +14,15 @@
 #ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
+class Location;
 class OpBuilder;
 class RewritePatternSet;
 class RewriterBase;
@@ -33,7 +38,9 @@ class NarrowTypeEmulationConverter;
 namespace memref {
 class AllocOp;
 class AllocaOp;
+class CollapseShapeOp;
 class DeallocOp;
+class ExpandShapeOp;
 
 //===----------------------------------------------------------------------===//
 // Patterns
@@ -213,6 +220,100 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
 memref::AllocaOp allocToAlloca(
     RewriterBase &rewriter, memref::AllocOp alloc,
     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes holds the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// sizes#i =
+///     baseSizes#groupId / product(expandShapeSizes#j,
+///                                 for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId);
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// strides#i =
+///     origStrides#reassDim * product(expandShapeSizes#j, for j in
+///                                    reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type
+///   of the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+///   element.)
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId);
+
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
+/// values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+                                Location loc, ArrayRef<int64_t> maybeConstants,
+                                ArrayRef<OpFoldResult> values,
+                                llvm::function_ref<bool(int64_t)> isDynamic);
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes holds the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getCollapsedSize(CollapseShapeOp collapseShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId);
+
+/// Compute the collapsed stride of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually, this helper function returns the stride of the innermost
+/// dimension of that group in the original shape.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
+SmallVector<OpFoldResult> getCollapsedStride(CollapseShapeOp collapseShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId);
+
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..96bceae88af9d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -35,6 +36,7 @@ namespace memref {
 
 using namespace mlir;
 using namespace mlir::affine;
+using namespace mlir::memref;
 
 namespace {
 
@@ -250,24 +252,14 @@ struct ExtractStridedMetadataOpSubviewFolder
   }
 };
 
-/// Compute the expanded sizes of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origSizes hold the sizes of the source shape as values.
-/// This is used to compute the new sizes in cases of dynamic shapes.
-///
-/// sizes#i =
-///     baseSizes#groupId / product(expandShapeSizes#j,
-///                                 for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector<OpFoldResult>
-getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
-                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+} // namespace
+
+namespace mlir {
+namespace memref {
+SmallVector<OpFoldResult> getExpandedSizes(ExpandShapeOp expandShape,
+                                           OpBuilder &builder,
+                                           ArrayRef<OpFoldResult> origSizes,
+                                           unsigned groupId) {
   SmallVector<int64_t, 2> reassocGroup =
       expandShape.getReassociationIndices()[groupId];
   assert(!reassocGroup.empty() &&
@@ -305,31 +297,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
   return expandedSizes;
 }
 
-/// Compute the expanded strides of the given \p expandShape for the
-/// \p groupId-th reassociation group.
-/// \p origStrides and \p origSizes hold respectively the strides and sizes
-/// of the source shape as values.
-/// This is used to compute the strides in cases of dynamic shapes and/or
-/// dynamic stride for this reassociation group.
-///
-/// strides#i =
-///     origStrides#reassDim * product(expandShapeSizes#j, for j in
-///                                    reassIdx#i+1..reassIdx#i+group.size-1)
-///
-/// Where reassIdx#i is the reassociation index for at index i in \p groupId
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
-///
-/// TODO: Move this utility function directly within ExpandShapeOp. For now,
-/// this is not possible because this function uses the Affine dialect and the
-/// MemRef dialect cannot depend on the Affine dialect.
-SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
+SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
                                              OpBuilder &builder,
                                              ArrayRef<OpFoldResult> origSizes,
                                              ArrayRef<OpFoldResult> origStrides,
@@ -405,18 +373,10 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
   return expandedStrides;
 }
 
-/// Produce an OpFoldResult object with \p builder at \p loc representing
-/// `prod(valueOrConstant#i, for i in {indices})`,
-/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
-/// values[i] otherwise.
-///
-/// \pre for all index in indices: index < values.size()
-/// \pre for all index in indices: index < maybeConstants.size()
-static OpFoldResult
-getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
-                   ArrayRef<int64_t> maybeConstants,
-                   ArrayRef<OpFoldResult> values,
-                   llvm::function_ref<bool(int64_t)> isDynamic) {
+OpFoldResult getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder,
+                                Location loc, ArrayRef<int64_t> maybeConstants,
+                                ArrayRef<OpFoldResult> values,
+                                llvm::function_ref<bool(int64_t)> isDynamic) {
   AffineExpr productOfValues = builder.getAffineConstantExpr(1);
   SmallVector<Value> inputValues;
   unsigned numberOfSymbols = 0;
@@ -450,9 +410,10 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
 /// TODO: Move this utility function directly within CollapseShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
 /// MemRef dialect cannot depend on the Affine dialect.
-static SmallVector -getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder, - ArrayRef origSizes, unsigned groupId) { +SmallVector getCollapsedSize(CollapseShapeOp collapseShape, + OpBuilder &builder, + ArrayRef origSizes, + unsigned groupId) { SmallVector collapsedSize; MemRefType collapseShapeType = collapseShape.getResultType(); @@ -491,10 +452,11 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder, /// /// \post result.size() == 1, in other words, each group collapse to one /// dimension. -static SmallVector -getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, - ArrayRef origSizes, - ArrayRef origStrides, unsigned groupId) { +SmallVector getCollapsedStride(CollapseShapeOp collapseShape, + OpBuilder &builder, + ArrayRef origSizes, + ArrayRef origStrides, + unsigned groupId) { SmallVector reassocGroup = collapseShape.getReassociationIndices()[groupId]; assert(!reassocGroup.empty() && @@ -546,6 +508,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, return {lastValidStride}; } +} // namespace memref +} // namespace mlir + +namespace { /// From `reshape_like(memref, subSizes, subStrides))` compute /// diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 1208fddf37e0b..1e6805139f7ff 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -23,9 +23,11 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { @@ -89,17 +91,21 @@ static bool needFlattening(Value val) { return type.getRank() > 1; } -static bool checkLayout(Value val) { - auto type = cast(val.getType()); +static bool checkLayout(MemRefType type) { return type.getLayout().isIdentity() || isa(type.getLayout()); } +static bool checkLayout(Value val) { + return checkLayout(cast(val.getType())); +} + namespace { static Value getTargetMemref(Operation *op) { return llvm::TypeSwitch(op) .template Case([](auto op) { return op.getMemref(); }) + memref::AllocOp, memref::DeallocOp>( + [](auto op) { return op.getMemref(); }) .template Case( @@ -189,6 +195,10 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, rewriter, loc, op.getVector(), flatMemref, ValueRange{offset}); rewriter.replaceOp(op, newTransferWrite); }) + .template Case([&](auto op) { + auto newDealloc = memref::DeallocOp::create(rewriter, loc, flatMemref); + rewriter.replaceOp(op, newDealloc); + }) .Default([&](auto op) { op->emitOpError("unimplemented: do not know how to replace op."); }); @@ -197,7 +207,8 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, template static ValueRange getIndices(T op) { if constexpr (std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { return ValueRange{}; } else { return op.getIndices(); @@ -250,6 +261,249 @@ struct MemRefRewritePattern : public OpRewritePattern { } }; +/// Flattens memref global ops with more than 1 dimensions to 1 dimension. 
+struct FlattenGlobal final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static Attribute flattenAttribute(Attribute value, ShapedType newType) { + if (!value) + return value; + if (auto splatAttr = llvm::dyn_cast(value)) { + return splatAttr.reshape(newType); + } else if (auto denseAttr = llvm::dyn_cast(value)) { + return denseAttr.reshape(newType); + } else if (auto denseResourceAttr = + llvm::dyn_cast(value)) { + return DenseResourceElementsAttr::get(newType, + denseResourceAttr.getRawHandle()); + } + return {}; + } + + LogicalResult matchAndRewrite(memref::GlobalOp globalOp, + PatternRewriter &rewriter) const override { + auto oldType = llvm::dyn_cast(globalOp.getType()); + if (!oldType || !oldType.getLayout().isIdentity() || oldType.getRank() <= 1) + return failure(); + + auto tensorType = RankedTensorType::get({oldType.getNumElements()}, + oldType.getElementType()); + auto memRefType = + MemRefType::get({oldType.getNumElements()}, oldType.getElementType(), + AffineMap(), oldType.getMemorySpace()); + auto newInitialValue = + flattenAttribute(globalOp.getInitialValueAttr(), tensorType); + rewriter.replaceOpWithNewOp( + globalOp, globalOp.getSymName(), globalOp.getSymVisibilityAttr(), + memRefType, newInitialValue, globalOp.getConstant(), + /*alignment=*/IntegerAttr()); + return success(); + } +}; + +struct FlattenCollapseShape final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CollapseShapeOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + memref::ExtractStridedMetadataOp metadata = + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc()); + + SmallVector origSizes = metadata.getConstifiedMixedSizes(); + SmallVector origStrides = + metadata.getConstifiedMixedStrides(); + OpFoldResult offset = metadata.getConstifiedMixedOffset(); + + SmallVector collapsedSizes; + SmallVector collapsedStrides; + unsigned numGroups = op.getReassociationIndices().size(); + collapsedSizes.reserve(numGroups); + collapsedStrides.reserve(numGroups); + for (unsigned i = 0; i < numGroups; ++i) { + SmallVector groupSizes = + memref::getCollapsedSize(op, rewriter, origSizes, i); + SmallVector groupStrides = + memref::getCollapsedStride(op, rewriter, origSizes, origStrides, i); + collapsedSizes.append(groupSizes.begin(), groupSizes.end()); + collapsedStrides.append(groupStrides.begin(), groupStrides.end()); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSrc(), offset, collapsedSizes, + collapsedStrides); + return success(); + } +}; + +struct FlattenExpandShape final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ExpandShapeOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + memref::ExtractStridedMetadataOp metadata = + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSrc()); + + SmallVector origSizes = metadata.getConstifiedMixedSizes(); + SmallVector origStrides = + metadata.getConstifiedMixedStrides(); + OpFoldResult offset = metadata.getConstifiedMixedOffset(); + + SmallVector expandedSizes; + SmallVector expandedStrides; + unsigned numGroups = op.getReassociationIndices().size(); + expandedSizes.reserve(op.getResultType().getRank()); + expandedStrides.reserve(op.getResultType().getRank()); + + for (unsigned i = 0; i < numGroups; ++i) { + SmallVector groupSizes = + memref::getExpandedSizes(op, rewriter, origSizes, i); + SmallVector 
groupStrides = + memref::getExpandedStrides(op, rewriter, origSizes, origStrides, i); + expandedSizes.append(groupSizes.begin(), groupSizes.end()); + expandedStrides.append(groupStrides.begin(), groupStrides.end()); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSrc(), offset, expandedSizes, expandedStrides); + return success(); + } +}; + +// Flattens memref subview ops with more than 1 dimension into 1-D accesses. +struct FlattenSubView final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::SubViewOp op, + PatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(op.getSource().getType()); + if (!sourceType || sourceType.getRank() <= 1) + return failure(); + if (!checkLayout(sourceType)) + return failure(); + + MemRefType resultType = op.getType(); + if (resultType.getRank() <= 1 || !checkLayout(resultType)) + return failure(); + + unsigned elementBitWidth = sourceType.getElementTypeBitWidth(); + if (!elementBitWidth) + return failure(); + + Location loc = op.getLoc(); + + // Materialize offsets as values so they can participate in linearization. + SmallVector mixedOffsets = op.getMixedOffsets(); + SmallVector mixedSizes = op.getMixedSizes(); + SmallVector mixedStrides = op.getMixedStrides(); + + SmallVector offsetValues; + offsetValues.reserve(mixedOffsets.size()); + for (OpFoldResult ofr : mixedOffsets) + offsetValues.push_back(getValueFromOpFoldResult(rewriter, loc, ofr)); + + auto [flatSource, linearOffset] = getFlattenMemrefAndOffset( + rewriter, loc, op.getSource(), ValueRange(offsetValues)); + + memref::ExtractStridedMetadataOp sourceMetadata = + memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getSource()); + + SmallVector sourceStrides = + sourceMetadata.getConstifiedMixedStrides(); + OpFoldResult sourceOffset = sourceMetadata.getConstifiedMixedOffset(); + + llvm::SmallBitVector droppedDims = op.getDroppedDims(); + + SmallVector resultSizes; + SmallVector resultStrides; + resultSizes.reserve(resultType.getRank()); + resultStrides.reserve(resultType.getRank()); + + OpFoldResult resultOffset = sourceOffset; + for (auto zipped : llvm::enumerate(llvm::zip_equal( + mixedOffsets, sourceStrides, mixedSizes, mixedStrides))) { + auto idx = zipped.index(); + auto it = zipped.value(); + auto offsetOfr = std::get<0>(it); + auto strideOfr = std::get<1>(it); + auto sizeOfr = std::get<2>(it); + auto relativeStrideOfr = std::get<3>(it); + OpFoldResult contribution = [&]() -> OpFoldResult { + if (Attribute offsetAttr = dyn_cast(offsetOfr)) { + if (Attribute strideAttr = dyn_cast(strideOfr)) { + auto offsetInt = cast(offsetAttr).getInt(); + auto strideInt = cast(strideAttr).getInt(); + return rewriter.getIndexAttr(offsetInt * strideInt); + } + } + Value offsetVal = getValueFromOpFoldResult(rewriter, loc, offsetOfr); + Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr); + return rewriter.create(loc, offsetVal, strideVal) + .getResult(); + }(); + resultOffset = [&]() -> OpFoldResult { + if (Attribute offsetAttr = dyn_cast(resultOffset)) { + if (Attribute contribAttr = dyn_cast(contribution)) { + auto offsetInt = cast(offsetAttr).getInt(); + auto contribInt = cast(contribAttr).getInt(); + return rewriter.getIndexAttr(offsetInt + contribInt); + } + } + Value offsetVal = getValueFromOpFoldResult(rewriter, loc, resultOffset); + Value contribVal = + getValueFromOpFoldResult(rewriter, loc, contribution); + return rewriter.create(loc, offsetVal, contribVal) + .getResult(); + }(); + + 
if (droppedDims.test(idx)) + continue; + + resultSizes.push_back(sizeOfr); + OpFoldResult combinedStride = [&]() -> OpFoldResult { + if (Attribute relStrideAttr = dyn_cast(relativeStrideOfr)) { + if (Attribute strideAttr = dyn_cast(strideOfr)) { + auto relStrideInt = cast(relStrideAttr).getInt(); + auto strideInt = cast(strideAttr).getInt(); + return rewriter.getIndexAttr(relStrideInt * strideInt); + } + } + Value relStrideVal = + getValueFromOpFoldResult(rewriter, loc, relativeStrideOfr); + Value strideVal = getValueFromOpFoldResult(rewriter, loc, strideOfr); + return rewriter.create(loc, relStrideVal, strideVal) + .getResult(); + }(); + resultStrides.push_back(combinedStride); + } + + memref::LinearizedMemRefInfo linearizedInfo; + [[maybe_unused]] OpFoldResult linearizedIndex; + std::tie(linearizedInfo, linearizedIndex) = + memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth, + elementBitWidth, resultOffset, + resultSizes, resultStrides); + + Value flattenedSize = + getValueFromOpFoldResult(rewriter, loc, linearizedInfo.linearizedSize); + Value strideOne = arith::ConstantIndexOp::create(rewriter, loc, 1); + + Value flattenedSubview = memref::SubViewOp::create( + rewriter, loc, flatSource, ValueRange{linearOffset}, + ValueRange{flattenedSize}, ValueRange{strideOne}); + + Value replacement = memref::ReinterpretCastOp::create( + rewriter, loc, resultType, flattenedSubview, resultOffset, resultSizes, + resultStrides); + + rewriter.replaceOp(op, replacement); + return success(); + } +}; + struct FlattenMemrefsPass : public mlir::memref::impl::FlattenMemrefsPassBase { using Base::Base; @@ -271,6 +525,44 @@ struct FlattenMemrefsPass } // namespace +struct FlattenGetGlobal : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::GetGlobalOp op, + PatternRewriter &rewriter) const override { + // Check if this get_global references a multi-dimensional global + auto module = op->template getParentOfType(); + auto globalOp = + module.template lookupSymbol(op.getName()); + if (!globalOp) { + return failure(); + } + + auto globalType = globalOp.getType(); + auto resultType = op.getType(); + + // Only apply if the global has been flattened but the get_global hasn't + if (globalType.getRank() == 1 && resultType.getRank() > 1) { + auto newGetGlobal = memref::GetGlobalOp::create(rewriter, op.getLoc(), + globalType, op.getName()); + + // Cast the flattened result back to the original shape + memref::ExtractStridedMetadataOp stridedMetadata = + memref::ExtractStridedMetadataOp::create(rewriter, op.getLoc(), + op.getResult()); + auto castResult = memref::ReinterpretCastOp::create( + rewriter, op.getLoc(), resultType, newGetGlobal, + /*offset=*/rewriter.getIndexAttr(0), + stridedMetadata.getConstifiedMixedSizes(), + stridedMetadata.getConstifiedMixedStrides()); + rewriter.replaceOp(op, castResult); + return success(); + } + + return failure(); + } +}; + void memref::populateFlattenVectorOpsOnMemrefPatterns( RewritePatternSet &patterns) { patterns.insert, @@ -286,8 +578,10 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) { patterns.insert, MemRefRewritePattern, MemRefRewritePattern, - MemRefRewritePattern>( - patterns.getContext()); + MemRefRewritePattern, + MemRefRewritePattern, FlattenExpandShape, + FlattenCollapseShape, FlattenSubView, FlattenGetGlobal, + FlattenGlobal>(patterns.getContext()); } void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { diff --git 
a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir index e45a10ca0d431..05290f96f45e9 100644 --- a/mlir/test/Dialect/MemRef/flatten_memref.mlir +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -194,6 +194,77 @@ func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: in // ----- +func.func @flatten_subview_static(%arg0: memref<3x4xf32, strided<[4, 1], offset: 0>>) -> memref<2x2xf32, strided<[4, 1], offset: 1>> { + %sub = memref.subview %arg0[0, 1] [2, 2] [1, 1] + : memref<3x4xf32, strided<[4, 1], offset: 0>> to memref<2x2xf32, strided<[4, 1], offset: 1>> + return %sub : memref<2x2xf32, strided<[4, 1], offset: 1>> +} +// CHECK-LABEL: func @flatten_subview_static +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[FLAT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [12], strides: [1] +// CHECK: %[[SUB:.*]] = memref.subview %[[FLAT]][%[[C1]]] [%[[C8]]] [%[[C1]]] +// CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[SUB]] to offset: [1], sizes: [2, 2], strides: [4, 1] +// CHECK: return %[[CAST]] + +func.func @collapse_shape_static(%arg0: memref<2x3x4xf32>) -> memref<6x4xf32> { + %0 = memref.collapse_shape %arg0 [[0, 1], [2]] + : memref<2x3x4xf32> into memref<6x4xf32> + return %0 : memref<6x4xf32> +} +// CHECK-LABEL: func @collapse_shape_static +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [6, 4], strides: [4, 1] +// CHECK: return %[[REINT]] + +// ----- + +func.func @collapse_shape_dynamic( + %arg0: memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>>) -> + memref> { + %0 = memref.collapse_shape %arg0 [[0, 1], [2]] + : memref<2x?x4xf32, strided<[?, ?, ?], offset: ?>> + into memref> + return %0 : memref> +} +// CHECK: #map = affine_map<()[s0] -> (s0 * 2)> +// CHECK-LABEL: func @collapse_shape_dynamic +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %arg0 +// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#1] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]], 4], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2] +// CHECK: return %[[REINT]] + +// ----- + +func.func @expand_shape_static(%arg0: memref<6x4xf32>) -> memref<2x3x4xf32> { + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 4] + : memref<6x4xf32> into memref<2x3x4xf32> + return %0 : memref<2x3x4xf32> +} +// CHECK-LABEL: func @expand_shape_static +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1] +// CHECK: return %[[REINT]] + +// ----- + +func.func @expand_shape_dynamic( + %arg0: memref>, %size: index) -> + memref> { + %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%size, 3, 4] + : memref> + into memref> + return %0 : memref> +} +// CHECK: #map = affine_map<()[s0] -> (s0 floordiv 3)> +// CHECK: #map1 = affine_map<()[s0] -> (s0 * 3)> +// CHECK-LABEL: func @expand_shape_dynamic +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %arg0 +// CHECK: %[[SIZE:.*]] = affine.apply #map()[%[[SIZES]]#0] +// CHECK: %[[STRIDE:.*]] = affine.apply #map1()[%[[STRIDES]]#0] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]], 3, 4], strides: [%[[STRIDE]], %[[STRIDES]]#0, %[[STRIDES]]#1] +// CHECK: return %[[REINT]] + +// ----- + func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: 
vector<8xi2>, %row: index, %col: index) -> vector<8xi2> { %c0 = arith.constant 0 : i2 %0 = vector.transfer_read %input[%col, %row], %c0 {in_bounds = [true]} : memref<4x8xi2>, vector<8xi2> @@ -298,3 +369,60 @@ func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32, // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]] // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[1, 4], offset: 100>> to memref<32xf32, strided<[1], offset: 100>> // CHECK: memref.load %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[1], offset: 100>> + +// ----- + +func.func @dealloc_static_memref(%input: memref<4x8xf32>) { + memref.dealloc %input : memref<4x8xf32> + return +} + +// CHECK-LABEL: func @dealloc_static_memref +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>) +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32], strides: [1] : memref<4x8xf32> to memref<32xf32, strided<[1]>> +// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1]>> + +// ----- + +func.func @dealloc_dynamic_memref(%input: memref) { + memref.dealloc %input : memref + return +} + +// CHECK-LABEL: func @dealloc_dynamic_memref +// CHECK-SAME: (%[[ARG0:.*]]: memref) +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]] +// CHECK: %[[SIZE:.*]] = affine.max #{{.*}}()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [%[[SIZE]]], strides: [1] : memref to memref> +// CHECK: memref.dealloc %[[REINT]] : memref> + +// ----- + +func.func @dealloc_strided_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) { + memref.dealloc %input : memref<4x8xf32, strided<[8, 1], offset: 100>> + return +} + +// CHECK-LABEL: func @dealloc_strided_memref +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 1], offset: 100>>) +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>> +// CHECK-NEXT: memref.dealloc %[[REINT]] : memref<32xf32, strided<[1], offset: 100>> + +// ----- + +memref.global "private" constant @constant_3x3x1x1xf32 : memref<3x3x1x1xf32> = dense<[[[[-1.000000e+00]], [[0.000000e+00]], [[1.000000e+00]]], [[[-2.000000e+00]], [[0.000000e+00]], [[2.000000e+00]]], [[[-1.000000e+00]], [[0.000000e+00]], [[1.000000e+00]]]]> +func.func @load_global_with_offset(%i0: index, %i1: index, %i2: index, %i3: index) -> f32 { + %global = memref.get_global @constant_3x3x1x1xf32 : memref<3x3x1x1xf32> + %val = memref.load %global[%i0, %i1, %i2, %i3] : memref<3x3x1x1xf32> + return %val: f32 +} + +// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1, s2, s3] -> (s0 * 3 + s1 + s2 + s3)> +// CHECK: memref.global "private" constant @constant_3x3x1x1xf32 : memref<9xf32> = dense<[-1.000000e+00, 0.000000e+00, 1.000000e+00, -2.000000e+00, 0.000000e+00, 2.000000e+00, -1.000000e+00, 0.000000e+00, 1.000000e+00]> +//CHECK-LABEL: func.func @load_global_with_offset +// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index) +// CHECK: %[[GLOBAL:.+]] = memref.get_global @constant_3x3x1x1xf32 : memref<9xf32> +// CHECK: %[[INDEX:.+]] = affine.apply #[[$MAP]]()[%[[I0]], %[[I1]], %[[I2]], %[[I3]]] +// CHECK: %[[REINTERPRET:.+]] = memref.reinterpret_cast %[[GLOBAL]] to offset: [0], sizes: [9], strides: 
[1] : memref<9xf32> to memref<9xf32, strided<[1]>> +// CHECK: %[[LOAD:.+]] = memref.load %[[REINTERPRET]][%[[INDEX]]] : memref<9xf32, strided<[1]>> +// CHECK: return %[[LOAD]]
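
Not part of the patch itself: for readers who want to exercise the new flattening patterns outside of `FlattenMemrefsPass`, a minimal C++ sketch of pulling them into a local greedy rewrite. The helper name `flattenMemrefOpsIn` is made up for illustration, and the include path and driver spelling (`applyPatternsGreedily`) assume a recent tree; older trees declare the populate functions in a different header and spell the driver `applyPatternsAndFoldGreedily`.

```cpp
// Illustrative only -- not part of this patch. Runs the flattening patterns
// registered by populateFlattenMemrefOpsPatterns /
// populateFlattenVectorOpsOnMemrefPatterns on an arbitrary root op.
#include "mlir/Dialect/MemRef/Transforms/Passes.h" // or wherever the populate*Patterns declarations live in your tree
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper, for illustration.
static LogicalResult flattenMemrefOpsIn(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // memref op patterns extended by this patch: expand_shape, collapse_shape,
  // subview, global/get_global, dealloc, plus the existing load/store ones.
  memref::populateFlattenMemrefOpsPatterns(patterns);
  // vector transfer_read/transfer_write/load/store/maskedload/maskedstore
  // on memrefs.
  memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
  // Recent trees spell this applyPatternsGreedily; older ones use
  // applyPatternsAndFoldGreedily.
  return applyPatternsGreedily(root, std::move(patterns));
}
```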