diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 4820b7a747ac2..e07c72b839e7c 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -33,6 +33,7 @@ def AMDGPU_Dialect : Dialect { "gpu::GPUDialect" ]; let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; } def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">; @@ -79,6 +80,30 @@ def AMDGPU_AddressSpaceAttr : EnumAttr traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +//===----------------------------------------------------------------------===// +// AMDGPU Type definitions +//===----------------------------------------------------------------------===// + +def AMDGPU_TDMBaseType : AMDGPU_Type<"TDMBase", "tdm_base"> { + let summary = "Pair of base addresses that move data between LDS and global storage."; + let description = [{ + This type is opaque and it is used to represent a struct of two addresses. + One address is in LDS while the other is in global memory. + }]; + let parameters = (ins "Type":$elementType); + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType), [{ + return $_get(elementType.getContext(), elementType); + }]> + ]; + let assemblyFormat = "`<` $elementType `>`"; +} + //===----------------------------------------------------------------------===// // AMDGPU Op definitions //===----------------------------------------------------------------------===// @@ -1192,4 +1217,35 @@ def AMDGPU_ScaledMFMAOp : }]; let hasCanonicalizer = 1; } + +def AMDGPU_MakeDmaBaseOp : + AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>, + Arguments<(ins + Arg:$src, + Variadic:$srcIndices, + Arg:$dst, + Variadic:$dstIndices)>, + Results<(outs AMDGPU_TDMBaseType: $base)> { + + // TODO: + // * Add verifiers such that one of the memrefs is from LDS and the other global. + // * Add verifiers to make sure that the type is in the correct direction. + // * Add verifiers to make sure that the number of indices do not exceed the number of dimensions. + + let summary = "Pair of based addresses used when moving tiles between LDS and global memory."; + let description = [{ + This operation creates a pair of addresses that will be used by tensor_load_to_lds + and tensor_store_from_lds. + + This operation creates a value corresponding to the tensor descriptor (D#) group 0 + found in TensorLoadToLDSOp and TensorStoreFromLDSOp in the rocdl dialect. + + These tensor DMA operations were introduced in gfx1250. + }]; + + let assemblyFormat = [{ + $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst) `to` type(results) + }]; +} + #endif // AMDGPU diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h index dcd9f95a7561f..a7680fb5c3191 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h @@ -25,6 +25,7 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc" #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc" namespace mlir::amdgpu { /// Parser for the `custom` custom assembly format used by @@ -52,6 +53,9 @@ inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *, #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.h.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPU.h.inc" diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index d55f3cec47c1f..cdc10c60a42ae 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -55,6 +55,10 @@ void AMDGPUDialect::initialize() { #define GET_OP_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + >(); addAttributes< #define GET_ATTRDEF_LIST #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" @@ -839,5 +843,8 @@ void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results, #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc" diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 09134cb4704bb..653f9f64d24f4 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -685,3 +685,15 @@ func.func @memory_counter_wait() { amdgpu.memory_counter_wait exp(4) func.return } + +// CHECK-LABEL: func @make_dma_base +// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEM:.+]]: memref<8xi32>, %[[SMEM:.+]]: memref<8xi32, #gpu.address_space>) +func.func @make_dma_base(%idx: index, %mem: memref<8xi32>, %smem: memref<8xi32, #gpu.address_space>) { + // CHECK: amdgpu.make_dma_base %[[MEM]][%[[IDX]]], %[[SMEM]][%[[IDX]]] : memref<8xi32>, memref<8xi32, #gpu.address_space> to !amdgpu.tdm_base + amdgpu.make_dma_base %mem[%idx], %smem[%idx] : memref<8xi32>, memref<8xi32, #gpu.address_space> to !amdgpu.tdm_base + + // CHECK: amdgpu.make_dma_base %[[SMEM]][%[[IDX]]], %[[MEM]][%[[IDX]]] : memref<8xi32, #gpu.address_space>, memref<8xi32> to !amdgpu.tdm_base + amdgpu.make_dma_base %smem[%idx], %mem[%idx] : memref<8xi32, #gpu.address_space>, memref<8xi32> to !amdgpu.tdm_base + func.return +} +