diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h index bdec699eb4ce4..30f33ed2fd1d6 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -18,6 +18,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/MemOpInterfaces.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Interfaces/ShapedOpInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 671cc05e963b4..cd92ca98b2530 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/MemOpInterfaces.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/ShapedOpInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -145,7 +146,8 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [ DeclareOpInterfaceMethods, Pure, ViewLikeOpInterface, - SameOperandsAndResultType + SameOperandsAndResultType, + DeclareOpInterfaceMethods ]> { let summary = "assumption that gives alignment information to the input memref"; @@ -456,6 +458,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return", def MemRef_CastOp : MemRef_Op<"cast", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, MemRefsNormalizable, Pure, SameOperandsAndResultShape, @@ -1194,6 +1197,7 @@ def LoadOp : MemRef_Op<"load", "memref", "result", "::llvm::cast($_self).getElementType()">, MemRefsNormalizable, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "load operation"; @@ -1284,6 +1288,7 @@ def LoadOp : MemRef_Op<"load", def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + MemorySpaceCastOpInterface, MemRefsNormalizable, Pure, SameOperandsAndResultElementType, @@ -1302,6 +1307,10 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [ If the source and target address spaces are the same, this operation is a noop. + Finally, if the target memory-space is the generic/default memory-space, + then it is assumed this cast can be bubbled down safely. See the docs of + `MemorySpaceCastOpInterface` interface for more details. + Example: ```mlir @@ -1321,6 +1330,27 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [ let extraClassDeclaration = [{ Value getViewSource() { return getSource(); } + + //===------------------------------------------------------------------===// + // MemorySpaceCastConsumerOpInterface + //===------------------------------------------------------------------===// + /// Returns the `source` memref. + TypedValue getSourcePtr(); + /// Returns the `dest` memref. + TypedValue getTargetPtr(); + /// Returns whether the memory-space cast is valid. Only casts between + /// memrefs are considered valid. Further, the `tgt` and `src` should only + /// differ on the memory-space parameter of the memref type. 
+ bool isValidMemorySpaceCast(PtrLikeTypeInterface tgt, + PtrLikeTypeInterface src); + /// Clones the operation using a new target type and source value. + MemorySpaceCastOpInterface cloneMemorySpaceCastOp( + OpBuilder &b, PtrLikeTypeInterface tgt, + TypedValue src); + /// Returns whether the `source` value can be promoted by the + /// `MemorySpaceCastConsumerOpInterface::bubbleDownCasts` method. The only + /// casts the op recognizes as promotable are to the generic memory-space. + bool isSourcePromotable(); }]; let hasFolder = 1; @@ -1376,6 +1406,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> { def MemRef_ReinterpretCastOp : MemRef_OpWithOffsetSizesAndStrides<"reinterpret_cast", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, AttrSizedOperandSegments, MemRefsNormalizable, Pure, @@ -1603,6 +1634,7 @@ def MemRef_RankOp : MemRef_Op<"rank", [Pure]> { def MemRef_ReshapeOp: MemRef_Op<"reshape", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure, ViewLikeOpInterface]> { let summary = "memref reshape operation"; @@ -1701,6 +1733,7 @@ class MemRef_ReassociativeReshapeOp traits = []> : def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "operation to produce a memref with a higher rank."; let description = [{ @@ -1822,7 +1855,9 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ } def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { let summary = "operation to produce a memref with a smaller rank."; let description = [{ The `memref.collapse_shape` op produces a new view with a smaller rank @@ -1929,6 +1964,7 @@ def MemRef_StoreOp : MemRef_Op<"store", "memref", "value", "::llvm::cast($_self).getElementType()">, MemRefsNormalizable, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "store operation"; @@ -2006,6 +2042,7 @@ def MemRef_StoreOp : MemRef_Op<"store", def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface, @@ -2281,6 +2318,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [ def MemRef_TransposeOp : MemRef_Op<"transpose", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure]>, Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>, Results<(outs AnyStridedMemRef)> { @@ -2316,6 +2354,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [ def MemRef_ViewOp : MemRef_Op<"view", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, Pure]> { let summary = "memref view operation"; @@ -2392,6 +2431,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [ //===----------------------------------------------------------------------===// def AtomicRMWOp : MemRef_Op<"atomic_rmw", [ + DeclareOpInterfaceMethods, AllTypesMatch<["value", "result"]>, TypesMatchWith<"value type matches element type of memref", "memref", "value", diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 63410b8bea747..bbf55f5d507e3 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -27,6 +27,7 @@ #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include 
"mlir/Interfaces/IndexingMapOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/MemOpInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 26d06624cb976..252c0b72456df 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -24,6 +24,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/IndexingMapOpInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/MemOpInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -1246,6 +1247,7 @@ def Vector_TransferReadOp : DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, AttrSizedOperandSegments, DestinationStyleOpInterface ]>, @@ -1493,6 +1495,7 @@ def Vector_TransferWriteOp : DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, AttrSizedOperandSegments, DestinationStyleOpInterface ]>, @@ -1649,6 +1652,7 @@ def Vector_TransferWriteOp : def Vector_LoadOp : Vector_Op<"load", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods ]> { let summary = "reads an n-D slice of memory into an n-D vector"; let description = [{ @@ -1765,6 +1769,7 @@ def Vector_LoadOp : Vector_Op<"load", [ def Vector_StoreOp : Vector_Op<"store", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods ]> { let summary = "writes an n-D vector to an n-D slice of memory"; let description = [{ @@ -1869,7 +1874,7 @@ def Vector_StoreOp : Vector_Op<"store", [ } def Vector_MaskedLoadOp : - Vector_Op<"maskedload">, + Vector_Op<"maskedload", [DeclareOpInterfaceMethods]>, Arguments<(ins Arg:$base, Variadic:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, @@ -1961,7 +1966,7 @@ def Vector_MaskedLoadOp : } def Vector_MaskedStoreOp : - Vector_Op<"maskedstore">, + Vector_Op<"maskedstore", [DeclareOpInterfaceMethods]>, Arguments<(ins Arg:$base, Variadic:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, @@ -2041,6 +2046,7 @@ def Vector_MaskedStoreOp : def Vector_GatherOp : Vector_Op<"gather", [ DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins Arg, "", [MemRead]>:$base, @@ -2144,7 +2150,7 @@ def Vector_GatherOp : } def Vector_ScatterOp : - Vector_Op<"scatter">, + Vector_Op<"scatter", [DeclareOpInterfaceMethods]>, Arguments<(ins Arg:$base, Variadic:$offsets, VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices, @@ -2229,7 +2235,7 @@ def Vector_ScatterOp : } def Vector_ExpandLoadOp : - Vector_Op<"expandload">, + Vector_Op<"expandload", [DeclareOpInterfaceMethods]>, Arguments<(ins Arg:$base, Variadic:$indices, FixedVectorOfNonZeroRankOf<[I1]>:$mask, @@ -2317,7 +2323,7 @@ def Vector_ExpandLoadOp : } def Vector_CompressStoreOp : - Vector_Op<"compressstore">, + Vector_Op<"compressstore", [DeclareOpInterfaceMethods]>, Arguments<(ins Arg:$base, Variadic:$indices, FixedVectorOfNonZeroRankOf<[I1]>:$mask, diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index 2add220fdfb7c..a5feb592045c0 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ 
b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
+add_mlir_interface(MemOpInterfaces)
 add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.h b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
new file mode 100644
index 0000000000000..cdc423f5da1a5
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.h
@@ -0,0 +1,36 @@
+//===- MemOpInterfaces.h - Memory operation interfaces ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of interfaces for operations that interact
+// with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_H
+#define MLIR_INTERFACES_MEMOPINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+/// Attempt to verify the given memory space cast operation.
+LogicalResult verifyMemorySpaceCastOpInterface(Operation *op);
+
+/// Tries to bubble down, in place, a `MemorySpaceCastOpInterface` operation
+/// referenced by `operand`. On success, it returns `std::nullopt`. It
+/// returns failure if `operand` doesn't reference a
+/// `MemorySpaceCastOpInterface` op.
+FailureOr<std::optional<SmallVector<Value>>>
+bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/MemOpInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemOpInterfaces.td b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
new file mode 100644
index 0000000000000..1a64e97c3412d
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MemOpInterfaces.td
@@ -0,0 +1,125 @@
+//===- MemOpInterfaces.td - Memory operation interfaces --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces for operations that interact with memory.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MEMOPINTERFACES_TD
+#define MLIR_INTERFACES_MEMOPINTERFACES_TD
+
+include "mlir/IR/OpBase.td"
+
+def MemorySpaceCastConsumerOpInterface :
+    OpInterface<"MemorySpaceCastConsumerOpInterface"> {
+  let description = [{
+    An interface for operations that can consume memory-space cast-like
+    operations.
+
+    This interface can be used to bubble down memory-space cast operations;
+    see the `bubble-down-memory-space-casts` pass for an example.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Attempts to bubble down the incoming cast-like operands. On success
+        it returns a `std::optional<SmallVector<Value>>`; otherwise it
+        returns failure. If the optional is `std::nullopt`, the cast was
+        performed in place; otherwise the method returns a list of
+        replacement values. If new results are produced, they must be
+        compatible with the results of the original operation.
+
+        If the operation was not modified in place, the interface guarantees
+        it is valid to erase the original operation.
+        If the operation was modified in place, the interface guarantees
+        that no operations were created by the method, and that no further
+        IR modification is necessary.
+
+        Implementations of this method must not erase or replace the
+        original operation; it is the caller's responsibility to erase or
+        replace the op with the results provided by the method.
+
+        Finally, implementations of this method must guarantee that the IR
+        remains valid at all times.
+      }],
+      "::llvm::FailureOr<std::optional<::llvm::SmallVector<::mlir::Value>>>",
+      "bubbleDownCasts",
+      (ins "::mlir::OpBuilder &":$builder)
+    >,
+  ];
+}
+
+def MemorySpaceCastOpInterface : OpInterface<"MemorySpaceCastOpInterface"> {
+  let description = [{
+    An interface for operations that perform memory-space casts. This
+    interface assumes that the cast operation is `pure`.
+
+    These operations are expected to have a well-defined ptr-like operand
+    and a well-defined ptr-like target result.
+
+    This interface also makes it possible to determine whether a cast can be
+    bubbled down by the `MemorySpaceCastConsumerOpInterface`, giving control
+    over which casts can be bubbled down.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<[{
+        Returns the source ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getSourcePtr"
+    >,
+    InterfaceMethod<[{
+        Returns the target ptr-like value.
+      }],
+      "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>", "getTargetPtr"
+    >,
+    InterfaceMethod<[{
+        Returns whether the memory space cast specified by `tgt` and `src`
+        is supported.
+      }],
+      "bool", "isValidMemorySpaceCast",
+      (ins "::mlir::PtrLikeTypeInterface":$tgt,
+           "::mlir::PtrLikeTypeInterface":$src)
+    >,
+    InterfaceMethod<[{
+        Clones the memory space cast op with the given source and target type.
+      }],
+      "::mlir::MemorySpaceCastOpInterface", "cloneMemorySpaceCastOp",
+      (ins "::mlir::OpBuilder &":$builder, "::mlir::PtrLikeTypeInterface":$tgt,
+           "::mlir::TypedValue<::mlir::PtrLikeTypeInterface>":$src)
+    >,
+    InterfaceMethod<[{
+        Returns whether the source pointer of the memory-space cast can be
+        used by the `MemorySpaceCastConsumerOpInterface::bubbleDownCasts`
+        method to promote the source pointer and bubble down the cast.
+
+        For example, a cast operation might decide that all casts to the
+        generic memory-space can be promoted.
+      }],
+      "bool", "isSourcePromotable"
+    >
+  ];
+  let verify = [{
+    return ::mlir::detail::verifyMemorySpaceCastOpInterface($_op);
+  }];
+  let extraClassDeclaration = [{
+    /// Returns the underlying `MemorySpaceCastOpInterface` op if `value`
+    /// is produced by a `MemorySpaceCastOpInterface` op and
+    /// `isSourcePromotable` returns true; otherwise it returns null.
+    static ::mlir::MemorySpaceCastOpInterface
+    getIfPromotableCast(::mlir::Value value) {
+      auto op = ::llvm::dyn_cast_or_null<::mlir::MemorySpaceCastOpInterface>(
+          value.getDefiningOp());
+      if (!op || !op.isSourcePromotable())
+        return nullptr;
+      return op;
+    }
+  }];
+}
+
+#endif // MLIR_INTERFACES_MEMOPINTERFACES_TD
diff --git a/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h b/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
new file mode 100644
index 0000000000000..99db092879a90
--- /dev/null
+++ b/mlir/include/mlir/Transforms/BubbleDownMemorySpaceCasts.h
@@ -0,0 +1,20 @@
+//===-- BubbleDownMemorySpaceCasts.h - Bubble down cast patterns -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
+#define MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
+
+namespace mlir {
+class PatternBenefit;
+class RewritePatternSet;
+/// Collect a set of patterns to bubble down memory-space cast operations.
+void populateBubbleDownMemorySpaceCastPatterns(RewritePatternSet &patterns,
+                                               PatternBenefit benefit);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_BUBBLEDOWNMEMORYSPACECASTS_H
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..1c035f2a843ff 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -46,6 +46,7 @@ class GreedyRewriteConfig;
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
+#define GEN_PASS_DECL_BUBBLEDOWNMEMORYSPACECASTS
 #include "mlir/Transforms/Passes.h.inc"
 
 /// Creates an instance of the Canonicalizer pass, configured with default
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index beb59784947c5..b2b7f20a497e3 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -585,4 +585,48 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> {
   ];
 }
 
+def BubbleDownMemorySpaceCasts :
+    Pass<"bubble-down-memory-space-casts"> {
+  let summary = "Bubbles down memory-space cast operations.";
+  let description = [{
+    This pass iteratively tries to bubble down all possible memory-space cast
+    operations. Note that which casts get bubbled down is determined by the
+    `MemorySpaceCastConsumerOpInterface` and `MemorySpaceCastOpInterface`
+    interfaces rather than by the pass itself: the pass only looks for
+    operations implementing `MemorySpaceCastConsumerOpInterface` and invokes
+    the interface methods to perform the bubbling down.
+
+    Example:
+
+    ```mlir
+    func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+      %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32>
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32>
+      %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32>
+      %loaded = memref.load %collapsed[%c0] : memref<16xf32>
+      %added = arith.addf %loaded, %arg2 : f32
+      memref.store %added, %collapsed[%c0] : memref<16xf32>
+      %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32
+      return %collapsed : memref<16xf32>
+    }
+
+    // mlir-opt --bubble-down-memory-space-casts
+    func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> {
+      %c4 = arith.constant 4 : index
+      %c0 = arith.constant 0 : index
+      %expand_shape = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1>
+      %collapse_shape = memref.collapse_shape %expand_shape [[0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1>
+      %memspacecast = memref.memory_space_cast %collapse_shape : memref<16xf32, 1> to memref<16xf32>
+      %0 = memref.load %collapse_shape[%c0] : memref<16xf32, 1>
+      %1 = arith.addf %0, %arg2 : f32
+      memref.store %1, %collapse_shape[%c0] : memref<16xf32, 1>
+      %2 = memref.atomic_rmw addf %arg2, %collapse_shape[%c4] : (f32, memref<16xf32, 1>) -> f32
+      return %memspacecast : memref<16xf32>
+    }
+    ```
+  }];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
index 734294bd014c6..e25a0121a3359 100644
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5d15d5f6e3de4..349b4deb29023 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -111,6 +111,65 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
   }
 }
 
+/// Helper function to retrieve a lossless memory-space cast, and the
+/// corresponding new result memref type.
+static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface,
+                  PtrLikeTypeInterface>
+getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src) {
+  MemorySpaceCastOpInterface castOp =
+      MemorySpaceCastOpInterface::getIfPromotableCast(src);
+
+  // Bail if there is no promotable, lossless cast.
+  if (!castOp)
+    return {};
+
+  // Transform the source and target type of `castOp` to have the same metadata
+  // as `resultTy`. Bail if not possible.
+  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
+      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
+  if (failed(srcTy))
+    return {};
+
+  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
+      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
+  if (failed(tgtTy))
+    return {};
+
+  // Check if this is a valid memory-space cast.
+  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
+    return {};
+
+  return std::make_tuple(castOp, *tgtTy, *srcTy);
+}
+
+/// Implementation of the `bubbleDownCasts` method for memref operations that
+/// return a single memref result.
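+/// The implementation recreates the operation so that it consumes the cast's
+/// source value directly, with the result type rebased onto the source
+/// memory space, and then casts the new result back to the original target
+/// memory space, so existing users keep observing the original type.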
+template <typename ConcreteOpTy>
+static FailureOr<std::optional<SmallVector<Value>>>
+bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
+                                 OpOperand &src) {
+  auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
+  // Bail if we cannot cast.
+  if (!castOp)
+    return failure();
+
+  // Create the new operands.
+  SmallVector<Value> operands;
+  llvm::append_range(operands, op->getOperands());
+  operands[src.getOperandNumber()] = castOp.getSourcePtr();
+
+  // Create the new op and results.
+  auto newOp = ConcreteOpTy::create(
+      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
+      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
+
+  // Insert a memory-space cast to the original memory space of the op.
+  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
+      builder, tgtTy,
+      cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
+  return std::optional<SmallVector<Value>>(
+      SmallVector<Value>({result.getTargetPtr()}));
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -542,6 +601,11 @@ OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
   return getMemref();
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -710,6 +774,11 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+CastOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
@@ -1601,6 +1670,12 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+LoadOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // MemorySpaceCastOp
 //===----------------------------------------------------------------------===//
@@ -1645,6 +1720,32 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
   return Value{};
 }
 
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
+  return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+}
+
+TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
+  return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+}
+
+bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
+                                               PtrLikeTypeInterface src) {
+  return isa<BaseMemRefType>(tgt) &&
+         tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
+}
+
+MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
+    OpBuilder &b, PtrLikeTypeInterface tgt,
+    TypedValue<PtrLikeTypeInterface> src) {
+  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
+  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
+}
+
+/// The only cast we recognize as promotable is to the generic space.
+bool MemorySpaceCastOp::isSourcePromotable() {
+  return getDest().getType().getMemorySpace() == nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // PrefetchOp
 //===----------------------------------------------------------------------===//
@@ -2041,6 +2142,11 @@ void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
@@ -2348,6 +2454,11 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                 ComposeExpandOfCollapseOp>(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
+}
+
 /// Compute the layout map after collapsing a given source MemRef type with the
 /// specified reassociation indices.
 ///
@@ -2569,6 +2680,11 @@ OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
                                                        adaptor.getOperands());
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
@@ -2609,6 +2725,11 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -2626,6 +2747,12 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this, getValueToStore());
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+StoreOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // SubViewOp
 //===----------------------------------------------------------------------===//
@@ -3282,6 +3409,11 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+SubViewOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
@@ -3382,6 +3514,11 @@ OpFoldResult TransposeOp::fold(FoldAdaptor) {
   return {};
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+TransposeOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ViewOp
 //===----------------------------------------------------------------------===//
@@ -3525,6 +3662,11 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+ViewOp::bubbleDownCasts(OpBuilder &builder) {
+  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // AtomicRMWOp
 //===----------------------------------------------------------------------===//
@@ -3570,6 +3712,12 @@ OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
   return OpFoldResult();
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8d6e263934fb4..b2e5a5b1e36cc 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5087,6 +5087,14 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
+  if (!hasPureBufferSemantics())
+    return failure();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -5574,6 +5582,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
+  if (!hasPureBufferSemantics())
+    return failure();
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -5628,6 +5644,12 @@ std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+LoadOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -5667,6 +5689,12 @@ std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+StoreOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//
@@ -5721,6 +5749,12 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedStoreOp
 //===----------------------------------------------------------------------===//
@@ -5771,6 +5805,12 @@ LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // GatherOp
 //===----------------------------------------------------------------------===//
@@ -5874,6 +5914,12 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+GatherOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -5936,6 +5982,12 @@ void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+ScatterOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // ExpandLoadOp
 //===----------------------------------------------------------------------===//
@@ -5984,6 +6036,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            getResult());
+}
+
 //===----------------------------------------------------------------------===//
 // CompressStoreOp
 //===----------------------------------------------------------------------===//
@@ -6030,6 +6088,12 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add(context);
 }
 
+FailureOr<std::optional<SmallVector<Value>>>
+CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
+  return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getBaseMutable(),
+                                                            ValueRange());
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index fdc19844702bc..388de1c3e5abf 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -11,6 +11,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
+  MemOpInterfaces.cpp
   MemorySlotInterfaces.cpp
   ParallelCombiningOpInterface.cpp
   RuntimeVerifiableOpInterface.cpp
@@ -79,6 +80,7 @@ add_mlir_library(MLIRLoopLikeInterface
   MLIRFunctionInterfaces
 )
 
+add_mlir_interface_library(MemOpInterfaces)
 add_mlir_interface_library(MemorySlotInterfaces)
 add_mlir_interface_library(ParallelCombiningOpInterface)
 add_mlir_interface_library(RuntimeVerifiableOpInterface)
diff --git a/mlir/lib/Interfaces/MemOpInterfaces.cpp b/mlir/lib/Interfaces/MemOpInterfaces.cpp
new file mode 100644
index 0000000000000..fe5c717f67bc4
--- /dev/null
+++ b/mlir/lib/Interfaces/MemOpInterfaces.cpp
@@ -0,0 +1,73 @@
+//===- MemOpInterfaces.cpp - Memory operation interfaces -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+
+LogicalResult mlir::detail::verifyMemorySpaceCastOpInterface(Operation *op) {
+  auto memCastOp = cast<MemorySpaceCastOpInterface>(op);
+
+  // Verify that the source and target pointers are valid.
+  Value sourcePtr = memCastOp.getSourcePtr();
+  Value targetPtr = memCastOp.getTargetPtr();
+  if (!sourcePtr || !targetPtr) {
+    return op->emitError()
+           << "memory space cast op must have valid source and target pointers";
+  }
+
+  if (sourcePtr.getType().getTypeID() != targetPtr.getType().getTypeID()) {
+    return op->emitError()
+           << "expected source and target types of the same kind";
+  }
+
+  // Verify the types implement `PtrLikeTypeInterface`.
+  auto sourceType = dyn_cast<PtrLikeTypeInterface>(sourcePtr.getType());
+  if (!sourceType) {
+    return op->emitError()
+           << "source type must implement `PtrLikeTypeInterface`, but got: "
+           << sourcePtr.getType();
+  }
+
+  auto targetType = dyn_cast<PtrLikeTypeInterface>(targetPtr.getType());
+  if (!targetType) {
+    return op->emitError()
+           << "target type must implement `PtrLikeTypeInterface`, but got: "
+           << targetPtr.getType();
+  }
+
+  // Verify that the operation has exactly one result.
+  if (op->getNumResults() != 1) {
+    return op->emitError()
+           << "memory space cast op must have exactly one result";
+  }
+
+  return success();
+}
+
+FailureOr<std::optional<SmallVector<Value>>>
+mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand,
+                                                   ValueRange results) {
+  MemorySpaceCastOpInterface castOp =
+      MemorySpaceCastOpInterface::getIfPromotableCast(operand.get());
+
+  // Bail if the operand is not produced by a promotable cast.
+  if (!castOp)
+    return failure();
+
+  // Modify the op in place.
+  operand.set(castOp.getSourcePtr());
+  return std::optional<SmallVector<Value>>();
+}
+
+#include "mlir/Interfaces/MemOpInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
new file mode 100644
index 0000000000000..00dac19e37171
--- /dev/null
+++ b/mlir/lib/Transforms/BubbleDownMemorySpaceCasts.cpp
@@ -0,0 +1,69 @@
+//===- BubbleDownMemorySpaceCasts.cpp - Bubble down casts transform ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/BubbleDownMemorySpaceCasts.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/MemOpInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+namespace mlir {
+#define GEN_PASS_DEF_BUBBLEDOWNMEMORYSPACECASTS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+//===----------------------------------------------------------------------===//
+// BubbleDownCastsPattern pattern
+//===----------------------------------------------------------------------===//
+/// Pattern to bubble down casts into consumer operations.
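+/// The pattern interprets the tri-state result of `bubbleDownCasts`:
+/// `failure()` means there was nothing to do, `std::nullopt` means the op
+/// was updated in place (reported to the driver via `modifyOpInPlace`), and
+/// a list of values means the op must be replaced with those values.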
+struct BubbleDownCastsPattern
+    : public OpInterfaceRewritePattern<MemorySpaceCastConsumerOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(MemorySpaceCastConsumerOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<std::optional<SmallVector<Value>>> results =
+        op.bubbleDownCasts(rewriter);
+    if (failed(results))
+      return failure();
+    if (!results->has_value()) {
+      rewriter.modifyOpInPlace(op, []() {});
+      return success();
+    }
+    rewriter.replaceOp(op, **results);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// BubbleDownMemorySpaceCasts pass
+//===----------------------------------------------------------------------===//
+
+struct BubbleDownMemorySpaceCasts
+    : public impl::BubbleDownMemorySpaceCastsBase<BubbleDownMemorySpaceCasts> {
+  using impl::BubbleDownMemorySpaceCastsBase<
+      BubbleDownMemorySpaceCasts>::BubbleDownMemorySpaceCastsBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateBubbleDownMemorySpaceCastPatterns(patterns, PatternBenefit(1));
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::populateBubbleDownMemorySpaceCastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<BubbleDownCastsPattern>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 058039e47313e..54b67f5c7a91e 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRTransforms
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
+  BubbleDownMemorySpaceCasts.cpp
   InlinerPass.cpp
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
@@ -31,6 +32,7 @@ add_mlir_library(MLIRTransforms
   MLIRAnalysis
   MLIRFunctionInterfaces
   MLIRLoopLikeInterface
+  MLIRMemOpInterfaces
   MLIRMemorySlotInterfaces
   MLIRPass
   MLIRRuntimeVerifiableOpInterface
diff --git a/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir b/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
new file mode 100644
index 0000000000000..e4fce89cffb45
--- /dev/null
+++ b/mlir/test/Transforms/test-bubble-down-memory-space-casts.mlir
@@ -0,0 +1,298 @@
+// RUN: mlir-opt %s --bubble-down-memory-space-casts | FileCheck %s
+
+#map = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>
+// CHECK-LABEL: func.func @load_store(
+// CHECK-SAME: %[[ARG0:.*]]: memref,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref
+// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref
+// CHECK: return
+// CHECK: }
+func.func @load_store(%arg0: memref, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref to memref
+  %0 = memref.load %memspacecast[%arg1] : memref
+  memref.store %0, %memspacecast[%arg1] : memref
+  return
+}
+
+// CHECK-LABEL: func.func @load_store_unfoldable(
+// CHECK-SAME: %[[ARG0:.*]]: memref,
+// CHECK-SAME: %[[ARG1:.*]]: index) {
+// CHECK: %[[VAL_0:.*]] = memref.memory_space_cast %[[ARG0]] : memref to memref
+// CHECK: %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[ARG1]]] : memref
+// CHECK: memref.store %[[VAL_1]], %[[VAL_0]]{{\[}}%[[ARG1]]] : memref
+// CHECK: return
+// CHECK: }
+func.func @load_store_unfoldable(%arg0: memref, %arg1: index) {
+  %memspacecast = memref.memory_space_cast %arg0 : memref to memref
+  %0 = memref.load %memspacecast[%arg1] : memref
memref.store %0, %memspacecast[%arg1] : memref + return +} + +// CHECK-LABEL: func.func @cast( +// CHECK-SAME: %[[ARG0:.*]]: memref<2xf32, 1>, +// CHECK-SAME: %[[ARG1:.*]]: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) { +// CHECK: %[[VAL_0:.*]] = memref.cast %[[ARG0]] : memref<2xf32, 1> to memref<*xf32, 1> +// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<*xf32, 1> to memref<*xf32> +// CHECK: %[[VAL_2:.*]] = memref.cast %[[ARG1]] : memref<*xf32, 1> to memref<3x2xf32, 1> +// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<3x2xf32, 1> to memref<3x2xf32> +// CHECK: return %[[VAL_1]], %[[VAL_3]] : memref<*xf32>, memref<3x2xf32> +// CHECK: } +func.func @cast(%arg0: memref<2xf32, 1>, %arg1: memref<*xf32, 1>) -> (memref<*xf32>, memref<3x2xf32>) { + %memspacecast = memref.memory_space_cast %arg0 : memref<2xf32, 1> to memref<2xf32> + %1 = memref.cast %memspacecast : memref<2xf32> to memref<*xf32> + %memspacecast_1 = memref.memory_space_cast %arg1 : memref<*xf32, 1> to memref<*xf32> + %2 = memref.cast %memspacecast_1 : memref<*xf32> to memref<3x2xf32> + return %1, %2 : memref<*xf32>, memref<3x2xf32> +} + +// CHECK-LABEL: func.func @view( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref { +// CHECK: %[[VAL_0:.*]] = arith.constant 100 : index +// CHECK: %[[VAL_1:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[ARG2]], %[[VAL_0]]] : memref to memref +// CHECK: %[[VAL_2:.*]] = memref.memory_space_cast %[[VAL_1]] : memref to memref +// CHECK: return %[[VAL_2]] : memref +// CHECK: } +func.func @view(%arg0: memref, %arg1: index, %arg2: index) -> memref { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %c100 = arith.constant 100 : index + %view = memref.view %memspacecast[%arg1][%arg2, %c100] : memref to memref + return %view : memref +} + +// CHECK-LABEL: func.func @subview( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> { +// CHECK: %[[VAL_0:.*]] = memref.subview %[[ARG0]][4, 2] [8, 2] [3, 2] : memref to memref<8x2xf32, strided<[?, 2], offset: ?>, 1> +// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<8x2xf32, strided<[?, 2], offset: ?>, 1> to memref<8x2xf32, strided<[?, 2], offset: ?>> +// CHECK: return %[[VAL_1]] : memref<8x2xf32, strided<[?, 2], offset: ?>> +// CHECK: } +func.func @subview(%arg0: memref, %arg1: index) -> memref<8x2xf32, strided<[?, 2], offset: ?>> { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %subview = memref.subview %memspacecast[4, 2] [8, 2] [3, 2] : memref to memref<8x2xf32, strided<[?, 2], offset: ?>> + return %subview : memref<8x2xf32, strided<[?, 2], offset: ?>> +} + +// CHECK-LABEL: func.func @reinterpret_cast( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> { +// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 10 : index +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: {{\[}}%[[VAL_1]]], sizes: [10, %[[VAL_0]]], strides: {{\[}}%[[VAL_0]], 1] : memref to memref<10x?xf32, strided<[?, 1], offset: ?>, 1> +// CHECK: %[[VAL_3:.*]] = memref.memory_space_cast %[[VAL_2]] : memref<10x?xf32, strided<[?, 1], offset: ?>, 1> to memref<10x?xf32, strided<[?, 1], offset: ?>> +// CHECK: return %[[VAL_3]] : memref<10x?xf32, strided<[?, 1], offset: ?>> +// CHECK: } +func.func @reinterpret_cast(%arg0: 
memref, %arg1: index) -> memref<10x?xf32, strided<[?, 1], offset: ?>> { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %reinterpret_cast = memref.reinterpret_cast %memspacecast to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] : memref to memref<10x?xf32, strided<[?, 1], offset: ?>> + return %reinterpret_cast : memref<10x?xf32, strided<[?, 1], offset: ?>> +} + +// CHECK-LABEL: func.func @reshape( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: memref<1xindex>) -> memref { +// CHECK: %[[VAL_0:.*]] = memref.reshape %[[ARG0]](%[[ARG1]]) : (memref, memref<1xindex>) -> memref +// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref to memref +// CHECK: return %[[VAL_1]] : memref +// CHECK: } +func.func @reshape(%arg0: memref, %arg1: memref<1xindex>) -> memref { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %reshape = memref.reshape %memspacecast(%arg1) : (memref, memref<1xindex>) -> memref + return %reshape : memref +} + +// CHECK-LABEL: func.func @expand_shape( +// CHECK-SAME: %[[ARG0:.*]]: memref<12xf32, 1>) -> memref<3x4xf32> { +// CHECK: %[[VAL_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1]] output_shape [3, 4] : memref<12xf32, 1> into memref<3x4xf32, 1> +// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<3x4xf32, 1> to memref<3x4xf32> +// CHECK: return %[[VAL_1]] : memref<3x4xf32> +// CHECK: } +func.func @expand_shape(%arg0: memref<12xf32, 1>) -> memref<3x4xf32> { + %memspacecast = memref.memory_space_cast %arg0 : memref<12xf32, 1> to memref<12xf32> + %expand_shape = memref.expand_shape %memspacecast [[0, 1]] output_shape [3, 4] : memref<12xf32> into memref<3x4xf32> + return %expand_shape : memref<3x4xf32> +} + +// CHECK-LABEL: func.func @collapse_shape( +// CHECK-SAME: %[[ARG0:.*]]: memref<3x4xf32, 1>) -> memref<12xf32> { +// CHECK: %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1]] : memref<3x4xf32, 1> into memref<12xf32, 1> +// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref<12xf32, 1> to memref<12xf32> +// CHECK: return %[[VAL_1]] : memref<12xf32> +// CHECK: } +func.func @collapse_shape(%arg0: memref<3x4xf32, 1>) -> memref<12xf32> { + %memspacecast = memref.memory_space_cast %arg0 : memref<3x4xf32, 1> to memref<3x4xf32> + %collapse_shape = memref.collapse_shape %memspacecast [[0, 1]] : memref<3x4xf32> into memref<12xf32> + return %collapse_shape : memref<12xf32> +} + +// CHECK-LABEL: func.func @transpose( +// CHECK-SAME: %[[ARG0:.*]]: memref) -> memref { +// CHECK: %[[VAL_0:.*]] = memref.transpose %[[ARG0]] (d0, d1) -> (d1, d0) : memref to memref +// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref to memref +// CHECK: return %[[VAL_1]] : memref +// CHECK: } +func.func @transpose(%arg0: memref) -> memref { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %transpose = memref.transpose %memspacecast (d0, d1) -> (d1, d0) : memref to memref + return %transpose : memref +} + +// CHECK-LABEL: func.func @atomic_rmw( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index, +// CHECK-SAME: %[[ARG2:.*]]: f32) -> f32 { +// CHECK: %[[VAL_0:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[ARG0]]{{\[}}%[[ARG1]]] : (f32, memref) -> f32 +// CHECK: return %[[VAL_0]] : f32 +// CHECK: } +func.func @atomic_rmw(%arg0: memref, %arg1: index, %arg2: f32) -> f32 { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %0 = 
memref.atomic_rmw addf %arg2, %memspacecast[%arg1] : (f32, memref) -> f32 + return %0 : f32 +} + +// CHECK-LABEL: func.func @assume_alignment( +// CHECK-SAME: %[[ARG0:.*]]: memref) -> memref { +// CHECK: %[[VAL_0:.*]] = memref.assume_alignment %[[ARG0]], 16 : memref +// CHECK: %[[VAL_1:.*]] = memref.memory_space_cast %[[VAL_0]] : memref to memref +// CHECK: return %[[VAL_1]] : memref +// CHECK: } +func.func @assume_alignment(%arg0: memref) -> memref { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %1 = memref.assume_alignment %memspacecast, 16 : memref + return %1 : memref +} + +// CHECK-LABEL: func.func @op_with_cast_sequence( +// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf32, 1>, +// CHECK-SAME: %[[ARG1:.*]]: index, +// CHECK-SAME: %[[ARG2:.*]]: f32) -> memref<16xf32> { +// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32, 1> into memref<4x2x2xf32, 1> +// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0, 1, 2]] : memref<4x2x2xf32, 1> into memref<16xf32, 1> +// CHECK: %[[VAL_4:.*]] = memref.memory_space_cast %[[VAL_3]] : memref<16xf32, 1> to memref<16xf32> +// CHECK: %[[VAL_5:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1> +// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_5]], %[[ARG2]] : f32 +// CHECK: memref.store %[[VAL_6]], %[[VAL_3]]{{\[}}%[[VAL_1]]] : memref<16xf32, 1> +// CHECK: %[[VAL_7:.*]] = memref.atomic_rmw addf %[[ARG2]], %[[VAL_3]]{{\[}}%[[VAL_0]]] : (f32, memref<16xf32, 1>) -> f32 +// CHECK: return %[[VAL_4]] : memref<16xf32> +// CHECK: } +func.func @op_with_cast_sequence(%arg0: memref<4x4xf32, 1>, %arg1: index, %arg2: f32) -> memref<16xf32> { + %memspacecast = memref.memory_space_cast %arg0 : memref<4x4xf32, 1> to memref<4x4xf32> + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %expanded = memref.expand_shape %memspacecast [[0], [1, 2]] output_shape [4, 2, 2] : memref<4x4xf32> into memref<4x2x2xf32> + %collapsed = memref.collapse_shape %expanded [[0, 1, 2]] : memref<4x2x2xf32> into memref<16xf32> + %loaded = memref.load %collapsed[%c0] : memref<16xf32> + %added = arith.addf %loaded, %arg2 : f32 + memref.store %added, %collapsed[%c0] : memref<16xf32> + %atomic_result = memref.atomic_rmw addf %arg2, %collapsed[%c4] : (f32, memref<16xf32>) -> f32 + return %collapsed : memref<16xf32> +} + +// CHECK-LABEL: func.func @transfer_read_write( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_1:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_0]] : memref, vector<4xf32> +// CHECK: vector.transfer_write %[[VAL_1]], %[[ARG0]]{{\[}}%[[ARG1]]] : vector<4xf32>, memref +// CHECK: return +// CHECK: } +func.func @transfer_read_write(%arg0: memref, %arg1: index) { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %c0 = arith.constant 0.0 : f32 + %0 = vector.transfer_read %memspacecast[%arg1], %c0 : memref, vector<4xf32> + vector.transfer_write %0, %memspacecast[%arg1] : vector<4xf32>, memref + return +} + +// NOTE: The operations disappear because they can get folded. 
+// CHECK-LABEL: func.func @transfer_read_write_tensor( +// CHECK-SAME: %[[ARG0:.*]]: tensor, +// CHECK-SAME: %[[ARG1:.*]]: index) -> tensor { +// CHECK: return %[[ARG0]] : tensor +// CHECK: } +func.func @transfer_read_write_tensor(%arg0: tensor, %arg1: index) -> tensor { + %c0 = arith.constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%arg1], %c0 : tensor, vector<4xf32> + %1 = vector.transfer_write %0, %arg0[%arg1] : vector<4xf32>, tensor + return %1 : tensor +} + +// CHECK-LABEL: func.func @vector_load_store( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index) { +// CHECK: %[[VAL_0:.*]] = vector.load %[[ARG0]]{{\[}}%[[ARG1]]] : memref, vector<4xf32> +// CHECK: vector.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[ARG1]]] : memref, vector<4xf32> +// CHECK: return +// CHECK: } +func.func @vector_load_store(%arg0: memref, %arg1: index) { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %0 = vector.load %memspacecast[%arg1] : memref, vector<4xf32> + vector.store %0, %memspacecast[%arg1] : memref, vector<4xf32> + return +} + +// CHECK-LABEL: func.func @masked_load_store( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index) { +// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> +// CHECK: %[[VAL_2:.*]] = vector.maskedload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref, vector<4xi1>, vector<4xf32> into vector<4xf32> +// CHECK: vector.maskedstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref, vector<4xi1>, vector<4xf32> +// CHECK: return +// CHECK: } +func.func @masked_load_store(%arg0: memref, %arg1: index) { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1> + %passthrough = arith.constant dense<0.0> : vector<4xf32> + %0 = vector.maskedload %memspacecast[%arg1], %mask, %passthrough : memref, vector<4xi1>, vector<4xf32> into vector<4xf32> + vector.maskedstore %memspacecast[%arg1], %mask, %0 : memref, vector<4xi1>, vector<4xf32> + return +} + +// CHECK-LABEL: func.func @gather_scatter( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index) { +// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense : vector<4xi1> +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = vector.gather %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_0]] : memref, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> +// CHECK: vector.scatter %[[ARG0]]{{\[}}%[[VAL_3]]] {{\[}}%[[VAL_2]]], %[[VAL_1]], %[[VAL_4]] : memref, vector<4xindex>, vector<4xi1>, vector<4xf32> +// CHECK: return +// CHECK: } +func.func @gather_scatter(%arg0: memref, %arg1: index) { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %c0 = arith.constant 0 : index + %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> + %mask = arith.constant dense : vector<4xi1> + %passthrough = arith.constant dense<0.0> : vector<4xf32> + %0 = vector.gather %memspacecast[%c0] [%indices], %mask, %passthrough : memref, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> + vector.scatter %memspacecast[%c0] [%indices], %mask, %0 : memref, vector<4xindex>, vector<4xi1>, vector<4xf32> + return +} + +// 
CHECK-LABEL: func.func @expandload_compressstore( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: index) { +// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> +// CHECK: %[[VAL_2:.*]] = vector.expandload %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_0]] : memref, vector<4xi1>, vector<4xf32> into vector<4xf32> +// CHECK: vector.compressstore %[[ARG0]]{{\[}}%[[ARG1]]], %[[VAL_1]], %[[VAL_2]] : memref, vector<4xi1>, vector<4xf32> +// CHECK: return +// CHECK: } +func.func @expandload_compressstore(%arg0: memref, %arg1: index) { + %memspacecast = memref.memory_space_cast %arg0 : memref to memref + %mask = arith.constant dense<[true, true, false, false]> : vector<4xi1> + %passthrough = arith.constant dense<0.0> : vector<4xf32> + %0 = vector.expandload %memspacecast[%arg1], %mask, %passthrough : memref, vector<4xi1>, vector<4xf32> into vector<4xf32> + vector.compressstore %memspacecast[%arg1], %mask, %0 : memref, vector<4xi1>, vector<4xf32> + return +}
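
For dialects that want their own cast operation to participate in bubbling, here is a producer-side sketch (an editorial example, not part of this patch); it mirrors the `memref.memory_space_cast` implementation above, and `MyCastOp`, its accessors, and its `create` builder are hypothetical:

```cpp
#include "mlir/Interfaces/MemOpInterfaces.h"

using namespace mlir;

// Assumes `MyCastOp` declares the `MemorySpaceCastOpInterface` methods in ODS
// and has one ptr-like operand (`getSource`) and one ptr-like result.
TypedValue<PtrLikeTypeInterface> MyCastOp::getSourcePtr() {
  return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
}

TypedValue<PtrLikeTypeInterface> MyCastOp::getTargetPtr() {
  return cast<TypedValue<PtrLikeTypeInterface>>(getResult());
}

bool MyCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
                                      PtrLikeTypeInterface src) {
  // Accept only casts where `tgt` and `src` differ solely in memory space.
  return tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
}

MemorySpaceCastOpInterface
MyCastOp::cloneMemorySpaceCastOp(OpBuilder &b, PtrLikeTypeInterface tgt,
                                 TypedValue<PtrLikeTypeInterface> src) {
  return MyCastOp::create(b, getLoc(), tgt, src);
}

bool MyCastOp::isSourcePromotable() {
  // Promote only casts into the default (null) memory space.
  return getTargetPtr().getType().getMemorySpace() == nullptr;
}
```

With these methods in place, the op is picked up automatically by the patterns registered through `populateBubbleDownMemorySpaceCastPatterns` and, consequently, by the `bubble-down-memory-space-casts` pass.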