diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 1f0e5cf7e7f56..c5afff7354ca3 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -1,4 +1,4 @@ -//===-- AMDGPU.td - AMDGPU dialect definitions *- tablegen -*------===// +//===-- AMDGPU.td - AMDGPU dialect *- tablegen -*--------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,1798 +6,9 @@ // //===----------------------------------------------------------------------===// -#ifndef AMDGPU -#define AMDGPU +#ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPU_TD +#define MLIR_DIALECT_AMDGPU_IR_AMDGPU_TD -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir/IR/EnumAttr.td" -include "mlir/IR/Properties.td" -include "mlir/IR/OpBase.td" +include "mlir/Dialect/AMDGPU/IR/AMDGPUOps.td" -def AMDGPU_Dialect : Dialect { - let name = "amdgpu"; - let cppNamespace = "::mlir::amdgpu"; - let description = [{ - The `AMDGPU` dialect provides wrappers around AMD-specific functionality - and LLVM intrinsics. These wrappers should be used in conjunction with - more generic dialects, such as `gpu` and `vector`, when generating LLVM IR - that will eventually be executed on AMD hardware. - - # What goes here? - In many cases, AMD GPU functionality can be accessed either though generic - operations (such as those in the `gpu`, `vector`, or `math`) or through - the `rocdl` dialect's intrinsic wrappers. However, there are instances where - AMD-specific functionally benefits from a wrapper around the underlying - LLVM intrinsics. 
- - In general terms, operations or types should be added to this dialect when they - wrap some AMD-specific functionality in a way that makes it work better with the - MLIR ecosystem and its types or when those buitins would be needlessly - complex to work with (such as if they features magic constants at the LLVM level). - - An additional set of operations that belong in this dialect are those that - have chipset-specific differences that can be abstracted over in a useful way. - - To give some concrete examples: - - - `amdgpu.mfma` and `amdgpu.wmma` exist in order to make a large set of - intrinsics more compatible with the MLIR type system (such as by allowing - 8-bit float vectors to be passed as `vector` or - `vector` instead of as packed 32-bit integers whose element type - is controlled by separate operator-level constants. These operations also - allow the same `amdgpu.mfma` operation to be used regardless of the target - chip. - - `amdgpu.swizzle_bitmode` provides a wrapper around the `ds.swizzle` intrinsic, - allowing a wider range of types (such as `vector<2xf16>`) to be used natively - and eliminating the need to pack the and, or, and xor components using opaque - shifts. - - Operations like `amdgpu.gather_to_lds` provide `memref`-ized wrappers around - intrinsics that take a pointer, and are nontrivial enough to justify inclusion - in this dialect. - - - Note that simple intrinsics like `rocdl.sin` or `rocdl.s.barrier` should not - receive wrapper operations, as nothing is gained from the duplicate operation. - As a rule of thumb, if an operation's rewrite in AMDGPUToROCDL would be only - a `replaceOpWithNewOp` call, no AMDGPU dialect operation is needed. - - # Design guidelines - - Operations should leverage MLIR's "standard" types where possible. MLIR has - a more extensible type system than LLVM (especially in the area of small floats) - and those types should be used to create more ergonomic wrappers. 
In particular, - intrinsics that take pointers should have wrappers in this dialect that take - `memref` arguments and indices. - - Operations should use properties or attributes in cases where the underlying - intrinsic uses `immarg`s (except in cases where that attribute can be represented - in the type system). - - If it is possible to generalize the types of an operation, it should be done. - For example, the underlying operations for permutations and swizzles always - take 32-bit operands. Their AMDGPU wrappers can take any type, and will apply - padding and expansion to multiple instructions as needed. This makes these - operations easier to target because it hides the bitcasts and extracts - until the final lowering. - - When the underlying operation uses magic constants, those should be presented - in a more programmer-friendly fashion, such as through enums or though - using separate arguments that are later combined. (For example, see the - design of the `amdgpu.dpp` and `amdgpu.fat_raw_buffer_cast` operations.) - - If sufficiently similar functionality on multiple hardware generations can be - encapsulated into a single operation, it should be done. The lowering to - intrinsics should either throw an error when an unsupported capability is - used or ignore it. Which of these is two failure modes is more appropriate - depends on the nature of the feature, but errors are a safe default choice. - - # Documentation guidelines - - AMDGPU dialect operations should document how any abstractions they introduce - translate to LLVM intrinsics or hardware operations. - - While documenting the semantics of the underlying operations is not required, - is preferred to provide an overview of the operation's functionality, - especially in cases where the documentation is widely distributed. Someone - looking at an AMDGPU dialect operation should be able to generally understand - what it does and have found the keywords they'll need for more detail. 
- - Operation documentation should include usage examples. - - Note that this dialect uses LLVM's gfx numbers to refer to individual - architectures/chipsets and not product names or codenames. - }]; - - - let dependentDialects = [ - "ROCDL::ROCDLDialect", - "arith::ArithDialect", - "gpu::GPUDialect" - ]; - let useDefaultAttributePrinterParser = 1; - let useDefaultTypePrinterParser = 1; -} - -def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">; - -def AnyIntegerOrFloatOr1DVector : - AnyTypeOf<[AnyIntegerOrFloat, FixedVectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>; - -//===----------------------------------------------------------------------===// -// AMDGPU general attribute definitions -//===----------------------------------------------------------------------===// - -def AMDGPU_AddressSpace : I32EnumAttr<"AddressSpace", - "AMDGPU-specific address spaces", - [ - I32EnumAttrCase<"FatRawBuffer", 0, "fat_raw_buffer">, - I32EnumAttrCase<"BufferRsrc", 1, "buffer_rsrc">, - I32EnumAttrCase<"FatStructuredBuffer", 2, "fat_structured_buffer">, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::amdgpu"; -} - -def AMDGPU_AddressSpaceAttr : EnumAttr { - let description = [{ - AMDGPU-specific memory spaces that may not have exact analogues on other - GPU targets or backends. - - - `fat_raw_buffer` is the memory space used when a memref is stored as - as a "buffer fat pointer" - that is, a buffer resource (that is set up to - use raw byte-level indexing) along with its offset. The AMDGPU backend - implements `ptr addrspace(7)` to represent these fat pointers so that - buffer resources (which allow advanced features like bounds checking or - cache swizzling) can be used like ordinary LLVM pointers or memrefs. - See also the `fat_raw_buffer_cast` operation - - `buffer_rsrc` is the memory space for `ptr addrspace(8)`, representing a - buffer resource. 
It should not be used for memrefs, since it does not support - indexing - - `fat_structured_buffer` represents `ptr addrspace(9)`, a buffer resource - that carries both an index and offset field, which are used for complex - structured indexing that is primarily seen in graphics applications. This - is also incompatible with the simple indexing model supported by memref. - }]; - let assemblyFormat = "`<` $value `>`"; -} - -//===----------------------------------------------------------------------===// -// AMDGPU Type definitions -//===----------------------------------------------------------------------===// - -class AMDGPU_Type traits = []> - : TypeDef { - let mnemonic = typeMnemonic; -} - -def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> { - let summary = "Pair of base addresses that move data between LDS and global storage."; - let description = [{ - This type is opaque and it is used to represent a struct of two addresses. - One address is in LDS while the other is in global memory. - - The value defined by this operation is only intended to be used by - amdgpu.tdm_make_descriptor. - }]; - let parameters = (ins "Type":$elementType); - let builders = [ - TypeBuilderWithInferredContext<(ins "Type":$elementType), [{ - return $_get(elementType.getContext(), elementType); - }]> - ]; - let assemblyFormat = "`<` $elementType `>`"; -} - -def AMDGPU_TDMGatherBaseType : AMDGPU_Type<"TDMGatherBase", "tdm_gather_base"> { - let summary = "Pair of base addresses that move data between LDS and global storage."; - let description = [{ - This type is opaque and it is used to represent a struct of two addresses. - One address is in LDS while the other is in global memory. - - This operation is similar to amdgpu.tdm_make_base but intended to be - used in gather mode. - - The value defined by this operation is only intended to be used by - amdgpu.tdm_make_gather_descriptor. 
- }]; - let parameters = (ins "Type":$elementType, "Type":$indexType); - let builders = [ - TypeBuilderWithInferredContext<(ins "Type":$elementType, "Type": $indexType), [{ - return $_get(elementType.getContext(), elementType, indexType); - }]> - ]; - let assemblyFormat = "`<` $elementType `,` $indexType`>`"; - let genVerifyDecl = 1; -} - -def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> { - let summary = "Descriptors used in tensor store/load operations."; - let description = [{ - This type is opaque and corresponds to the two or four descriptor groups - used in tensor_load_to_lds or tensor_store_from_lds. - }]; -} - -class AMDGPU_ConcreteVector : - FixedVectorOfLengthAndType<[length], [elem]>, - BuildableType< - "::mlir::VectorType::get({" # length # "} ," - # elem.builderCall # ")">; - -//===----------------------------------------------------------------------===// -// AMDGPU Op definitions -//===----------------------------------------------------------------------===// - -class AMDGPU_Op traits = []> : - Op {} - -def AMDGPU_ExtPackedFp8Op : - AMDGPU_Op<"ext_packed_fp8", [Pure]>, - Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN, - VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source, - ConfinedAttr]>:$index)>, - Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> { - let summary = "Extend a fp8 value to a float or a vector of packed fp8 values to two floats"; - - let description = [{ - Extend one or two 8-bit floats in `source[index]` to a 32-bit float or - two floats and return them. - - This rather unusual signature arises from the fact that AMD GPUs cannot - easily work with sub 32-bit quantities, so the compiler intrinsics for - extending 8-bit floats (which are, currently, the only way to work with - this operation) take packed vectors of 4 such floats. 
- - If the passed-in vector has fewer than four elements, or the input is scalar, - the remaining values in the <4 x i8> will be filled with - undefined values as needed. - }]; - let assemblyFormat = [{ - attr-dict $source `[` $index `]` `:` type($source) `to` type($res) - }]; -} - -def AMDGPU_ScaledExtPackedMatrixOp - : AMDGPU_Op<"scaled_ext_packed_matrix", [Pure, AllShapesMatch<["source", "res"]>]>, - Arguments<( - ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>, - FixedVectorOfShapeAndType<[8], F8E4M3FN>, - FixedVectorOfShapeAndType<[8], F8E5M2>, - FixedVectorOfShapeAndType<[16], F6E2M3FN>, - FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source, - FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale, - ConfinedAttr]>:$blockSize, - ConfinedAttr]>:$firstScaleLane, - ConfinedAttr, IntMaxValue<3>]>:$firstScaleByte)>, - Results<( - outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>, - FixedVectorOfShapeAndType<[8], F16>, - FixedVectorOfShapeAndType<[8], BF16>, - FixedVectorOfShapeAndType<[16], F32>, - FixedVectorOfShapeAndType<[16], F16>, - FixedVectorOfShapeAndType<[16], BF16>]>:$res)> { - - let summary = "Extend a wave-wide matrix of packed floating point values"; - - let description = [{ - Extend matrix of microfloats (8 or 16 elements per lane) using a set of scales - that may be stored on other lanes. - - The scales applied to the input microfloats are stored in bytes which - come from the `scales` input provided in a *half* of the wave identified - by `firstScaleLane`. The bytes used is selected by `firstScaleByte` and depends - on the type of `source`. The 16 vectors in consecutive lanes starting from - `firstScaleLane` (which we'll call the scale vectors) will be used by both - halves of the wave (with lane L reading from L % 16'th scale vector). - - When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN each half of the - wave will use a different byte. The first one being `firstScaleByte` and - the second one being `firstScaleByte` + 1. 
When the block size is 32, - `firstScaleByte` can be either 0 or 2, selecting halves of the scale vectors. - Lanes 0-15 will read from `firstScaleByte` and lanes 16-31 will read - from `firstScaleByte` + 1. - - - For example: - ```mlir - // Input: 8-element vector of F8E4M3FN, converting to F32 - // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1 - %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) - blockSize(32) firstScaleLane(0) firstScaleByte(0) - : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> - - // Input: 16-element vector of F6E2M3FN, converting to F16 - // Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3 - %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) - blockSize(32) firstScaleLane(16) firstScaleByte(2) - : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - ``` - - When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN and - the block size is 16, `firstScaleByte` can be 0 or 1. - Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors, - while lanes 16-31 read from `firstScaleByte` + 2. - For example: - ```mlir - // Input: 8-element vector of F8E5M2, converting to BF16 - // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2) - %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) - blockSize(16) firstScaleLane(0) firstScaleByte(0) - : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> - - // Input: 16-element vector of F6E3M2FN, converting to F32 - // Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2) - %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) - blockSize(16) firstScaleLane(16) firstScaleByte(1) - : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> - ``` - - Note: the layout for the scales generally mirrors how the WMMA - instructions use for matrix scales. These selection operands allows - one to choose portions of the matrix to convert. 
- - When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 32, - then the same byte will be used by both halves of the wave. - In this case, `firstScaleByte` can be any value from 0 to 3. - - When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 16, - following combinations are allowed: - * `firstScaleLane(0), firstScaleByte(0)` - * `firstScaleLane(16), firstScaleByte(2)` - all other combinations are reserved. - - Available on gfx1250+. - }]; - - let assemblyFormat = [{ - attr-dict $source - `scale` `(` $scale `)` - `blockSize` `(` $blockSize `)` - `firstScaleLane` `(` $firstScaleLane`)` - `firstScaleByte` `(` $firstScaleByte `)` - `:` type($source) `,` type($scale) `->` type($res) - }]; - - let hasVerifier = 1; - -} - -def AMDGPU_ScaledExtPackedOp - : AMDGPU_Op<"scaled_ext_packed", [Pure]>, - Arguments<( - ins AnyTypeOf<[VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>, - VectorOfLengthAndType<[1, 2, 3, 4, 5, 6, 7, 8], - [F4E2M1FN]>]>:$source, - F32:$scale, - ConfinedAttr]>:$index)>, - Results<( - outs AnyTypeOf<[FixedVectorOfLengthAndType<[2], [F32]>, - FixedVectorOfLengthAndType<[2], [F16]>, - FixedVectorOfLengthAndType<[2], [BF16]>]>:$res)> { - let summary = "Extend a vector of packed floating point values"; - - let description = [{ - Extend and scale two packed floats in `source[index]` to two floats and - return them. - - This rather unusual signature arises from the fact that AMD GPUs cannot - easily work with sub 32-bit quantities, so the compiler intrinsics for - extending 8-bit floats (which are, currently, the only way to work with - this operation) take packed vectors of 2 such floats. - - If the passed-in vector has fewer than two elements, or the input is scalar, - the remaining values in the <2 x i8> will be filled with - undefined values as needed. 
- }]; - let assemblyFormat = [{ - attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res) - }]; -} - -def AMDGPU_PackedTrunc2xFp8Op : - AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>, - Arguments<(ins F32:$sourceA, - Optional:$sourceB, - ConfinedAttr]>:$wordIndex, - Optional>:$existing)>, - Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> { - let summary = "Round two floats into a packed vector of 8-bit floats"; - let description = [{ - Round the inputs `sourceA` and `sourceB` (which is undefined if not - specified) into the low or high word (bottom two or top two) elements - of the returned vector, keeping the other two elements of `existing` - unchanged if present (or undefined if it was not passed in). - - The reason for this odd signature is that AMD GPUs cannot easily work with - sub-registers, and so the conversion intrinsics (which are currently the - only way to work with 8-bit float types) take packed vectors of 4 8-bit - values. - }]; - let assemblyFormat = [{ - attr-dict $sourceA `,` ($sourceB^):(`undef`)? - `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]` - `:` type($sourceA) `to` type($res) (`into` type($existing)^)? 
- }]; - let hasVerifier = 1; -} - -def AMDGPU_PackedScaledTruncOp - : AMDGPU_Op<"packed_scaled_trunc", [Pure]>, - Arguments<(ins VectorOfLengthAndType<[1, 2], [F32, F16, BF16]>:$source, - F32:$scale, - ConfinedAttr]>:$index, - Optional, - FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>>:$existing)>, - Results<( - outs AnyTypeOf<[FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>, - FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>:$res)> { - let summary = "Round two floats into a packed vector of floats"; - let description = [{ - Scale and round the inputs `source` (which is undefined if not - specified) into the low or high word (bottom two or top two) elements - of the returned vector, keeping the other two elements of `existing` - unchanged if present (or undefined if it was not passed in). - - The reason for this odd signature is that AMD GPUs cannot easily work with - sub-registers, and so the conversion intrinsics take 32-bit wide - packed vectors of float values. - }]; - let assemblyFormat = [{ - attr-dict $source `into` ($existing^):(`undef`)? `[` $index `]` - `,` $scale - `:` type($source) `to` type($res) (`into` type($existing)^)? - }]; - let hasVerifier = 1; -} - -def AMDGPU_PackedStochRoundFp8Op : - AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>, - Arguments<(ins F32:$source, - I32:$stochiasticParam, - ConfinedAttr]>:$storeIndex, - Optional>:$existing)>, - Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> { - let summary = "Round float stochiastically into a packed vector of 8-bit floats"; - let description = [{ - Round the input `source`, adding in `stochiasticParam`, and place it into - the `storeIndex`th element of `res`. - - If `existing` is passed in, elements of `res` other than the one at `storeIndex` - are copied from `existing`. 
- - The reason for this odd signature is that AMD GPUs cannot easily work with - sub-registers, and so the conversion intrinsics (which are currently the - only way to work with 8-bit float types) take packed vectors of 4 8-bit - values. - }]; - let assemblyFormat = [{ - attr-dict $source `+` $stochiasticParam - `into` ($existing^):(`undef`)? `[` $storeIndex `]` - `:` type($source) `to` type($res) (`into` type($existing)^)? - }]; - let hasVerifier = 1; -} - -def AMDGPU_FatRawBufferCastOp : - AMDGPU_Op<"fat_raw_buffer_cast", - [Pure, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - ViewLikeOpInterface, AttrSizedOperandSegments]>, - Arguments<(ins AnyMemRef:$source, - Optional:$validBytes, - Optional>:$cacheSwizzleStride, - DefaultValuedAttr:$boundsCheck, - UnitAttr:$resetOffset)>, - Results<(outs AnyMemRef:$result)> { - // TODO: Set `resetOffset` and `boundsCheck` to use `Property` once - // we implemented pythonic binding for `Property`. - let summary = "Create a raw buffer fat pointer that matches `memref`"; - let description = [{ - Wraps the memory pointed to by `source` as a raw buffer fat pointer, or, - in LLVM terms, a `ptr addrspace(7)`, returning a memref that has the same - sizes and layout but the `#amdgpu.address_space` - address space. - - This memref can be used with standard memref operations like `memref.load`, - `memref.store`, and `memref.atomicrmw`, which will be lowered to the relevant - buffer intrinsics. (`vector.masked_load/store` will work once there's backend - support for lowering them, and then this document will be updated) - - If `validBytes` is given, it is the number of bytes that will be valid as - an offset to `out`. If it is not provided, this will be inferred from - the size of the memref during lowering. This size is - max_{d = 0 upto rank(source)} (sizes[d] * strides[d]) * sizeof(element type). - - The flags of the buffer descriptor will be set up to enable raw usage - - for example, stride = 0, add_tid = 0, and so on. 
The `boundsCheck` - property determines if bounds checking is enabled or not (on architectures - where this can be controlled - that is, on RDNA chips). - - If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled - on architectures that support it. This swizzling, unlike the main swizzling - mode (whose usage makes a buffer non-raw) does not affect index calculation, - but does affect cache behavior. Mixing access between cache-swizzled raw - buffers and other forms of memory access, like ordinary pointer loads or - unswizzled buffer pointers can cause incorrect behavior and must be avoided. - - This operation preserves the sizes, strides, and offset of the input - memref - they'll be added in by `memref.load` later. However, if - `resetOffset` is set, that offset will be added to the base pointer. - If the value of the memref's offset is not uniform (independent of the lane/thread ID), - this will lead to substantially decreased performance due to the need for - a waterfall loop on the base address of the buffer resource. - }]; - - let extraClassDeclaration = [{ - Value getViewSource() { return getSource(); } - }]; - - let assemblyFormat = [{ - $source oilist (`validBytes` `(` $validBytes `)` - | `cacheSwizzleStride` `(` $cacheSwizzleStride `)` - | `boundsCheck` `(` $boundsCheck `)` - | `resetOffset` $resetOffset ) - attr-dict `:` type($source) `to` type($result) - }]; - - let hasVerifier = 1; -} - -/// Raw buffer load -def AMDGPU_RawBufferLoadOp : - AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>, - AttrSizedOperandSegments]>, - Arguments<(ins Arg:$memref, - Variadic:$indices, - DefaultValuedAttr:$boundsCheck, - OptionalAttr:$indexOffset, - Optional:$sgprOffset)>, - Results<(outs AnyType:$value)> { - - let summary = "Raw Buffer load, exposing GCN features"; - let description = [{ - The `amdgpu.raw_buffer_load` op is a wrapper around the buffer load intrinsics - available on AMD GPUs, including extensions in newer GPUs. 
- - The index into the buffer is computed as for `memref.load` with the additon - of `indexOffset` and `sgprOffset` (which **may or may not** be considered - in bounds checks and includes any offset present on the memref type if it's - non-zero). - - All indices and offsets are in units of the memref's data type and are - converted to bytes during lowering. - - When a load is out of bounds, the instruction returns zero. - Partially-out of bounds have chipset-dependent behavior: whether reading - 2 elements starting at index 7 of a `memref<8xf32>` returns the last element - in the first vector component depends on the architecture. - - The memref struct is converted into a buffer resource (a V#) and the arguments - are translated to intrinsic arguments as follows: - - The base address of the buffer is the base address of the memref - - The stride is 0 to enable raw mode - - The number of records is the size of the memref, in bytes - In the case of dynamically-shaped memrefs, this is computed at runtime - as max_d (size(d) * stride(d)) * sizeof(elementType(memref)) - - The offset enable bit is 1, the index enable bit is 0. - - The thread ID addition bit is off - - If `boundsCheck` is false and the target chipset is RDNA, OOB_SELECT is set - to 2 to disable bounds checks, otherwise it is 3 - - The cache coherency bits are off - }]; - let assemblyFormat = [{ - attr-dict $memref `[` $indices `]` - (`sgprOffset` $sgprOffset^)? `:` - type($memref) (`,` type($indices)^)? 
`->` type($value) - }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -/// Raw buffer store -def AMDGPU_RawBufferStoreOp : - AMDGPU_Op<"raw_buffer_store", [AllElementTypesMatch<["value", "memref"]>, - AttrSizedOperandSegments]>, - Arguments<(ins AnyType:$value, - Arg:$memref, - Variadic:$indices, - DefaultValuedAttr:$boundsCheck, - OptionalAttr:$indexOffset, - Optional:$sgprOffset)> { - - let summary = "Raw Buffer Store, exposing GCN features"; - let description = [{ - The `amdgpu.raw_buffer_store` op is a wrapper around the buffer store - intrinsics available on AMD GPUs, including extensions in newer GPUs. - - The store index is computed as in `memref.store` with the addition of - `indexOffset` (which is included for uniformity with atomics and may be useful - when writing vectorized code) and `sgprOffset` (which is added after bounds - checks and implicitly includes the offset of the memref type if non-zero). - All index components are in terms of the elements of the memref, not bytes, - and are scaled up appropriately. - - Out of bounds stores are ignored in hardware. - Wthether a vector write that includes some in-bounds and soeme out-of-bounds - components is partically completed is chipset-dependent. - - See `amdgpu.raw_buffer_load` for a description of how the underlying - instruction is constructed. - }]; - let assemblyFormat = [{ - attr-dict $value `->` $memref `[` $indices `]` - (`sgprOffset` $sgprOffset^)? `:` - type($value) `->` type($memref) (`,` type($indices)^)? 
- }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -// Raw buffer atomic compare-and-swap -def AMDGPU_RawBufferAtomicCmpswapOp : - AMDGPU_Op<"raw_buffer_atomic_cmpswap", [ - AttrSizedOperandSegments, - AllTypesMatch<["src", "cmp", "value"]>, - AllElementTypesMatch<["value", "memref"]>]>, - Arguments<(ins AnyType:$src, - AnyType:$cmp, - Arg:$memref, - Variadic:$indices, - DefaultValuedAttr:$boundsCheck, - OptionalAttr:$indexOffset, - Optional:$sgprOffset)>, - Results<(outs AnyType:$value)> { - - let summary = "Raw Buffer Atomic compare-and-swap"; - let description = [{ - The `amdgpu.raw_buffer_atomic_cmpswap` op is a wrapper around the - buffer-based atomic compare-and-swap min available on AMD GPUs. - - The index into the buffer is computed as for `memref.store` with the addition - of `indexOffset` (which is used to aid in emitting vectorized code) and, - if present `sgprOffset` (which is added after bounds checks and includes - any non-zero offset on the memref type). - - All indexing components are given in terms of the memref's element size, not - the byte lengths required by the intrinsic. - - Out of bounds atomic operations are ignored in hardware. - - See `amdgpu.raw_buffer_load` for a description of how the underlying - instruction is constructed. - }]; - let assemblyFormat = [{ - attr-dict $src `,` $cmp `->` $memref `[` $indices `]` - (`sgprOffset` $sgprOffset^)? 
`:` - type($value) `->` type($memref) `,` type($indices) - }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -// Raw buffer atomic floating point add -def AMDGPU_RawBufferAtomicFaddOp : - AMDGPU_Op<"raw_buffer_atomic_fadd", [AllElementTypesMatch<["value", "memref"]>, - AttrSizedOperandSegments]>, - Arguments<(ins AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16, BF16]>]>:$value, - Arg:$memref, - Variadic:$indices, - DefaultValuedAttr:$boundsCheck, - OptionalAttr:$indexOffset, - Optional:$sgprOffset)> { - - let summary = "Raw Buffer Floating-point Atomic Add (MI-* only)"; - let description = [{ - The `amdgpu.raw_buffer_atomic_fadd` op is a wrapper around the - buffer-based atomic floating point addition available on the MI-* series - of AMD GPUs. - - The index into the buffer is computed as for `memref.store` with the addition - of `indexOffset` (which is used to aid in emitting vectorized code) and, - if present `sgprOffset` (which is added after bounds checks and includes - any non-zero offset on the memref type). - - All indexing components are given in terms of the memref's element size, not - the byte lengths required by the intrinsic. - - Out of bounds atomic operations are ignored in hardware. - - See `amdgpu.raw_buffer_load` for a description of how the underlying - instruction is constructed. - }]; - let assemblyFormat = [{ - attr-dict $value `->` $memref `[` $indices `]` - (`sgprOffset` $sgprOffset^)? 
`:` - type($value) `->` type($memref) `,` type($indices) - }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -// Raw buffer atomic floating point max -def AMDGPU_RawBufferAtomicFmaxOp : - AMDGPU_Op<"raw_buffer_atomic_fmax", [AllElementTypesMatch<["value", "memref"]>, - AttrSizedOperandSegments]>, - Arguments<(ins AnyTypeOf<[F32, F64]>:$value, - Arg:$memref, - Variadic:$indices, - DefaultValuedAttr:$boundsCheck, - OptionalAttr:$indexOffset, - Optional:$sgprOffset)> { - - let summary = "Raw Buffer Floating-point Atomic Max (non-GFX9)"; - let description = [{ - The `amdgpu.raw_buffer_atomic_fmax` op is a wrapper around the - buffer-based atomic floating point max available on AMD GPUs (except GFX9). - - The index into the buffer is computed as for `memref.store` with the addition - of `indexOffset` (which is used to aid in emitting vectorized code) and, - if present `sgprOffset` (which is added after bounds checks and includes - any non-zero offset on the memref type). - - All indexing components are given in terms of the memref's element size, not - the byte lengths required by the intrinsic. - - Out of bounds atomic operations are ignored in hardware. - - See `amdgpu.raw_buffer_load` for a description of how the underlying - instruction is constructed. - }]; - let assemblyFormat = [{ - attr-dict $value `->` $memref `[` $indices `]` - (`sgprOffset` $sgprOffset^)? 
`:` - type($value) `->` type($memref) `,` type($indices) - }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -// Raw buffer atomic signed integer max -def AMDGPU_RawBufferAtomicSmaxOp : - AMDGPU_Op<"raw_buffer_atomic_smax", [ - AttrSizedOperandSegments]>, - Arguments<(ins I32:$value, - Arg:$memref, - Variadic:$indices, - DefaultValuedAttr:$boundsCheck, - OptionalAttr:$indexOffset, - Optional:$sgprOffset)> { - - let summary = "Raw Buffer Signed Integer Atomic Max"; - let description = [{ - The `amdgpu.raw_buffer_atomic_smax` op is a wrapper around the - buffer-based atomic signed integer max available on AMD GPUs. - - The index into the buffer is computed as for `memref.store` with the addition - of `indexOffset` (which is used to aid in emitting vectorized code) and, - if present `sgprOffset` (which is added after bounds checks and includes - any non-zero offset on the memref type). - - All indexing components are given in terms of the memref's element size, not - the byte lengths required by the intrinsic. - - Out of bounds atomic operations are ignored in hardware. - - See `amdgpu.raw_buffer_load` for a description of how the underlying - instruction is constructed. - }]; - let assemblyFormat = [{ - attr-dict $value `->` $memref `[` $indices `]` - (`sgprOffset` $sgprOffset^)? `:` - type($value) `->` type($memref) `,` type($indices) - }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -// Raw buffer atomic unsigned integer min -def AMDGPU_RawBufferAtomicUminOp : - AMDGPU_Op<"raw_buffer_atomic_umin", [ - AttrSizedOperandSegments]>, - Arguments<(ins I32:$value, - Arg:$memref, - Variadic:$indices, - DefaultValuedAttr:$boundsCheck, - OptionalAttr:$indexOffset, - Optional:$sgprOffset)> { - - let summary = "Raw Buffer Unsigned Integer Atomic Min"; - let description = [{ - The `amdgpu.raw_buffer_atomic_umin` op is a wrapper around the - buffer-based atomic signed integer min available on AMD GPUs. 
- - The index into the buffer is computed as for `memref.store` with the addition - of `indexOffset` (which is used to aid in emitting vectorized code) and, - if present `sgprOffset` (which is added after bounds checks and includes - any non-zero offset on the memref type). - - All indexing components are given in terms of the memref's element size, not - the byte lengths required by the intrinsic. - - Out of bounds atomic operations are ignored in hardware. - - See `amdgpu.raw_buffer_load` for a description of how the underlying - instruction is constructed. - }]; - let assemblyFormat = [{ - attr-dict $value `->` $memref `[` $indices `]` - (`sgprOffset` $sgprOffset^)? `:` - type($value) `->` type($memref) `,` type($indices) - }]; - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - -def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm", - "The possible permutations for a DPP operation", - [ - I32EnumAttrCase<"quad_perm", 0>, - I32EnumAttrCase<"row_shl", 1>, - I32EnumAttrCase<"row_shr", 2>, - I32EnumAttrCase<"row_ror", 3>, - I32EnumAttrCase<"wave_shl", 4>, - I32EnumAttrCase<"wave_shr", 5>, - I32EnumAttrCase<"wave_ror", 6>, - I32EnumAttrCase<"wave_rol", 7>, - I32EnumAttrCase<"row_mirror", 8>, - I32EnumAttrCase<"row_half_mirror", 9>, - I32EnumAttrCase<"row_bcast_15", 10>, - I32EnumAttrCase<"row_bcast_31", 11> - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::amdgpu"; -} - -def AMDGPU_DPPPermAttr : EnumAttr; - -def AMDGPU_DPPOp : AMDGPU_Op<"dpp", - [Pure, SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>, - Arguments<(ins AnyType:$old, - AnyType:$src, - AMDGPU_DPPPermAttr:$kind, - OptionalAttr>:$permArgument, - DefaultValuedAttr:$row_mask, - DefaultValuedAttr:$bank_mask, - DefaultValuedAttr:$bound_ctrl)> { - let summary = "AMDGPU DPP operation"; - let description = [{ - This operation represents DPP functionality in a GPU program. 
- DPP provides the following operations: - - Full crossbar in a group of four (`quad_perm`) - - Wavefront shift left by one lane (`wave_shl`) - - Wavefront shift right by one lane (`wave_shr`) - - Wavefront rotate right by one lane (`wave_ror`) - - Wavefront rotate left by one lane (`wave_rol`) - - Row shift left by 1–15 lanes (`row_shl`) - - Row shift right by 1–15 lanes (`row_shr`) - - Row rotate right by 1–15 lanes (`row_ror`) - - Reverse within a row (`row_mirror`) - - Reverse within a half-row (`row_half_mirror`) - - Broadcast the 15th lane of each row to the next row (`row_bcast`) - - Broadcast lane 31 to rows 2 and 3 (`row_bcast`) - }]; - let results = (outs AnyType:$result); - let assemblyFormat = [{ - $old $src $kind (`(` $permArgument^ `)`)? attr-dict `:` type($result) - }]; - let hasVerifier = 1; -} - -def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode", - [Pure, AllTypesMatch<["result", "src"]>]>, - Arguments<(ins AnyIntegerOrFloatOr1DVector:$src, - I32Attr:$and_mask, - I32Attr:$or_mask, - I32Attr:$xor_mask - )> { - let summary = "AMDGPU ds_swizzle op, bitmode variant"; - let description = [{ - High-level wrapper on bitmode `rocdl.ds_swizzle` op, masks are represented - as separate fields so user won't need to do manual bitpacking. - - Supports arbitrary int/float/vector types, which will be repacked to i32 and - one or more `rocdl.ds_swizzle` ops during lowering. - }]; - let results = (outs AnyIntegerOrFloatOr1DVector:$result); - let assemblyFormat = [{ - $src $and_mask $or_mask $xor_mask attr-dict `:` type($result) - }]; -} - -def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["result", "src"]>]> { - let summary = "AMDGPU permlane swap op"; - let description = [{ - High-level wrapper on `rocdl.permlane{16,32}.swap` variants for permutations - on rows of lanes in a subgroup. - - Supports arbitrary int/float/vector types, which will be repacked to i32 and - one or more `rocdl.permlane_swap` ops during lowering. 
- Supported lane permutations: - - Swap the data between odd and even rows of 16 lanes - - Swap the data between the first 32 lanes and the last 32 lanes - - Example: - ```mlir - %0 = amdgpu.permlane_swap %src 16 : f16 - %1 = amdgpu.permlane_swap %src 32 { fetch_inactive = true, bound_ctrl = true } : f16 - ``` - - Operands: - * `$src`: Vector register to permute across lanes of the subgroup. - * `$row_length`: The length of a row to permute in number of lanes (valid values are 16 and 32). - * `$fetch_inactive`: Optional. Used to dertermine behavior of a fetch from a disabled lane. - `fetch_inactive = false`: If the source lane is disabled, use `bound_ctrl` to determine the source value. - `fetch_inactive = true`: If the source lane is disabled, fetch the source value anyway (ignoring `bound_ctrl`). - * `$bound_ctrl`: Optional. Used to determine what a thread should do if its source operand is from - a disabled lane: use the value zero, or disable the write. - `bound_ctrl = false`: Do not write when source is from a disabled lane - `bound_ctrl = true`: Use zero as input if source is from a disabled lane - - Note: Lowering is only supported on gfx950 and up. - }]; - let arguments = (ins AnyIntegerOrFloatOr1DVector:$src, - I32Attr:$row_length, - DefaultValuedAttr:$fetch_inactive, - DefaultValuedAttr:$bound_ctrl); - let results = (outs AnyIntegerOrFloatOr1DVector:$result); - let assemblyFormat = [{ - $src $row_length attr-dict `:` type($result) - }]; - let hasVerifier = 1; -} - -def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> { - let summary = "Barrier that includes a wait for LDS memory operations."; - let description = [{ - **DEPRECATION NOTICE**: Unless you need the inline-assembly-based workaround - for gfx908/MI-100, you should represent this pattern with the equivalent - - ```mlir - gpu.barrier memfence [#gpu.address_space] - ``` - - instead. 
- - `amdgpu.lds_barrier` is both a barrier (all workitems in a workgroup must reach - the barrier before any of them may proceed past it) and a wait for all - operations that affect the Local Data Store (LDS) issued from that workgroup - to complete before the workgroup may continue. Since the LDS is per-workgroup - memory, this barrier may be used, for example, to ensure all workitems have - written data to LDS before any workitem attempts to read from it. - - Note that `lds_barrier` does **not** force reads to or from global memory - to complete before execution continues. Therefore, it should be used when - operations on global memory can be issued far in advance of when their results - are used (for example, by writing them to LDS). - - WARNING: On architectures that do not support the BackOffBarrier feature, - (those which will implement this barrier by emitting inline assembly), - use of this operation will impede the usabiliity of memory watches (including - breakpoints set on variables) when debugging. 
- }]; - let assemblyFormat = "attr-dict"; - let hasCanonicalizer = 1; -} - -def AMDGPU_SchedBarrierOpOpt : I32BitEnumAttr<"sched_barrier_opt_enum", - "The possible options for scheduling barriers", - [ - I32BitEnumAttrCaseNone<"none">, - I32BitEnumAttrCaseBit<"non_mem_non_sideffect", 0>, - I32BitEnumAttrCaseBit<"valu", 1>, - I32BitEnumAttrCaseBit<"salu", 2>, - I32BitEnumAttrCaseBit<"mfma_wmma", 3>, - I32BitEnumAttrCaseBit<"all_vmem", 4>, - I32BitEnumAttrCaseBit<"vmem_read", 5>, - I32BitEnumAttrCaseBit<"vmem_write", 6>, - I32BitEnumAttrCaseBit<"all_ds", 7>, - I32BitEnumAttrCaseBit<"ds_read", 8>, - I32BitEnumAttrCaseBit<"ds_write", 9>, - I32BitEnumAttrCaseBit<"transcendental", 10> - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::amdgpu"; -} - -def AMDGPU_SchedBarrierOpOptAttr : EnumAttr{ - let assemblyFormat = "`<` $value `>`"; -} - -def AMDGPU_SchedBarrierOp : - AMDGPU_Op<"sched_barrier">, - Arguments<(ins AMDGPU_SchedBarrierOpOptAttr:$opts)> - { - let summary = "Barrier that limits the backend scheduler of instruction movement"; - let description = [{ - `amdgpu.sched_barrier` serves as a barrier that could be - configured to restrict movements of instructions through it as - defined by sched_barrier_opts. - }]; - let assemblyFormat = [{ - `allow` `=` $opts attr-dict - }]; -} - -def AMDGPU_MemoryCounterWaitOp : - AMDGPU_Op<"memory_counter_wait">, - Arguments<(ins - OptionalAttr:$load, - OptionalAttr:$store, - OptionalAttr:$ds, - OptionalAttr:$exp, - OptionalAttr:$tensor - )> - { - let summary = "Wait for specified hardware counters"; - let description = [{ - Wait for the specified counters to be less-than or equal-to the provided - values before continuing. - - Counters can lower to different instructions on different architectires, - including clamping to the some HW supported max value or combining multiple - counters into one. 
- }]; - let assemblyFormat = [{ - oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` | `tensor` `(` $tensor `)` ) attr-dict - }]; - - let hasCanonicalizer = 1; -} - -def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB", - "The possible permutations of the lanes storing B available in an MFMA", - [ - I32EnumAttrCase<"none", 0>, - I32EnumAttrCase<"bcast_first_32", 1>, - I32EnumAttrCase<"bcast_second_32", 2>, - I32EnumAttrCase<"rotate_16_right", 3>, - I32EnumAttrCase<"bcast_first_16", 4>, - I32EnumAttrCase<"bcast_second_16", 5>, - I32EnumAttrCase<"bcast_third_16", 6>, - I32EnumAttrCase<"bcast_fourth_16", 7> - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::amdgpu"; -} - -def AMDGPU_MFMAPermBAttr : EnumAttr; - -// mfma -def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64, - VectorOfLengthAndType<[2], [F32]>, - VectorOfLengthAndType<[4, 8], [F16]>, - VectorOfLengthAndType<[2, 4, 8], [BF16]>, - VectorOfLengthAndType<[4, 8, 16], [I8]>, - VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>, - VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>, - VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>; -def MFMAOutTypes : AnyTypeOf<[F64, - VectorOfLengthAndType<[4, 16, 32], [F32]>, - VectorOfLengthAndType<[4, 16, 32], [I32]>, - VectorOfLengthAndType<[4], [F64]>]>; - -// sparse_mfma (smfmac) -def SMFMACSparseInTypes : AnyTypeOf<[ - VectorOfLengthAndType<[4, 8], [F16]>, - VectorOfLengthAndType<[4, 8], [BF16]>, - VectorOfLengthAndType<[8, 16], [I8]>, - VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>, - VectorOfLengthAndType<[8, 16], [F8E4M3FNUZ, F8E5M2FNUZ]> -]>; - -def SMFMACDenseInTypes : AnyTypeOf<[ - VectorOfLengthAndType<[8, 16], [F16]>, - VectorOfLengthAndType<[8, 16], [BF16]>, - VectorOfLengthAndType<[16, 32], [I8]>, - VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>, - VectorOfLengthAndType<[16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]> -]>; - -def SMFMACOutTypes : AnyTypeOf<[ - VectorOfLengthAndType<[4, 16], [F32]>, 
- VectorOfLengthAndType<[4, 16], [I32]> -]>; - -def SMFMACIdxTypes : AnyTypeOf<[ - FixedVectorOfLengthAndType<[4], [I8]>, - FixedVectorOfLengthAndType<[2], [I16]> -]>; - -// scaled_mfma -def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>, - VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>; -def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>; - -// scaled_wmma -def ScaledWMMAInTypes - : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>, - VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>, - VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>; - -def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F32]>]>; - -// wmma -def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>, - VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>, - VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>, - VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>, - VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>; -def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>, - VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>; - -def AMDGPU_MFMAOp : - AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>, - Pure]>, - Arguments<(ins - ConfinedAttr]>:$m, - ConfinedAttr]>:$n, - ConfinedAttr]>:$k, - DefaultValuedAttr]>, "1">:$blocks, - MFMAInTypes:$sourceA, - MFMAInTypes:$sourceB, - MFMAOutTypes:$destC, - DefaultValuedAttr:$cbsz, - DefaultValuedAttr:$abid, - DefaultValuedAttr:$blgp, - UnitAttr:$reducePrecision, - UnitAttr:$negateA, - UnitAttr:$negateB, - UnitAttr:$negateC)>, - Results<(outs MFMAOutTypes: $destD)> { - let summary = "MLIR wrapper for CDNA mfma instructions"; - let description = [{ - The `amdgpu.mfma` op is an MLIR wrapper around intrinsics - for various `mfma` instructions in the CDNA architecture, which perform - multiple outer products in order to allow fast matrix multiplication. 
- - The wrapper will select an appropriate `mfma` instruction, if one is available, - based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the - types of the source and destination arguments. - - For information on the layouts of the input and output matrices (which are stored - in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation. - - The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave - are permuted when matrix data is being loaded: `blgp` can be any number of - fixed permutations, `cbsz` specifies the log_2 of the number of chunks the lanes - holding sourceA are split into, and `abid` selects one of those chunks. - - Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA - intrinsics that take an integer type of width `4K`. For example, - one can provide a vector<4xi8> as an argument to an MFMA instruction that - logically takes 4 i8s but whose intrinsics are specified to take an i32. - In these cases, the bytes in the vector will be concatenated in little-endian - order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on). - - The negateA, negateB, and negateC flags are only supported for double-precision - operations on gfx94x. 
- - Example: - ```mlir - %0 = amdgpu.mfma 16x16x16 %matA * %matB + %matC - : vector<4xf16>, vector<4xf16>, vector<4xf32> - - %1 = amdgpu.mfma 32x32x1 %matD * %matE + %matF - { abid = 1 : i32, cbsz = 1 : i32, blocks = 2 : i32 } - blgp = bcast_second_32 : f32, f32, vector<32xf32> - ``` - }]; - let assemblyFormat = [{ - custom($m, $n, $k) $sourceA `*` $sourceB `+` $destC - attr-dict - `blgp` `=` $blgp - `:` type($sourceA) `,` type($sourceB) `,` type($destC) - }]; - let hasVerifier = 1; -} - -def AMDGPU_WMMAOp : - AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>, - Pure]>, - Arguments<(ins - ConfinedAttr]>:$m, - ConfinedAttr]>:$n, - ConfinedAttr]>:$k, - WMMAInTypes:$sourceA, - WMMAInTypes:$sourceB, - WMMAOutTypes:$destC, - DefaultValuedAttr]>, "0">:$subwordOffset, - UnitAttr:$unsignedA, - UnitAttr:$unsignedB, - UnitAttr:$clamp)>, - Results<(outs WMMAOutTypes: $destD)> { - let summary = "MLIR wrapper for wmma instructions"; - let description = [{ - The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma` - instructions in the AMDGPU architecture, which perform matrix multiplication. - - On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions. - - On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for - all element types, and K=32 for i4 sources. - - On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128, - depending on the element types. - - On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 - (or 16xbf16) vector containing only 8 valid values: - - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14. - - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15. - On gfx12/RDNA4 and gfx1250, the result is instead returned as vector where all - the values are valid and the `subwordOffset` must be `0`, as it cannot be used. - - `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned. 
- - The `clamp` flag is used to saturate the output of type T to `numeric_limits::max()` - in case of overflow. - - Example: - ```mlir - %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16> - - %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<8xf32>, vector<8xf32> - - %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf4E2M1FN>, vector<64xf4E2M1FN>, vector<8xf32> - - %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32> - ``` - }]; - let assemblyFormat = [{ - custom($m, $n, $k) $sourceA `*` $sourceB `+` $destC - attr-dict - `:` type($sourceA) `,` type($sourceB) `,` type($destC) - }]; - let hasVerifier = 1; -} - -def AMDGPU_SparseMFMAOp : - AMDGPU_Op<"sparse_mfma", [AllTypesMatch<["destC", "destD"]>, - Pure]>, - Arguments<(ins - ConfinedAttr]>:$m, - ConfinedAttr]>:$n, - ConfinedAttr]>:$k, - SMFMACSparseInTypes:$sourceA, - SMFMACDenseInTypes:$sourceB, - SMFMACOutTypes:$destC, - SMFMACIdxTypes:$sparseIdx, - DefaultValuedAttr:$cbsz, - DefaultValuedAttr:$abid)>, - Results<(outs SMFMACOutTypes: $destD)> { - let summary = "MLIR wrapper for CDNA sparse mfma (smfmac) instructions"; - let description = [{ - The `amdgpu.sparse_mfma` op is an MLIR wrapper around intrinsics for various - `smfmac` instructions in the AMDGPU architecture, which perform matrix - multiply-accumulate operations using 2:4 structured sparsity on matrix A - with dense matrices B, C, and D. - - On gfx942, smfmac intrinsics support: - - M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources - - M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources - - On gfx950, smfmac intrinsics additionally support: - - M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources - - M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources - - The `sparseIdx` parameter contains packed indices identifying the positions - of non-zero elements in the 2:4 sparse matrix A. 
For 16-bit source data, - use `vector<4xi8>` (four 8-bit indices). For 8-bit source data, use - `vector<2xi16>` (two 16-bit indices). - - The `cbsz` and `abid` parameters are repurposed to select the index set. - If `cbsz == 0`, then `abid[1:0]` selects which index set to use. - If `cbsz != 0`, then the very first is selected. - - Example: - ```mlir - %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>) - : vector<4xf16>, vector<8xf16>, vector<4xf32> - - %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>) - : vector<8xi8>, vector<16xi8>, vector<4xi32> - - %2 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>) - { cbsz = 0 : i32, abid = 1 : i32 } - : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32> - ``` - }]; - let assemblyFormat = [{ - custom($m, $n, $k) $sourceA `*` $sourceB `+` $destC - `sparse` `(` $sparseIdx `:` type($sparseIdx) `)` - attr-dict - `:` type($sourceA) `,` type($sourceB) `,` type($destC) - }]; - let hasVerifier = 1; -} - -def AMDGPU_GatherToLDSOp : - AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>, - Arguments<(ins - Arg:$src, - Variadic:$srcIndices, - Arg:$dst, - Variadic:$dstIndices, - TypeAttr:$transferType - )>, - Results<(outs)> { - let summary = "MLIR wrapper for CDNA Gather to LDS instructions"; - let description = [{ - The `amdgpu.gather_to_lds` op is a wrapper around the `global_load_lds` instructions. - - Operands: - * `$src`: global memory (including fat buffer) memref to read from. - * `$srcIndices`: indices into `$src` to read from for this thread. - * `$dst`: LDS memory memref to write to. - * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread. - The elements gathered by the subgroup will be written contiguously in order of lane ID - starting at `$dst[$dstIndices]`. Byte-sized (ex. i8) or short-sized (ex. i16) - types will be zero-padded/extended to 32 bits before being written. 
96-bit types - (ex. vector<3xf32>) will be zero-padded to 128 bits before being written. Only the - offsets held by lane 0 are used. - * `$transferType`: type of the data to be transferred by each thread. This is used to determine - the size of the data to be transferred and the number of threads in the subgroup. - The transfer type must be a scalar type or a vector type with a single element type. - - The `$dst`, along with its indices, points to the memory location the subgroup of this thread - will write to. - - Note: only supported on gfx9 and gfx10. - }]; - let assemblyFormat = [{ - $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst) - }]; - let hasVerifier = 1; - let hasCanonicalizer = 1; -} - -def AMDGPU_TransposeLoadOp : - AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>, - Arguments<(ins Arg:$src, Variadic:$srcIndices)>, - Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> { - let summary = "MLIR wrapper for CDNA Transpose Load instructions"; - let description = [{ - The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions. - The transpose load op represents a subgroup load from LDS memory, - where the subgroup of threads collectively reads a matrix from the source - memref, with each thread reading a vector of the matrix, and gets a transposed matrix - in as the result. That is, each thread reads a vector of the col-major matrix at different - indices, and the thread's read result is a vector of the corresponding row of the transposed - matrix. - - This op is a direct wrapper around the ROCDL `ds_read_tr` family intrinsics. Please refer - to the CDNA4 ISA documentation for more details about its exact semantics. - - Format example: - ``` - %0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16> - ``` - Operands: - * `$src`: LDS memref to read from. - * `$srcIndices`: indices into `$src` to read from for this thread. 
- * `$result`: target register this transpose load instruction will write to. - - Note: Lowering is only supported on gfx950 and up. - }]; - let assemblyFormat = [{ - $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result) - }]; - let hasVerifier = 1; -} - -def AMDGPU_ScaledMFMAOp : - AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>, - Pure]>, - Arguments<(ins - ConfinedAttr]>:$m, - ConfinedAttr]>:$n, - ConfinedAttr]>:$k, - ScaledMFMAInTypes:$sourceA, - ScaledMFMAInTypes:$sourceB, - ScaledMFMAOutTypes:$destC, - AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesA, - AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesB, - ConfinedAttr]>:$scalesIdxA, - ConfinedAttr]>:$scalesIdxB - )>, - Results<(outs ScaledMFMAOutTypes: $destD)> { - let summary = "MLIR wrapper for CDNA scaled mfma instructions"; - let description = [{ - The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics - for various scaled versions of `mfma` instructions in the CDNA architecture, which - perform multiple outer products in order to allow fast matrix multiplication. - - The wrapper will select an appropriate `mfma` instruction, if one is available, - based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the - types of the source and destination arguments. - - Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA - intrinsics that take an integer type of width `4K`. For example, - one can provide a `vector<4xi8>` as an argument to an MFMA instruction that - logically takes 4 i8s but whose intrinsics are specified to take an i32. - In these cases, the bytes in the vector will be concatenated in little-endian - order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on). 
- - This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences: - - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and - fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as - their tile size. - - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp` - are omitted from this wrapper. - - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported - for double-precision operations on gfx94x and so are not included here. - - Example: - ```mlir - %0 = amdgpu.scaled_mfma 32x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 - : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32> - ``` - }]; - let assemblyFormat = [{ - custom($m, $n, $k) ` ` - `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` - `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC - attr-dict - `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC) - }]; - let hasCanonicalizer = 1; -} - -def AMDGPU_ScaledWMMAOp - : AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>, - Arguments<(ins ConfinedAttr]>:$m, - ConfinedAttr]>:$n, - ConfinedAttr]>:$k, - ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB, - ScaledWMMAOutTypes:$destC, - VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA, - ConfinedAttr]>:$a_first_scale_lane, - VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB, - ConfinedAttr]>:$b_first_scale_lane)>, - Results<(outs ScaledWMMAOutTypes:$destD)> { - // TODO: E5M3FNU scales are supported, but there is not yet MLIR support for - // this datatype. Once we have support for that, update the scaleA and scaleB - // types here. - let summary = "MLIR wrapper for scaled wmma instructions"; - let description = [{ - The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled - `wmma` instructions. 
These instructions perform matrix multiplication with - per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats. - - The scale instructions support a block size of 16 or 32 and two tile sizes: - - 16x16x128 with mixed f8/f6/f4 formats (output: vector<8xf32>) - - 32x16x128 with f4 format only (output: vector<16xf32>) - - Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values - (either f8E8M0FNU, or f8E4M3FN) that are packed into i32/i64 values during - lowering. Each lane can operate on 4 bytes (4 scale values), and the - number of scales required for each matrix is determined by: - num_scales_A = (M × K) / block_size - num_scales_B = (N × K) / block_size - - The index attributes (`a_first_scale_lane`, `b_first_scale_lane`) select - which lane to start reading scale values from (0 or 16): - - For block size 32, 32 lanes across a single wave are used for the scale - values. If the number of scales (num_scales_A or num_scales_B) can fit - into half of the available lanes - (i.e., num_scales / scales_per_lane == 16 (num_lanes)), - then then first_scale_lane can be either 0 or 16. If all lanes are required - for storing the scale values (num_scales / scales_per_lane == 32 (num_lanes)), - then the first_scale_lane must be 0. - - For block size 16, the same rules apply as above except that there are 64 - lanes across two waves that are used for the scale values. When - num_scales / scales_per_lane == 32 (num lanes), then 16 lanes from each wave are used. - first_scale_lane of 0 or 16 will decide which lanes are used for this. When - num_scales / scales_per_lane == 64 (num_lanes), then first_scale_lane must - be set to 0. 
- - Example: - ```mlir - // 16x16x128: fp8 inputs - %0 = amdgpu.scaled_wmma 16x16x128 (%scaleVecA * %matA) * (%scaleVecB * %matB) + %matC - {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} - : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, - vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32> - - // 32x16x128: fp4 inputs with different scale lanes - %1 = amdgpu.scaled_wmma 32x16x128 (%scaleVecD * %matD) * (%scaleVecE * %matE) + %matF - {a_first_scale_lane = 0 : i32, b_first_scale_lane = 16 : i32} - : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>, - vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32> - ``` - }]; - let assemblyFormat = [{ - custom($m, $n, $k) ` ` - `(` $scaleA `*` $sourceA `)` `*` - `(` $scaleB `*` $sourceB `)` `+` $destC - attr-dict - `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC) - }]; - let hasVerifier = 1; -} - -class AMDGPU_DmaBaseOp : - AMDGPU_Op]>, - Arguments<(ins Arg:$global, - Variadic:$global_indices, - Arg:$lds, - Variadic:$lds_indices)>, - Results<(outs outType: $base)> { - - // TODO: - // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions. - - let assemblyFormat = [{ - $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results) - }]; -} - -def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU_TDMGatherBaseType> { - let summary = "Pair of based addresses used when moving tiles between LDS and global memory."; - - let description = [{ - This operation creates a pair of addresses that will be used by `tensor_load_to_lds` - and `tensor_store_from_lds`. - - This operation creates a value corresponding to the tensor descriptor (D#) group 0 - found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect. 
- - Unlike `make_dma_base`, this operation returns `!amdgpu.tdm_gather_base<$element_type, $index_type>` - which is only compatible with `make_gather_dma_descriptor`. Using the descriptor returned - by `make_gather_dma_descriptor` will set the `tensor_load_to_lds` and `tensor_store_from_lds` to gather mode. - - ```mlir - %base = amdgpu.make_gather_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space> -> !amdgpu.tdm_gather_base - // %indices : i16 - %descriptor = amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_gather_base, i16 -> !amdgpu.tdm_descriptor - amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor - ``` - }]; - - let hasVerifier = 1; - - let extraClassDeclaration = [{ - static constexpr bool isGather() { - return true; - } - }]; -} - - -def AMDGPU_MakeDmaBaseOp : AMDGPU_DmaBaseOp<"make_dma_base", AMDGPU_TDMBaseType> { - - let summary = "Pair of based addresses used when moving tiles between LDS and global memory."; - let description = [{ - This operation creates a pair of addresses that will be used by tensor_load_to_lds - and tensor_store_from_lds. - - This operation creates a value corresponding to the tensor descriptor (D#) group 0 - found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect. - - For example: - - ```mlir - %base = amdgpu.make_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space> -> !amdgpu.tdm_base - %descriptor = amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor - amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor - ``` - - to - - ```mlir - // pseudo-code - %global_base = llvm.extractvalue %global_memref[1] - %global_address = llvm.get_element_ptr ... 
- - %lds_base = llvm.extractvalue %lds_memref[1] - %lds_address = llvm.get_element_ptr ... - - // Definition of %base - %undef = llvm.mlir.undef : vector<4xi32> - %v0 = llvm.insertelement %15, %undef[0] : vector<4xi32> - %v1 = llvm.insertelement %lds_address, %v0[1] : vector<4xi32> - %v2 = llvm.insertelement %global_address_low, %v1[2] : vector<4xi32> - %base = llvm.insertelement %global_address_high, %v2[3] : vector<4xi32> - - rocdl.tensor.load.to.lds %base, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> - ``` - - These tensor DMA operations were introduced in gfx1250. - }]; - - let hasVerifier = 1; - - let extraClassDeclaration = [{ - static constexpr bool isGather() { - return false; - } - }]; -} - -class AMDGPU_MakeDescriptorOp : - AMDGPU_Op, - Results<(outs AMDGPU_TDMDescriptorType: $desc)> { - - dag baseArgs = (ins - Variadic: $global_dynamic_sizes, - DenseI64ArrayAttr: $global_static_sizes, - Variadic: $global_dynamic_strides, - DenseI64ArrayAttr: $global_static_strides, - Variadic: $shared_dynamic_sizes, - DenseI64ArrayAttr: $shared_static_sizes, - Optional>: $workgroup_mask, - Optional: $early_timeout, - Optional: $pad_amount, - Optional: $pad_interval, - Optional: $atomic_barrier_address, - Variadic: $atomic_barrier_indices, - Optional: $global_increment, - Optional: $lds_increment, - Optional: $iteration_count); - - code extraClassDeclarationBase = [{ - int64_t getRank() { - return getGlobalStaticSizes().size(); - } - - unsigned getElementTypeWidth() { - return getBase().getType().getElementType().getIntOrFloatBitWidth(); - } - - SmallVector getMixedGlobalSizes() { - return getMixedValues(getGlobalStaticSizes(), getGlobalDynamicSizes(), getContext()); - } - - SmallVector getMixedGlobalStrides() { - return getMixedValues(getGlobalStaticStrides(), getGlobalDynamicStrides(), getContext()); - } - - SmallVector getMixedSharedSizes() { - return getMixedValues(getSharedStaticSizes(), getSharedDynamicSizes(), getContext()); - } - - }]; 
- -} - -def AMDGPU_MakeGatherDmaDescriptorOp : AMDGPU_MakeDescriptorOp<"make_gather_dma_descriptor"> { - dag args = (ins AMDGPU_TDMGatherBaseType: $base, - AnyTypeOf<[VectorOfMinMaxLengthAndType<1, 8, [I32]>, - VectorOfMinMaxLengthAndType<1, 16, [I16]>]>: $indices); - let arguments = !con(args, baseArgs); - let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS."; - - let assemblyFormat = [{ - $base `[` $indices `]` - `globalSize` custom($global_dynamic_sizes, $global_static_sizes) - `globalStride` custom($global_dynamic_strides, $global_static_strides) - `sharedSize` custom($shared_dynamic_sizes, $shared_static_sizes) - ( `padShared` `(` $pad_amount^ `every` $pad_interval `)` )? - ( `workgroupMask` $workgroup_mask^ ( `earlyTimeout` $early_timeout^)?)? - ( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]` - `:` type($atomic_barrier_address) `)`)? - ( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )? - attr-dict `:` qualified(type($base)) `,` type($indices) `->` type(results) - }]; - - let hasVerifier = 1; - let hasFolder = 1; - - let extraClassDeclaration = extraClassDeclarationBase # [{ - static constexpr bool isGather() { - return true; - } - }]; -} - -def AMDGPU_MakeDmaDescriptorOp : AMDGPU_MakeDescriptorOp<"make_dma_descriptor"> { - dag args = (ins AMDGPU_TDMBaseType: $base); - let arguments = !con(args, baseArgs); - let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS."; - let description = [{ - Make all descriptor groups needed by tensor memory operations. - - The $base operand corresponds to the base pair addresses, one must be an address in LDS - while the other must be a global memory location. - - $global_{static/dynamic}_sizes determine the size of the tensor. - $global_{static/dynamic}_strides determine the strides of the tensor. - $shared_{static/dynamic}_sizes determines the size of the tile. 
- - $workgroup_mask broadcast load to workgroups inside of a workgroup cluster - (0 = do not broadcast result to workgroup, 1 = broadcast result to workgroup). Ignored for stores. - An all zeros mask is interpreted as a non-broadcasted load. - - $early_timeout return data to requesters as soon as cache supplies it. - - Padding can be applied to the LDS address when copying from memory to LDS, - but not when copying from LDS to memory. - The values in the padded target addresses remain the same as before the operation was applied. - $pad_interval must be a power of two contained in [2, 256]. - $pad_amount must be a value contained in [1, 128]. - - $atomic_barrier_address must be aligned to 8 bytes. - - 2D and 3D tensors may be iterated over by setting $global_increment, $lds_increment, and $iteration_count. - $global_increment determines how much to increment the starting global memory address per iteration in units of the $base's element type. - $lds_increment determines how much to increment the starting LDS address per iteration in units of the $base's element type. - $iterate_count determines how many times to iterate, it must be a value in the inclusive interval [1, 256]. - - ```mlir - // Example of moving a two-dimensional tensor to LDS. - %base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space> -> !amdgpu.tdm_base - %descriptor = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor - amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor - - // Example of moving a two dimension tensor to LDS where padding is applied after every integer. 
- %base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space> -> !amdgpu.tdm_base - %descriptor = amdgpu.make_dma_descriptor %base globalSize [32, 32] globalStride [32, 1] sharedSize [64, 64] padShared(%pad_amount every %pad_interval) : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor - amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor - ``` - }]; - - let assemblyFormat = [{ - $base - `globalSize` custom($global_dynamic_sizes, $global_static_sizes) - `globalStride` custom($global_dynamic_strides, $global_static_strides) - `sharedSize` custom($shared_dynamic_sizes, $shared_static_sizes) - ( `padShared` `(` $pad_amount^ `every` $pad_interval `)` )? - ( `workgroupMask` $workgroup_mask^ ( `earlyTimeout` $early_timeout^)?)? - ( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]` - `:` type($atomic_barrier_address) `)`)? - ( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )? - attr-dict `:` qualified(type($base)) `->` type(results) - }]; - - let hasVerifier = 1; - let hasFolder = 1; - - let extraClassDeclaration = extraClassDeclarationBase # [{ - static constexpr bool isGather() { - return false; - } - }]; - -} - -def AMDGPU_TensorLoadToLDSOp : - AMDGPU_Op<"tensor_load_to_lds", [MemoryEffects<[MemWrite, MemRead]>]>, - Arguments<(ins AMDGPU_TDMDescriptorType: $desc)> { - let summary = "Load tensors from global memory to LDS."; - let description = [{ - Load tensors of up to five dimensions from global memory to LDS. - - This operation was introduced in gfx1250. - }]; - - let assemblyFormat = [{ - $desc attr-dict `:` qualified(type($desc)) - }]; -} - -def AMDGPU_TensorStoreFromLDSOp : - AMDGPU_Op<"tensor_store_from_lds", [MemoryEffects<[MemWrite, MemRead]>]>, - Arguments<(ins AMDGPU_TDMDescriptorType: $desc)> { - - let summary = "Store tensors from LDS to global memory."; - let description = [{ - Store tensors of up to five dimensions from LDS to global memory. 
- - This operation was introduced in gfx1250. - }]; - - let assemblyFormat = [{ - $desc attr-dict `:` qualified(type($desc)) - }]; -} - -#endif // AMDGPU +#endif // MLIR_DIALECT_AMDGPU_IR_AMDGPU_TD diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUAttrs.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUAttrs.td new file mode 100644 index 0000000000000..c862fb2fc5a3a --- /dev/null +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUAttrs.td @@ -0,0 +1,50 @@ +//===-- AMDGPUAttrs.td - AMDGPU dialect attributes *- tablegen -*----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPUATTRS_TD +#define MLIR_DIALECT_AMDGPU_IR_AMDGPUATTRS_TD + +include "mlir/Dialect/AMDGPU/IR/AMDGPUBase.td" +include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.td" + +def AMDGPU_AddressSpaceAttr : EnumAttr { + let description = [{ + AMDGPU-specific memory spaces that may not have exact analogues on other + GPU targets or backends. + + - `fat_raw_buffer` is the memory space used when a memref is stored as + as a "buffer fat pointer" - that is, a buffer resource (that is set up to + use raw byte-level indexing) along with its offset. The AMDGPU backend + implements `ptr addrspace(7)` to represent these fat pointers so that + buffer resources (which allow advanced features like bounds checking or + cache swizzling) can be used like ordinary LLVM pointers or memrefs. + See also the `fat_raw_buffer_cast` operation + - `buffer_rsrc` is the memory space for `ptr addrspace(8)`, representing a + buffer resource. 
It should not be used for memrefs, since it does not support + indexing + - `fat_structured_buffer` represents `ptr addrspace(9)`, a buffer resource + that carries both an index and offset field, which are used for complex + structured indexing that is primarily seen in graphics applications. This + is also incompatible with the simple indexing model supported by memref. + }]; + let assemblyFormat = "`<` $value `>`"; +} + +def AMDGPU_DPPPermAttr : EnumAttr; + +def AMDGPU_SchedBarrierOpOptAttr : EnumAttr{ + let assemblyFormat = "`<` $value `>`"; +} + +def AMDGPU_MFMAPermBAttr : EnumAttr; + +#endif // MLIR_DIALECT_AMDGPU_IR_AMDGPUATTRS_TD diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUBase.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUBase.td new file mode 100644 index 0000000000000..0ceea9334fc03 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUBase.td @@ -0,0 +1,122 @@ +//===-- AMDGPUBase.td - AMDGPU dialect base *- tablegen -*-----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPUBASE_TD +#define MLIR_DIALECT_AMDGPU_IR_AMDGPUBASE_TD + +include "mlir/IR/DialectBase.td" + +def AMDGPU_Dialect : Dialect { + let name = "amdgpu"; + let cppNamespace = "::mlir::amdgpu"; + let description = [{ + The `AMDGPU` dialect provides wrappers around AMD-specific functionality + and LLVM intrinsics. These wrappers should be used in conjunction with + more generic dialects, such as `gpu` and `vector`, when generating LLVM IR + that will eventually be executed on AMD hardware. + + # What goes here? 
+ In many cases, AMD GPU functionality can be accessed either through generic + operations (such as those in the `gpu`, `vector`, or `math`) or through + the `rocdl` dialect's intrinsic wrappers. However, there are instances where + AMD-specific functionality benefits from a wrapper around the underlying + LLVM intrinsics. + + In general terms, operations or types should be added to this dialect when they + wrap some AMD-specific functionality in a way that makes it work better with the + MLIR ecosystem and its types or when those builtins would be needlessly + complex to work with (such as if they feature magic constants at the LLVM level). + + An additional set of operations that belong in this dialect are those that + have chipset-specific differences that can be abstracted over in a useful way. + + To give some concrete examples: + + - `amdgpu.mfma` and `amdgpu.wmma` exist in order to make a large set of + intrinsics more compatible with the MLIR type system (such as by allowing + 8-bit float vectors to be passed as `vector` or + `vector` instead of as packed 32-bit integers whose element type + is controlled by separate operator-level constants). These operations also + allow the same `amdgpu.mfma` operation to be used regardless of the target + chip. + - `amdgpu.swizzle_bitmode` provides a wrapper around the `ds.swizzle` intrinsic, + allowing a wider range of types (such as `vector<2xf16>`) to be used natively + and eliminating the need to pack the and, or, and xor components using opaque + shifts. + - Operations like `amdgpu.gather_to_lds` provide `memref`-ized wrappers around + intrinsics that take a pointer, and are nontrivial enough to justify inclusion + in this dialect. + + + Note that simple intrinsics like `rocdl.sin` or `rocdl.s.barrier` should not + receive wrapper operations, as nothing is gained from the duplicate operation.
+ As a rule of thumb, if an operation's rewrite in AMDGPUToROCDL would be only + a `replaceOpWithNewOp` call, no AMDGPU dialect operation is needed. + + # Design guidelines + + Operations should leverage MLIR's "standard" types where possible. MLIR has + a more extensible type system than LLVM (especially in the area of small floats) + and those types should be used to create more ergonomic wrappers. In particular, + intrinsics that take pointers should have wrappers in this dialect that take + `memref` arguments and indices. + + Operations should use properties or attributes in cases where the underlying + intrinsic uses `immarg`s (except in cases where that attribute can be represented + in the type system). + + If it is possible to generalize the types of an operation, it should be done. + For example, the underlying operations for permutations and swizzles always + take 32-bit operands. Their AMDGPU wrappers can take any type, and will apply + padding and expansion to multiple instructions as needed. This makes these + operations easier to target because it hides the bitcasts and extracts + until the final lowering. + + When the underlying operation uses magic constants, those should be presented + in a more programmer-friendly fashion, such as through enums or through + using separate arguments that are later combined. (For example, see the + design of the `amdgpu.dpp` and `amdgpu.fat_raw_buffer_cast` operations.) + + If sufficiently similar functionality on multiple hardware generations can be + encapsulated into a single operation, it should be done. The lowering to + intrinsics should either throw an error when an unsupported capability is + used or ignore it. Which of these two failure modes is more appropriate + depends on the nature of the feature, but errors are a safe default choice. + + # Documentation guidelines + + AMDGPU dialect operations should document how any abstractions they introduce + translate to LLVM intrinsics or hardware operations.
+ + While documenting the semantics of the underlying operations is not required, + it is preferred to provide an overview of the operation's functionality, + especially in cases where the documentation is widely distributed. Someone + looking at an AMDGPU dialect operation should be able to generally understand + what it does and have found the keywords they'll need for more detail. + + Operation documentation should include usage examples. + + Note that this dialect uses LLVM's gfx numbers to refer to individual + architectures/chipsets and not product names or codenames. + }]; + + + let dependentDialects = [ + "ROCDL::ROCDLDialect", + "arith::ArithDialect", + "gpu::GPUDialect" + ]; + let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; + let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + }]; +} + +#endif // MLIR_DIALECT_AMDGPU_IR_AMDGPUBASE_TD diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h index a7680fb5c3191..69d77b8d89361 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h @@ -25,6 +25,8 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc" #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc" + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttrs.h.inc" #include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc" namespace mlir::amdgpu { @@ -51,7 +53,7 @@ inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *, } // namespace mlir::amdgpu #define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttrs.h.inc" #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc" diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUEnums.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUEnums.td new file mode 100644 index 0000000000000..4ec7cb3cd7307 --- /dev/null +++
b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUEnums.td @@ -0,0 +1,83 @@ +//===-- AMDGPUEnums.td - AMDGPU dialect enums *- tablegen -*---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPUENUMS_TD +#define MLIR_DIALECT_AMDGPU_IR_AMDGPUENUMS_TD + +include "mlir/Dialect/AMDGPU/IR/AMDGPUBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/Properties.td" + +//===----------------------------------------------------------------------===// +// AMDGPU general enum definitions +//===----------------------------------------------------------------------===// + +def AMDGPU_AddressSpace : I32Enum<"AddressSpace", + "AMDGPU-specific address spaces", + [ + I32EnumCase<"FatRawBuffer", 0, "fat_raw_buffer">, + I32EnumCase<"BufferRsrc", 1, "buffer_rsrc">, + I32EnumCase<"FatStructuredBuffer", 2, "fat_structured_buffer">, + ]> { + let cppNamespace = "::mlir::amdgpu"; +} + +def AMDGPU_DPPPerm : I32Enum<"DPPPerm", + "The possible permutations for a DPP operation", + [ + I32EnumAttrCase<"quad_perm", 0>, + I32EnumAttrCase<"row_shl", 1>, + I32EnumAttrCase<"row_shr", 2>, + I32EnumAttrCase<"row_ror", 3>, + I32EnumAttrCase<"wave_shl", 4>, + I32EnumAttrCase<"wave_shr", 5>, + I32EnumAttrCase<"wave_ror", 6>, + I32EnumAttrCase<"wave_rol", 7>, + I32EnumAttrCase<"row_mirror", 8>, + I32EnumAttrCase<"row_half_mirror", 9>, + I32EnumAttrCase<"row_bcast_15", 10>, + I32EnumAttrCase<"row_bcast_31", 11> + ]> { + let cppNamespace = "::mlir::amdgpu"; +} + +def AMDGPU_SchedBarrierOpOpt : I32BitEnum<"sched_barrier_opt_enum", + "The possible options for scheduling barriers", + [ + I32BitEnumAttrCaseNone<"none">, + I32BitEnumAttrCaseBit<"non_mem_non_sideffect", 0>, + I32BitEnumAttrCaseBit<"valu", 1>, + 
I32BitEnumAttrCaseBit<"salu", 2>, + I32BitEnumAttrCaseBit<"mfma_wmma", 3>, + I32BitEnumAttrCaseBit<"all_vmem", 4>, + I32BitEnumAttrCaseBit<"vmem_read", 5>, + I32BitEnumAttrCaseBit<"vmem_write", 6>, + I32BitEnumAttrCaseBit<"all_ds", 7>, + I32BitEnumAttrCaseBit<"ds_read", 8>, + I32BitEnumAttrCaseBit<"ds_write", 9>, + I32BitEnumAttrCaseBit<"transcendental", 10> + ]> { + let cppNamespace = "::mlir::amdgpu"; +} + +def AMDGPU_MFMAPermB : I32Enum<"MFMAPermB", + "The possible permutations of the lanes storing B available in an MFMA", + [ + I32EnumAttrCase<"none", 0>, + I32EnumAttrCase<"bcast_first_32", 1>, + I32EnumAttrCase<"bcast_second_32", 2>, + I32EnumAttrCase<"rotate_16_right", 3>, + I32EnumAttrCase<"bcast_first_16", 4>, + I32EnumAttrCase<"bcast_second_16", 5>, + I32EnumAttrCase<"bcast_third_16", 6>, + I32EnumAttrCase<"bcast_fourth_16", 7> + ]> { + let cppNamespace = "::mlir::amdgpu"; +} + +#endif // MLIR_DIALECT_AMDGPU_IR_AMDGPUENUMS_TD diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td new file mode 100644 index 0000000000000..24e40f40c2031 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td @@ -0,0 +1,1544 @@ +//===-- AMDGPUOps.td - AMDGPU dialect operations *- tablegen -*------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPUOPS_TD +#define MLIR_DIALECT_AMDGPU_IR_AMDGPUOPS_TD + +include "mlir/Dialect/AMDGPU/IR/AMDGPUBase.td" +include "mlir/Dialect/AMDGPU/IR/AMDGPUAttrs.td" +include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td" + +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// AMDGPU common type constraints +//===----------------------------------------------------------------------===// + +class AMDGPU_ConcreteVector : + FixedVectorOfLengthAndType<[length], [elem]>, + BuildableType< + "::mlir::VectorType::get({" # length # "} ," + # elem.builderCall # ")">; + +def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">; + +def AnyIntegerOrFloatOr1DVector : + AnyTypeOf<[AnyIntegerOrFloat, FixedVectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>; + +//===----------------------------------------------------------------------===// +// AMDGPU Op definitions +//===----------------------------------------------------------------------===// + +class AMDGPU_Op traits = []> : + Op {} + +def AMDGPU_ExtPackedFp8Op : + AMDGPU_Op<"ext_packed_fp8", [Pure]>, + Arguments<(ins AnyTypeOf<[F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN, + VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2FNUZ, F8E4M3FNUZ, F8E5M2, F8E4M3FN]>]>:$source, + ConfinedAttr]>:$index)>, + Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> { + let summary = "Extend a fp8 value to a float or a vector of packed fp8 values to two floats"; + + let description = [{ + Extend one or two 8-bit floats in `source[index]` to a 32-bit float or + two floats and return them. 
+ + This rather unusual signature arises from the fact that AMD GPUs cannot + easily work with sub 32-bit quantities, so the compiler intrinsics for + extending 8-bit floats (which are, currently, the only way to work with + this operation) take packed vectors of 4 such floats. + + If the passed-in vector has fewer than four elements, or the input is scalar, + the remaining values in the <4 x i8> will be filled with + undefined values as needed. + }]; + let assemblyFormat = [{ + attr-dict $source `[` $index `]` `:` type($source) `to` type($res) + }]; +} + +def AMDGPU_ScaledExtPackedMatrixOp + : AMDGPU_Op<"scaled_ext_packed_matrix", [Pure, AllShapesMatch<["source", "res"]>]>, + Arguments<( + ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>, + FixedVectorOfShapeAndType<[8], F8E4M3FN>, + FixedVectorOfShapeAndType<[8], F8E5M2>, + FixedVectorOfShapeAndType<[16], F6E2M3FN>, + FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source, + FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale, + ConfinedAttr]>:$blockSize, + ConfinedAttr]>:$firstScaleLane, + ConfinedAttr, IntMaxValue<3>]>:$firstScaleByte)>, + Results<( + outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>, + FixedVectorOfShapeAndType<[8], F16>, + FixedVectorOfShapeAndType<[8], BF16>, + FixedVectorOfShapeAndType<[16], F32>, + FixedVectorOfShapeAndType<[16], F16>, + FixedVectorOfShapeAndType<[16], BF16>]>:$res)> { + + let summary = "Extend a wave-wide matrix of packed floating point values"; + + let description = [{ + Extend matrix of microfloats (8 or 16 elements per lane) using a set of scales + that may be stored on other lanes. + + The scales applied to the input microfloats are stored in bytes which + come from the `scales` input provided in a *half* of the wave identified + by `firstScaleLane`. The bytes used is selected by `firstScaleByte` and depends + on the type of `source`. 
The 16 vectors in consecutive lanes starting from + `firstScaleLane` (which we'll call the scale vectors) will be used by both + halves of the wave (with lane L reading from L % 16'th scale vector). + + When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN each half of the + wave will use a different byte. The first one being `firstScaleByte` and + the second one being `firstScaleByte` + 1. When the block size is 32, + `firstScaleByte` can be either 0 or 2, selecting halves of the scale vectors. + Lanes 0-15 will read from `firstScaleByte` and lanes 16-31 will read + from `firstScaleByte` + 1. + + + For example: + ```mlir + // Input: 8-element vector of F8E4M3FN, converting to F32 + // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1 + %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) + blockSize(32) firstScaleLane(0) firstScaleByte(0) + : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + + // Input: 16-element vector of F6E2M3FN, converting to F16 + // Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3 + %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) + blockSize(32) firstScaleLane(16) firstScaleByte(2) + : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + ``` + + When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN and + the block size is 16, `firstScaleByte` can be 0 or 1. + Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors, + while lanes 16-31 read from `firstScaleByte` + 2. 
+ For example: + ```mlir + // Input: 8-element vector of F8E5M2, converting to BF16 + // Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2) + %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) + blockSize(16) firstScaleLane(0) firstScaleByte(0) + : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // Input: 16-element vector of F6E3M2FN, converting to F32 + // Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2) + %result = amdgpu.scaled_ext_packed_matrix %source scale(%scales) + blockSize(16) firstScaleLane(16) firstScaleByte(1) + : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + ``` + + Note: the layout for the scales generally mirrors how the WMMA + instructions use for matrix scales. These selection operands allows + one to choose portions of the matrix to convert. + + When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 32, + then the same byte will be used by both halves of the wave. + In this case, `firstScaleByte` can be any value from 0 to 3. + + When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 16, + following combinations are allowed: + * `firstScaleLane(0), firstScaleByte(0)` + * `firstScaleLane(16), firstScaleByte(2)` + all other combinations are reserved. + + Available on gfx1250+. 
+ }]; + + let assemblyFormat = [{ + attr-dict $source + `scale` `(` $scale `)` + `blockSize` `(` $blockSize `)` + `firstScaleLane` `(` $firstScaleLane`)` + `firstScaleByte` `(` $firstScaleByte `)` + `:` type($source) `,` type($scale) `->` type($res) + }]; + + let hasVerifier = 1; + +} + +def AMDGPU_ScaledExtPackedOp + : AMDGPU_Op<"scaled_ext_packed", [Pure]>, + Arguments<( + ins AnyTypeOf<[VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>, + VectorOfLengthAndType<[1, 2, 3, 4, 5, 6, 7, 8], + [F4E2M1FN]>]>:$source, + F32:$scale, + ConfinedAttr]>:$index)>, + Results<( + outs AnyTypeOf<[FixedVectorOfLengthAndType<[2], [F32]>, + FixedVectorOfLengthAndType<[2], [F16]>, + FixedVectorOfLengthAndType<[2], [BF16]>]>:$res)> { + let summary = "Extend a vector of packed floating point values"; + + let description = [{ + Extend and scale two packed floats in `source[index]` to two floats and + return them. + + This rather unusual signature arises from the fact that AMD GPUs cannot + easily work with sub 32-bit quantities, so the compiler intrinsics for + extending 8-bit floats (which are, currently, the only way to work with + this operation) take packed vectors of 2 such floats. + + If the passed-in vector has fewer than two elements, or the input is scalar, + the remaining values in the <2 x i8> will be filled with + undefined values as needed. 
+ }]; + let assemblyFormat = [{ + attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res) + }]; +} + +def AMDGPU_PackedTrunc2xFp8Op : + AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>, + Arguments<(ins F32:$sourceA, + Optional:$sourceB, + ConfinedAttr]>:$wordIndex, + Optional>:$existing)>, + Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> { + let summary = "Round two floats into a packed vector of 8-bit floats"; + let description = [{ + Round the inputs `sourceA` and `sourceB` (which is undefined if not + specified) into the low or high word (bottom two or top two) elements + of the returned vector, keeping the other two elements of `existing` + unchanged if present (or undefined if it was not passed in). + + The reason for this odd signature is that AMD GPUs cannot easily work with + sub-registers, and so the conversion intrinsics (which are currently the + only way to work with 8-bit float types) take packed vectors of 4 8-bit + values. + }]; + let assemblyFormat = [{ + attr-dict $sourceA `,` ($sourceB^):(`undef`)? + `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]` + `:` type($sourceA) `to` type($res) (`into` type($existing)^)? 
+ }]; + let hasVerifier = 1; +} + +def AMDGPU_PackedScaledTruncOp + : AMDGPU_Op<"packed_scaled_trunc", [Pure]>, + Arguments<(ins VectorOfLengthAndType<[1, 2], [F32, F16, BF16]>:$source, + F32:$scale, + ConfinedAttr]>:$index, + Optional, + FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>>:$existing)>, + Results<( + outs AnyTypeOf<[FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>, + FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>:$res)> { + let summary = "Round two floats into a packed vector of floats"; + let description = [{ + Scale and round the inputs `source` (which is undefined if not + specified) into the low or high word (bottom two or top two) elements + of the returned vector, keeping the other two elements of `existing` + unchanged if present (or undefined if it was not passed in). + + The reason for this odd signature is that AMD GPUs cannot easily work with + sub-registers, and so the conversion intrinsics take 32-bit wide + packed vectors of float values. + }]; + let assemblyFormat = [{ + attr-dict $source `into` ($existing^):(`undef`)? `[` $index `]` + `,` $scale + `:` type($source) `to` type($res) (`into` type($existing)^)? + }]; + let hasVerifier = 1; +} + +def AMDGPU_PackedStochRoundFp8Op : + AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>, + Arguments<(ins F32:$source, + I32:$stochiasticParam, + ConfinedAttr]>:$storeIndex, + Optional>:$existing)>, + Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> { + let summary = "Round float stochiastically into a packed vector of 8-bit floats"; + let description = [{ + Round the input `source`, adding in `stochiasticParam`, and place it into + the `storeIndex`th element of `res`. + + If `existing` is passed in, elements of `res` other than the one at `storeIndex` + are copied from `existing`. 
+ + The reason for this odd signature is that AMD GPUs cannot easily work with + sub-registers, and so the conversion intrinsics (which are currently the + only way to work with 8-bit float types) take packed vectors of 4 8-bit + values. + }]; + let assemblyFormat = [{ + attr-dict $source `+` $stochiasticParam + `into` ($existing^):(`undef`)? `[` $storeIndex `]` + `:` type($source) `to` type($res) (`into` type($existing)^)? + }]; + let hasVerifier = 1; +} + +def AMDGPU_FatRawBufferCastOp : + AMDGPU_Op<"fat_raw_buffer_cast", + [Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + ViewLikeOpInterface, AttrSizedOperandSegments]>, + Arguments<(ins AnyMemRef:$source, + Optional:$validBytes, + Optional>:$cacheSwizzleStride, + DefaultValuedAttr:$boundsCheck, + UnitAttr:$resetOffset)>, + Results<(outs AnyMemRef:$result)> { + // TODO: Set `resetOffset` and `boundsCheck` to use `Property` once + // we implemented pythonic binding for `Property`. + let summary = "Create a raw buffer fat pointer that matches `memref`"; + let description = [{ + Wraps the memory pointed to by `source` as a raw buffer fat pointer, or, + in LLVM terms, a `ptr addrspace(7)`, returning a memref that has the same + sizes and layout but the `#amdgpu.address_space` + address space. + + This memref can be used with standard memref operations like `memref.load`, + `memref.store`, and `memref.atomicrmw`, which will be lowered to the relevant + buffer intrinsics. (`vector.masked_load/store` will work once there's backend + support for lowering them, and then this document will be updated) + + If `validBytes` is given, it is the number of bytes that will be valid as + an offset to `out`. If it is not provided, this will be inferred from + the size of the memref during lowering. This size is + max_{d = 0 upto rank(source)} (sizes[d] * strides[d]) * sizeof(element type). + + The flags of the buffer descriptor will be set up to enable raw usage - + for example, stride = 0, add_tid = 0, and so on. 
The `boundsCheck` + property determines if bounds checking is enabled or not (on architectures + where this can be controlled - that is, on RDNA chips). + + If `cacheSwizzleStride` is provided, L1 cache swizzling will be enabled + on architectures that support it. This swizzling, unlike the main swizzling + mode (whose usage makes a buffer non-raw) does not affect index calculation, + but does affect cache behavior. Mixing access between cache-swizzled raw + buffers and other forms of memory access, like ordinary pointer loads or + unswizzled buffer pointers can cause incorrect behavior and must be avoided. + + This operation preserves the sizes, strides, and offset of the input + memref - they'll be added in by `memref.load` later. However, if + `resetOffset` is set, that offset will be added to the base pointer. + If the value of the memref's offset is not uniform (independent of the lane/thread ID), + this will lead to substantially decreased performance due to the need for + a waterfall loop on the base address of the buffer resource. + }]; + + let extraClassDeclaration = [{ + Value getViewSource() { return getSource(); } + }]; + + let assemblyFormat = [{ + $source oilist (`validBytes` `(` $validBytes `)` + | `cacheSwizzleStride` `(` $cacheSwizzleStride `)` + | `boundsCheck` `(` $boundsCheck `)` + | `resetOffset` $resetOffset ) + attr-dict `:` type($source) `to` type($result) + }]; + + let hasVerifier = 1; +} + +/// Raw buffer load +def AMDGPU_RawBufferLoadOp : + AMDGPU_Op<"raw_buffer_load", [AllElementTypesMatch<["value", "memref"]>, + AttrSizedOperandSegments]>, + Arguments<(ins Arg:$memref, + Variadic:$indices, + DefaultValuedAttr:$boundsCheck, + OptionalAttr:$indexOffset, + Optional:$sgprOffset)>, + Results<(outs AnyType:$value)> { + + let summary = "Raw Buffer load, exposing GCN features"; + let description = [{ + The `amdgpu.raw_buffer_load` op is a wrapper around the buffer load intrinsics + available on AMD GPUs, including extensions in newer GPUs. 
+ + The index into the buffer is computed as for `memref.load` with the addition + of `indexOffset` and `sgprOffset` (which **may or may not** be considered + in bounds checks and includes any offset present on the memref type if it's + non-zero). + + All indices and offsets are in units of the memref's data type and are + converted to bytes during lowering. + + When a load is out of bounds, the instruction returns zero. + Partially out-of-bounds loads have chipset-dependent behavior: whether reading + 2 elements starting at index 7 of a `memref<8xf32>` returns the last element + in the first vector component depends on the architecture. + + The memref struct is converted into a buffer resource (a V#) and the arguments + are translated to intrinsic arguments as follows: + - The base address of the buffer is the base address of the memref + - The stride is 0 to enable raw mode + - The number of records is the size of the memref, in bytes + In the case of dynamically-shaped memrefs, this is computed at runtime + as max_d (size(d) * stride(d)) * sizeof(elementType(memref)) + - The offset enable bit is 1, the index enable bit is 0. + - The thread ID addition bit is off + - If `boundsCheck` is false and the target chipset is RDNA, OOB_SELECT is set + to 2 to disable bounds checks, otherwise it is 3 + - The cache coherency bits are off + }]; + let assemblyFormat = [{ + attr-dict $memref `[` $indices `]` + (`sgprOffset` $sgprOffset^)? `:` + type($memref) (`,` type($indices)^)?
`->` type($value) + }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +/// Raw buffer store +def AMDGPU_RawBufferStoreOp : + AMDGPU_Op<"raw_buffer_store", [AllElementTypesMatch<["value", "memref"]>, + AttrSizedOperandSegments]>, + Arguments<(ins AnyType:$value, + Arg:$memref, + Variadic:$indices, + DefaultValuedAttr:$boundsCheck, + OptionalAttr:$indexOffset, + Optional:$sgprOffset)> { + + let summary = "Raw Buffer Store, exposing GCN features"; + let description = [{ + The `amdgpu.raw_buffer_store` op is a wrapper around the buffer store + intrinsics available on AMD GPUs, including extensions in newer GPUs. + + The store index is computed as in `memref.store` with the addition of + `indexOffset` (which is included for uniformity with atomics and may be useful + when writing vectorized code) and `sgprOffset` (which is added after bounds + checks and implicitly includes the offset of the memref type if non-zero). + All index components are in terms of the elements of the memref, not bytes, + and are scaled up appropriately. + + Out of bounds stores are ignored in hardware. + Whether a vector write that includes some in-bounds and some out-of-bounds + components is partially completed is chipset-dependent. + + See `amdgpu.raw_buffer_load` for a description of how the underlying + instruction is constructed. + }]; + let assemblyFormat = [{ + attr-dict $value `->` $memref `[` $indices `]` + (`sgprOffset` $sgprOffset^)? `:` + type($value) `->` type($memref) (`,` type($indices)^)?
+ }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +// Raw buffer atomic compare-and-swap +def AMDGPU_RawBufferAtomicCmpswapOp : + AMDGPU_Op<"raw_buffer_atomic_cmpswap", [ + AttrSizedOperandSegments, + AllTypesMatch<["src", "cmp", "value"]>, + AllElementTypesMatch<["value", "memref"]>]>, + Arguments<(ins AnyType:$src, + AnyType:$cmp, + Arg:$memref, + Variadic:$indices, + DefaultValuedAttr:$boundsCheck, + OptionalAttr:$indexOffset, + Optional:$sgprOffset)>, + Results<(outs AnyType:$value)> { + + let summary = "Raw Buffer Atomic compare-and-swap"; + let description = [{ + The `amdgpu.raw_buffer_atomic_cmpswap` op is a wrapper around the + buffer-based atomic compare-and-swap available on AMD GPUs. + + The index into the buffer is computed as for `memref.store` with the addition + of `indexOffset` (which is used to aid in emitting vectorized code) and, + if present `sgprOffset` (which is added after bounds checks and includes + any non-zero offset on the memref type). + + All indexing components are given in terms of the memref's element size, not + the byte lengths required by the intrinsic. + + Out of bounds atomic operations are ignored in hardware. + + See `amdgpu.raw_buffer_load` for a description of how the underlying + instruction is constructed. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $cmp `->` $memref `[` $indices `]` + (`sgprOffset` $sgprOffset^)?
`:` + type($value) `->` type($memref) `,` type($indices) + }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +// Raw buffer atomic floating point add +def AMDGPU_RawBufferAtomicFaddOp : + AMDGPU_Op<"raw_buffer_atomic_fadd", [AllElementTypesMatch<["value", "memref"]>, + AttrSizedOperandSegments]>, + Arguments<(ins AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16, BF16]>]>:$value, + Arg:$memref, + Variadic:$indices, + DefaultValuedAttr:$boundsCheck, + OptionalAttr:$indexOffset, + Optional:$sgprOffset)> { + + let summary = "Raw Buffer Floating-point Atomic Add (MI-* only)"; + let description = [{ + The `amdgpu.raw_buffer_atomic_fadd` op is a wrapper around the + buffer-based atomic floating point addition available on the MI-* series + of AMD GPUs. + + The index into the buffer is computed as for `memref.store` with the addition + of `indexOffset` (which is used to aid in emitting vectorized code) and, + if present `sgprOffset` (which is added after bounds checks and includes + any non-zero offset on the memref type). + + All indexing components are given in terms of the memref's element size, not + the byte lengths required by the intrinsic. + + Out of bounds atomic operations are ignored in hardware. + + See `amdgpu.raw_buffer_load` for a description of how the underlying + instruction is constructed. + }]; + let assemblyFormat = [{ + attr-dict $value `->` $memref `[` $indices `]` + (`sgprOffset` $sgprOffset^)? 
`:` + type($value) `->` type($memref) `,` type($indices) + }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +// Raw buffer atomic floating point max +def AMDGPU_RawBufferAtomicFmaxOp : + AMDGPU_Op<"raw_buffer_atomic_fmax", [AllElementTypesMatch<["value", "memref"]>, + AttrSizedOperandSegments]>, + Arguments<(ins AnyTypeOf<[F32, F64]>:$value, + Arg:$memref, + Variadic:$indices, + DefaultValuedAttr:$boundsCheck, + OptionalAttr:$indexOffset, + Optional:$sgprOffset)> { + + let summary = "Raw Buffer Floating-point Atomic Max (non-GFX9)"; + let description = [{ + The `amdgpu.raw_buffer_atomic_fmax` op is a wrapper around the + buffer-based atomic floating point max available on AMD GPUs (except GFX9). + + The index into the buffer is computed as for `memref.store` with the addition + of `indexOffset` (which is used to aid in emitting vectorized code) and, + if present `sgprOffset` (which is added after bounds checks and includes + any non-zero offset on the memref type). + + All indexing components are given in terms of the memref's element size, not + the byte lengths required by the intrinsic. + + Out of bounds atomic operations are ignored in hardware. + + See `amdgpu.raw_buffer_load` for a description of how the underlying + instruction is constructed. + }]; + let assemblyFormat = [{ + attr-dict $value `->` $memref `[` $indices `]` + (`sgprOffset` $sgprOffset^)? 
`:` + type($value) `->` type($memref) `,` type($indices) + }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +// Raw buffer atomic signed integer max +def AMDGPU_RawBufferAtomicSmaxOp : + AMDGPU_Op<"raw_buffer_atomic_smax", [ + AttrSizedOperandSegments]>, + Arguments<(ins I32:$value, + Arg:$memref, + Variadic:$indices, + DefaultValuedAttr:$boundsCheck, + OptionalAttr:$indexOffset, + Optional:$sgprOffset)> { + + let summary = "Raw Buffer Signed Integer Atomic Max"; + let description = [{ + The `amdgpu.raw_buffer_atomic_smax` op is a wrapper around the + buffer-based atomic signed integer max available on AMD GPUs. + + The index into the buffer is computed as for `memref.store` with the addition + of `indexOffset` (which is used to aid in emitting vectorized code) and, + if present `sgprOffset` (which is added after bounds checks and includes + any non-zero offset on the memref type). + + All indexing components are given in terms of the memref's element size, not + the byte lengths required by the intrinsic. + + Out of bounds atomic operations are ignored in hardware. + + See `amdgpu.raw_buffer_load` for a description of how the underlying + instruction is constructed. + }]; + let assemblyFormat = [{ + attr-dict $value `->` $memref `[` $indices `]` + (`sgprOffset` $sgprOffset^)? `:` + type($value) `->` type($memref) `,` type($indices) + }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + +// Raw buffer atomic unsigned integer min +def AMDGPU_RawBufferAtomicUminOp : + AMDGPU_Op<"raw_buffer_atomic_umin", [ + AttrSizedOperandSegments]>, + Arguments<(ins I32:$value, + Arg:$memref, + Variadic:$indices, + DefaultValuedAttr:$boundsCheck, + OptionalAttr:$indexOffset, + Optional:$sgprOffset)> { + + let summary = "Raw Buffer Unsigned Integer Atomic Min"; + let description = [{ + The `amdgpu.raw_buffer_atomic_umin` op is a wrapper around the + buffer-based atomic unsigned integer min available on AMD GPUs.
+ + The index into the buffer is computed as for `memref.store` with the addition + of `indexOffset` (which is used to aid in emitting vectorized code) and, + if present `sgprOffset` (which is added after bounds checks and includes + any non-zero offset on the memref type). + + All indexing components are given in terms of the memref's element size, not + the byte lengths required by the intrinsic. + + Out of bounds atomic operations are ignored in hardware. + + See `amdgpu.raw_buffer_load` for a description of how the underlying + instruction is constructed. + }]; + let assemblyFormat = [{ + attr-dict $value `->` $memref `[` $indices `]` + (`sgprOffset` $sgprOffset^)? `:` + type($value) `->` type($memref) `,` type($indices) + }]; + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + + +def AMDGPU_DPPOp : AMDGPU_Op<"dpp", + [Pure, SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>, + Arguments<(ins AnyType:$old, + AnyType:$src, + AMDGPU_DPPPermAttr:$kind, + OptionalAttr>:$permArgument, + DefaultValuedAttr:$row_mask, + DefaultValuedAttr:$bank_mask, + DefaultValuedAttr:$bound_ctrl)> { + let summary = "AMDGPU DPP operation"; + let description = [{ + This operation represents DPP functionality in a GPU program. + DPP provides the following operations: + - Full crossbar in a group of four (`quad_perm`) + - Wavefront shift left by one lane (`wave_shl`) + - Wavefront shift right by one lane (`wave_shr`) + - Wavefront rotate right by one lane (`wave_ror`) + - Wavefront rotate left by one lane (`wave_rol`) + - Row shift left by 1–15 lanes (`row_shl`) + - Row shift right by 1–15 lanes (`row_shr`) + - Row rotate right by 1–15 lanes (`row_ror`) + - Reverse within a row (`row_mirror`) + - Reverse within a half-row (`row_half_mirror`) + - Broadcast the 15th lane of each row to the next row (`row_bcast`) + - Broadcast lane 31 to rows 2 and 3 (`row_bcast`) + }]; + let results = (outs AnyType:$result); + let assemblyFormat = [{ + $old $src $kind (`(` $permArgument^ `)`)? 
attr-dict `:` type($result) + }]; + let hasVerifier = 1; +} + +def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode", + [Pure, AllTypesMatch<["result", "src"]>]>, + Arguments<(ins AnyIntegerOrFloatOr1DVector:$src, + I32Attr:$and_mask, + I32Attr:$or_mask, + I32Attr:$xor_mask + )> { + let summary = "AMDGPU ds_swizzle op, bitmode variant"; + let description = [{ + High-level wrapper on bitmode `rocdl.ds_swizzle` op; masks are represented + as separate fields so the user won't need to do manual bitpacking. + + Supports arbitrary int/float/vector types, which will be repacked to i32 and + one or more `rocdl.ds_swizzle` ops during lowering. + }]; + let results = (outs AnyIntegerOrFloatOr1DVector:$result); + let assemblyFormat = [{ + $src $and_mask $or_mask $xor_mask attr-dict `:` type($result) + }]; +} + +def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["result", "src"]>]> { + let summary = "AMDGPU permlane swap op"; + let description = [{ + High-level wrapper on `rocdl.permlane{16,32}.swap` variants for permutations + on rows of lanes in a subgroup. + + Supports arbitrary int/float/vector types, which will be repacked to i32 and + one or more `rocdl.permlane_swap` ops during lowering. + Supported lane permutations: + - Swap the data between odd and even rows of 16 lanes + - Swap the data between the first 32 lanes and the last 32 lanes + + Example: + ```mlir + %0 = amdgpu.permlane_swap %src 16 : f16 + %1 = amdgpu.permlane_swap %src 32 { fetch_inactive = true, bound_ctrl = true } : f16 + ``` + + Operands: + * `$src`: Vector register to permute across lanes of the subgroup. + * `$row_length`: The length of a row to permute in number of lanes (valid values are 16 and 32). + * `$fetch_inactive`: Optional. Used to determine behavior of a fetch from a disabled lane. + `fetch_inactive = false`: If the source lane is disabled, use `bound_ctrl` to determine the source value.
+ `fetch_inactive = true`: If the source lane is disabled, fetch the source value anyway (ignoring `bound_ctrl`). + * `$bound_ctrl`: Optional. Used to determine what a thread should do if its source operand is from + a disabled lane: use the value zero, or disable the write. + `bound_ctrl = false`: Do not write when source is from a disabled lane + `bound_ctrl = true`: Use zero as input if source is from a disabled lane + + Note: Lowering is only supported on gfx950 and up. + }]; + let arguments = (ins AnyIntegerOrFloatOr1DVector:$src, + I32Attr:$row_length, + DefaultValuedAttr:$fetch_inactive, + DefaultValuedAttr:$bound_ctrl); + let results = (outs AnyIntegerOrFloatOr1DVector:$result); + let assemblyFormat = [{ + $src $row_length attr-dict `:` type($result) + }]; + let hasVerifier = 1; +} + +def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> { + let summary = "Barrier that includes a wait for LDS memory operations."; + let description = [{ + **DEPRECATION NOTICE**: Unless you need the inline-assembly-based workaround + for gfx908/MI-100, you should represent this pattern with the equivalent + + ```mlir + gpu.barrier memfence [#gpu.address_space] + ``` + + instead. + + `amdgpu.lds_barrier` is both a barrier (all workitems in a workgroup must reach + the barrier before any of them may proceed past it) and a wait for all + operations that affect the Local Data Store (LDS) issued from that workgroup + to complete before the workgroup may continue. Since the LDS is per-workgroup + memory, this barrier may be used, for example, to ensure all workitems have + written data to LDS before any workitem attempts to read from it. + + Note that `lds_barrier` does **not** force reads to or from global memory + to complete before execution continues. Therefore, it should be used when + operations on global memory can be issued far in advance of when their results + are used (for example, by writing them to LDS). 
+ + WARNING: On architectures that do not support the BackOffBarrier feature, + (those which will implement this barrier by emitting inline assembly), + use of this operation will impede the usability of memory watches (including + breakpoints set on variables) when debugging. + }]; + let assemblyFormat = "attr-dict"; + let hasCanonicalizer = 1; +} + +def AMDGPU_SchedBarrierOp : + AMDGPU_Op<"sched_barrier">, + Arguments<(ins AMDGPU_SchedBarrierOpOptAttr:$opts)> + { + let summary = "Barrier that limits the backend scheduler of instruction movement"; + let description = [{ + `amdgpu.sched_barrier` serves as a barrier that could be + configured to restrict movements of instructions through it as + defined by sched_barrier_opts. + }]; + let assemblyFormat = [{ + `allow` `=` $opts attr-dict + }]; +} + +def AMDGPU_MemoryCounterWaitOp : + AMDGPU_Op<"memory_counter_wait">, + Arguments<(ins + OptionalAttr:$load, + OptionalAttr:$store, + OptionalAttr:$ds, + OptionalAttr:$exp, + OptionalAttr:$tensor + )> + { + let summary = "Wait for specified hardware counters"; + let description = [{ + Wait for the specified counters to be less-than or equal-to the provided + values before continuing. + + Counters can lower to different instructions on different architectures, + including clamping to some HW-supported max value or combining multiple + counters into one.
+ }]; + let assemblyFormat = [{ + oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` | `tensor` `(` $tensor `)` ) attr-dict + }]; + + let hasCanonicalizer = 1; +} + + +// mfma +def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64, + VectorOfLengthAndType<[2], [F32]>, + VectorOfLengthAndType<[4, 8], [F16]>, + VectorOfLengthAndType<[2, 4, 8], [BF16]>, + VectorOfLengthAndType<[4, 8, 16], [I8]>, + VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>, + VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>, + VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>; +def MFMAOutTypes : AnyTypeOf<[F64, + VectorOfLengthAndType<[4, 16, 32], [F32]>, + VectorOfLengthAndType<[4, 16, 32], [I32]>, + VectorOfLengthAndType<[4], [F64]>]>; + +// sparse_mfma (smfmac) +def SMFMACSparseInTypes : AnyTypeOf<[ + VectorOfLengthAndType<[4, 8], [F16]>, + VectorOfLengthAndType<[4, 8], [BF16]>, + VectorOfLengthAndType<[8, 16], [I8]>, + VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>, + VectorOfLengthAndType<[8, 16], [F8E4M3FNUZ, F8E5M2FNUZ]> +]>; + +def SMFMACDenseInTypes : AnyTypeOf<[ + VectorOfLengthAndType<[8, 16], [F16]>, + VectorOfLengthAndType<[8, 16], [BF16]>, + VectorOfLengthAndType<[16, 32], [I8]>, + VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>, + VectorOfLengthAndType<[16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]> +]>; + +def SMFMACOutTypes : AnyTypeOf<[ + VectorOfLengthAndType<[4, 16], [F32]>, + VectorOfLengthAndType<[4, 16], [I32]> +]>; + +def SMFMACIdxTypes : AnyTypeOf<[ + FixedVectorOfLengthAndType<[4], [I8]>, + FixedVectorOfLengthAndType<[2], [I16]> +]>; + +// scaled_mfma +def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>, + VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>; +def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>; + +// scaled_wmma +def ScaledWMMAInTypes + : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>, + VectorOfLengthAndType<[64], [F6E2M3FN, 
F6E3M2FN]>, + VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>; + +def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F32]>]>; + +// wmma +def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>, + VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>, + VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>, + VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>, + VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>; +def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>, + VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>; + +def AMDGPU_MFMAOp : + AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>, + Pure]>, + Arguments<(ins + ConfinedAttr]>:$m, + ConfinedAttr]>:$n, + ConfinedAttr]>:$k, + DefaultValuedAttr]>, "1">:$blocks, + MFMAInTypes:$sourceA, + MFMAInTypes:$sourceB, + MFMAOutTypes:$destC, + DefaultValuedAttr:$cbsz, + DefaultValuedAttr:$abid, + DefaultValuedAttr:$blgp, + UnitAttr:$reducePrecision, + UnitAttr:$negateA, + UnitAttr:$negateB, + UnitAttr:$negateC)>, + Results<(outs MFMAOutTypes: $destD)> { + let summary = "MLIR wrapper for CDNA mfma instructions"; + let description = [{ + The `amdgpu.mfma` op is an MLIR wrapper around intrinsics + for various `mfma` instructions in the CDNA architecture, which perform + multiple outer products in order to allow fast matrix multiplication. + + The wrapper will select an appropriate `mfma` instruction, if one is available, + based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the + types of the source and destination arguments. + + For information on the layouts of the input and output matrices (which are stored + in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation. 
+ + The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave + are permuted when matrix data is being loaded: `blgp` can be any number of + fixed permutations, `cbsz` specifies the log_2 of the number of chunks the lanes + holding sourceA are split into, and `abid` selects one of those chunks. + + Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA + intrinsics that take an integer type of width `4K`. For example, + one can provide a vector<4xi8> as an argument to an MFMA instruction that + logically takes 4 i8s but whose intrinsics are specified to take an i32. + In these cases, the bytes in the vector will be concatenated in little-endian + order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on). + + The negateA, negateB, and negateC flags are only supported for double-precision + operations on gfx94x. + + Example: + ```mlir + %0 = amdgpu.mfma 16x16x16 %matA * %matB + %matC + : vector<4xf16>, vector<4xf16>, vector<4xf32> + + %1 = amdgpu.mfma 32x32x1 %matD * %matE + %matF + { abid = 1 : i32, cbsz = 1 : i32, blocks = 2 : i32 } + blgp = bcast_second_32 : f32, f32, vector<32xf32> + ``` + }]; + let assemblyFormat = [{ + custom($m, $n, $k) $sourceA `*` $sourceB `+` $destC + attr-dict + `blgp` `=` $blgp + `:` type($sourceA) `,` type($sourceB) `,` type($destC) + }]; + let hasVerifier = 1; +} + +def AMDGPU_WMMAOp : + AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>, + Pure]>, + Arguments<(ins + ConfinedAttr]>:$m, + ConfinedAttr]>:$n, + ConfinedAttr]>:$k, + WMMAInTypes:$sourceA, + WMMAInTypes:$sourceB, + WMMAOutTypes:$destC, + DefaultValuedAttr]>, "0">:$subwordOffset, + UnitAttr:$unsignedA, + UnitAttr:$unsignedB, + UnitAttr:$clamp)>, + Results<(outs WMMAOutTypes: $destD)> { + let summary = "MLIR wrapper for wmma instructions"; + let description = [{ + The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma` + instructions in the AMDGPU architecture, which perform matrix multiplication. 
+ + On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions. + + On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for + all element types, and K=32 for i4 sources. + + On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128, + depending on the element types. + + On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 + (or 16xbf16) vector containing only 8 valid values: + - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14. + - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15. + On gfx12/RDNA4 and gfx1250, the result is instead returned as vector where all + the values are valid and the `subwordOffset` must be `0`, as it cannot be used. + + `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned. + + The `clamp` flag is used to saturate the output of type T to `numeric_limits::max()` + in case of overflow. + + Example: + ```mlir + %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16> + + %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<8xf32>, vector<8xf32> + + %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf4E2M1FN>, vector<64xf4E2M1FN>, vector<8xf32> + + %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32> + ``` + }]; + let assemblyFormat = [{ + custom($m, $n, $k) $sourceA `*` $sourceB `+` $destC + attr-dict + `:` type($sourceA) `,` type($sourceB) `,` type($destC) + }]; + let hasVerifier = 1; +} + +def AMDGPU_SparseMFMAOp : + AMDGPU_Op<"sparse_mfma", [AllTypesMatch<["destC", "destD"]>, + Pure]>, + Arguments<(ins + ConfinedAttr]>:$m, + ConfinedAttr]>:$n, + ConfinedAttr]>:$k, + SMFMACSparseInTypes:$sourceA, + SMFMACDenseInTypes:$sourceB, + SMFMACOutTypes:$destC, + SMFMACIdxTypes:$sparseIdx, + DefaultValuedAttr:$cbsz, + DefaultValuedAttr:$abid)>, + Results<(outs SMFMACOutTypes: $destD)> { + let summary = 
"MLIR wrapper for CDNA sparse mfma (smfmac) instructions"; + let description = [{ + The `amdgpu.sparse_mfma` op is an MLIR wrapper around intrinsics for various + `smfmac` instructions in the AMDGPU architecture, which perform matrix + multiply-accumulate operations using 2:4 structured sparsity on matrix A + with dense matrices B, C, and D. + + On gfx942, smfmac intrinsics support: + - M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources + - M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources + + On gfx950, smfmac intrinsics additionally support: + - M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources + - M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources + + The `sparseIdx` parameter contains packed indices identifying the positions + of non-zero elements in the 2:4 sparse matrix A. For 16-bit source data, + use `vector<4xi8>` (four 8-bit indices). For 8-bit source data, use + `vector<2xi16>` (two 16-bit indices). + + The `cbsz` and `abid` parameters are repurposed to select the index set. + If `cbsz == 0`, then `abid[1:0]` selects which index set to use. + If `cbsz != 0`, then the very first is selected. 
+ + Example: + ```mlir + %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>) + : vector<4xf16>, vector<8xf16>, vector<4xf32> + + %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>) + : vector<8xi8>, vector<16xi8>, vector<4xi32> + + %2 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>) + { cbsz = 0 : i32, abid = 1 : i32 } + : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32> + ``` + }]; + let assemblyFormat = [{ + custom($m, $n, $k) $sourceA `*` $sourceB `+` $destC + `sparse` `(` $sparseIdx `:` type($sparseIdx) `)` + attr-dict + `:` type($sourceA) `,` type($sourceB) `,` type($destC) + }]; + let hasVerifier = 1; +} + +def AMDGPU_GatherToLDSOp : + AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>, + Arguments<(ins + Arg:$src, + Variadic:$srcIndices, + Arg:$dst, + Variadic:$dstIndices, + TypeAttr:$transferType + )>, + Results<(outs)> { + let summary = "MLIR wrapper for CDNA Gather to LDS instructions"; + let description = [{ + The `amdgpu.gather_to_lds` op is a wrapper around the `global_load_lds` instructions. + + Operands: + * `$src`: global memory (including fat buffer) memref to read from. + * `$srcIndices`: indices into `$src` to read from for this thread. + * `$dst`: LDS memory memref to write to. + * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread. + The elements gathered by the subgroup will be written contiguously in order of lane ID + starting at `$dst[$dstIndices]`. Byte-sized (ex. i8) or short-sized (ex. i16) + types will be zero-padded/extended to 32 bits before being written. 96-bit types + (ex. vector<3xf32>) will be zero-padded to 128 bits before being written. Only the + offsets held by lane 0 are used. + * `$transferType`: type of the data to be transferred by each thread. This is used to determine + the size of the data to be transferred and the number of threads in the subgroup. 
+ The transfer type must be a scalar type or a vector type with a single element type. + + The `$dst`, along with its indices, points to the memory location the subgroup of this thread + will write to. + + Note: only supported on gfx9 and gfx10. + }]; + let assemblyFormat = [{ + $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst) + }]; + let hasVerifier = 1; + let hasCanonicalizer = 1; +} + +def AMDGPU_TransposeLoadOp : + AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>, + Arguments<(ins Arg:$src, Variadic:$srcIndices)>, + Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> { + let summary = "MLIR wrapper for CDNA Transpose Load instructions"; + let description = [{ + The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions. + The transpose load op represents a subgroup load from LDS memory, + where the subgroup of threads collectively reads a matrix from the source + memref, with each thread reading a vector of the matrix, and gets a transposed matrix + as the result. That is, each thread reads a vector of the col-major matrix at different + indices, and the thread's read result is a vector of the corresponding row of the transposed + matrix. + + This op is a direct wrapper around the ROCDL `ds_read_tr` family of intrinsics. Please refer + to the CDNA4 ISA documentation for more details about its exact semantics. + + Format example: + ``` + %0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16> + ``` + Operands: + * `$src`: LDS memref to read from. + * `$srcIndices`: indices into `$src` to read from for this thread. + * `$result`: target register this transpose load instruction will write to. + + Note: Lowering is only supported on gfx950 and up.
+ }]; + let assemblyFormat = [{ + $src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result) + }]; + let hasVerifier = 1; +} + +def AMDGPU_ScaledMFMAOp : + AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>, + Pure]>, + Arguments<(ins + ConfinedAttr]>:$m, + ConfinedAttr]>:$n, + ConfinedAttr]>:$k, + ScaledMFMAInTypes:$sourceA, + ScaledMFMAInTypes:$sourceB, + ScaledMFMAOutTypes:$destC, + AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesA, + AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesB, + ConfinedAttr]>:$scalesIdxA, + ConfinedAttr]>:$scalesIdxB + )>, + Results<(outs ScaledMFMAOutTypes: $destD)> { + let summary = "MLIR wrapper for CDNA scaled mfma instructions"; + let description = [{ + The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics + for various scaled versions of `mfma` instructions in the CDNA architecture, which + perform multiple outer products in order to allow fast matrix multiplication. + + The wrapper will select an appropriate `mfma` instruction, if one is available, + based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the + types of the source and destination arguments. + + Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA + intrinsics that take an integer type of width `4K`. For example, + one can provide a `vector<4xi8>` as an argument to an MFMA instruction that + logically takes 4 i8s but whose intrinsics are specified to take an i32. + In these cases, the bytes in the vector will be concatenated in little-endian + order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on). + + This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences: + - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and + fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as + their tile size. + - `amdgpu.scaled_mfma` does not support broadcasting. 
So, `cbsz`, `abid`, and `blgp` + are omitted from this wrapper. + - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported + for double-precision operations on gfx94x and so are not included here. + + Example: + ```mlir + %0 = amdgpu.scaled_mfma 32x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 + : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32> + ``` + }]; + let assemblyFormat = [{ + custom($m, $n, $k) ` ` + `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` + `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC + attr-dict + `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC) + }]; + let hasCanonicalizer = 1; +} + +def AMDGPU_ScaledWMMAOp + : AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>, + Arguments<(ins ConfinedAttr]>:$m, + ConfinedAttr]>:$n, + ConfinedAttr]>:$k, + ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB, + ScaledWMMAOutTypes:$destC, + VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA, + ConfinedAttr]>:$a_first_scale_lane, + VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB, + ConfinedAttr]>:$b_first_scale_lane)>, + Results<(outs ScaledWMMAOutTypes:$destD)> { + // TODO: E5M3FNU scales are supported, but there is not yet MLIR support for + // this datatype. Once we have support for that, update the scaleA and scaleB + // types here. + let summary = "MLIR wrapper for scaled wmma instructions"; + let description = [{ + The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled + `wmma` instructions. These instructions perform matrix multiplication with + per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats. 
+ + The scale instructions support a block size of 16 or 32 and two tile sizes: + - 16x16x128 with mixed f8/f6/f4 formats (output: vector<8xf32>) + - 32x16x128 with f4 format only (output: vector<16xf32>) + + Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values + (either f8E8M0FNU, or f8E4M3FN) that are packed into i32/i64 values during + lowering. Each lane can operate on 4 bytes (4 scale values), and the + number of scales required for each matrix is determined by: + num_scales_A = (M × K) / block_size + num_scales_B = (N × K) / block_size + + The index attributes (`a_first_scale_lane`, `b_first_scale_lane`) select + which lane to start reading scale values from (0 or 16): + - For block size 32, 32 lanes across a single wave are used for the scale + values. If the number of scales (num_scales_A or num_scales_B) can fit + into half of the available lanes + (i.e., num_scales / scales_per_lane == 16 (num_lanes)), + then the first_scale_lane can be either 0 or 16. If all lanes are required + for storing the scale values (num_scales / scales_per_lane == 32 (num_lanes)), + then the first_scale_lane must be 0. + - For block size 16, the same rules apply as above except that there are 64 + lanes across two waves that are used for the scale values. When + num_scales / scales_per_lane == 32 (num_lanes), then 16 lanes from each wave are used. + first_scale_lane of 0 or 16 will decide which lanes are used for this. When + num_scales / scales_per_lane == 64 (num_lanes), then first_scale_lane must + be set to 0. 
+ + Example: + ```mlir + // 16x16x128: fp8 inputs + %0 = amdgpu.scaled_wmma 16x16x128 (%scaleVecA * %matA) * (%scaleVecB * %matB) + %matC + {a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32} + : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, + vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32> + + // 32x16x128: fp4 inputs with different scale lanes + %1 = amdgpu.scaled_wmma 32x16x128 (%scaleVecD * %matD) * (%scaleVecE * %matE) + %matF + {a_first_scale_lane = 0 : i32, b_first_scale_lane = 16 : i32} + : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>, + vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32> + ``` + }]; + let assemblyFormat = [{ + custom($m, $n, $k) ` ` + `(` $scaleA `*` $sourceA `)` `*` + `(` $scaleB `*` $sourceB `)` `+` $destC + attr-dict + `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC) + }]; + let hasVerifier = 1; +} + +class AMDGPU_DmaBaseOp : + AMDGPU_Op]>, + Arguments<(ins Arg:$global, + Variadic:$global_indices, + Arg:$lds, + Variadic:$lds_indices)>, + Results<(outs outType: $base)> { + + // TODO: + // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions. + + let assemblyFormat = [{ + $global `[` $global_indices `]` `,` $lds `[` $lds_indices `]` attr-dict `:` type($global) `,` type($lds) `->` type(results) + }]; +} + +def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU_TDMGatherBaseType> { + let summary = "Pair of based addresses used when moving tiles between LDS and global memory."; + + let description = [{ + This operation creates a pair of addresses that will be used by `tensor_load_to_lds` + and `tensor_store_from_lds`. + + This operation creates a value corresponding to the tensor descriptor (D#) group 0 + found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect. 
+ + Unlike `make_dma_base`, this operation returns `!amdgpu.tdm_gather_base<$element_type, $index_type>` + which is only compatible with `make_gather_dma_descriptor`. Using the descriptor returned + by `make_gather_dma_descriptor` will set the `tensor_load_to_lds` and `tensor_store_from_lds` to gather mode. + + ```mlir + %base = amdgpu.make_gather_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space> -> !amdgpu.tdm_gather_base + // %indices : i16 + %descriptor = amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_gather_base, i16 -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor + ``` + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static constexpr bool isGather() { + return true; + } + }]; +} + + +def AMDGPU_MakeDmaBaseOp : AMDGPU_DmaBaseOp<"make_dma_base", AMDGPU_TDMBaseType> { + + let summary = "Pair of based addresses used when moving tiles between LDS and global memory."; + let description = [{ + This operation creates a pair of addresses that will be used by tensor_load_to_lds + and tensor_store_from_lds. + + This operation creates a value corresponding to the tensor descriptor (D#) group 0 + found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect. + + For example: + + ```mlir + %base = amdgpu.make_dma_base %global[%idx0, %idx1], %lds[%idx2, %idx3] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space> -> !amdgpu.tdm_base + %descriptor = amdgpu.make_dma_descriptor %base globalSize [2, 2] globalStride [2, 1] sharedSize [2, 2] : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor + ``` + + to + + ```mlir + // pseudo-code + %global_base = llvm.extractvalue %global_memref[1] + %global_address = llvm.get_element_ptr ... 
+ + %lds_base = llvm.extractvalue %lds_memref[1] + %lds_address = llvm.get_element_ptr ... + + // Definition of %base + %undef = llvm.mlir.undef : vector<4xi32> + %v0 = llvm.insertelement %15, %undef[0] : vector<4xi32> + %v1 = llvm.insertelement %lds_address, %v0[1] : vector<4xi32> + %v2 = llvm.insertelement %global_address_low, %v1[2] : vector<4xi32> + %base = llvm.insertelement %global_address_high, %v2[3] : vector<4xi32> + + rocdl.tensor.load.to.lds %base, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32> + ``` + + These tensor DMA operations were introduced in gfx1250. + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + static constexpr bool isGather() { + return false; + } + }]; +} + +class AMDGPU_MakeDescriptorOp : + AMDGPU_Op, + Results<(outs AMDGPU_TDMDescriptorType: $desc)> { + + dag baseArgs = (ins + Variadic: $global_dynamic_sizes, + DenseI64ArrayAttr: $global_static_sizes, + Variadic: $global_dynamic_strides, + DenseI64ArrayAttr: $global_static_strides, + Variadic: $shared_dynamic_sizes, + DenseI64ArrayAttr: $shared_static_sizes, + Optional>: $workgroup_mask, + Optional: $early_timeout, + Optional: $pad_amount, + Optional: $pad_interval, + Optional: $atomic_barrier_address, + Variadic: $atomic_barrier_indices, + Optional: $global_increment, + Optional: $lds_increment, + Optional: $iteration_count); + + code extraClassDeclarationBase = [{ + int64_t getRank() { + return getGlobalStaticSizes().size(); + } + + unsigned getElementTypeWidth() { + return getBase().getType().getElementType().getIntOrFloatBitWidth(); + } + + SmallVector getMixedGlobalSizes() { + return getMixedValues(getGlobalStaticSizes(), getGlobalDynamicSizes(), getContext()); + } + + SmallVector getMixedGlobalStrides() { + return getMixedValues(getGlobalStaticStrides(), getGlobalDynamicStrides(), getContext()); + } + + SmallVector getMixedSharedSizes() { + return getMixedValues(getSharedStaticSizes(), getSharedDynamicSizes(), getContext()); + } + + }]; 
+ +} + +def AMDGPU_MakeGatherDmaDescriptorOp : AMDGPU_MakeDescriptorOp<"make_gather_dma_descriptor"> { + dag args = (ins AMDGPU_TDMGatherBaseType: $base, + AnyTypeOf<[VectorOfMinMaxLengthAndType<1, 8, [I32]>, + VectorOfMinMaxLengthAndType<1, 16, [I16]>]>: $indices); + let arguments = !con(args, baseArgs); + let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS."; + + let assemblyFormat = [{ + $base `[` $indices `]` + `globalSize` custom($global_dynamic_sizes, $global_static_sizes) + `globalStride` custom($global_dynamic_strides, $global_static_strides) + `sharedSize` custom($shared_dynamic_sizes, $shared_static_sizes) + ( `padShared` `(` $pad_amount^ `every` $pad_interval `)` )? + ( `workgroupMask` $workgroup_mask^ ( `earlyTimeout` $early_timeout^)?)? + ( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]` + `:` type($atomic_barrier_address) `)`)? + ( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )? + attr-dict `:` qualified(type($base)) `,` type($indices) `->` type(results) + }]; + + let hasVerifier = 1; + let hasFolder = 1; + + let extraClassDeclaration = extraClassDeclarationBase # [{ + static constexpr bool isGather() { + return true; + } + }]; +} + +def AMDGPU_MakeDmaDescriptorOp : AMDGPU_MakeDescriptorOp<"make_dma_descriptor"> { + dag args = (ins AMDGPU_TDMBaseType: $base); + let arguments = !con(args, baseArgs); + let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS."; + let description = [{ + Make all descriptor groups needed by tensor memory operations. + + The $base operand corresponds to the base pair addresses, one must be an address in LDS + while the other must be a global memory location. + + $global_{static/dynamic}_sizes determine the size of the tensor. + $global_{static/dynamic}_strides determine the strides of the tensor. + $shared_{static/dynamic}_sizes determines the size of the tile. 
+ + $workgroup_mask broadcasts the load to workgroups inside of a workgroup cluster + (0 = do not broadcast result to workgroup, 1 = broadcast result to workgroup). Ignored for stores. + An all zeros mask is interpreted as a non-broadcasted load. + + $early_timeout returns data to requesters as soon as the cache supplies it. + + Padding can be applied to the LDS address when copying from memory to LDS, + but not when copying from LDS to memory. + The values in the padded target addresses remain the same as before the operation was applied. + $pad_interval must be a power of two contained in [2, 256]. + $pad_amount must be a value contained in [1, 128]. + + $atomic_barrier_address must be aligned to 8 bytes. + + 2D and 3D tensors may be iterated over by setting $global_increment, $lds_increment, and $iteration_count. + $global_increment determines how much to increment the starting global memory address per iteration in units of the $base's element type. + $lds_increment determines how much to increment the starting LDS address per iteration in units of the $base's element type. + $iteration_count determines how many times to iterate; it must be a value in the inclusive interval [1, 256]. + + ```mlir + // Example of moving a two-dimensional tensor to LDS. + %base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<64x64xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32> + %descriptor = amdgpu.make_dma_descriptor %base globalSize [64, 64] globalStride [64, 1] sharedSize [64, 64] : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor + + // Example of moving a two-dimensional tensor to LDS where padding is applied after every integer. 
+ %base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space> -> !amdgpu.tdm_base + %descriptor = amdgpu.make_dma_descriptor %base globalSize [32, 32] globalStride [32, 1] sharedSize [64, 64] padShared(%pad_amount every %pad_interval) : !amdgpu.tdm_base -> !amdgpu.tdm_descriptor + amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor + ``` + }]; + + let assemblyFormat = [{ + $base + `globalSize` custom($global_dynamic_sizes, $global_static_sizes) + `globalStride` custom($global_dynamic_strides, $global_static_strides) + `sharedSize` custom($shared_dynamic_sizes, $shared_static_sizes) + ( `padShared` `(` $pad_amount^ `every` $pad_interval `)` )? + ( `workgroupMask` $workgroup_mask^ ( `earlyTimeout` $early_timeout^)?)? + ( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]` + `:` type($atomic_barrier_address) `)`)? + ( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )? + attr-dict `:` qualified(type($base)) `->` type(results) + }]; + + let hasVerifier = 1; + let hasFolder = 1; + + let extraClassDeclaration = extraClassDeclarationBase # [{ + static constexpr bool isGather() { + return false; + } + }]; + +} + +def AMDGPU_TensorLoadToLDSOp : + AMDGPU_Op<"tensor_load_to_lds", [MemoryEffects<[MemWrite, MemRead]>]>, + Arguments<(ins AMDGPU_TDMDescriptorType: $desc)> { + let summary = "Load tensors from global memory to LDS."; + let description = [{ + Load tensors of up to five dimensions from global memory to LDS. + + This operation was introduced in gfx1250. + }]; + + let assemblyFormat = [{ + $desc attr-dict `:` qualified(type($desc)) + }]; +} + +def AMDGPU_TensorStoreFromLDSOp : + AMDGPU_Op<"tensor_store_from_lds", [MemoryEffects<[MemWrite, MemRead]>]>, + Arguments<(ins AMDGPU_TDMDescriptorType: $desc)> { + + let summary = "Store tensors from LDS to global memory."; + let description = [{ + Store tensors of up to five dimensions from LDS to global memory. 
+ + This operation was introduced in gfx1250. + }]; + + let assemblyFormat = [{ + $desc attr-dict `:` qualified(type($desc)) + }]; +} + +#endif // MLIR_DIALECT_AMDGPU_IR_AMDGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td new file mode 100644 index 0000000000000..3ea1ba35815c3 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUTypes.td @@ -0,0 +1,72 @@ +//===-- AMDGPUTypes.td - AMDGPU dialect types *- tablegen -*---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPUTYPES_TD +#define MLIR_DIALECT_AMDGPU_IR_AMDGPUTYPES_TD + +include "mlir/Dialect/AMDGPU/IR/AMDGPUBase.td" +include "mlir/IR/AttrTypeBase.td" + +//===----------------------------------------------------------------------===// +// AMDGPU Type definitions +//===----------------------------------------------------------------------===// + +class AMDGPU_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> { + let summary = "Pair of base addresses that move data between LDS and global storage."; + let description = [{ + This type is opaque and it is used to represent a struct of two addresses. + One address is in LDS while the other is in global memory. + + The value defined by this operation is only intended to be used by + amdgpu.tdm_make_descriptor. 
+ }]; + let parameters = (ins "Type":$elementType); + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType), [{ + return $_get(elementType.getContext(), elementType); + }]> + ]; + let assemblyFormat = "`<` $elementType `>`"; +} + +def AMDGPU_TDMGatherBaseType : AMDGPU_Type<"TDMGatherBase", "tdm_gather_base"> { + let summary = "Pair of base addresses that move data between LDS and global storage."; + let description = [{ + This type is opaque and it is used to represent a struct of two addresses. + One address is in LDS while the other is in global memory. + + This type is similar to `!amdgpu.tdm_base` but intended to be + used in gather mode. + + Values of this type are only intended to be used by + `amdgpu.make_gather_dma_descriptor`. + }]; + let parameters = (ins "Type":$elementType, "Type":$indexType); + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType, "Type": $indexType), [{ + return $_get(elementType.getContext(), elementType, indexType); + }]> + ]; + let assemblyFormat = "`<` $elementType `,` $indexType`>`"; + let genVerifyDecl = 1; +} + +def AMDGPU_TDMDescriptorType : AMDGPU_Type<"TDMDescriptor", "tdm_descriptor"> { + let summary = "Descriptors used in tensor store/load operations."; + let description = [{ + This type is opaque and corresponds to the two or four descriptor groups + used in tensor_load_to_lds or tensor_store_from_lds. 
+ }]; +} + +#endif // MLIR_DIALECT_AMDGPU_IR_AMDGPUTYPES_TD diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt index cab34696946e6..3462c3b8b74ae 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/CMakeLists.txt @@ -1,12 +1,12 @@ add_mlir_dialect(AMDGPU amdgpu) add_mlir_doc(AMDGPU AMDGPU Dialects/ -gen-dialect-doc) -set(LLVM_TARGET_DEFINITIONS AMDGPU.td) +set(LLVM_TARGET_DEFINITIONS AMDGPUEnums.td) mlir_tablegen(AMDGPUEnums.h.inc -gen-enum-decls) mlir_tablegen(AMDGPUEnums.cpp.inc -gen-enum-defs) add_mlir_dialect_tablegen_target(MLIRAMDGPUEnumsGen) -set(LLVM_TARGET_DEFINITIONS AMDGPU.td) -mlir_tablegen(AMDGPUAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu) -mlir_tablegen(AMDGPUAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu) +set(LLVM_TARGET_DEFINITIONS AMDGPUAttrs.td) +mlir_tablegen(AMDGPUAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu) +mlir_tablegen(AMDGPUAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu) add_mlir_dialect_tablegen_target(MLIRAMDGPUAttributesIncGen) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUAttrs.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUAttrs.cpp new file mode 100644 index 0000000000000..9428b166a8cf0 --- /dev/null +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUAttrs.cpp @@ -0,0 +1,26 @@ +//===- AMDGPUAttrs.cpp - MLIR AMDGPU dialect attributes -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AMDGPU dialect attributes. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" + +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttrs.cpp.inc" + +void mlir::amdgpu::AMDGPUDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttrs.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index dd741d56d39d0..1d4e5eddce019 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -15,27 +15,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" - -#include -#include -#include -#include using namespace mlir; using namespace mlir::amdgpu; @@ -56,1194 +38,7 @@ void AMDGPUDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" - >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" - >(); + registerTypes(); + registerAttributes(); addInterfaces(); } - 
-//===----------------------------------------------------------------------===// -// 8-bit float ops -//===----------------------------------------------------------------------===// -LogicalResult PackedTrunc2xFp8Op::verify() { - if (getExisting() && getExisting().getType() != getResult().getType()) - return emitOpError("existing values must have same type as result"); - return success(); -} - -LogicalResult PackedStochRoundFp8Op::verify() { - if (getExisting() && getExisting().getType() != getResult().getType()) - return emitOpError("existing values must have same type as result"); - return success(); -} - -//===----------------------------------------------------------------------===// -// mxfp float ops -//===----------------------------------------------------------------------===// -LogicalResult PackedScaledTruncOp::verify() { - if (getExisting() && getExisting().getType() != getResult().getType()) - return emitOpError("existing values must have same type as result"); - return success(); -} - -//===----------------------------------------------------------------------===// -// FatRawBufferCastOp -//===----------------------------------------------------------------------===// - -/// Convert the type `source` to one with the same sizes and strides - and -/// offset, unless `stripOffset` is true, in which case the offset is reset to -/// 0, if the offset should be reset but the layout of `source` isn't either the -/// identity layout or a strided layout, this function fails. 
-static FailureOr getFatRawBufferTypeLike(MemRefType source, - bool resetOffset) { - MLIRContext *ctx = source.getContext(); - MemRefType::Builder mb(source); - mb.setMemorySpace( - amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer)); - MemRefLayoutAttrInterface layout = source.getLayout(); - if (resetOffset && !layout.isIdentity()) { - auto stridedLayout = dyn_cast(layout); - if (!stridedLayout) - return failure(); - MemRefLayoutAttrInterface newLayout = - StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()); - // Special case: if resetting the offset causes the strided layout to become - // the identity layout, then reset to the identity layout. - // TODO: this'll get a lot simpler when we have the contiguous layout. - SmallVector stridesIfIdentity; - if (source.hasStaticShape()) { - stridesIfIdentity = computeSuffixProduct(source.getShape()); - } else if (source.getRank() <= 1) { - stridesIfIdentity = SmallVector(source.getRank(), 1); - } - if (stridesIfIdentity == stridedLayout.getStrides()) { - newLayout = AffineMapAttr::get( - AffineMap::getMultiDimIdentityMap(source.getRank(), ctx)); - } - mb.setLayout(newLayout); - } - return (MemRefType)(mb); -} - -LogicalResult FatRawBufferCastOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - Adaptor adaptor(operands, attributes, properties, regions); - auto sourceType = - dyn_cast_if_present(adaptor.getSource().getType()); - if (!sourceType) - return failure(); - FailureOr resultType = - getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset()); - if (failed(resultType)) - return failure(); - inferredReturnTypes = SmallVector{*resultType}; - return success(); -} - -FailureOr FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder, - int resultIndex, - int dim) { - assert(resultIndex == 0 && "FatRawBufferCastOp has a single 
result"); - return memref::getMixedSize(builder, getLoc(), getSource(), dim); -} - -LogicalResult FatRawBufferCastOp::verify() { - FailureOr expectedResultType = - getFatRawBufferTypeLike(getSource().getType(), getResetOffset()); - if (failed(expectedResultType)) - return emitOpError("source type ") - << getSource().getType() << " can't have its offset reset"; - if (getResult().getType() != *expectedResultType) - return emitOpError("expected result type to be ") - << *expectedResultType << " but got " << getResult().getType(); - return success(); -} - -static bool hasGlobalMemorySpace(Attribute memorySpace) { - if (!memorySpace) - return true; - if (auto intMemorySpace = dyn_cast(memorySpace)) - return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1; - if (auto gpuMemorySpace = dyn_cast(memorySpace)) - return gpuMemorySpace.getValue() == gpu::AddressSpace::Global; - return false; -} - -static bool hasWorkgroupMemorySpace(Attribute memorySpace) { - if (!memorySpace) - return false; - if (auto intMemorySpace = dyn_cast(memorySpace)) - return intMemorySpace.getInt() == 3; - if (auto gpuMemorySpace = dyn_cast(memorySpace)) - return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup; - return false; -} - -static bool hasFatRawBufferMemorySpace(Attribute memorySpace) { - if (!memorySpace) - return false; - if (auto intMemorySpace = dyn_cast(memorySpace)) - return intMemorySpace.getInt() == 7; - if (auto gpuMemorySpace = dyn_cast(memorySpace)) - return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer; - return false; -} - -//===----------------------------------------------------------------------===// -// RawBuffer*Op -//===----------------------------------------------------------------------===// -template -static LogicalResult verifyRawBufferOp(T &op) { - MemRefType bufferType = llvm::cast(op.getMemref().getType()); - bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace()); - - if (!isGlobal) - return op.emitOpError( - 
"Buffer ops must operate on a memref in global memory"); - if (!bufferType.hasRank()) - return op.emitOpError( - "Cannot meaningfully buffer_store to an unranked memref"); - if (static_cast(op.getIndices().size()) != bufferType.getRank()) - return op.emitOpError("Expected " + Twine(bufferType.getRank()) + - " indices to memref"); - return success(); -} - -LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); } - -LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); } - -LogicalResult RawBufferAtomicFaddOp::verify() { - return verifyRawBufferOp(*this); -} - -LogicalResult RawBufferAtomicFmaxOp::verify() { - return verifyRawBufferOp(*this); -} - -LogicalResult RawBufferAtomicSmaxOp::verify() { - return verifyRawBufferOp(*this); -} - -LogicalResult RawBufferAtomicUminOp::verify() { - return verifyRawBufferOp(*this); -} - -LogicalResult RawBufferAtomicCmpswapOp::verify() { - return verifyRawBufferOp(*this); -} - -static std::optional getConstantUint32(Value v) { - APInt cst; - if (!v.getType().isInteger(32)) - return std::nullopt; - if (matchPattern(v, m_ConstantInt(&cst))) - return cst.getZExtValue(); - return std::nullopt; -} - -template -static bool staticallyOutOfBounds(OpType op) { - if (!op.getBoundsCheck()) - return false; - MemRefType bufferType = op.getMemref().getType(); - if (!bufferType.hasStaticShape()) - return false; - int64_t offset; - SmallVector strides; - if (failed(bufferType.getStridesAndOffset(strides, offset))) - return false; - int64_t result = offset + op.getIndexOffset().value_or(0); - if (op.getSgprOffset()) { - std::optional sgprOffset = getConstantUint32(op.getSgprOffset()); - if (!sgprOffset) - return false; - result += *sgprOffset; - } - if (strides.size() != op.getIndices().size()) - return false; - int64_t indexVal = 0; - for (auto pair : llvm::zip(strides, op.getIndices())) { - int64_t stride = std::get<0>(pair); - Value idx = std::get<1>(pair); - std::optional idxVal = 
getConstantUint32(idx); - if (!idxVal) - return false; - indexVal += stride * *idxVal; - } - result += indexVal; - if (result > std::numeric_limits::max()) - // Overflow means don't drop - return false; - return result >= bufferType.getNumElements(); -} - -namespace { -template -struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override { - if (!staticallyOutOfBounds(op)) - return failure(); - Type loadType = op.getResult().getType(); - rw.replaceOpWithNewOp(op, loadType, - rw.getZeroAttr(loadType)); - return success(); - } -}; - -template -struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override { - if (!staticallyOutOfBounds(op)) - return failure(); - - rw.eraseOp(op); - return success(); - } -}; -} // end namespace - -void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>(context); -} - -void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>(context); -} - -void RawBufferAtomicFaddOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add>(context); -} - -void RawBufferAtomicFmaxOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add>(context); -} - -void RawBufferAtomicSmaxOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add>(context); -} - -void RawBufferAtomicUminOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add>(context); -} - -void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add>( - context); -} - 
-//===----------------------------------------------------------------------===// -// ScaledExtPackedMatrixOp -//===----------------------------------------------------------------------===// -LogicalResult ScaledExtPackedMatrixOp::verify() { - int blockSize = getBlockSize(); - assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size"); - - int firstScaleByte = getFirstScaleByte(); - int firstScaleLane = getFirstScaleLane(); - auto sourceType = cast(getSource().getType()); - Type elementType = sourceType.getElementType(); - auto floatType = cast(elementType); - unsigned bitWidth = floatType.getWidth(); - - assert(llvm::is_contained(llvm::ArrayRef{4, 6, 8}, bitWidth)); - - const bool is_fp8 = bitWidth == 8; - const bool is_block_16 = blockSize == 16; - - if (!is_fp8) { - if (is_block_16) { - if (!llvm::is_contained({0, 1}, firstScaleByte)) { - return emitOpError("blockSize of 16 can only have firstScaleByte be 0 " - "or 1 for f4 and f6."); - } - } else { - if (!llvm::is_contained({0, 2}, firstScaleByte)) { - return emitOpError("blockSize of 32 can only have firstScaleByte be 0 " - "or 2 for f4 and f6."); - } - } - } else { - if (is_block_16) { - bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) || - ((firstScaleLane == 16) && (firstScaleByte == 2)); - if (!is_valid) { - return emitOpError("blockSize of 16 can only have (firstScaleLane, " - "firstScaleByte) be (0, 0) or (16, 2) for f8."); - } - } - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// WMMAOp -//===----------------------------------------------------------------------===// - -ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser, - IntegerAttr &m, IntegerAttr &n, - IntegerAttr &k) { - SmallVector dimensions; - if (parser.parseDimensionList(dimensions, false, false)) - return failure(); - if (dimensions.size() != 3) - return parser.emitError(parser.getCurrentLocation()) - << "expected 3 dimensions in 
MNK dimension list"; - - m = parser.getBuilder().getI32IntegerAttr(dimensions[0]); - n = parser.getBuilder().getI32IntegerAttr(dimensions[1]); - k = parser.getBuilder().getI32IntegerAttr(dimensions[2]); - return success(); -} - -LogicalResult WMMAOp::verify() { - auto sourceAType = cast(getSourceA().getType()); - auto sourceBType = cast(getSourceB().getType()); - auto destType = cast(getDestC().getType()); - - Type sourceAElemType = sourceAType.getElementType(); - Type sourceBElemType = sourceBType.getElementType(); - if (sourceAType.getNumElements() != sourceBType.getNumElements()) { - return emitOpError("source vectors have different lengths: ") - << sourceAType << " vs. " << sourceBType; - } - - bool isDestFloat = destType.getElementType().isFloat(); - bool isSrcFloat = sourceAElemType.isFloat(); - - if (isDestFloat && !isSrcFloat) - return emitOpError("expected float sources with float destination"); - if (!isDestFloat && isSrcFloat) - return emitOpError("expected int sources with int destination"); - - if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { - return emitOpError( - "source element types must match (except for fp8/bf8) but have ") - << sourceAType << " and " << sourceBType; - } - - if (isSrcFloat) { - if (getClamp()) - return emitOpError("clamp flag is not supported for float types"); - if (getUnsignedA() || getUnsignedB()) - return emitOpError("unsigned flags are not supported for float types"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// ScaledWMMAOp -//===----------------------------------------------------------------------===// - -LogicalResult ScaledWMMAOp::verify() { - // Helper functions for type classification. 
- auto isF8 = llvm::IsaPred; - auto isF6 = llvm::IsaPred; - auto isF4 = llvm::IsaPred; - auto isScaleF8 = llvm::IsaPred; - auto isE8M0 = llvm::IsaPred; - auto isE4M3 = llvm::IsaPred; - - auto sourceAType = cast(getSourceA().getType()); - auto sourceBType = cast(getSourceB().getType()); - auto destType = cast(getDestC().getType()); - - // Validate source element types are small floats (fp4/fp6/fp8). - Type aElemType = sourceAType.getElementType(); - Type bElemType = sourceBType.getElementType(); - - // Validate vector lengths based on dimensions. - int64_t m = getM(); - int64_t aLen = sourceAType.getNumElements(); - int64_t bLen = sourceBType.getNumElements(); - int64_t expectedOutLen = (m == 16) ? 8 : 16; - - if (destType.getNumElements() != expectedOutLen) - return emitOpError("expected output vector of length ") - << expectedOutLen << " but got " << destType.getNumElements(); - - if (m == 16) { - // For 16×16×128: both A and B must be 64 elements. - if (aLen != 64) - return emitOpError( - "for 16x16x128, sourceA must have 64 elements but got ") - << aLen; - if (bLen != 64) - return emitOpError( - "for 16x16x128, sourceB must have 64 elements but got ") - << bLen; - } else { // m == 32 - // For 32×16×128: only fp4 is supported, A is 128, B is 64. - if (!isF4(aElemType) && !isF4(bElemType)) - return emitOpError("32x16x128 only supports fp4 element types"); - - if (aLen != 128) - return emitOpError( - "for 32x16x128, sourceA must have 128 elements but got ") - << aLen; - if (bLen != 64) - return emitOpError( - "for 32x16x128, sourceB must have 64 elements but got ") - << bLen; - - // For 32x16x128, matrix A uses all 32 lanes so a_first_scale_lane must be - // 0. - if (getAFirstScaleLane() != 0) - return emitOpError("for 32x16x128, a_first_scale_lane must be 0"); - } - - // Validate scale types and their compatibility with matrix element types. 
- auto scaleAType = cast(getScaleA().getType()); - auto scaleBType = cast(getScaleB().getType()); - Type scaleAElemType = scaleAType.getElementType(); - Type scaleBElemType = scaleBType.getElementType(); - - // Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN). - if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType)) - return emitOpError( - "scale operands must have f8 element types (E8M0FNU or E4M3FN)"); - - // Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid. - if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType)) - return success(); - - // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M3|E4M3). - if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) && - isF4(bElemType) && isE4M3(scaleBElemType)) - return success(); - - // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M3|E4M3), Scale B (E8M0). - if (isF4(aElemType) && isE4M3(scaleAElemType) && - (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType)) - return success(); - - // Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3). - if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) && - isE4M3(scaleBElemType)) - return success(); - - // No valid combination matched. 
- return emitOpError("invalid combination of matrix and scale types: ") - << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType - << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType; -} - -//===----------------------------------------------------------------------===// -// MFMAOp -//===----------------------------------------------------------------------===// -LogicalResult MFMAOp::verify() { - constexpr uint32_t waveSize = 64; - Builder b(getContext()); - - Type sourceType = getSourceA().getType(); - Type destType = getDestC().getType(); - - Type sourceElem = sourceType, destElem = destType; - uint32_t sourceLen = 1, destLen = 1; - if (auto sourceVector = dyn_cast(sourceType)) { - sourceLen = sourceVector.getNumElements(); - sourceElem = sourceVector.getElementType(); - } - if (auto destVector = dyn_cast(destType)) { - destLen = destVector.getNumElements(); - destElem = destVector.getElementType(); - } - - Type sourceBType = getSourceB().getType(); - if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) { - int64_t sourceBLen = 1; - Type sourceBElem = sourceBType; - if (auto sourceBVector = llvm::dyn_cast(sourceBType)) { - sourceBLen = sourceBVector.getNumElements(); - sourceBElem = sourceBVector.getElementType(); - } - if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) && - !sourceBElem.isFloat(4)) - return emitOpError("expected both source operands to have small-float " - "elements if one does"); - if (sourceLen != sourceBLen) - return emitOpError( - "expected both small-float source vectors to have the same length"); - } else { - if (sourceType != sourceBType) - return emitOpError("expected both non-small-float source operand types " - "to match exactly"); - } - // Normalize the wider integer types the compiler expects to i8. 
- if (sourceElem.isInteger(32)) { - sourceLen *= 4; - sourceElem = b.getI8Type(); - } - if (sourceElem.isInteger(64)) { - sourceLen *= 8; - sourceElem = b.getI8Type(); - } - - int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize; - if (sourceLen != numSourceElems) - return emitOpError("expected " + Twine(numSourceElems) + - " source values for this operation but got " + - Twine(sourceLen)); - - int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize; - if (destLen != numDestElems) - return emitOpError("expected " + Twine(numDestElems) + - " result values for this operation but got " + - Twine(destLen)); - - if (destElem.isF64() && getBlgp() != MFMAPermB::none) - return emitOpError( - "double-precision ops do not support permuting lanes of B"); - if (destElem.isF64() && getCbsz() != 0) - return emitOpError( - "double-precision ops do not support permuting lanes of A"); - if (getAbid() >= (1u << getCbsz())) - return emitOpError( - "block ID for permuting A (abid) must be below 2 ** cbsz"); - - if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64()) - return emitOpError( - "negation flags only available for double-precision operations"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// SparseMFMAOp -//===----------------------------------------------------------------------===// - -LogicalResult SparseMFMAOp::verify() { - constexpr uint32_t waveSize = 64; - - auto sparseType = cast(getSourceA().getType()); - auto denseType = cast(getSourceB().getType()); - auto destType = cast(getDestC().getType()); - - Type sparseElem = sparseType.getElementType(); - Type denseElem = denseType.getElementType(); - int64_t sparseLen = sparseType.getNumElements(); - int64_t denseLen = denseType.getNumElements(); - int64_t destLen = destType.getNumElements(); - - if (denseLen != 2 * sparseLen) - return emitOpError("expected dense source operand to have exactly double " - "the number 
of elements of the sparse source operand"); - - // Check that source element types are compatible. - // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8). - // For other types, element types must match exactly. - bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8); - if (!bothFloat8 && sparseElem != denseElem) - return emitOpError( - "expected source operands to have the same element type"); - - // When CBSZ == 0, ABID selects the index set within the sparse index VGPR. - // When CBSZ != 0, the first index set is always used (ABID ignored). - bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8); - // 8-bit source: ABID selects one of two 16-bit index sets. - if (getCbsz() == 0 && is8BitSource && getAbid() > 1) - return emitOpError("ABID must be 0 or 1 for 8-bit source data"); - // 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid). - if (getCbsz() == 0 && !is8BitSource && getAbid() > 3) - return emitOpError("ABID must be between 0 and 3 for 16-bit source data"); - - // Validate sparseIdx type matches source element type. - auto sparseIdxType = cast(getSparseIdx().getType()); - if (is8BitSource) { - // 8-bit source data requires vector<2xi16> sparse indices. - if (sparseIdxType.getNumElements() != 2 || - !sparseIdxType.getElementType().isInteger(16)) - return emitOpError("expected vector<2xi16> sparse indices for 8-bit " - "source data, but got ") - << getSparseIdx().getType(); - } else { - // 16-bit source data requires vector<4xi8> sparse indices. 
- if (sparseIdxType.getNumElements() != 4 || - !sparseIdxType.getElementType().isInteger(8)) - return emitOpError("expected vector<4xi8> sparse indices for 16-bit " - "source data, but got ") - << getSparseIdx().getType(); - } - - int64_t expectedSourceElems = (getM() * getK()) / waveSize; - if (denseLen != expectedSourceElems) - return emitOpError("expected " + Twine(expectedSourceElems) + - " source values for this operation but got " + - Twine(denseLen)); - - int64_t expectedDestElems = (getM() * getN()) / waveSize; - if (destLen != expectedDestElems) - return emitOpError("expected " + Twine(expectedDestElems) + - " result values for this operation but got " + - Twine(destLen)); - - return success(); -} - -//===----------------------------------------------------------------------===// -// DPPOp -//===----------------------------------------------------------------------===// -LogicalResult DPPOp::verify() { - Type srcType = getSrc().getType(); - if (srcType.getIntOrFloatBitWidth() > 64) { - return emitOpError("integer and floating point types larger than 64 bits " - "are not supported"); - } - - DPPPerm kind = getKind(); - Attribute permArgument = getPermArgument().value_or(Attribute{}); - - switch (kind) { - - case DPPPerm::quad_perm: { - auto quadPermAttr = dyn_cast_or_null(permArgument); - if (!quadPermAttr || quadPermAttr.size() != 4) { - return emitOpError("quad_perm attribute must have exactly 4 elements"); - } - for (auto elem : quadPermAttr.getAsRange()) { - int32_t num = elem.getInt(); - if (num < 0 || num > 3) { - return emitOpError( - "Each element of quad_perm must be in the range [0, 3]"); - } - } - } break; - - case DPPPerm::row_shl: - case DPPPerm::row_shr: - case DPPPerm::row_ror: { - if (!permArgument) { - return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) + - "' value not specified"); - } - if (auto intAttr = dyn_cast(permArgument)) { - uint32_t attrValue = intAttr.getInt(); - if (attrValue < 1 || attrValue > 15) { - return 
emitOpError("Attribute value must be between 1 and 15"); - } - } - } break; - - case DPPPerm::wave_shl: - case DPPPerm::wave_shr: - case DPPPerm::wave_rol: - case DPPPerm::wave_ror: - case DPPPerm::row_mirror: - case DPPPerm::row_half_mirror: - case DPPPerm::row_bcast_15: - case DPPPerm::row_bcast_31: { - if (permArgument && !isa(permArgument)) { - return emitOpError("Expected unit attribute for permArgument, but found " - "non-trivial argument"); - } - break; - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// PermlaneSwapOp -//===----------------------------------------------------------------------===// -LogicalResult PermlaneSwapOp::verify() { - unsigned rowLength = getRowLength(); - - if (rowLength != 16 && rowLength != 32) - return emitOpError("row_length attribute must either be 16 or 32."); - - return success(); -} - -/// Remove amdgpu.lds_barrier after amdgpu.lds_barrier. -static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op, - PatternRewriter &rewriter) { - if (isa_and_nonnull(op->getNextNode())) { - rewriter.eraseOp(op); - return success(); - } - return failure(); -} - -void LDSBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(eraseRedundantLDSBarrierOps); -} - -//===----------------------------------------------------------------------===// -// MemoryCounterWaitOp -//===----------------------------------------------------------------------===// - -namespace { -/// Fuse adjacent memory counter wait ops, taking the minimum value of the -/// counters. 
-struct FuseMemoryCounterWaitOp final : OpRewritePattern { - using Base::Base; - - LogicalResult matchAndRewrite(MemoryCounterWaitOp op, - PatternRewriter &rewriter) const override { - auto next = dyn_cast(op->getNextNode()); - if (!next) - return failure(); - - auto setters = {&MemoryCounterWaitOp::setLoad, - &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs, - &MemoryCounterWaitOp::setExp, - &MemoryCounterWaitOp::setTensor}; - auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(), - op.getTensor()}; - auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(), - next.getExp(), next.getTensor()}; - rewriter.modifyOpInPlace(op, [&] { - for (auto [setter, lhs, rhs] : - llvm::zip_equal(setters, lhsVals, rhsVals)) { - if (lhs && rhs) { - (op.*setter)(std::min(*lhs, *rhs)); - } else if (lhs) { - (op.*setter)(*lhs); - } else if (rhs) { - (op.*setter)(*rhs); - } - } - }); - rewriter.eraseOp(next); - return success(); - } -}; -} // namespace - -void MemoryCounterWaitOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add(context); -} - -//===----------------------------------------------------------------------===// -// GatherToLDSOp -//===----------------------------------------------------------------------===// - -LogicalResult GatherToLDSOp::verify() { - MemRefType srcType = cast(getSrc().getType()); - MemRefType dstType = cast(getDst().getType()); - - if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1)) - return emitOpError("destination type inner most dim must be contiguous"); - - auto elemType = srcType.getElementType(); - // Check $src and $dst element types are the same. - if (elemType != dstType.getElementType()) - return emitOpError("source and destination element types must match"); - - // copy type sizes should be 1, 2, 4, 12 or 16 bytes. 
- auto transferType = getTransferType(); - int transferSize; - if (auto vectorTransfer = dyn_cast(transferType)) { - transferSize = vectorTransfer.getNumElements() * - vectorTransfer.getElementTypeBitWidth(); - } else { - transferSize = transferType.getIntOrFloatBitWidth(); - } - if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize)) - return emitOpError( - "Transfering type size must be 8, 16, 32, 96 or 128 bits"); - - if (!hasGlobalMemorySpace(srcType.getMemorySpace()) && - !hasFatRawBufferMemorySpace(srcType.getMemorySpace())) - return emitOpError( - "source memory address space must be global or fat raw buffer"); - - if (!hasWorkgroupMemorySpace(dstType.getMemorySpace())) - return emitOpError("destination memory address space must be Workgroup"); - - return success(); -} - -namespace { -/// If the source/target of a GatherToLDSOp is a CastOp that only removes static -/// information or changes layout, the cast can be skipped. -struct FoldGatherToLDSOfCast final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GatherToLDSOp gatherOp, - PatternRewriter &rewriter) const override { - bool modified = false; - auto foldCast = [&](OpOperand &operand) { - if (auto castOp = operand.get().getDefiningOp()) { - if (memref::CastOp::canFoldIntoConsumerOp(castOp)) { - rewriter.modifyOpInPlace(gatherOp, - [&] { operand.assign(castOp.getSource()); }); - modified = true; - } - } - }; - - foldCast(gatherOp.getSrcMutable()); - foldCast(gatherOp.getDstMutable()); - - return success(modified); - } -}; -} // namespace - -void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -//===----------------------------------------------------------------------===// -// TransposeLoadOp -//===----------------------------------------------------------------------===// - -LogicalResult TransposeLoadOp::verify() { - MemRefType srcType = cast(getSrc().getType()); - - if 
(!hasWorkgroupMemorySpace(srcType.getMemorySpace())) - return emitOpError("source memory address space must be Workgroup"); - - auto transferType = cast(getType()); - size_t numElements = transferType.getNumElements(); - size_t elementTypeSize = - transferType.getElementType().getIntOrFloatBitWidth(); - - // ElementSize -> NumElements - const llvm::SmallDenseMap kValidLoadSizeMap = { - {4, 16}, - {6, 16}, - {8, 8}, - {16, 4}, - }; - - auto validNumElems = kValidLoadSizeMap.find(elementTypeSize); - if (validNumElems == kValidLoadSizeMap.end()) - return emitOpError("Unsupported element type size for transpose load: ") - << elementTypeSize << " bits"; - - if (numElements != validNumElems->second) - return emitOpError( - "Transferring type size mismatch: expected num of elements: ") - << validNumElems->second; - - return success(); -} - -//===----------------------------------------------------------------------===// -// MakeDmaBaseOp -//===----------------------------------------------------------------------===// - -template -static LogicalResult verifyBase(BaseOp op) { - auto ldsType = cast(op.getLds().getType()); - auto globalType = cast(op.getGlobal().getType()); - if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace())) - return op.emitOpError( - "lds memref must have workgroup address space attribute."); - if (!hasGlobalMemorySpace(globalType.getMemorySpace())) - return op.emitOpError( - "global memref must have global address space attribute."); - - Type elementType = ldsType.getElementType(); - unsigned width = elementType.getIntOrFloatBitWidth(); - - if (!llvm::is_contained({8u, 16u, 32u, 64u}, width)) - return op.emitOpError( - "element type must be 1, 2, 4, or 8 bytes long but type was ") - << width << " bits long."; - return success(); -} - -LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); } - -//===----------------------------------------------------------------------===// -// MakeGatherDmaBaseOp 
-//===----------------------------------------------------------------------===// - -LogicalResult -TDMGatherBaseType::verify(function_ref emitError, - Type elementType, Type indexType) { - unsigned width = elementType.getIntOrFloatBitWidth(); - if (!llvm::is_contained({8u, 16u, 32u, 64u}, width)) - return emitError() - << "element type must be 1, 2, 4, or 8 bytes wide but type " - << elementType << " is " << width / 8 << " bytes wide."; - MLIRContext *ctx = elementType.getContext(); - Type i16 = IntegerType::get(ctx, 32); - Type i32 = IntegerType::get(ctx, 16); - if (!llvm::is_contained({i16, i32}, indexType)) - return emitError() << "index type must be i16 or i32 but index type is " - << indexType << "."; - return success(); -} - -LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); } - -//===----------------------------------------------------------------------===// -// MakeDmaDescriptorOp -//===----------------------------------------------------------------------===// - -template -static LogicalResult verifyDescriptorOp(DescriptorOp op) { - ArrayRef globalStaticStrides = op.getGlobalStaticStrides(); - - if (globalStaticStrides.empty()) - return op.emitOpError("strides must not be empty."); - if (globalStaticStrides.back() != 1) - return op.emitOpError("strides for the innermost dimension must be 1."); - - ArrayRef globalStaticSizes = op.getGlobalStaticSizes(); - size_t rank = globalStaticSizes.size(); - if (rank > 5) - return op.emitOpError("tensor and tile must be at most of rank 5."); - if (rank != globalStaticStrides.size()) - return op.emitOpError("strides and sizes must have same rank."); - - ArrayRef sharedStaticSizes = op.getSharedStaticSizes(); - if (rank != sharedStaticSizes.size()) - return op.emitOpError("tensor must have same rank as tile."); - - unsigned elementTypeWidth = op.getElementTypeWidth(); - if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth)) - return op.emitOpError( - "element type width must be 1, 2, 4 or 
8 bytes, but was ") - << elementTypeWidth << " bits long"; - - if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) { - auto atomicBarrierAddressType = - cast(atomicBarrierAddress.getType()); - bool barrierInLDS = - hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()); - if (!barrierInLDS) - return op.emitOpError("atomic barrier address must be in LDS."); - } - - if (op.getEarlyTimeout() && !op.getWorkgroupMask()) - return op.emitOpError( - "early timeout does not apply when workgroup_mask is not set."); - return success(); -} - -template -static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) { - SmallVector mixedGlobalSizes(op.getMixedGlobalSizes()); - SmallVector mixedGlobalStrides(op.getMixedGlobalStrides()); - SmallVector mixedSharedSizes(op.getMixedSharedSizes()); - - if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true, - /*onlyNonZero=*/true)) && - failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true, - /*onlyNonZero=*/true)) && - failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true, - /*onlyNonZero=*/true))) - return nullptr; - - SmallVector dynamicGlobalSizes, dynamicGlobalStrides, - dynamicSharedSizes; - SmallVector staticGlobalSizes, staticGlobalStrides, - staticSharedSizes; - - dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes, - staticGlobalSizes); - op.setGlobalStaticSizes(staticGlobalSizes); - op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes); - - dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides, - staticGlobalStrides); - op.setGlobalStaticStrides(staticGlobalStrides); - op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides); - - dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes, - staticSharedSizes); - op.setSharedStaticSizes(staticSharedSizes); - op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes); - return op.getResult(); -} - -LogicalResult 
MakeDmaDescriptorOp::verify() { - return verifyDescriptorOp(*this); -} - -OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) { - return foldDescriptorOp(*this, adaptor); -} - -//===----------------------------------------------------------------------===// -// MakeGatherDmaDescriptorOp -//===----------------------------------------------------------------------===// - -LogicalResult MakeGatherDmaDescriptorOp::verify() { - ArrayRef globalStaticSizes = getGlobalStaticSizes(); - size_t rank = globalStaticSizes.size(); - if (rank > 2) - return emitOpError( - "tensor and tile must be at most of rank two in gather mode."); - Value indices = getIndices(); - Type elementType = cast(indices.getType()).getElementType(); - if (elementType != getBase().getType().getIndexType()) - return emitOpError("indices' element type must match base's element type."); - - return verifyDescriptorOp(*this); -} - -OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) { - return foldDescriptorOp(*this, adaptor); -} - -//===----------------------------------------------------------------------===// -// ScaledMFMAOp -//===----------------------------------------------------------------------===// - -namespace { -/// Check if the scales input is used in other scaled mfma's while they exist. -/// If theyre unused then pack the scales. -struct PackScales final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ScaledMFMAOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto setOpsel = [&op](unsigned idx, int64_t val) { - switch (idx) { - case 3: - op.setScalesIdxA(val); - break; - case 4: - op.setScalesIdxB(val); - break; - default: - break; - } - }; - - // For every scale operand of this ScaledMFMAOp, if the scale is produced by - // the extraction of a single scale from some vector, then attempt to - // extract 4 values from that vector instead. 
- // - // Example: (f8 here means f8E8M0FNU) - // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...> - // %scale = vector.insert %unit, ... : f8 into vector<4xf8> - // amdgpu.scaled_mfma(%scale[0] * ... - // - // rewrite to: - // - // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector - // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector - // amdgpu.scaled_mfma(%scale[0-3] * ... - // - // This creates duplicate shape_casts for every use but these will be - // removed in CSE. - for (auto opIdx : std::array({3, 4})) { - auto insertOp = op.getOperand(opIdx).getDefiningOp(); - if (!insertOp) { - return rewriter.notifyMatchFailure(op, - "defining op not a vector.insert"); - } - // If the extracted value is not a single scalar, then it has been packed. - if (isa(insertOp.getValueToStore().getType())) { - return rewriter.notifyMatchFailure( - op, "scaled mfma operand already packed"); - } - - auto extractOp = - insertOp.getValueToStore().getDefiningOp(); - if (!extractOp) { - return rewriter.notifyMatchFailure(op, - "defining op not a vector.extract"); - } - - Value scaleSrc = extractOp.getOperand(0); - auto scaleSrcType = dyn_cast(scaleSrc.getType()); - if (!scaleSrcType) { - return rewriter.notifyMatchFailure(op, "not a vector type"); - } - - // We do not handle dynamic dims yet, assume that the input is padded to - // a static shape now. - if (!scaleSrcType.hasStaticShape()) { - return rewriter.notifyMatchFailure(op, - "dynamic dims not yet supported"); - } - - int64_t numElements = scaleSrcType.getNumElements(); - if (numElements <= 4) { - return rewriter.notifyMatchFailure( - op, "no packing if # of scales less than four"); - } - - // Find a linearized idx using the size and offsets of the extract op. 
- auto extractedPos = llvm::to_vector_of( - llvm::reverse(extractOp.getStaticPosition())); - ArrayRef scaleSrcShape = scaleSrcType.getShape(); - int64_t scaleSrcRank = scaleSrcType.getRank(); - SmallVector extractSizes(scaleSrcRank, 1); - for (int64_t i = 1; i < scaleSrcRank; ++i) { - extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i]; - } - int64_t idx = linearize(extractedPos, extractSizes); - - // All n scales (where n is the total number of scales) must now be - // extracted in chunks of 4 elements. This is done by dividing the - // original vector of scales into groups of 4 elements - // at offsets 0, 4, ..., m (where m = n/4). All extractions of a - // scale at a particular index are now replaced with an extraction - // of the entire group of 4 elements to which that index belongs. - // - // If the number of scales happens to be indivisible by 4, extract - // the remaining n - m scales in a chunk of 4 elements starting at - // offset n - 4. - int64_t offset = idx - (idx % 4); - int64_t opsel = idx - offset; - int64_t size = 4l; - // Accomdate remaining elements in the case of non-4-divisible vectors. 
- if (numElements - offset < size) { - opsel = size - (numElements - idx); - offset = numElements - 4l; - } - Type scaleSrcElemType = scaleSrcType.getElementType(); - auto newSrcType = - VectorType::get(ArrayRef{numElements}, scaleSrcElemType); - Value newScaleSrc = - vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc); - auto extract = vector::ExtractStridedSliceOp::create( - rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size}, - ArrayRef{int64_t(1)}); - rewriter.modifyOpInPlace(op, [&] { - op->setOperand(opIdx, extract); - setOpsel(opIdx, opsel); - }); - } - return success(); - } -}; -} // namespace - -void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc" - -#define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" - -#define GET_OP_CLASSES -#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUEnums.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUEnums.cpp new file mode 100644 index 0000000000000..44107bd2b95fb --- /dev/null +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUEnums.cpp @@ -0,0 +1,18 @@ +//===- AMDGPUAttrs.cpp - MLIR AMDGPU dialect attributes -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AMDGPU dialect attributes. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" + +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/StringExtras.h" + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc" diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp new file mode 100644 index 0000000000000..87a813a31608d --- /dev/null +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp @@ -0,0 +1,1212 @@ +//===- AMDGPUOps.cpp - MLIR AMDGPU dialect operations ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AMDGPU dialect operations, their verifiers, and +// their canonicalizations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include +#include +#include + +using namespace mlir; +using namespace mlir::amdgpu; + +//===----------------------------------------------------------------------===// +// 8-bit float ops 
+//===----------------------------------------------------------------------===// +LogicalResult PackedTrunc2xFp8Op::verify() { + if (getExisting() && getExisting().getType() != getResult().getType()) + return emitOpError("existing values must have same type as result"); + return success(); +} + +LogicalResult PackedStochRoundFp8Op::verify() { + if (getExisting() && getExisting().getType() != getResult().getType()) + return emitOpError("existing values must have same type as result"); + return success(); +} + +//===----------------------------------------------------------------------===// +// mxfp float ops +//===----------------------------------------------------------------------===// +LogicalResult PackedScaledTruncOp::verify() { + if (getExisting() && getExisting().getType() != getResult().getType()) + return emitOpError("existing values must have same type as result"); + return success(); +} + +//===----------------------------------------------------------------------===// +// FatRawBufferCastOp +//===----------------------------------------------------------------------===// + +/// Convert the type `source` to one with the same sizes and strides - and +/// offset, unless `stripOffset` is true, in which case the offset is reset to +/// 0, if the offset should be reset but the layout of `source` isn't either the +/// identity layout or a strided layout, this function fails. 
+static FailureOr getFatRawBufferTypeLike(MemRefType source, + bool resetOffset) { + MLIRContext *ctx = source.getContext(); + MemRefType::Builder mb(source); + mb.setMemorySpace( + amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer)); + MemRefLayoutAttrInterface layout = source.getLayout(); + if (resetOffset && !layout.isIdentity()) { + auto stridedLayout = dyn_cast(layout); + if (!stridedLayout) + return failure(); + MemRefLayoutAttrInterface newLayout = + StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides()); + // Special case: if resetting the offset causes the strided layout to become + // the identity layout, then reset to the identity layout. + // TODO: this'll get a lot simpler when we have the contiguous layout. + SmallVector stridesIfIdentity; + if (source.hasStaticShape()) { + stridesIfIdentity = computeSuffixProduct(source.getShape()); + } else if (source.getRank() <= 1) { + stridesIfIdentity = SmallVector(source.getRank(), 1); + } + if (stridesIfIdentity == stridedLayout.getStrides()) { + newLayout = AffineMapAttr::get( + AffineMap::getMultiDimIdentityMap(source.getRank(), ctx)); + } + mb.setLayout(newLayout); + } + return (MemRefType)(mb); +} + +LogicalResult FatRawBufferCastOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + Adaptor adaptor(operands, attributes, properties, regions); + auto sourceType = + dyn_cast_if_present(adaptor.getSource().getType()); + if (!sourceType) + return failure(); + FailureOr resultType = + getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset()); + if (failed(resultType)) + return failure(); + inferredReturnTypes = SmallVector{*resultType}; + return success(); +} + +FailureOr FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder, + int resultIndex, + int dim) { + assert(resultIndex == 0 && "FatRawBufferCastOp has a single 
result"); + return memref::getMixedSize(builder, getLoc(), getSource(), dim); +} + +LogicalResult FatRawBufferCastOp::verify() { + FailureOr expectedResultType = + getFatRawBufferTypeLike(getSource().getType(), getResetOffset()); + if (failed(expectedResultType)) + return emitOpError("source type ") + << getSource().getType() << " can't have its offset reset"; + if (getResult().getType() != *expectedResultType) + return emitOpError("expected result type to be ") + << *expectedResultType << " but got " << getResult().getType(); + return success(); +} + +static bool hasGlobalMemorySpace(Attribute memorySpace) { + if (!memorySpace) + return true; + if (auto intMemorySpace = dyn_cast(memorySpace)) + return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1; + if (auto gpuMemorySpace = dyn_cast(memorySpace)) + return gpuMemorySpace.getValue() == gpu::AddressSpace::Global; + return false; +} + +static bool hasWorkgroupMemorySpace(Attribute memorySpace) { + if (!memorySpace) + return false; + if (auto intMemorySpace = dyn_cast(memorySpace)) + return intMemorySpace.getInt() == 3; + if (auto gpuMemorySpace = dyn_cast(memorySpace)) + return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup; + return false; +} + +static bool hasFatRawBufferMemorySpace(Attribute memorySpace) { + if (!memorySpace) + return false; + if (auto intMemorySpace = dyn_cast(memorySpace)) + return intMemorySpace.getInt() == 7; + if (auto gpuMemorySpace = dyn_cast(memorySpace)) + return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer; + return false; +} + +//===----------------------------------------------------------------------===// +// RawBuffer*Op +//===----------------------------------------------------------------------===// +template +static LogicalResult verifyRawBufferOp(T &op) { + MemRefType bufferType = llvm::cast(op.getMemref().getType()); + bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace()); + + if (!isGlobal) + return op.emitOpError( + 
"Buffer ops must operate on a memref in global memory"); + if (!bufferType.hasRank()) + return op.emitOpError( + "Cannot meaningfully buffer_store to an unranked memref"); + if (static_cast(op.getIndices().size()) != bufferType.getRank()) + return op.emitOpError("Expected " + Twine(bufferType.getRank()) + + " indices to memref"); + return success(); +} + +LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); } + +LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); } + +LogicalResult RawBufferAtomicFaddOp::verify() { + return verifyRawBufferOp(*this); +} + +LogicalResult RawBufferAtomicFmaxOp::verify() { + return verifyRawBufferOp(*this); +} + +LogicalResult RawBufferAtomicSmaxOp::verify() { + return verifyRawBufferOp(*this); +} + +LogicalResult RawBufferAtomicUminOp::verify() { + return verifyRawBufferOp(*this); +} + +LogicalResult RawBufferAtomicCmpswapOp::verify() { + return verifyRawBufferOp(*this); +} + +static std::optional getConstantUint32(Value v) { + APInt cst; + if (!v.getType().isInteger(32)) + return std::nullopt; + if (matchPattern(v, m_ConstantInt(&cst))) + return cst.getZExtValue(); + return std::nullopt; +} + +template +static bool staticallyOutOfBounds(OpType op) { + if (!op.getBoundsCheck()) + return false; + MemRefType bufferType = op.getMemref().getType(); + if (!bufferType.hasStaticShape()) + return false; + int64_t offset; + SmallVector strides; + if (failed(bufferType.getStridesAndOffset(strides, offset))) + return false; + int64_t result = offset + op.getIndexOffset().value_or(0); + if (op.getSgprOffset()) { + std::optional sgprOffset = getConstantUint32(op.getSgprOffset()); + if (!sgprOffset) + return false; + result += *sgprOffset; + } + if (strides.size() != op.getIndices().size()) + return false; + int64_t indexVal = 0; + for (auto pair : llvm::zip(strides, op.getIndices())) { + int64_t stride = std::get<0>(pair); + Value idx = std::get<1>(pair); + std::optional idxVal = 
getConstantUint32(idx); + if (!idxVal) + return false; + indexVal += stride * *idxVal; + } + result += indexVal; + if (result > std::numeric_limits::max()) + // Overflow means don't drop + return false; + return result >= bufferType.getNumElements(); +} + +namespace { +template +struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override { + if (!staticallyOutOfBounds(op)) + return failure(); + Type loadType = op.getResult().getType(); + rw.replaceOpWithNewOp(op, loadType, + rw.getZeroAttr(loadType)); + return success(); + } +}; + +template +struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override { + if (!staticallyOutOfBounds(op)) + return failure(); + + rw.eraseOp(op); + return success(); + } +}; +} // end namespace + +void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>(context); +} + +void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>(context); +} + +void RawBufferAtomicFaddOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add>(context); +} + +void RawBufferAtomicFmaxOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add>(context); +} + +void RawBufferAtomicSmaxOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add>(context); +} + +void RawBufferAtomicUminOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add>(context); +} + +void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add>( + context); +} + 
+//===----------------------------------------------------------------------===// +// ScaledExtPackedMatrixOp +//===----------------------------------------------------------------------===// +LogicalResult ScaledExtPackedMatrixOp::verify() { + int blockSize = getBlockSize(); + assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size"); + + int firstScaleByte = getFirstScaleByte(); + int firstScaleLane = getFirstScaleLane(); + auto sourceType = cast(getSource().getType()); + Type elementType = sourceType.getElementType(); + auto floatType = cast(elementType); + unsigned bitWidth = floatType.getWidth(); + + assert(llvm::is_contained(llvm::ArrayRef{4, 6, 8}, bitWidth)); + + const bool is_fp8 = bitWidth == 8; + const bool is_block_16 = blockSize == 16; + + if (!is_fp8) { + if (is_block_16) { + if (!llvm::is_contained({0, 1}, firstScaleByte)) { + return emitOpError("blockSize of 16 can only have firstScaleByte be 0 " + "or 1 for f4 and f6."); + } + } else { + if (!llvm::is_contained({0, 2}, firstScaleByte)) { + return emitOpError("blockSize of 32 can only have firstScaleByte be 0 " + "or 2 for f4 and f6."); + } + } + } else { + if (is_block_16) { + bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) || + ((firstScaleLane == 16) && (firstScaleByte == 2)); + if (!is_valid) { + return emitOpError("blockSize of 16 can only have (firstScaleLane, " + "firstScaleByte) be (0, 0) or (16, 2) for f8."); + } + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// WMMAOp +//===----------------------------------------------------------------------===// + +ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser, + IntegerAttr &m, IntegerAttr &n, + IntegerAttr &k) { + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, false, false)) + return failure(); + if (dimensions.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expected 3 dimensions in 
MNK dimension list"; + + m = parser.getBuilder().getI32IntegerAttr(dimensions[0]); + n = parser.getBuilder().getI32IntegerAttr(dimensions[1]); + k = parser.getBuilder().getI32IntegerAttr(dimensions[2]); + return success(); +} + +LogicalResult WMMAOp::verify() { + auto sourceAType = cast(getSourceA().getType()); + auto sourceBType = cast(getSourceB().getType()); + auto destType = cast(getDestC().getType()); + + Type sourceAElemType = sourceAType.getElementType(); + Type sourceBElemType = sourceBType.getElementType(); + if (sourceAType.getNumElements() != sourceBType.getNumElements()) { + return emitOpError("source vectors have different lengths: ") + << sourceAType << " vs. " << sourceBType; + } + + bool isDestFloat = destType.getElementType().isFloat(); + bool isSrcFloat = sourceAElemType.isFloat(); + + if (isDestFloat && !isSrcFloat) + return emitOpError("expected float sources with float destination"); + if (!isDestFloat && isSrcFloat) + return emitOpError("expected int sources with int destination"); + + if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { + return emitOpError( + "source element types must match (except for fp8/bf8) but have ") + << sourceAType << " and " << sourceBType; + } + + if (isSrcFloat) { + if (getClamp()) + return emitOpError("clamp flag is not supported for float types"); + if (getUnsignedA() || getUnsignedB()) + return emitOpError("unsigned flags are not supported for float types"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// ScaledWMMAOp +//===----------------------------------------------------------------------===// + +LogicalResult ScaledWMMAOp::verify() { + // Helper functions for type classification. 
+ auto isF8 = llvm::IsaPred; + auto isF6 = llvm::IsaPred; + auto isF4 = llvm::IsaPred; + auto isScaleF8 = llvm::IsaPred; + auto isE8M0 = llvm::IsaPred; + auto isE4M3 = llvm::IsaPred; + + auto sourceAType = cast(getSourceA().getType()); + auto sourceBType = cast(getSourceB().getType()); + auto destType = cast(getDestC().getType()); + + // Validate source element types are small floats (fp4/fp6/fp8). + Type aElemType = sourceAType.getElementType(); + Type bElemType = sourceBType.getElementType(); + + // Validate vector lengths based on dimensions. + int64_t m = getM(); + int64_t aLen = sourceAType.getNumElements(); + int64_t bLen = sourceBType.getNumElements(); + int64_t expectedOutLen = (m == 16) ? 8 : 16; + + if (destType.getNumElements() != expectedOutLen) + return emitOpError("expected output vector of length ") + << expectedOutLen << " but got " << destType.getNumElements(); + + if (m == 16) { + // For 16×16×128: both A and B must be 64 elements. + if (aLen != 64) + return emitOpError( + "for 16x16x128, sourceA must have 64 elements but got ") + << aLen; + if (bLen != 64) + return emitOpError( + "for 16x16x128, sourceB must have 64 elements but got ") + << bLen; + } else { // m == 32 + // For 32×16×128: only fp4 is supported, A is 128, B is 64. + if (!isF4(aElemType) && !isF4(bElemType)) + return emitOpError("32x16x128 only supports fp4 element types"); + + if (aLen != 128) + return emitOpError( + "for 32x16x128, sourceA must have 128 elements but got ") + << aLen; + if (bLen != 64) + return emitOpError( + "for 32x16x128, sourceB must have 64 elements but got ") + << bLen; + + // For 32x16x128, matrix A uses all 32 lanes so a_first_scale_lane must be + // 0. + if (getAFirstScaleLane() != 0) + return emitOpError("for 32x16x128, a_first_scale_lane must be 0"); + } + + // Validate scale types and their compatibility with matrix element types. 
+ auto scaleAType = cast(getScaleA().getType()); + auto scaleBType = cast(getScaleB().getType()); + Type scaleAElemType = scaleAType.getElementType(); + Type scaleBElemType = scaleBType.getElementType(); + + // Validate scale element types are valid scale f8 types (E8M0FNU or E4M3FN). + if (!isScaleF8(scaleAElemType) || !isScaleF8(scaleBElemType)) + return emitOpError( + "scale operands must have f8 element types (E8M0FNU or E4M3FN)"); + + // Any matrices A/B (fp8|fp6|fp4) with E8M0 scales for matrix A/B are valid. + if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType)) + return success(); + + // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M3|E4M3). + if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) && + isF4(bElemType) && isE4M3(scaleBElemType)) + return success(); + + // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M3|E4M3), Scale B (E8M0). + if (isF4(aElemType) && isE4M3(scaleAElemType) && + (isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType)) + return success(); + + // Matrix A (F4) x Matrix B (F4) with Scale A (E4M3), Scale B (E4M3). + if (isF4(aElemType) && isF4(bElemType) && isE4M3(scaleAElemType) && + isE4M3(scaleBElemType)) + return success(); + + // No valid combination matched. 
+ return emitOpError("invalid combination of matrix and scale types: ") + << "sourceA=" << aElemType << ", scaleA=" << scaleAElemType + << ", sourceB=" << bElemType << ", scaleB=" << scaleBElemType; +} + +//===----------------------------------------------------------------------===// +// MFMAOp +//===----------------------------------------------------------------------===// +LogicalResult MFMAOp::verify() { + constexpr uint32_t waveSize = 64; + Builder b(getContext()); + + Type sourceType = getSourceA().getType(); + Type destType = getDestC().getType(); + + Type sourceElem = sourceType, destElem = destType; + uint32_t sourceLen = 1, destLen = 1; + if (auto sourceVector = dyn_cast(sourceType)) { + sourceLen = sourceVector.getNumElements(); + sourceElem = sourceVector.getElementType(); + } + if (auto destVector = dyn_cast(destType)) { + destLen = destVector.getNumElements(); + destElem = destVector.getElementType(); + } + + Type sourceBType = getSourceB().getType(); + if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) { + int64_t sourceBLen = 1; + Type sourceBElem = sourceBType; + if (auto sourceBVector = llvm::dyn_cast(sourceBType)) { + sourceBLen = sourceBVector.getNumElements(); + sourceBElem = sourceBVector.getElementType(); + } + if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) && + !sourceBElem.isFloat(4)) + return emitOpError("expected both source operands to have small-float " + "elements if one does"); + if (sourceLen != sourceBLen) + return emitOpError( + "expected both small-float source vectors to have the same length"); + } else { + if (sourceType != sourceBType) + return emitOpError("expected both non-small-float source operand types " + "to match exactly"); + } + // Normalize the wider integer types the compiler expects to i8. 
+ if (sourceElem.isInteger(32)) { + sourceLen *= 4; + sourceElem = b.getI8Type(); + } + if (sourceElem.isInteger(64)) { + sourceLen *= 8; + sourceElem = b.getI8Type(); + } + + int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize; + if (sourceLen != numSourceElems) + return emitOpError("expected " + Twine(numSourceElems) + + " source values for this operation but got " + + Twine(sourceLen)); + + int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize; + if (destLen != numDestElems) + return emitOpError("expected " + Twine(numDestElems) + + " result values for this operation but got " + + Twine(destLen)); + + if (destElem.isF64() && getBlgp() != MFMAPermB::none) + return emitOpError( + "double-precision ops do not support permuting lanes of B"); + if (destElem.isF64() && getCbsz() != 0) + return emitOpError( + "double-precision ops do not support permuting lanes of A"); + if (getAbid() >= (1u << getCbsz())) + return emitOpError( + "block ID for permuting A (abid) must be below 2 ** cbsz"); + + if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64()) + return emitOpError( + "negation flags only available for double-precision operations"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// SparseMFMAOp +//===----------------------------------------------------------------------===// + +LogicalResult SparseMFMAOp::verify() { + constexpr uint32_t waveSize = 64; + + auto sparseType = cast(getSourceA().getType()); + auto denseType = cast(getSourceB().getType()); + auto destType = cast(getDestC().getType()); + + Type sparseElem = sparseType.getElementType(); + Type denseElem = denseType.getElementType(); + int64_t sparseLen = sparseType.getNumElements(); + int64_t denseLen = denseType.getNumElements(); + int64_t destLen = destType.getNumElements(); + + if (denseLen != 2 * sparseLen) + return emitOpError("expected dense source operand to have exactly double " + "the number 
of elements of the sparse source operand"); + + // Check that source element types are compatible. + // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8). + // For other types, element types must match exactly. + bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8); + if (!bothFloat8 && sparseElem != denseElem) + return emitOpError( + "expected source operands to have the same element type"); + + // When CBSZ == 0, ABID selects the index set within the sparse index VGPR. + // When CBSZ != 0, the first index set is always used (ABID ignored). + bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8); + // 8-bit source: ABID selects one of two 16-bit index sets. + if (getCbsz() == 0 && is8BitSource && getAbid() > 1) + return emitOpError("ABID must be 0 or 1 for 8-bit source data"); + // 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid). + if (getCbsz() == 0 && !is8BitSource && getAbid() > 3) + return emitOpError("ABID must be between 0 and 3 for 16-bit source data"); + + // Validate sparseIdx type matches source element type. + auto sparseIdxType = cast(getSparseIdx().getType()); + if (is8BitSource) { + // 8-bit source data requires vector<2xi16> sparse indices. + if (sparseIdxType.getNumElements() != 2 || + !sparseIdxType.getElementType().isInteger(16)) + return emitOpError("expected vector<2xi16> sparse indices for 8-bit " + "source data, but got ") + << getSparseIdx().getType(); + } else { + // 16-bit source data requires vector<4xi8> sparse indices. 
+ if (sparseIdxType.getNumElements() != 4 || + !sparseIdxType.getElementType().isInteger(8)) + return emitOpError("expected vector<4xi8> sparse indices for 16-bit " + "source data, but got ") + << getSparseIdx().getType(); + } + + int64_t expectedSourceElems = (getM() * getK()) / waveSize; + if (denseLen != expectedSourceElems) + return emitOpError("expected " + Twine(expectedSourceElems) + + " source values for this operation but got " + + Twine(denseLen)); + + int64_t expectedDestElems = (getM() * getN()) / waveSize; + if (destLen != expectedDestElems) + return emitOpError("expected " + Twine(expectedDestElems) + + " result values for this operation but got " + + Twine(destLen)); + + return success(); +} + +//===----------------------------------------------------------------------===// +// DPPOp +//===----------------------------------------------------------------------===// +LogicalResult DPPOp::verify() { + Type srcType = getSrc().getType(); + if (srcType.getIntOrFloatBitWidth() > 64) { + return emitOpError("integer and floating point types larger than 64 bits " + "are not supported"); + } + + DPPPerm kind = getKind(); + Attribute permArgument = getPermArgument().value_or(Attribute{}); + + switch (kind) { + + case DPPPerm::quad_perm: { + auto quadPermAttr = dyn_cast_or_null(permArgument); + if (!quadPermAttr || quadPermAttr.size() != 4) { + return emitOpError("quad_perm attribute must have exactly 4 elements"); + } + for (auto elem : quadPermAttr.getAsRange()) { + int32_t num = elem.getInt(); + if (num < 0 || num > 3) { + return emitOpError( + "Each element of quad_perm must be in the range [0, 3]"); + } + } + } break; + + case DPPPerm::row_shl: + case DPPPerm::row_shr: + case DPPPerm::row_ror: { + if (!permArgument) { + return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) + + "' value not specified"); + } + if (auto intAttr = dyn_cast(permArgument)) { + uint32_t attrValue = intAttr.getInt(); + if (attrValue < 1 || attrValue > 15) { + return 
emitOpError("Attribute value must be between 1 and 15"); + } + } + } break; + + case DPPPerm::wave_shl: + case DPPPerm::wave_shr: + case DPPPerm::wave_rol: + case DPPPerm::wave_ror: + case DPPPerm::row_mirror: + case DPPPerm::row_half_mirror: + case DPPPerm::row_bcast_15: + case DPPPerm::row_bcast_31: { + if (permArgument && !isa(permArgument)) { + return emitOpError("Expected unit attribute for permArgument, but found " + "non-trivial argument"); + } + break; + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// PermlaneSwapOp +//===----------------------------------------------------------------------===// +LogicalResult PermlaneSwapOp::verify() { + unsigned rowLength = getRowLength(); + + if (rowLength != 16 && rowLength != 32) + return emitOpError("row_length attribute must either be 16 or 32."); + + return success(); +} + +/// Remove amdgpu.lds_barrier after amdgpu.lds_barrier. +static LogicalResult eraseRedundantLDSBarrierOps(LDSBarrierOp op, + PatternRewriter &rewriter) { + if (isa_and_nonnull(op->getNextNode())) { + rewriter.eraseOp(op); + return success(); + } + return failure(); +} + +void LDSBarrierOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(eraseRedundantLDSBarrierOps); +} + +//===----------------------------------------------------------------------===// +// MemoryCounterWaitOp +//===----------------------------------------------------------------------===// + +namespace { +/// Fuse adjacent memory counter wait ops, taking the minimum value of the +/// counters. 
+struct FuseMemoryCounterWaitOp final : OpRewritePattern { + using Base::Base; + + LogicalResult matchAndRewrite(MemoryCounterWaitOp op, + PatternRewriter &rewriter) const override { + auto next = dyn_cast(op->getNextNode()); + if (!next) + return failure(); + + auto setters = {&MemoryCounterWaitOp::setLoad, + &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs, + &MemoryCounterWaitOp::setExp, + &MemoryCounterWaitOp::setTensor}; + auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(), + op.getTensor()}; + auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(), + next.getExp(), next.getTensor()}; + rewriter.modifyOpInPlace(op, [&] { + for (auto [setter, lhs, rhs] : + llvm::zip_equal(setters, lhsVals, rhsVals)) { + if (lhs && rhs) { + (op.*setter)(std::min(*lhs, *rhs)); + } else if (lhs) { + (op.*setter)(*lhs); + } else if (rhs) { + (op.*setter)(*rhs); + } + } + }); + rewriter.eraseOp(next); + return success(); + } +}; +} // namespace + +void MemoryCounterWaitOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add(context); +} + +//===----------------------------------------------------------------------===// +// GatherToLDSOp +//===----------------------------------------------------------------------===// + +LogicalResult GatherToLDSOp::verify() { + MemRefType srcType = cast(getSrc().getType()); + MemRefType dstType = cast(getDst().getType()); + + if (dstType.getRank() > 0 && !dstType.areTrailingDimsContiguous(1)) + return emitOpError("destination type inner most dim must be contiguous"); + + auto elemType = srcType.getElementType(); + // Check $src and $dst element types are the same. + if (elemType != dstType.getElementType()) + return emitOpError("source and destination element types must match"); + + // copy type sizes should be 1, 2, 4, 12 or 16 bytes. 
+ auto transferType = getTransferType(); + int transferSize; + if (auto vectorTransfer = dyn_cast(transferType)) { + transferSize = vectorTransfer.getNumElements() * + vectorTransfer.getElementTypeBitWidth(); + } else { + transferSize = transferType.getIntOrFloatBitWidth(); + } + if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize)) + return emitOpError( + "Transfering type size must be 8, 16, 32, 96 or 128 bits"); + + if (!hasGlobalMemorySpace(srcType.getMemorySpace()) && + !hasFatRawBufferMemorySpace(srcType.getMemorySpace())) + return emitOpError( + "source memory address space must be global or fat raw buffer"); + + if (!hasWorkgroupMemorySpace(dstType.getMemorySpace())) + return emitOpError("destination memory address space must be Workgroup"); + + return success(); +} + +namespace { +/// If the source/target of a GatherToLDSOp is a CastOp that only removes static +/// information or changes layout, the cast can be skipped. +struct FoldGatherToLDSOfCast final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherToLDSOp gatherOp, + PatternRewriter &rewriter) const override { + bool modified = false; + auto foldCast = [&](OpOperand &operand) { + if (auto castOp = operand.get().getDefiningOp()) { + if (memref::CastOp::canFoldIntoConsumerOp(castOp)) { + rewriter.modifyOpInPlace(gatherOp, + [&] { operand.assign(castOp.getSource()); }); + modified = true; + } + } + }; + + foldCast(gatherOp.getSrcMutable()); + foldCast(gatherOp.getDstMutable()); + + return success(modified); + } +}; +} // namespace + +void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//===----------------------------------------------------------------------===// +// TransposeLoadOp +//===----------------------------------------------------------------------===// + +LogicalResult TransposeLoadOp::verify() { + MemRefType srcType = cast(getSrc().getType()); + + if 
(!hasWorkgroupMemorySpace(srcType.getMemorySpace())) + return emitOpError("source memory address space must be Workgroup"); + + auto transferType = cast(getType()); + size_t numElements = transferType.getNumElements(); + size_t elementTypeSize = + transferType.getElementType().getIntOrFloatBitWidth(); + + // ElementSize -> NumElements + const llvm::SmallDenseMap kValidLoadSizeMap = { + {4, 16}, + {6, 16}, + {8, 8}, + {16, 4}, + }; + + auto validNumElems = kValidLoadSizeMap.find(elementTypeSize); + if (validNumElems == kValidLoadSizeMap.end()) + return emitOpError("Unsupported element type size for transpose load: ") + << elementTypeSize << " bits"; + + if (numElements != validNumElems->second) + return emitOpError( + "Transferring type size mismatch: expected num of elements: ") + << validNumElems->second; + + return success(); +} + +//===----------------------------------------------------------------------===// +// MakeDmaBaseOp +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyBase(BaseOp op) { + auto ldsType = cast(op.getLds().getType()); + auto globalType = cast(op.getGlobal().getType()); + if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace())) + return op.emitOpError( + "lds memref must have workgroup address space attribute."); + if (!hasGlobalMemorySpace(globalType.getMemorySpace())) + return op.emitOpError( + "global memref must have global address space attribute."); + + Type elementType = ldsType.getElementType(); + unsigned width = elementType.getIntOrFloatBitWidth(); + + if (!llvm::is_contained({8u, 16u, 32u, 64u}, width)) + return op.emitOpError( + "element type must be 1, 2, 4, or 8 bytes long but type was ") + << width << " bits long."; + return success(); +} + +LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); } + +//===----------------------------------------------------------------------===// +// MakeGatherDmaBaseOp 
+//===----------------------------------------------------------------------===// + +LogicalResult +TDMGatherBaseType::verify(function_ref emitError, + Type elementType, Type indexType) { + unsigned width = elementType.getIntOrFloatBitWidth(); + if (!llvm::is_contained({8u, 16u, 32u, 64u}, width)) + return emitError() + << "element type must be 1, 2, 4, or 8 bytes wide but type " + << elementType << " is " << width / 8 << " bytes wide."; + MLIRContext *ctx = elementType.getContext(); + Type i16 = IntegerType::get(ctx, 32); + Type i32 = IntegerType::get(ctx, 16); + if (!llvm::is_contained({i16, i32}, indexType)) + return emitError() << "index type must be i16 or i32 but index type is " + << indexType << "."; + return success(); +} + +LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); } + +//===----------------------------------------------------------------------===// +// MakeDmaDescriptorOp +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyDescriptorOp(DescriptorOp op) { + ArrayRef globalStaticStrides = op.getGlobalStaticStrides(); + + if (globalStaticStrides.empty()) + return op.emitOpError("strides must not be empty."); + if (globalStaticStrides.back() != 1) + return op.emitOpError("strides for the innermost dimension must be 1."); + + ArrayRef globalStaticSizes = op.getGlobalStaticSizes(); + size_t rank = globalStaticSizes.size(); + if (rank > 5) + return op.emitOpError("tensor and tile must be at most of rank 5."); + if (rank != globalStaticStrides.size()) + return op.emitOpError("strides and sizes must have same rank."); + + ArrayRef sharedStaticSizes = op.getSharedStaticSizes(); + if (rank != sharedStaticSizes.size()) + return op.emitOpError("tensor must have same rank as tile."); + + unsigned elementTypeWidth = op.getElementTypeWidth(); + if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth)) + return op.emitOpError( + "element type width must be 1, 2, 4 or 
8 bytes, but was ") + << elementTypeWidth << " bits long"; + + if (Value atomicBarrierAddress = op.getAtomicBarrierAddress()) { + auto atomicBarrierAddressType = + cast(atomicBarrierAddress.getType()); + bool barrierInLDS = + hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()); + if (!barrierInLDS) + return op.emitOpError("atomic barrier address must be in LDS."); + } + + if (op.getEarlyTimeout() && !op.getWorkgroupMask()) + return op.emitOpError( + "early timeout does not apply when workgroup_mask is not set."); + return success(); +} + +template +static OpFoldResult foldDescriptorOp(DescriptorOp op, FoldAdaptor adaptor) { + SmallVector mixedGlobalSizes(op.getMixedGlobalSizes()); + SmallVector mixedGlobalStrides(op.getMixedGlobalStrides()); + SmallVector mixedSharedSizes(op.getMixedSharedSizes()); + + if (failed(foldDynamicIndexList(mixedGlobalSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true)) && + failed(foldDynamicIndexList(mixedGlobalStrides, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true)) && + failed(foldDynamicIndexList(mixedSharedSizes, /*onlyNonNegative=*/true, + /*onlyNonZero=*/true))) + return nullptr; + + SmallVector dynamicGlobalSizes, dynamicGlobalStrides, + dynamicSharedSizes; + SmallVector staticGlobalSizes, staticGlobalStrides, + staticSharedSizes; + + dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes, + staticGlobalSizes); + op.setGlobalStaticSizes(staticGlobalSizes); + op.getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes); + + dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides, + staticGlobalStrides); + op.setGlobalStaticStrides(staticGlobalStrides); + op.getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides); + + dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes, + staticSharedSizes); + op.setSharedStaticSizes(staticSharedSizes); + op.getSharedDynamicSizesMutable().assign(dynamicSharedSizes); + return op.getResult(); +} + +LogicalResult 
MakeDmaDescriptorOp::verify() { + return verifyDescriptorOp(*this); +} + +OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) { + return foldDescriptorOp(*this, adaptor); +} + +//===----------------------------------------------------------------------===// +// MakeGatherDmaDescriptorOp +//===----------------------------------------------------------------------===// + +LogicalResult MakeGatherDmaDescriptorOp::verify() { + ArrayRef globalStaticSizes = getGlobalStaticSizes(); + size_t rank = globalStaticSizes.size(); + if (rank > 2) + return emitOpError( + "tensor and tile must be at most of rank two in gather mode."); + Value indices = getIndices(); + Type elementType = cast(indices.getType()).getElementType(); + if (elementType != getBase().getType().getIndexType()) + return emitOpError("indices' element type must match base's element type."); + + return verifyDescriptorOp(*this); +} + +OpFoldResult MakeGatherDmaDescriptorOp::fold(FoldAdaptor adaptor) { + return foldDescriptorOp(*this, adaptor); +} + +//===----------------------------------------------------------------------===// +// ScaledMFMAOp +//===----------------------------------------------------------------------===// + +namespace { +/// Packs ScaledMFMAOp scale operands: a scale inserted from one extracted +/// scalar is replaced by a 4-wide slice of the source and an index into it. +struct PackScales final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ScaledMFMAOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto setOpsel = [&op](unsigned idx, int64_t val) { + switch (idx) { + case 3: + op.setScalesIdxA(val); + break; + case 4: + op.setScalesIdxB(val); + break; + default: + break; + } + }; + + // For every scale operand of this ScaledMFMAOp, if the scale is produced by + // the extraction of a single scale from some vector, then attempt to + // extract 4 values from that vector instead. 
+ // + // Example: (f8 here means f8E8M0FNU) + // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...> + // %scale = vector.insert %unit, ... : f8 into vector<4xf8> + // amdgpu.scaled_mfma(%scale[0] * ... + // + // rewrite to: + // + // %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8> + // %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8> + // amdgpu.scaled_mfma(%scale[0-3] * ... + // + // This creates duplicate shape_casts for every use but these will be + // removed in CSE. + for (auto opIdx : std::array({3, 4})) { + auto insertOp = op.getOperand(opIdx).getDefiningOp(); + if (!insertOp) { + return rewriter.notifyMatchFailure(op, + "defining op not a vector.insert"); + } + // If the extracted value is not a single scalar, then it has been packed. + if (isa(insertOp.getValueToStore().getType())) { + return rewriter.notifyMatchFailure( + op, "scaled mfma operand already packed"); + } + + auto extractOp = + insertOp.getValueToStore().getDefiningOp(); + if (!extractOp) { + return rewriter.notifyMatchFailure(op, + "defining op not a vector.extract"); + } + + Value scaleSrc = extractOp.getOperand(0); + auto scaleSrcType = dyn_cast(scaleSrc.getType()); + if (!scaleSrcType) { + return rewriter.notifyMatchFailure(op, "not a vector type"); + } + + // We do not handle dynamic dims yet, assume that the input is padded to + // a static shape now. + if (!scaleSrcType.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, + "dynamic dims not yet supported"); + } + + int64_t numElements = scaleSrcType.getNumElements(); + if (numElements <= 4) { + return rewriter.notifyMatchFailure( + op, "no packing if # of scales less than four"); + } + + // Find a linearized idx using the size and offsets of the extract op. 
+ auto extractedPos = llvm::to_vector_of( + llvm::reverse(extractOp.getStaticPosition())); + ArrayRef scaleSrcShape = scaleSrcType.getShape(); + int64_t scaleSrcRank = scaleSrcType.getRank(); + SmallVector extractSizes(scaleSrcRank, 1); + for (int64_t i = 1; i < scaleSrcRank; ++i) { + extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i]; + } + int64_t idx = linearize(extractedPos, extractSizes); + + // All n scales (where n is the total number of scales) must now be + // extracted in chunks of 4 elements. This is done by dividing the + // original vector of scales into m = floor(n / 4) groups of 4 elements + // at offsets 0, 4, ..., 4 * (m - 1). All extractions of a + // scale at a particular index are now replaced with an extraction + // of the entire group of 4 elements to which that index belongs. + // + // If the number of scales happens to be indivisible by 4, extract + // the remaining n - 4 * m scales in a chunk of 4 elements starting at + // offset n - 4. + int64_t offset = idx - (idx % 4); + int64_t opsel = idx - offset; + int64_t size = 4l; + // Accommodate remaining elements in the case of non-4-divisible vectors. 
+ if (numElements - offset < size) { + opsel = size - (numElements - idx); + offset = numElements - 4l; + } + Type scaleSrcElemType = scaleSrcType.getElementType(); + auto newSrcType = + VectorType::get(ArrayRef{numElements}, scaleSrcElemType); + Value newScaleSrc = + vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc); + auto extract = vector::ExtractStridedSliceOp::create( + rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size}, + ArrayRef{int64_t(1)}); + rewriter.modifyOpInPlace(op, [&] { + op->setOperand(opIdx, extract); + setOpsel(opIdx, opsel); + }); + } + return success(); + } +}; +} // namespace + +void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUTypes.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUTypes.cpp new file mode 100644 index 0000000000000..a12695cbb2ba7 --- /dev/null +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUTypes.cpp @@ -0,0 +1,26 @@ +//===- AMDGPUTypes.cpp - MLIR AMDGPU dialect types ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AMDGPU dialect types. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" + +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + +void mlir::amdgpu::AMDGPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt index 5d14a05945e95..aef460904baaa 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt @@ -1,5 +1,9 @@ add_mlir_dialect_library(MLIRAMDGPUDialect + AMDGPUAttrs.cpp AMDGPUDialect.cpp + AMDGPUEnums.cpp + AMDGPUOps.cpp + AMDGPUTypes.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU