From d3b0e49f06ebbef30a32a2a48e47c58c5bac1b00 Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Fri, 12 Jan 2024 01:47:35 +0000
Subject: [PATCH 1/8] update with upstream practices

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  95 ++--
 .../mlir/Dialect/XeGPU/IR/XeGPUDialect.td     |  34 +-
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 440 ++++++++----------
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |   8 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  25 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 171 +++----
 6 files changed, 363 insertions(+), 410 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 84112b8b18a81..00b149090003f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -10,12 +10,11 @@
 #define MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

 include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
-
 include "mlir/IR/EnumAttr.td"

 class XeGPUAttr<string name, string attrMnemonic, list<Trait> traits = [],
                 string baseCppClass = "::mlir::Attribute">
-    : AttrDef<XeGPUDialect, name, traits, baseCppClass> {
+    : AttrDef<XeGPU_Dialect, name, traits, baseCppClass> {
   let mnemonic = attrMnemonic;
 }

@@ -49,7 +48,7 @@ def XeGPU_SgMapAttr: XeGPUAttr<"SubGroupMap", "sg_map"> {

 def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
   let parameters = (ins
-        DefaultValuedParameter<"xegpu::MemoryScope", "xegpu::MemoryScope::GLOBAL">: $memory_scope,
+        DefaultValuedParameter<"xegpu::MemoryScopeKind", "xegpu::MemoryScopeKind::GLOBAL">: $memory_scope,
         DefaultValuedParameter<"int", "1">: $array_length,
         DefaultValuedParameter<"bool", "true">: $boundary_check,
         OptionalParameter<"xegpu::ScatteredAttr">: $scattered,
@@ -58,7 +57,7 @@ def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
   let builders = [
     AttrBuilder<(ins
-      CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::GLOBAL">:$memory_scope,
+      CArg<"xegpu::MemoryScopeKind", "xegpu::MemoryScopeKind::GLOBAL">:$memory_scope,
       CArg<"int", "1">:$array_length,
       CArg<"xegpu::ScatteredAttr", "{}">:$scattered,
       CArg<"xegpu::SubGroupMapAttr", "{}">:$map
@@ -72,42 +71,62 @@ def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
   let hasCustomAssemblyFormat = true;
 }

-def XeGPU_ArgTypeAttr : I32EnumAttr<
-  "ArgType", "", [ I32EnumAttrCase<"Vector", 0, "vector">,
-                   I32EnumAttrCase<"Scalar", 1, "scalar"> ]> {
-  let cppNamespace = "::mlir::xegpu";
+def ARG_TYPE_VECTOR : I32EnumAttrCase<"VECTOR", 0, "vector">;
+def ARG_TYPE_SCALAR : I32EnumAttrCase<"SCALAR", 1, "scalar">;
+def XeGPU_ArgTypeKind : I32EnumAttr<"ArgTypeKind",
+               "Argument type for Invoke_SIMD op",
+               [ARG_TYPE_VECTOR, ARG_TYPE_SCALAR]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
 }

-def XeGPU_ModeAttr : I32EnumAttr<
-  "Mode", "", [ I32EnumAttrCase<"SIMT", 0, "simt">,
-                I32EnumAttrCase<"VC", 1, "vc"> ]> {
-  let cppNamespace = "::mlir::xegpu";
+def MODE_SIMT : I32EnumAttrCase<"SIMT", 0, "simt">;
+def MODE_VC : I32EnumAttrCase<"VC", 1, "vc">;
+def XeGPU_ModeKind : I32EnumAttr<"ModeKind",
+             "The mode an operator runs on",
+             [MODE_SIMT, MODE_VC]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
 }

-def XeGPU_MemoryScopeAttr : I32EnumAttr<
-  "MemoryScope", "", [ I32EnumAttrCase<"GLOBAL", 0, "global">,
-                       I32EnumAttrCase<"SLM", 1, "slm"> ]> {
-  let cppNamespace = "::mlir::xegpu";
+def MEMORY_SCOPE_GLOBAL: I32EnumAttrCase<"GLOBAL", 0, "global">;
+def MEMORY_SCOPE_SHARED: I32EnumAttrCase<"SLM", 1, "slm">;
+def XeGPU_MemoryScopeKind: I32EnumAttr<"MemoryScopeKind",
+      "The scope of the memory the tensor descriptor is created for",
+      [MEMORY_SCOPE_GLOBAL, MEMORY_SCOPE_SHARED]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
 }

-def XeGPU_CacheReadAttr : I32EnumAttr<
-  "CacheReadHint", "", [ I32EnumAttrCase<"UNCACHED", 0, "uncached">,
-                         I32EnumAttrCase<"CACHED", 1, "cached">,
-                         I32EnumAttrCase<"STREAMING", 2, "streaming">,
-                         I32EnumAttrCase<"READ_INVALIDATE", 3, "read_invalidate"> ]> {
-
-  let cppNamespace = "::mlir::xegpu";
+def CACHE_KIND_CACHED: I32EnumAttrCase<"CACHED", 0, "cached">;
+def CACHE_KIND_UNCACHED: I32EnumAttrCase<"UNCACHED", 1, "uncached">;
+def CACHE_KIND_STREAMING: I32EnumAttrCase<"STREAMING", 2, "streaming">;
+def CACHE_KIND_INVALIDATE: I32EnumAttrCase<"READ_INVALIDATE", 3, "read_invalidate">;
+def CACHE_KIND_WRITE_BACK: I32EnumAttrCase<"WRITE_BACK", 4, "write_back">;
+def CACHE_KIND_WRITE_THROUGH: I32EnumAttrCase<"WRITE_THROUGH", 5, "write_through">;
+
+def XeGPU_ReadCacheKind : I32EnumAttr<"ReadCacheKind",
+       "Cache behavior for read",
+       [CACHE_KIND_CACHED, CACHE_KIND_UNCACHED,
+        CACHE_KIND_STREAMING, CACHE_KIND_INVALIDATE]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
 }

-def XeGPU_CacheWriteAttr : I32EnumAttr<
-  "CacheWriteHint", "", [ I32EnumAttrCase<"UNCACHED", 0, "uncached">,
-                          I32EnumAttrCase<"WRITE_THROUGH", 1, "write_through">,
-                          I32EnumAttrCase<"WRITE_BACK", 2, "write_back">,
-                          I32EnumAttrCase<"STREAMING", 3, "streaming"> ]> {
-
-  let cppNamespace = "::mlir::xegpu";
+def XeGPU_WriteCacheKind: I32EnumAttr<"WriteCacheKind",
+       "Cache behavior for write",
+       [CACHE_KIND_UNCACHED, CACHE_KIND_STREAMING,
+        CACHE_KIND_WRITE_BACK, CACHE_KIND_WRITE_THROUGH]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
 }

+def XeGPU_ArgTypeAttr : EnumAttr<XeGPU_Dialect, XeGPU_ArgTypeKind, "arg_type">;
+def XeGPU_ModeAttr : EnumAttr<XeGPU_Dialect, XeGPU_ModeKind, "mode">;
+def XeGPU_MemoryScopeAttr : EnumAttr<XeGPU_Dialect, XeGPU_MemoryScopeKind, "memory_scope">;
+def XeGPU_ReadCacheAttr : EnumAttr<XeGPU_Dialect, XeGPU_ReadCacheKind, "read_cache">;
+def XeGPU_WriteCacheAttr : EnumAttr<XeGPU_Dialect, XeGPU_WriteCacheKind, "write_cache">;
+
 // RMW kind attribute
 def ATOMIC_RMW_KIND_ADDF   : I64EnumAttrCase<"addf", 0>;
 def ATOMIC_RMW_KIND_ADDI   : I64EnumAttrCase<"addi", 1>;
@@ -123,14 +142,16 @@ def ATOMIC_RMW_KIND_MULI   : I64EnumAttrCase<"muli", 10>;
 def ATOMIC_RMW_KIND_ORI    : I64EnumAttrCase<"ori", 11>;
 def ATOMIC_RMW_KIND_ANDI   : I64EnumAttrCase<"andi", 12>;

-def XeGPU_AtomicRMWKindAttr : I64EnumAttr<
-    "AtomicRMWKind", "",
-    [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
-     ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
-     ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
-     ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
-     ATOMIC_RMW_KIND_ANDI]> {
+def XeGPU_AtomicRMWKind : I64EnumAttr<"AtomicRMWKind",
+       "Operation type for AtomicRMW",
+       [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
+        ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
+        ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
+        ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
+        ATOMIC_RMW_KIND_ANDI]> {
+  let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::xegpu";
 }
+def XeGPU_AtomicRMWKindAttr : EnumAttr<XeGPU_Dialect, XeGPU_AtomicRMWKind, "atomic_rmw_kind">;

 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index ae29f87a8812a..f85ccb32cc43b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -22,31 +22,23 @@ include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"

-
-// Provide a definition of the 'XeGPU' dialect in the ODS framework so that we
-// can define our operations.
-def XeGPUDialect : Dialect {
-  // The namespace of our dialect
+def XeGPU_Dialect : Dialect {
   let name = "xegpu";
-
-  // A short one-line summary of our dialect.
+  let cppNamespace = "::mlir::xegpu";
   let summary = "The XeGPU dialect that models Intel GPU's ISA";
-
-  // A longer description of our dialect.
   let description = [{
-    The XeGPU dialect models Intel Xe ISA semantics but works at vector and
-    TensorDesc data type. It provides 1:1 mappings to match Xe instructions like
-    DPAS and 2D block load. The matrix size being processed at this level
-    exactly matches the hardware instructions or the intrinsic supported by
-    the lower-level GPU compiler.
-  }];
-
-  // The C++ namespace that the dialect class definition resides in.
-  let cppNamespace = "::mlir::xegpu";
-
-  let dependentDialects = ["::mlir::memref::MemRefDialect"];
+      The XeGPU dialect models Intel Xe ISA semantics but works at vector and
+      TensorDesc data type. It provides 1:1 mappings to match Xe instructions
+      like DPAS and 2D block load. The matrix size being processed at this level
+      exactly matches the hardware instructions or the intrinsic supported by
+      the lower-level GPU compiler.
+  }];
+
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "memref::MemRefDialect"
+  ];

-  // TODO: temporary disable it.
   let useDefaultTypePrinterParser = true;
   let useDefaultAttributePrinterParser = true;
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 6866f903d715e..7e62185d7b08e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -19,35 +19,36 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
 // * The parent dialect of the operation.
 // * The mnemonic for the operation, or the name without the dialect prefix.
 // * A list of traits for the operation.
-class XeGPU_Op<string mnemonic, list<Trait> traits = []> :
-    Op<XeGPUDialect, mnemonic, traits>;
+class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
+          Op<XeGPU_Dialect, mnemonic, traits>;

-def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSegments]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSegments]> {
   let summary = "create nd tensor descriptor operation";
   let description = [{
     The "create_nd_tdesc" operation creates a TensorDescType which represents
     a sub-view of a 2D memory region (It can be extended to support N-D memory
-    region if needed in future). Elements in the subview continuous in each dimention.
-    It encodes the following important information for supporting intel hardware features:
-
-    * source: an object representing (starting address/pointer of) a 2D memory reagion. It can
-           be either a 2D memref object, or simply a pointer represented by uint64_t type.
-    * offsets: two index values represents offsets from the "source" at the each dimension at
-           which the subview of the target memory will be created. It is encoded via two
-           variables, including "dynamic_offsets" and "static_offsets", such that it can
-           accept various forms, such as, operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4])).
-    * shape: the shape information of the memory region pointed by the "source". It is typically
-           encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>. But if "source"
-           is simply a pointer represented as uint64_t type, or a memref type without shape information
-           e.g., memref<?x?xf16>, the shape information has to be explicitly passed via the "dynamic_shape"
-           argument. Currently "dynamic_shape" only accepts operands(e.g., [%c4096, %c4096]),
Currently "dynamic_shape" only accepts operands(e.g., [%c4096, %c4096]), - not attributes(e.g., [4096, 4096]). - * strides: the strides of the memory region pointed by the "source". Similar to shape, it is typically - encoded via the MemRefType of the source too. But if "source" is simply a pointer represented - as uint64_t type, or a memref type without shape information e.g., memref, the strides - information has to be explicitly passed via the "dynamic_strides" argument. And it currently - only accepts operands two. + region if needed in future). Elements in the subview continuous in each + dimention. It encodes the following important information for supporting + Intel hardware features: + + * source: an object representing (starting address/pointer of) a 2D memory reagion. + It can be either a 2D memref object, or simply a pointer represented by uint64_t type. + * offsets: two index values represents offsets from the "source" at the each dimension + at which the subview of the target memory will be created. It is encoded via two + variables, including "dynamic_offsets" and "static_offsets", such that it can + accept various forms, such as, operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4])). + * shape: the shape information of the memory region pointed by the "source". It is + typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>. + But if "source" is simply a pointer represented as uint64_t type, or a memref + type without shape information e.g., memref, the shape information has + to be explicitly passed via the "dynamic_shape" argument. Currently "dynamic_shape" + only accepts operands(e.g., [%c4096, %c4096]), not attributes(e.g., [4096, 4096]). + * strides: the strides of the memory region pointed by the "source". Similar to shape, + it is typically encoded via the MemRefType of the source too. But if "source" is + simply a pointer represented as uint64_t type, or a memref type without shape + information e.g., memref, the strides information has to be explicitly + passed via the "dynamic_strides" argument. And it currently only accepts operands two. 
    Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
    %0 = memref.alloc() : memref<32x24xf32>
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %1 = xegpu.create_nd_tdesc %0[%c0, %c1]: memref<32x24xf32> -> TensorDesc<8x16xf32>

    Example 2 (suppose the tensor shape inferred by the compiler is 8x16):
    %0 = memref.alloc(%h, %w) : memref<?x?xf32>
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %1 = xegpu.create_nd_tdesc %0[%c0, %c1], [%h, %w], [%w, %c1]: memref<?x?xf32> -> TensorDesc<8x16xf32>

    Example 3 (suppose the tensor shape inferred by the compiler is 8x16):
    %0 = ... : ui64
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %1 = xegpu.create_nd_tdesc %0[%c0, %c1], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
  }];
@@ -68,30 +69,32 @@ def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSe
     %1 = xegpu.create_nd_tdesc %0[%c0, %c1], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
   }];

-  let arguments = (ins XeGPU_BaseAddrType: $source,
-                   Variadic<Index>: $dynamic_offsets,
-                   Variadic<Index>: $dynamic_shape,
-                   Variadic<Index>: $dynamic_strides,
-                   DenseI64ArrayAttr: $static_offsets,
-                   DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::Mode::SIMT">: $mode);
-
+  let arguments = (ins XeGPU_BaseAddrType: $source,
+                       Variadic<Index>: $dynamic_offsets,
+                       Variadic<Index>: $dynamic_shape,
+                       Variadic<Index>: $dynamic_strides,
+                       DenseI64ArrayAttr: $static_offsets,
+                       DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::SIMT">: $mode);
   let results = (outs XeGPU_TensorDesc:$TensorDesc);

   let hasCustomAssemblyFormat = 1;
-  let skipDefaultBuilders = 1;
+  let hasVerifier = 1;
   let builders = [
-    OpBuilder<(ins "Type": $TensorDesc, "Value": $source, "ValueRange": $offsets,
-      "ValueRange": $shape, "ValueRange": $strides, "::llvm::ArrayRef<int64_t>": $static_offsets,
-      CArg<"xegpu::Mode", "xegpu::Mode::SIMT">: $mode)>,
+    OpBuilder<(ins "Type": $TensorDesc, "Value": $source, "ValueRange": $offsets,
+               "ValueRange": $shape, "ValueRange": $strides,
+               "llvm::ArrayRef<int64_t>": $static_offsets,
+               CArg<"xegpu::ModeKind", "xegpu::ModeKind::SIMT">: $mode)>,

-    OpBuilder<(ins "Type": $tdesc, "Value": $source, "::llvm::ArrayRef<OpFoldResult>": $offsets,
-      CArg<"xegpu::Mode", "xegpu::Mode::SIMT">: $mode)>,
+    OpBuilder<(ins "Type": $tdesc, "Value": $source,
+               "llvm::ArrayRef<OpFoldResult>": $offsets,
+               CArg<"xegpu::ModeKind", "xegpu::ModeKind::SIMT">: $mode)>,

-    OpBuilder<(ins "Type": $tdesc, "Value": $source, "::llvm::ArrayRef<OpFoldResult>": $offsets,
+    OpBuilder<(ins "Type": $tdesc, "Value": $source,
+               "llvm::ArrayRef<OpFoldResult>": $offsets,
                "ValueRange": $shape, "ValueRange": $stride,
-      CArg<"xegpu::Mode", "xegpu::Mode::SIMT">: $mode)>
+               CArg<"xegpu::ModeKind", "xegpu::ModeKind::SIMT">: $mode)>
   ];

   let extraClassDeclaration = [{
@@ -124,7 +127,6 @@ def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSe
     /// strides information from memref type will be ignored.
     llvm::SmallVector<OpFoldResult> getStrides();

-
     /// return the shape embeded in the memref type of the source.
     /// If source is not memref type. array of kDynamic will be returned.
     llvm::ArrayRef<int64_t> getStaticShape();
@@ -133,34 +135,30 @@ def XeGPU_CreateNdDescOp : XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSe
     /// If source is not memref type. array of kDynamic will be returned.
     llvm::ArrayRef<int64_t> getStaticStrides();

-
     /// Return the element type of the TensorDesc
     Type getElementType();

     /// Return the shape of the TensorDesc
     llvm::ArrayRef<int64_t> getTensorDescShape();
-
-
-
   }];
-
-  let hasVerifier = 1;
 }

-def XeGPU_CreateDescOp
-  : XeGPU_Op<"create_tdesc", [Pure]> {
-
+def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure]> {
   let summary = "create scattered tensor descritors (TensorDesc).";
   let description = [{
-    "create_tdesc" is similar to "create_nd_tdesc" in terms that it creates a TensorDesc for a memory region.
-    while "create_nd_tdesc" is for creating continious subviews, "create_tdesc" is for creating non-continious
-    (scattered) subviews. It only works with VectorCompute (VC) mode and accepts the following parameters:
+    "create_tdesc" is similar to "create_nd_tdesc" in terms that it creates
+    a Tensor Descriptor (TensorDescType) for a memory region. While "create_nd_tdesc"
+    is for creating contiguous subviews, "create_tdesc" is for creating non-contiguous
+    (scattered) subviews.
+    It is designed to work only with VectorCompute (VC) mode and
+    accepts the following parameters:
+
+    * source: a 1D memref or pointer (uint64_t) representing the memory object.
+    * offsets: a 1D vector containing the offsets of each access point; its size should
+      be aligned with the supported group size, e.g., vector<16xindex>. And each element
+      in the vector corresponds to a work item (SIMT lane) in the subgroup.
+    * chunk_size_per_lane: [optional attribute] indicates the number of contiguous
+      elements accessed for each offset, default is 1.

     Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
     %a = memref.alloc() : memref<1024xf32>
     %1 = xegpu.create_tdesc %a, %offsets: memref<1024xf32> -> TensorDesc<4xf32>

     Example 2. It assumes subgroup size is 4, and each workitem access 8 elements.
     It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
     %0 = memref.alloc() : memref<1024xf32>
     %1 = xegpu.create_tdesc %0, %offsets {chunk_size_per_lane = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
   }];

   let arguments = (ins XeGPU_BaseAddrType: $source,
                    XeGPU_OffsetType: $offsets,
                    DefaultValuedAttr<I32Attr, "1">: $chunk_size_per_lane,
-                   DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::Mode::VC">: $mode);
-
+                   DefaultValuedAttr<XeGPU_ModeAttr, "xegpu::ModeKind::VC">: $mode);
   let results = (outs XeGPU_TensorDesc:$TensorDesc);

   let builders = [
     OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
                    "Value": $offsets, CArg<"uint32_t", "1"> : $chunk_size_per_lane)>,
-
     OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
                    "Value": $offsets, "IntegerAttr": $chunk_size_per_lane)>
   ];
   let skipDefaultBuilders = 1;

   // Format: xegpu.create_tdesc %src, %offsets {mode=simt, chunk_size_per_lane=1}
-  //                   : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
+  //         : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }

 def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> {
-  let summary = "loads a n-D block from global memory (represented by TensorDesc) to registers (represented by vector)";
+  let summary = "loads an n-D block from memory (represented by TensorDesc) "
+                "to registers (represented by vector)";
   let description = [{
-    LoadNDOp essentially mimics the hardware block read instruction to read a block of data from memory to register.
-    It takes a set of cache hints for each level of cache, L1, L2 and L3. If hardware does not have a correspoding cache,
-    Corresponding cache hint attribute will be masked.
-
-    If both transpose and vnni_axis present at the same time. it assume to perform transpose first and then vnni transform.
+    LoadNDOp essentially mimics the hardware block read instruction to read
+    a block of data from memory to register. It takes a set of cache hints
+    for each level of cache, L1, L2 and L3. If the hardware does not have a
+    corresponding cache, the corresponding cache hint attribute will be masked.
+    If both transpose and vnni_axis are present at the same time, transpose is
+    assumed to be performed first, followed by the vnni transform.
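+
+    A hypothetical sketch (not taken from the existing examples; the result
+    shape assumes a VNNI packing factor of 2 for f16, and the actual factor
+    depends on the element type):
+
+    %2 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0}
+         : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16>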
}]; - let arguments = (ins - XeGPU_TensorDesc: $TensorDesc, - OptionalAttr: $vnni_axis, - OptionalAttr: $transpose, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, - DefaultValuedAttr: $mode); + let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + OptionalAttr: $vnni_axis, + OptionalAttr: $transpose, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode); let results = (outs XeGPU_ValueType: $value); let extraClassDeclaration = [{ - VectorType getValueType() { return llvm::dyn_cast(getValue().getType()); } @@ -232,20 +228,17 @@ def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> { // Format: xegpu.load_nd %1 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint=streaming} // : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; } def XeGPU_StoreNDOp : XeGPU_Op<"store_nd", []> { let summary = "stores a n-D block register region back to memory, currently only supports 2D"; - let arguments = (ins - XeGPU_TensorDesc: $TensorDesc, - XeGPU_ValueType: $value, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, - DefaultValuedAttr: $mode - ); + let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + XeGPU_ValueType: $value, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode); // Format: xegpu.store_nd %3, %2 {l1_hint = write_back, l2_hint = uncached} // : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> @@ -256,13 +249,12 @@ def XeGPU_StoreNDOp : XeGPU_Op<"store_nd", []> { def XeGPU_PrefetchNDOp : XeGPU_Op<"prefetch_nd", []> { let summary = "prefetches a nD block to cache"; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, - DefaultValuedAttr: $mode - ); + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode); - // In format of: xegpu.prefetch_nd %tdesc {l1_hint = cached, l2_hint = uncached}: + // Format: xegpu.prefetch_nd %tdesc {l1_hint = cached, l2_hint = uncached}: // !xegpu.tensor_desc<8x16xf16> let hasCustomAssemblyFormat = 1; } @@ -273,7 +265,7 @@ def XeGPU_UpdateNDOffsetOp : XeGPU_Op<"update_nd_offset", []> { let arguments = (ins XeGPU_TensorDesc: $TensorDesc, Variadic: $offsets, - DefaultValuedAttr: $mode); + DefaultValuedAttr: $mode); let results = (outs XeGPU_TensorDesc: $result); @@ -285,14 +277,13 @@ def XeGPU_UpdateNDOffsetOp : XeGPU_Op<"update_nd_offset", []> { let hasVerifier = 1; } - def XeGPU_DpasOp : XeGPU_Op<"dpas"> { let summary = "performs dpas computation"; let arguments = (ins XeGPU_DpasOpType : $lhs, XeGPU_DpasOpType : $rhs, Optional: $acc, - DefaultValuedAttr: $mode + DefaultValuedAttr: $mode ); let results = (outs XeGPU_Vector2DType: $result); let assemblyFormat = [{ @@ -322,36 +313,33 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas"> { def XeGPU_LoadGatherOp : XeGPU_Op<"load"> { let summary = "load a scalar at source[offset]."; - let arguments = (ins - XeGPU_TensorDesc: $TensorDesc, - XeGPU_MaskType: $mask, - OptionalAttr: $vnni_axis, - OptionalAttr: $transpose, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, - DefaultValuedAttr: $mode - ); - + let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + XeGPU_MaskType: $mask, + OptionalAttr: $vnni_axis, + OptionalAttr: $transpose, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode); let results = (outs 
XeGPU_ValueType: $value); let builders = [ - OpBuilder<(ins "Type": $value, "Value": $TensorDesc, "Value": $mask, "IntegerAttr": $vnni_axis, - CArg<"DenseI64ArrayAttr", "DenseI64ArrayAttr()">: $transpose, - CArg<"xegpu::CacheReadHintAttr", "xegpu::CacheReadHintAttr()">: $l1_hint, - CArg<"xegpu::CacheReadHintAttr", "xegpu::CacheReadHintAttr()">: $l2_hint, - CArg<"xegpu::CacheReadHintAttr", "xegpu::CacheReadHintAttr()">: $l3_hint)>, + OpBuilder<(ins "mlir::Type": $value, "mlir::Value": $TensorDesc, "mlir::Value": $mask, "mlir::IntegerAttr": $vnni_axis, + CArg<"mlir::DenseI64ArrayAttr", "mlir::DenseI64ArrayAttr()">: $transpose, + CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l1_hint, + CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l2_hint, + CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l3_hint)>, OpBuilder<(ins "Type": $value, "Value": $TensorDesc, "Value": $mask, "IntegerAttr": $vnni_axis, CArg<"DenseI64ArrayAttr", "DenseI64ArrayAttr()">: $transpose, - CArg<"xegpu::CacheReadHint", "xegpu::CacheReadHint::CACHED">: $l1_hint, - CArg<"xegpu::CacheReadHint", "xegpu::CacheReadHint::CACHED">: $l2_hint, - CArg<"xegpu::CacheReadHint", "xegpu::CacheReadHint::CACHED">: $l3_hint)> + CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l1_hint, + CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l2_hint, + CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l3_hint)> ]; let skipDefaultBuilders = 1; - // In format of: %2 = xegpu.load %1, %0 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached} + // Format: %2 = xegpu.load %1, %0 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached} // : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> let hasCustomAssemblyFormat = 1; let hasVerifier = 1; @@ -364,22 +352,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", []> { XeGPU_ValueType: $value, XeGPU_TensorDesc: $TensorDesc, XeGPU_MaskType: $mask, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, - DefaultValuedAttr: $mode + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode ); let builders = [ OpBuilder<(ins "Value": $value, "Value": $TensorDesc, "Value": $mask, - CArg<"xegpu::CacheWriteHintAttr", "xegpu::CacheWriteHintAttr()">: $l1_hint, - CArg<"xegpu::CacheWriteHintAttr", "xegpu::CacheWriteHintAttr()">: $l2_hint, - CArg<"xegpu::CacheWriteHintAttr", "xegpu::CacheWriteHintAttr()">: $l3_hint)>, - + CArg<"xegpu::WriteCacheKindAttr", "xegpu::WriteCacheKindAttr()">: $l1_hint, + CArg<"xegpu::WriteCacheKindAttr", "xegpu::WriteCacheKindAttr()">: $l2_hint, + CArg<"xegpu::WriteCacheKindAttr", "xegpu::WriteCacheKindAttr()">: $l3_hint)>, OpBuilder<(ins "Value": $value, "Value": $TensorDesc, "Value": $mask, - CArg<"xegpu::CacheWriteHint", "xegpu::CacheWriteHint::WRITE_BACK">: $l1_hint, - CArg<"xegpu::CacheWriteHint", "xegpu::CacheWriteHint::WRITE_BACK">: $l2_hint, - CArg<"xegpu::CacheWriteHint", "xegpu::CacheWriteHint::WRITE_BACK">: $l3_hint)> + CArg<"xegpu::WriteCacheKind", "xegpu::WriteCacheKind::WRITE_BACK">: $l1_hint, + CArg<"xegpu::WriteCacheKind", "xegpu::WriteCacheKind::WRITE_BACK">: $l2_hint, + CArg<"xegpu::WriteCacheKind", "xegpu::WriteCacheKind::WRITE_BACK">: $l3_hint)> ]; let skipDefaultBuilders = 1; @@ -389,71 +376,59 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", []> { let hasVerifier = 1; } -def XeGPU_UpdateOffsetOp - : XeGPU_Op<"update_offset", []> { - let summary = "update the offsets for the 
given tensor descriptor"; - - let arguments = (ins - XeGPU_TensorDesc: $TensorDesc, - XeGPU_OffsetType: $offsets, - DefaultValuedAttr: $mode - ); - - let results = (outs XeGPU_TensorDesc: $result); +def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", []> { + let summary = "update the offsets for the given tensor descriptor"; + let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + XeGPU_OffsetType: $offsets, + DefaultValuedAttr: $mode); + let results = (outs XeGPU_TensorDesc: $result); - let builders = [ - OpBuilder<(ins "Type": $result, "Value": $TensorDesc, "Value": $offsets), [{ - $_state.addOperands(TensorDesc); - $_state.addOperands(offsets); - $_state.getOrAddProperties().mode = xegpu::ModeAttr::get($_builder.getContext(), xegpu::Mode::VC); - $_state.addTypes(result); - }]> - ]; + let builders = [ + OpBuilder<(ins "Type": $result, "Value": $TensorDesc, "Value": $offsets)> + ]; - let skipDefaultBuilders = 1; + let skipDefaultBuilders = 1; - let assemblyFormat = [{ - $TensorDesc `,` $offsets (`{` `mode` `=` $mode^ `}`)? - attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) - }]; + let assemblyFormat = [{ + $TensorDesc `,` $offsets (`{` `mode` `=` $mode^ `}`)? + attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) + }]; - let hasVerifier = 1; - } + let hasVerifier = 1; +} def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { let summary = "prefetches a nD block to cache"; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, - DefaultValuedAttr: $mode - ); + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode); let builders = [ OpBuilder<(ins "Value": $TensorDesc, - CArg<"xegpu::CacheReadHintAttr", "xegpu::CacheReadHintAttr()">: $l1_hint, - CArg<"xegpu::CacheReadHintAttr", "xegpu::CacheReadHintAttr()">: $l2_hint, - CArg<"xegpu::CacheReadHintAttr", "xegpu::CacheReadHintAttr()">: $l3_hint)>, - + CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l1_hint, + CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l2_hint, + CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l3_hint)>, OpBuilder<(ins "Value": $TensorDesc, - CArg<"xegpu::CacheReadHint", "xegpu::CacheReadHint::CACHED">: $l1_hint, - CArg<"xegpu::CacheReadHint", "xegpu::CacheReadHint::CACHED">: $l2_hint, - CArg<"xegpu::CacheReadHint", "xegpu::CacheReadHint::CACHED">: $l3_hint)> + CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l1_hint, + CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l2_hint, + CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l3_hint)> ]; let skipDefaultBuilders = 1; + let hasVerifier = 1; - // In format of: xegpu.prefetch %tdesc {l1_hint = cached, l2_hint = uncached}: + // Format: xegpu.prefetch %tdesc {l1_hint = cached, l2_hint = uncached}: // !xegpu.tensor_desc<8x16xf16> let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; } def XeGPU_InvokeSIMDOp : XeGPU_Op<"invoke_SIMD", []> { let summary = "Invoke_SIMD operation"; let description = [{ - The `xegpu.invoke_SIMD` operation works similar to a direct call to a function. But it is - special to Intel GPU. + The `xegpu.invoke_SIMD` operation works similar to a direct call to a function. + But it is special to Intel GPU. 
}]; let arguments = (ins FlatSymbolRefAttr:$callee, @@ -463,15 +438,12 @@ def XeGPU_InvokeSIMDOp : XeGPU_Op<"invoke_SIMD", []> { let builders = [ OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, - "xegpu::ArgTypeAttr":$argType, CArg<"ValueRange", "{}">:$operands)>, - + "xegpu::ArgTypeKindAttr":$argType, CArg<"ValueRange", "{}">:$operands)>, OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, - "xegpu::ArgTypeAttr":$argType, CArg<"ValueRange", "{}">:$operands)>, - + "xegpu::ArgTypeKindAttr":$argType, CArg<"ValueRange", "{}">:$operands)>, OpBuilder<(ins "llvm::StringRef":$callee, "TypeRange":$results, - "xegpu::ArgTypeAttr":$argType, CArg<"ValueRange", "{}">:$operands)> + "xegpu::ArgTypeKindAttr":$argType, CArg<"ValueRange", "{}">:$operands)> ]; - } def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> { @@ -481,8 +453,9 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> { XeGPU_TensorDesc:$tensorDesc, XeGPU_MaskType:$mask, Optional:$value, - DefaultValuedAttr: $mode + DefaultValuedAttr: $mode ); + let results = (outs XeGPU_ValueType:$result); let assemblyFormat = [{ $kind $tensorDesc `,` $mask (`,` $value^)? (`{` `mode` `=` $mode^ `}`)? attr-dict `:` qualified(type(operands)) `->` type($result) @@ -490,98 +463,61 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> { let builders = [ OpBuilder<(ins "Type": $result, "xegpu::AtomicRMWKindAttr": $kind, - "Value": $tensorDesc, "Value": $mask, - "Value": $value)>, - + "Value": $tensorDesc, "Value": $mask, "Value": $value)>, OpBuilder<(ins "Type": $result, "xegpu::AtomicRMWKind": $kind, - "Value": $tensorDesc, "Value": $mask, - "Value": $value)> + "Value": $tensorDesc, "Value": $mask, "Value": $value)> ]; let skipDefaultBuilders = 1; - let hasVerifier = 1; } def XeGPU_AllocNbarrierOp: XeGPU_Op<"alloc_nbarrier", []> { - let summary = "allocate a specific number of named barriers."; - let arguments = (ins I32Attr: $nbarrierCount); - let assemblyFormat = "$nbarrierCount attr-dict"; + let summary = "allocate a specific number of named barriers."; + let arguments = (ins I32Attr: $nbarrierCount); + let assemblyFormat = "$nbarrierCount attr-dict"; } -def XeGPU_CreateNbarrierOp - : XeGPU_Op<"create_nbarrier", []> { - let summary = "create a named barrier."; - - let arguments = (ins - I8: $nbarrier_id, - I8: $nbarrier_role, - I8Attr: $num_producers, - I8Attr: $num_consumers, - DefaultValuedAttr: $mode - ); - - let results = (outs XeGPU_Nbarrier: $result); - - let assemblyFormat = [{ - $nbarrier_id `,` $nbarrier_role - attr-dict `:` `(` qualified(type($nbarrier_id)) `,` qualified(type($nbarrier_role)) `)` - `->` qualified(type($result)) - }]; - - // let hasVerifier = 1; - } - -def XeGPU_NbarrierArriveOp - : XeGPU_Op<"nbarrier_arrive", []> { - let summary = "arrive at a named barrier."; - - let arguments = (ins - XeGPU_Nbarrier: $payload - ); - - let assemblyFormat = [{ - $payload attr-dict `:` qualified(type($payload)) - }]; - } - -def XeGPU_NbarrierWaitOp - : XeGPU_Op<"nbarrier_wait", []> { - let summary = "wait for a named barrier."; - - let arguments = (ins - XeGPU_Nbarrier: $payload - ); - - let assemblyFormat = [{ - $payload attr-dict `:` qualified(type($payload)) - }]; - } - -def XeGPU_CompileHintOp - : XeGPU_Op<"compile_hint", []> { - let summary = "prevents the compiler from scheduling."; - - let assemblyFormat = [{ - attr-dict - }]; - } - -def XeGPU_MfenceOp - : XeGPU_Op<"mfence", []> { - let summary = "lsc fence."; - - let arguments = (ins - StrAttr: $memory_kind, - StrAttr: $fence_op, - StrAttr: $fence_scope - ); - - let 
assemblyFormat = [{ - attr-dict - }]; - } +def XeGPU_CreateNbarrierOp: XeGPU_Op<"create_nbarrier", []> { + let summary = "create a named barrier."; + let arguments = (ins I8: $nbarrier_id, + I8: $nbarrier_role, + I8Attr: $num_producers, + I8Attr: $num_consumers, + DefaultValuedAttr: $mode); + let results = (outs XeGPU_Nbarrier: $result); + let assemblyFormat = [{ + $nbarrier_id `,` $nbarrier_role + attr-dict `:` `(` qualified(type($nbarrier_id)) `,` qualified(type($nbarrier_role)) `)` + `->` qualified(type($result)) + }]; +} + +def XeGPU_NbarrierArriveOp: XeGPU_Op<"nbarrier_arrive", []> { + let summary = "arrive at a named barrier."; + let arguments = (ins XeGPU_Nbarrier: $payload); + let assemblyFormat = [{ $payload attr-dict `:` qualified(type($payload))}]; +} + +def XeGPU_NbarrierWaitOp: XeGPU_Op<"nbarrier_wait", []> { + let summary = "wait for a named barrier."; + let arguments = (ins XeGPU_Nbarrier: $payload); + let assemblyFormat = [{ $payload attr-dict `:` qualified(type($payload)) }]; +} + +def XeGPU_CompileHintOp: XeGPU_Op<"compile_hint", []> { + let summary = "prevents the compiler from scheduling."; + let assemblyFormat = [{ attr-dict }]; +} + +def XeGPU_MfenceOp: XeGPU_Op<"mfence", []> { + let summary = "lsc fence."; + let arguments = (ins StrAttr: $memory_kind, + StrAttr: $fence_op, + StrAttr: $fence_scope); + let assemblyFormat = [{ attr-dict }]; +} #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 8d2f1e769c304..4831cb1fa18c5 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -31,7 +31,7 @@ def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>; // common base class for types in XeGPU dialect class XeGPUTypeDef traits = [], string baseCppClass = "::mlir::Type"> - : TypeDef { + : TypeDef { let mnemonic = typeMnemonic; } @@ -108,13 +108,13 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", )>, TypeBuilder<(ins "llvm::ArrayRef": $shape, "mlir::Type": $elementType, - "mlir::xegpu::MemoryScope": $memory_scope, "int": $array_length, + "mlir::xegpu::MemoryScopeKind": $memory_scope, "int": $array_length, "bool": $boundary_check, "mlir::xegpu::ScatteredAttr": $scattered, "mlir::xegpu::SubGroupMapAttr": $mapping )>, TypeBuilderWithInferredContext<(ins "llvm::ArrayRef": $shape, "mlir::Type": $elementType, - "mlir::xegpu::MemoryScope": $memory_scope, "int": $array_length, + "mlir::xegpu::MemoryScopeKind": $memory_scope, "int": $array_length, "bool": $boundary_check, "mlir::xegpu::ScatteredAttr": $scattered, "mlir::xegpu::SubGroupMapAttr": $mapping )> @@ -147,7 +147,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc", return llvm::dyn_cast_if_present(getEncoding()); } - xegpu::MemoryScope getMemoryScope(); + xegpu::MemoryScopeKind getMemoryScope(); int getArrayLength(); bool getBoundaryCheck(); xegpu::ScatteredAttr getScattered(); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 4a2eec6fde163..7c85f44cc9360 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -76,8 +76,7 @@ mlir::LogicalResult SubGroupMapAttr::verify( mlir::Attribute TensorDescAttr::parse(mlir::AsmParser &parser, mlir::Type type) { - - mlir::FailureOr memory_scope; + mlir::FailureOr memory_scope; mlir::FailureOr array_length; mlir::FailureOr boundary_check; mlir::FailureOr scattered; @@ 
-105,7 +104,7 @@ mlir::Attribute TensorDescAttr::parse(mlir::AsmParser &parser, seen_memory_scope = true; // Parse variable 'memory_scope' memory_scope = - mlir::FieldParser::parse(parser); + mlir::FieldParser::parse(parser); if (mlir::failed(memory_scope)) return parser.emitError( parser.getCurrentLocation(), @@ -157,7 +156,7 @@ mlir::Attribute TensorDescAttr::parse(mlir::AsmParser &parser, if (parser.parseGreater()) return {}; return TensorDescAttr::get( - parser.getContext(), memory_scope.value_or(xegpu::MemoryScope::GLOBAL), + parser.getContext(), memory_scope.value_or(xegpu::MemoryScopeKind::GLOBAL), array_length.value_or(1), boundary_check.value_or(true), scattered.value_or(xegpu::ScatteredAttr()), map.value_or(xegpu::SubGroupMapAttr())); @@ -169,7 +168,7 @@ void TensorDescAttr::print(::mlir::AsmPrinter &printer) const { printer << "<"; - if (printDefaults || getMemoryScope() != xegpu::MemoryScope::GLOBAL) { + if (printDefaults || getMemoryScope() != xegpu::MemoryScopeKind::GLOBAL) { if (printSep) printer << ", "; printSep = true; @@ -208,7 +207,7 @@ void TensorDescAttr::print(::mlir::AsmPrinter &printer) const { bool TensorDescAttr::hasNonDefaultAttrs() { int count = 0; - if (getMemoryScope() != MemoryScope::GLOBAL) + if (getMemoryScope() != MemoryScopeKind::GLOBAL) count++; if (getBoundaryCheck() != true) count++; @@ -222,7 +221,7 @@ bool TensorDescAttr::hasNonDefaultAttrs() { } TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context, - xegpu::MemoryScope memory_scope, + xegpu::MemoryScopeKind memory_scope, int array_length, xegpu::ScatteredAttr scattered, xegpu::SubGroupMapAttr map) { @@ -287,11 +286,11 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const { auto encoding = getEncoding(); if (auto attr = getEncodingAsMapAttr()) { encoding = - TensorDescAttr::get(getContext(), MemoryScope::GLOBAL, 1, {}, attr); + TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1, {}, attr); } if (auto attr = getEncodingAsScatteredAttr()) { encoding = - TensorDescAttr::get(getContext(), MemoryScope::GLOBAL, 1, attr, {}); + TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1, attr, {}); } printer << ", " << encoding; } else if (auto encoding = getEncodingAsTensorDescAttr()) { @@ -312,7 +311,7 @@ TensorDescType TensorDescType::get(llvm::ArrayRef shape, TensorDescType TensorDescType::get(mlir::MLIRContext *context, llvm::ArrayRef shape, mlir::Type elementType, - mlir::xegpu::MemoryScope memory_scope, + mlir::xegpu::MemoryScopeKind memory_scope, int array_length, bool boundary_check, mlir::xegpu::ScatteredAttr scattered, mlir::xegpu::SubGroupMapAttr mapping) { @@ -323,7 +322,7 @@ TensorDescType TensorDescType::get(mlir::MLIRContext *context, TensorDescType TensorDescType::get(llvm::ArrayRef shape, mlir::Type elementType, - mlir::xegpu::MemoryScope memory_scope, + mlir::xegpu::MemoryScopeKind memory_scope, int array_length, bool boundary_check, mlir::xegpu::ScatteredAttr scattered, mlir::xegpu::SubGroupMapAttr mapping) { @@ -333,12 +332,12 @@ TensorDescType TensorDescType::get(llvm::ArrayRef shape, return Base::get(elementType.getContext(), shape, elementType, attr); } -xegpu::MemoryScope TensorDescType::getMemoryScope() { +xegpu::MemoryScopeKind TensorDescType::getMemoryScope() { auto attr = getEncodingAsTensorDescAttr(); if (attr) return attr.getMemoryScope(); // return default value - return MemoryScope::GLOBAL; + return MemoryScopeKind::GLOBAL; } int TensorDescType::getArrayLength() { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp 
b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index a831e3bc2ae5d..f559491299fe5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -125,15 +125,15 @@ parseOptionalAttrDict(OpAsmParser &parser, OperationState &result, if (nameId == "l1_hint" || nameId == "l2_hint" || nameId == "l3_hint") { if (isWrite) - return parseCustomEnumAttr( + return parseCustomEnumAttr( parser, result, nameId); else - return parseCustomEnumAttr( + return parseCustomEnumAttr( parser, result, nameId); } if (nameId == "mode") { - return parseCustomEnumAttr(parser, result, nameId); + return parseCustomEnumAttr(parser, result, nameId); } if (nameId == "chunk_size_per_lane" || nameId == "vnni_axis") @@ -217,7 +217,7 @@ static bool verifyAndInferShape(std::vector &shape, void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, Type TensorDesc, Value source, ValueRange offsets, ValueRange shape, ValueRange strides, - llvm::ArrayRef static_offsets, Mode mode) { + llvm::ArrayRef static_offsets, ModeKind mode) { auto offsetRank = static_offsets.size(); auto shapeRank = shape.size() ? shape.size() : getRankOf(source); @@ -243,13 +243,13 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, state.addAttribute(getStaticOffsetsAttrName(state.name), builder.getDenseI64ArrayAttr(static_offsets)); state.addAttribute(getModeAttrName(state.name), - xegpu::ModeAttr::get(builder.getContext(), mode)); + xegpu::ModeKindAttr::get(builder.getContext(), mode)); state.addTypes(TensorDesc); } void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, Type tdesc, Value source, - llvm::ArrayRef offsets, Mode mode) { + llvm::ArrayRef offsets, ModeKind mode) { auto ty = llvm::dyn_cast_if_present(source.getType()); assert(ty && ty.hasStaticShape() && offsets.size() == getRankOf(source)); @@ -267,7 +267,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, Type tdesc, Value source, llvm::ArrayRef offsets, ValueRange shape, ValueRange stride, - xegpu::Mode mode) { + ModeKind mode) { assert(shape.size() && offsets.size() && stride.size() && shape.size() == stride.size() && shape.size() == offsets.size()); @@ -391,7 +391,7 @@ void CreateNdDescOp::print(OpAsmPrinter &printer) { printer << "]"; } - if (printDefaults || mode != Mode::SIMT) { + if (printDefaults || mode != ModeKind::SIMT) { printer << ' ' << "{"; printer << "mode = " << mode; printer << "}"; @@ -415,12 +415,12 @@ LogicalResult CreateNdDescOp::verify() { "non-scattered operators.\n"); } - if (mode == Mode::VC && mapping) { + if (mode == ModeKind::VC && mapping) { return emitOpError("Mapping attribute of TensorDesc is not expected " "for VC mode operations.\n"); } - if (mode == Mode::SIMT && !mapping) { + if (mode == ModeKind::SIMT && !mapping) { return emitOpError("Expecting SgMap attribute for SIMT mode operators.\n"); } @@ -606,11 +606,11 @@ void CreateDescOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getOffsets(); - if (printDefaults || mode != Mode::SIMT || chunk != 1) { + if (printDefaults || mode != ModeKind::SIMT || chunk != 1) { printer << ' ' << "{"; } - if (printDefaults || mode != Mode::SIMT) { + if (printDefaults || mode != ModeKind::SIMT) { printer << "mode = " << mode; printSep = true; } @@ -621,7 +621,7 @@ void CreateDescOp::print(OpAsmPrinter &printer) { printer << "chunk_size_per_lane = " << chunk; } - if (printDefaults || mode != Mode::SIMT || chunk != 1) { + if (printDefaults || mode != ModeKind::SIMT || chunk != 1) { printer << "}"; } @@ -643,7 
+643,7 @@ LogicalResult CreateDescOp::verify() { auto tdescTy = getTensorDesc().getType(); auto chunkSize = getChunkSizePerLane(); - if (mode == Mode::SIMT || mapping) { + if (mode == ModeKind::SIMT || mapping) { return emitOpError("CreateDescOp only support VC mode and mapping " "attribute of TensorDesc is not expected.\n"); } @@ -685,7 +685,7 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state, state.getOrAddProperties().chunk_size_per_lane = builder.getIntegerAttr(builder.getIntegerType(32), chunk_size_per_lane); state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); state.addTypes(TensorDesc); } @@ -698,7 +698,7 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state, state.getOrAddProperties().chunk_size_per_lane = chunk_size_per_lane; state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); state.addTypes(TensorDesc); } @@ -748,11 +748,11 @@ void LoadNDOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getTensorDesc(); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << ' ' << "{"; } - if (printDefaults || mode != Mode::SIMT) { + if (printDefaults || mode != ModeKind::SIMT) { printer << "mode = " << mode; printSep = true; } @@ -774,7 +774,7 @@ void LoadNDOp::print(OpAsmPrinter &printer) { printCacheHintAttrs(printer, *this, printSep); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << "}"; } @@ -809,7 +809,7 @@ LogicalResult LoadNDOp::verify() { auto valueShape = valueTy.getShape().vec(); auto array_len = tdescTy.getArrayLength(); - if (mode == Mode::SIMT) { + if (mode == ModeKind::SIMT) { auto sgMap = tdescTy.getMapping(); if (!sgMap) { return emitOpError( @@ -925,18 +925,18 @@ void StoreNDOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getTensorDesc(); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << ' ' << "{"; } - if (printDefaults || mode != Mode::SIMT) { + if (printDefaults || mode != ModeKind::SIMT) { printer << "mode = " << getMode(); printSep = true; } printCacheHintAttrs(printer, *this, true); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << "}"; } @@ -969,7 +969,7 @@ LogicalResult StoreNDOp::verify() { auto mode = getMode(); - if (mode == Mode::VC) { // for VC mode, no attr attached + if (mode == ModeKind::VC) { // for VC mode, no attr attached if (dstTy.getShape() != valTy.getShape()) return emitOpError("In VC mode, the value (vector) shape doesn't match " "the memory (dst) shape.\n"); @@ -1039,18 +1039,18 @@ void PrefetchNDOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getTensorDesc(); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << ' ' << "{"; } - if (printDefaults || mode != Mode::SIMT) { + if (printDefaults || mode != ModeKind::SIMT) { printer << "mode = " << getMode(); printSep = true; } printCacheHintAttrs(printer, *this, true); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << "}"; } @@ -1157,11 +1157,11 @@ void 
LoadGatherOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getMask(); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << ' ' << "{"; } - if (printDefaults || mode != Mode::SIMT) { + if (printDefaults || mode != ModeKind::SIMT) { printer << "mode = " << getMode(); printSep = true; } @@ -1183,7 +1183,7 @@ void LoadGatherOp::print(OpAsmPrinter &printer) { printCacheHintAttrs(printer, *this, printSep); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << "}"; } @@ -1244,7 +1244,7 @@ LogicalResult LoadGatherOp::verify() { auto mode = getMode(); auto mapping = tdescTy.getMapping(); - if (mode == Mode::SIMT || mapping) { + if (mode == ModeKind::SIMT || mapping) { return emitOpError("LoadGatherOp only supports VC mode and mapping " "attribute of TensorDesc is not expected.\n"); } @@ -1280,8 +1280,8 @@ LogicalResult LoadGatherOp::verify() { void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, Value TensorDesc, Value mask, IntegerAttr vnni_axis, - DenseI64ArrayAttr transpose, CacheReadHintAttr l1_hint, - CacheReadHintAttr l2_hint, CacheReadHintAttr l3_hint) { + DenseI64ArrayAttr transpose, ReadCacheKindAttr l1_hint, + ReadCacheKindAttr l2_hint, ReadCacheKindAttr l3_hint) { state.addOperands(TensorDesc); state.addOperands(mask); if (vnni_axis) @@ -1300,14 +1300,14 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, state.getOrAddProperties().l3_hint = l3_hint; state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); state.addTypes(value); } void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, Value TensorDesc, Value mask, IntegerAttr vnni_axis, - DenseI64ArrayAttr transpose, CacheReadHint l1_hint, - CacheReadHint l2_hint, CacheReadHint l3_hint) { + DenseI64ArrayAttr transpose, ReadCacheKind l1_hint, + ReadCacheKind l2_hint, ReadCacheKind l3_hint) { state.addOperands(TensorDesc); state.addOperands(mask); if (vnni_axis) @@ -1317,13 +1317,13 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, state.getOrAddProperties().transpose = transpose; state.getOrAddProperties().l1_hint = - CacheReadHintAttr::get(builder.getContext(), l1_hint); + ReadCacheKindAttr::get(builder.getContext(), l1_hint); state.getOrAddProperties().l2_hint = - CacheReadHintAttr::get(builder.getContext(), l2_hint); + ReadCacheKindAttr::get(builder.getContext(), l2_hint); state.getOrAddProperties().l3_hint = - CacheReadHintAttr::get(builder.getContext(), l3_hint); + ReadCacheKindAttr::get(builder.getContext(), l3_hint); state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); state.addTypes(value); } @@ -1420,18 +1420,18 @@ void StoreScatterOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getMask(); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << ' ' << "{"; } - if (printDefaults || mode != Mode::SIMT) { + if (printDefaults || mode != ModeKind::SIMT) { printer << "mode = " << getMode(); printSep = true; } printCacheHintAttrs(printer, *this, printSep); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { 
printer << "}"; } @@ -1453,7 +1453,7 @@ LogicalResult StoreScatterOp::verify() { auto mode = getMode(); auto mapping = tdescTy.getMapping(); - if (mode != Mode::VC || mapping) + if (mode != ModeKind::VC || mapping) return emitOpError("StoreScatterOp only supports VC mode and mapping " "attribute of TensorDesc is not expected.\n"); @@ -1492,9 +1492,9 @@ LogicalResult StoreScatterOp::verify() { void StoreScatterOp::build(OpBuilder &builder, OperationState &state, Value value, Value TensorDesc, Value mask, - CacheWriteHintAttr l1_hint, - CacheWriteHintAttr l2_hint, - CacheWriteHintAttr l3_hint) { + WriteCacheKindAttr l1_hint, + WriteCacheKindAttr l2_hint, + WriteCacheKindAttr l3_hint) { state.addOperands(value); state.addOperands(TensorDesc); state.addOperands(mask); @@ -1508,26 +1508,26 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, state.getOrAddProperties().l3_hint = l3_hint; } state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); } void StoreScatterOp::build(OpBuilder &builder, OperationState &state, Value value, Value TensorDesc, Value mask, - CacheWriteHint l1_hint, CacheWriteHint l2_hint, - CacheWriteHint l3_hint) { + WriteCacheKind l1_hint, WriteCacheKind l2_hint, + WriteCacheKind l3_hint) { state.addOperands(value); state.addOperands(TensorDesc); state.addOperands(mask); state.getOrAddProperties().l1_hint = - CacheWriteHintAttr::get(builder.getContext(), l1_hint); + WriteCacheKindAttr::get(builder.getContext(), l1_hint); state.getOrAddProperties().l2_hint = - CacheWriteHintAttr::get(builder.getContext(), l2_hint); + WriteCacheKindAttr::get(builder.getContext(), l2_hint); ; state.getOrAddProperties().l3_hint = - CacheWriteHintAttr::get(builder.getContext(), l3_hint); + WriteCacheKindAttr::get(builder.getContext(), l3_hint); ; state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); } ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1566,18 +1566,18 @@ void PrefetchOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getTensorDesc(); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << ' ' << "{"; } - if (printDefaults || mode != Mode::SIMT) { + if (printDefaults || mode != ModeKind::SIMT) { printer << "mode = " << getMode(); printSep = true; } printCacheHintAttrs(printer, *this, printSep); - if (printDefaults || mode != Mode::SIMT || numAttrs > 1) { + if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { printer << "}"; } @@ -1595,7 +1595,7 @@ LogicalResult PrefetchOp::verify() { return emitOpError("Invalid TensorDesc. PrefetchOp only works on " "TensorDescs with ScatteredAttr."); - if (mode != Mode::VC || mapping) { + if (mode != ModeKind::VC || mapping) { return emitOpError("PrefetchOp only supports VC mode. 
and mapping " "attribute of TensorDesc is not expected.\n"); } @@ -1604,8 +1604,8 @@ LogicalResult PrefetchOp::verify() { } void PrefetchOp::build(OpBuilder &builder, OperationState &state, - Value TensorDesc, CacheReadHintAttr l1_hint, - CacheReadHintAttr l2_hint, CacheReadHintAttr l3_hint) { + Value TensorDesc, ReadCacheKindAttr l1_hint, + ReadCacheKindAttr l2_hint, ReadCacheKindAttr l3_hint) { state.addOperands(TensorDesc); if (l1_hint) state.getOrAddProperties().l1_hint = l1_hint; @@ -1617,22 +1617,30 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, state.getOrAddProperties().l3_hint = l3_hint; state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); } void PrefetchOp::build(OpBuilder &builder, OperationState &state, - Value TensorDesc, CacheReadHint l1_hint, - CacheReadHint l2_hint, CacheReadHint l3_hint) { + Value TensorDesc, ReadCacheKind l1_hint, + ReadCacheKind l2_hint, ReadCacheKind l3_hint) { state.addOperands(TensorDesc); state.getOrAddProperties().l1_hint = - CacheReadHintAttr::get(builder.getContext(), l1_hint); + ReadCacheKindAttr::get(builder.getContext(), l1_hint); state.getOrAddProperties().l2_hint = - CacheReadHintAttr::get(builder.getContext(), l2_hint); + ReadCacheKindAttr::get(builder.getContext(), l2_hint); state.getOrAddProperties().l3_hint = - CacheReadHintAttr::get(builder.getContext(), l3_hint); - ; + ReadCacheKindAttr::get(builder.getContext(), l3_hint); state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); +} + +void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, + Type result, Value TensorDesc, Value offsets) { + state.addOperands(TensorDesc); + state.addOperands(offsets); + state.getOrAddProperties().mode = + xegpu::ModeKindAttr::get(builder.getContext(), xegpu::ModeKind::VC); + state.addTypes(result); } LogicalResult UpdateOffsetOp::verify() { @@ -1641,9 +1649,8 @@ LogicalResult UpdateOffsetOp::verify() { auto resTy = getResult().getType(); if (srcTy != resTy) - return emitOpError( - "The result should have the same type" - "(shape and encoding attribute) as the input TensorDesc."); + return emitOpError("The result should have the same type (shape and " + "encoding attribute) as the input TensorDesc."); auto shape = srcTy.getShape(); @@ -1673,7 +1680,7 @@ LogicalResult UpdateNDOffsetOp::verify() { void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state, SymbolRefAttr callee, TypeRange results, - ArgTypeAttr argType, ValueRange operands) { + ArgTypeKindAttr argType, ValueRange operands) { state.addOperands(operands); state.addAttribute("argType", argType); state.addAttribute("callee", callee); @@ -1682,41 +1689,39 @@ void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state, void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state, StringAttr callee, TypeRange results, - ArgTypeAttr argType, ValueRange operands) { + ArgTypeKindAttr argType, ValueRange operands) { build(builder, state, SymbolRefAttr::get(callee), results, argType, operands); } void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state, llvm::StringRef callee, TypeRange results, - ArgTypeAttr argType, ValueRange operands) { - build(builder, state, StringAttr::get(builder.getContext(), callee), results, - argType, operands); + ArgTypeKindAttr argType, ValueRange operands) { + build(builder, state, StringAttr::get(builder.getContext(), callee), + results, 
argType, operands); } LogicalResult AtomicRMWOp::verify() { auto mode = getMode(); - if (mode != Mode::VC) { + if (mode != ModeKind::VC) { return emitOpError("AtomicRMWOp only work on VC mode.\n"); } return success(); } void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result, - AtomicRMWKindAttr kind, Value tensorDesc, Value mask, - Value value) { + AtomicRMWKindAttr kind, Value tensorDesc, Value mask, Value value) { state.addOperands(tensorDesc); state.addOperands(mask); if (value) state.addOperands(value); state.getOrAddProperties().kind = kind; state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); state.addTypes(result); } void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result, - AtomicRMWKind kind, Value tensorDesc, Value mask, - Value value) { + AtomicRMWKind kind, Value tensorDesc, Value mask, Value value) { state.addOperands(tensorDesc); state.addOperands(mask); if (value) @@ -1724,7 +1729,7 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result, state.getOrAddProperties().kind = AtomicRMWKindAttr::get(builder.getContext(), kind); state.getOrAddProperties().mode = - ModeAttr::get(builder.getContext(), Mode::VC); + ModeKindAttr::get(builder.getContext(), ModeKind::VC); state.addTypes(result); } From c27ce22680250934f6f286108f170acd9c2b999a Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Fri, 12 Jan 2024 19:03:51 +0000 Subject: [PATCH 2/8] update attr def following upstream practices --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 30 +++++++++---------- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 5 ++-- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 23 +++++++------- mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir | 6 ++-- mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir | 6 ++-- 5 files changed, 35 insertions(+), 35 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 00b149090003f..c54d65bbdf77b 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -128,21 +128,21 @@ def XeGPU_ReadCacheAttr : EnumAttr; // RMW kind attribute -def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; -def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; -def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; -def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>; -def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>; -def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>; -def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>; -def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>; -def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>; -def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; -def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; -def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>; -def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>; - -def XeGPU_AtomicRMWKind : I64EnumAttr<"AtomicRMWKind", +def ATOMIC_RMW_KIND_ADDF : I32EnumAttrCase<"addf", 0>; +def ATOMIC_RMW_KIND_ADDI : I32EnumAttrCase<"addi", 1>; +def ATOMIC_RMW_KIND_ASSIGN : I32EnumAttrCase<"assign", 2>; +def ATOMIC_RMW_KIND_MAXF : I32EnumAttrCase<"maxf", 3>; +def ATOMIC_RMW_KIND_MAXS : I32EnumAttrCase<"maxs", 4>; +def ATOMIC_RMW_KIND_MAXU : I32EnumAttrCase<"maxu", 5>; +def ATOMIC_RMW_KIND_MINF : I32EnumAttrCase<"minf", 6>; +def ATOMIC_RMW_KIND_MINS : I32EnumAttrCase<"mins", 7>; +def ATOMIC_RMW_KIND_MINU : I32EnumAttrCase<"minu", 
8>;
+def ATOMIC_RMW_KIND_MULF : I32EnumAttrCase<"mulf", 9>;
+def ATOMIC_RMW_KIND_MULI : I32EnumAttrCase<"muli", 10>;
+def ATOMIC_RMW_KIND_ORI : I32EnumAttrCase<"ori", 11>;
+def ATOMIC_RMW_KIND_ANDI : I32EnumAttrCase<"andi", 12>;
+
+def XeGPU_AtomicRMWKind : I32EnumAttr<"AtomicRMWKind",
 "Operation type for AtomicRMW",
 [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
 ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 7e62185d7b08e..743498a83053c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -222,7 +222,6 @@ def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> {
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
-
   }];

   // Format: xegpu.load_nd %1 {transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint=streaming}
@@ -380,7 +379,7 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", []> {
   let summary = "update the offsets for the given tensor descriptor";
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        XeGPU_OffsetType: $offsets,
-                       DefaultValuedAttr: $mode);
+                       DefaultValuedAttr: $mode);
   let results = (outs XeGPU_TensorDesc: $result);

   let builders = [
@@ -453,7 +452,7 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> {
     XeGPU_TensorDesc:$tensorDesc,
     XeGPU_MaskType:$mask,
     Optional:$value,
-    DefaultValuedAttr: $mode
+    DefaultValuedAttr: $mode
   );
   let results = (outs XeGPU_ValueType:$result);

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f559491299fe5..78798b147896d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -42,7 +42,7 @@ static void transpose(llvm::ArrayRef trans,
   std::vector old = shape;
   for (size_t i = 0; i < trans.size(); i++)
     shape[i] = old[trans[i]];
-};
+}

 template
 static std::string makeString(T array, bool breakline = false) {
@@ -97,7 +97,7 @@ static ParseResult parseBoolAndIntegerAttr(OpAsmParser &parser,
   if (attr)
     result.addAttribute(attrKeyword, attr);
   return success();
-};
+}

 /// @brief Parsing optional attribute list which are enclosed in braces "{}",
 /// and seperated by comma
@@ -1644,26 +1644,28 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
 }

 LogicalResult UpdateOffsetOp::verify() {
+
+  auto mode = getMode();
+  if (mode != ModeKind::VC)
+    return emitOpError("UpdateOffsetOp only works in VC mode.\n");
+
   auto srcTy = getTensorDesc().getType();
-  auto offTy = getOffsets().getType();
   auto resTy = getResult().getType();
-
   if (srcTy != resTy)
     return emitOpError("The result should have the same type (shape and "
                        "encoding attribute) as the input TensorDesc.");

-  auto shape = srcTy.getShape();
-
   if (!srcTy.getScattered()) {
     return emitOpError("Invalid TensorDesc. 
UpdateOffsetOp only works on " "TensorDescs with ScatteredAttr."); } - auto vecTy = llvm::dyn_cast(offTy); - if (!vecTy || vecTy.getRank() != 1) + auto offTy = llvm::dyn_cast(getOffsets().getType()); + if (!offTy || offTy.getRank() != 1) return emitOpError("The offset should be an 1D vector.\n"); - if (shape[0] != vecTy.getShape()[0]) + auto shape = srcTy.getShape(); + if (shape[0] != offTy.getShape()[0]) return emitOpError( "The offset should have same length as the dim-0 of TensorDesc."); @@ -1702,9 +1704,8 @@ void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state, LogicalResult AtomicRMWOp::verify() { auto mode = getMode(); - if (mode != ModeKind::VC) { + if (mode != ModeKind::VC) return emitOpError("AtomicRMWOp only work on VC mode.\n"); - } return success(); } diff --git a/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir b/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir index 25e7de33c6c12..08cfa47104622 100644 --- a/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir +++ b/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir @@ -10,7 +10,7 @@ func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : v // CHECK: xegpu.atomic_rmw // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16xf32> - xegpu.atomic_rmw "addf" %1, %mask, %value {mode=vc} + xegpu.atomic_rmw addf %1, %mask, %value {mode=vc} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16xf32> -> vector<16xf32> return @@ -23,7 +23,7 @@ func.func @test_atomic_rmw_0(%src: ui64, %offsets : vector<16 x index>, %value : // CHECK: xegpu.atomic_rmw // CHECK-SAME: tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32> - xegpu.atomic_rmw "mulf" %1, %mask, %value {mode=vc} + xegpu.atomic_rmw mulf %1, %mask, %value {mode=vc} : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32> return @@ -36,7 +36,7 @@ func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : // CHECK: xegpu.atomic_rmw // CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xf32> - xegpu.atomic_rmw "andi" %1, %mask, %value {mode=vc} + xegpu.atomic_rmw andi %1, %mask, %value {mode=vc} : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xf32> return diff --git a/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir b/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir index 096451df04564..0f7229a02aa18 100644 --- a/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir +++ b/mlir/test/Dialect/XeGPU/IR/atomic_rmw_vc.mlir @@ -10,7 +10,7 @@ func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : v // CHECK: xegpu.atomic_rmw // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> - xegpu.atomic_rmw "addf" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> -> vector<16x1xf32> + xegpu.atomic_rmw addf %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16x1xf32> -> vector<16x1xf32> return } @@ -21,7 +21,7 @@ func.func @test_atomic_rmw_0(%src: ui64, %offsets : vector<16 x index>, %value : // CHECK: xegpu.atomic_rmw // CHECK-SAME: !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> - xegpu.atomic_rmw "mulf" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32> + xegpu.atomic_rmw mulf 
%1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xf32, #xegpu.scattered>, vector<16xi1>, vector<16x2xf32> -> vector<16x2xf32> return } @@ -32,7 +32,7 @@ func.func @test_atomic_rmw_1(%src: ui64, %offsets : vector<16 x index>, %value : // CHECK: xegpu.atomic_rmw // CHECK-SAME: !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> - xegpu.atomic_rmw "andi" %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xf32> + xegpu.atomic_rmw andi %1, %mask, %value {mode = vc} : !xegpu.tensor_desc<16x2xi32, #xegpu.scattered>, vector<16xi1>, vector<16x2xi32> -> vector<16x2xf32> return } From c8fa03c10fd7f41d8d64c4c024961df42598ce17 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Fri, 12 Jan 2024 20:56:27 +0000 Subject: [PATCH 3/8] reorder the code --- .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 2 +- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 147 ++-- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 725 ++++++++++-------- mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir | 2 +- 4 files changed, 472 insertions(+), 404 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index c54d65bbdf77b..edcf863691217 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -152,6 +152,6 @@ def XeGPU_AtomicRMWKind : I32EnumAttr<"AtomicRMWKind", let genSpecializedAttr = 0; let cppNamespace = "::mlir::xegpu"; } -def XeGPU_AtomicRMWKindAttr : EnumAttr; +def XeGPU_AtomicRMWKindAttr : EnumAttr; #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 743498a83053c..6ed65286c2524 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -144,55 +144,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, AttrSizedOperandSeg } -def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure]> { - let summary = "create scattered tensor descritors (TensorDesc)."; - let description = [{ - "create_tdesc" is similar to "create_nd_tdesc" in terms that it creates - a Tensor Descriptor (TensorDescType) for a memory region. While "create_nd_tdesc" - is for creating continious subviews, "create_tdesc" is for creating non-continious - (scattered) subviews. It is designed only works with VectorCompute (VC) mode and - accepts the following parameters: - - * source: a 1D memref or pointer (uint64_t) represents the memory object. - * offsets: It is a 1D vector containing offsets of each access point, the supportted - group size, e.g., vector<16xindex>. And each element in the vector corresponds - to a work item (SIMT lane) in the subgroup. - * chunk_size_per_lane: [optional attribute] indicates number of continious elements - accessed for each offset, default is 1. - - Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64] - %a = memref.alloc() : memref<1024xf32> - %c0 = arith.constant dense<0, 16, 32, 64> : vector<4xindex> - %1 = xegpu.create_tdesc %a, %c0: memref<1024xf32> -> TensorDesc<4xf32> - - Example 2. It assumes subgroup size is 4, and each workitem access 8 elements. 
-     It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
-      %0 = memref.alloc() : memref<1024xf32>
-      %c0 = arith.constant dense<0, 16, 32, 64> : vector<4xindex>
-      %1 = xegpu.create_tdesc %0, %c0 {chunk_size_per_lane = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
-  }];
-
-  let arguments = (ins XeGPU_BaseAddrType: $source,
-                       XeGPU_OffsetType: $offsets,
-                       DefaultValuedAttr: $chunk_size_per_lane,
-                       DefaultValuedAttr: $mode);
-  let results = (outs XeGPU_TensorDesc:$TensorDesc);
-
-  let builders = [
-    OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
-                   "Value": $offsets, CArg<"uint32_t", "1"> : $chunk_size_per_lane)>,
-    OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
-                   "Value": $offsets, "IntegerAttr": $chunk_size_per_lane)>
-  ];
-  let skipDefaultBuilders = 1;
-
-  // Format: xegpu.create_tdesc %src, %offsets {mode=simt, chunk_size_per_lane=1}
-  // : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  let hasCustomAssemblyFormat = 1;
-  let hasVerifier = 1;
-}
-
-
 def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> {
   let summary = "loads a n-D block from memory (represented by TensorDesc)"
                 "to registers (represented by vector)";
@@ -276,36 +227,51 @@ def XeGPU_UpdateNDOffsetOp : XeGPU_Op<"update_nd_offset", []> {
   let hasVerifier = 1;
 }

-def XeGPU_DpasOp : XeGPU_Op<"dpas"> {
-  let summary = "performs dpas computation";
-  let arguments = (ins
-    XeGPU_DpasOpType : $lhs,
-    XeGPU_DpasOpType : $rhs,
-    Optional: $acc,
-    DefaultValuedAttr: $mode
-  );
-  let results = (outs XeGPU_Vector2DType: $result);
-  let assemblyFormat = [{
-    $lhs `,` $rhs (`,` $acc^)? (` ``{` `mode` `=` $mode^ `}`)? attr-dict `:`
-    qualified(type($lhs)) `,` qualified(type($rhs)) (`,` qualified(type($acc))^)? `->` qualified(type($result))
-  }];
-
-  let extraClassDeclaration = [{
-    VectorType getLhsType() {
-      return ::llvm::cast(getLhs().getType());
-    }
+
+def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure]> {
+  let summary = "create scattered tensor descriptors (TensorDesc).";
+  let description = [{
+    "create_tdesc" is similar to "create_nd_tdesc" in that it creates
+    a Tensor Descriptor (TensorDescType) for a memory region. While "create_nd_tdesc"
+    is for creating contiguous subviews, "create_tdesc" is for creating non-contiguous
+    (scattered) subviews. It is designed to work only with VectorCompute (VC) mode and
+    accepts the following parameters:

-    VectorType getRhsType() {
-      return ::llvm::cast(getRhs().getType());
-    }
+    * source: a 1D memref or pointer (uint64_t) representing the memory object.
+    * offsets: a 1D vector containing the offset of each access point; its length is
+      the supported subgroup size, e.g., vector<16xindex>. Each element in the vector
+      corresponds to a work item (SIMT lane) in the subgroup.
+    * chunk_size_per_lane: [optional attribute] indicates the number of contiguous
+      elements accessed for each offset; the default is 1.

-    VectorType getAccType() {
-      return ::llvm::cast(getAcc().getType());
-    }
+    Example 1. It assumes a subgroup size of 4, and accesses a[0], a[16], a[32], a[64]
+      %a = memref.alloc() : memref<1024xf32>
+      %c0 = arith.constant dense<0, 16, 32, 64> : vector<4xindex>
+      %1 = xegpu.create_tdesc %a, %c0: memref<1024xf32> -> TensorDesc<4xf32>

-    VectorType getResultType() { return getResult().getType(); }
+    Example 2. It assumes a subgroup size of 4, and each work item accesses 8 elements. 
+ It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71] + %0 = memref.alloc() : memref<1024xf32> + %c0 = arith.constant dense<0, 16, 32, 64> : vector<4xindex> + %1 = xegpu.create_tdesc %0, %c0 {chunk_size_per_lane = 8}: memref<1024xf32> -> TensorDesc<4x8xf32> }]; + let arguments = (ins XeGPU_BaseAddrType: $source, + XeGPU_OffsetType: $offsets, + DefaultValuedAttr: $chunk_size_per_lane, + DefaultValuedAttr: $mode); + let results = (outs XeGPU_TensorDesc:$TensorDesc); + + let builders = [ + OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source, + "Value": $offsets, CArg<"uint32_t", "1"> : $chunk_size_per_lane)>, + OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source, + "Value": $offsets, "IntegerAttr": $chunk_size_per_lane)> + ]; + let skipDefaultBuilders = 1; + + // Format: xegpu.create_tdesc %src, %offsets {mode=simt, chunk_size_per_lane=1} + // : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -375,6 +341,39 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", []> { let hasVerifier = 1; } +def XeGPU_DpasOp : XeGPU_Op<"dpas"> { + let summary = "performs dpas computation"; + let arguments = (ins + XeGPU_DpasOpType : $lhs, + XeGPU_DpasOpType : $rhs, + Optional: $acc, + DefaultValuedAttr: $mode + ); + let results = (outs XeGPU_Vector2DType: $result); + let assemblyFormat = [{ + $lhs `,` $rhs (`,` $acc^)? (` ``{` `mode` `=` $mode^ `}`)? attr-dict `:` + qualified(type($lhs)) `,` qualified(type($rhs)) (`,` qualified(type($acc))^)? `->` qualified(type($result)) + }]; + + let extraClassDeclaration = [{ + VectorType getLhsType() { + return ::llvm::cast(getLhs().getType()); + } + + VectorType getRhsType() { + return ::llvm::cast(getRhs().getType()); + } + + VectorType getAccType() { + return ::llvm::cast(getAcc().getType()); + } + + VectorType getResultType() { return getResult().getType(); } + }]; + + let hasVerifier = 1; +} + def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", []> { let summary = "update the offsets for the given tensor descriptor"; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 78798b147896d..fca6c10aae81b 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -27,6 +27,22 @@ namespace xegpu { extern bool printDefaultValues(); +template +static std::string makeString(T array, bool breakline = false) { + std::string buf; + buf.clear(); + llvm::raw_string_ostream os(buf); + os << "["; + for (size_t i = 1; i < array.size(); i++) { + os << array[i - 1] << ", "; + if (breakline) + os << "\n\t\t"; + } + os << array.back() << "]"; + os.flush(); + return buf; +} + static size_t getRankOf(Value value) { if (value.getType().isIntOrIndexOrFloat()) return 0; @@ -44,20 +60,29 @@ static void transpose(llvm::ArrayRef trans, shape[i] = old[trans[i]]; } -template -static std::string makeString(T array, bool breakline = false) { - std::string buf; - buf.clear(); - llvm::raw_string_ostream os(buf); - os << "["; - for (size_t i = 1; i < array.size(); i++) { - os << array[i - 1] << ", "; - if (breakline) - os << "\n\t\t"; +static bool verifyAndInferShape(std::vector &shape, + SubGroupMapAttr sgMap) { + if (sgMap) { + auto wiLayout = sgMap.getWiLayout(); + auto wiData = sgMap.getWiData(); + + if ((int64_t)shape.size() != wiData.size() || + (int64_t)shape.size() != wiLayout.size()) { + return false; + } + + for (size_t i 
= 0; i < shape.size(); i++) { + + if ((shape[i] % (wiLayout[i] * wiData[i]) != 0 && + (wiLayout[i] * wiData[i]) % shape[i] != 0) || + shape[i] % wiLayout[i] != 0 || shape[i] % wiData[i] != 0) { + return false; + } + shape[i] /= wiLayout[i]; + } } - os << array.back() << "]"; - os.flush(); - return buf; + + return true; } template @@ -99,12 +124,6 @@ static ParseResult parseBoolAndIntegerAttr(OpAsmParser &parser, return success(); } -/// @brief Parsing optional attribute list which are enclosed in braces "{}", -/// and seperated by comma -/// @param parser -/// @param result -/// @param allowedKeywords -/// @return static ParseResult parseOptionalAttrDict(OpAsmParser &parser, OperationState &result, llvm::ArrayRef allowedKeywords, @@ -177,43 +196,10 @@ static void printCacheHintAttrs(OpAsmPrinter &printer, T op, bool printSep) { } } -static bool verifyAndInferShape(std::vector &shape, - SubGroupMapAttr sgMap) { - if (sgMap) { - auto wiLayout = sgMap.getWiLayout(); - auto wiData = sgMap.getWiData(); - - if ((int64_t)shape.size() != wiData.size() || - (int64_t)shape.size() != wiLayout.size()) { - return false; - } - - for (size_t i = 0; i < shape.size(); i++) { - - if ((shape[i] % (wiLayout[i] * wiData[i]) != 0 && - (wiLayout[i] * wiData[i]) % shape[i] != 0) || - shape[i] % wiLayout[i] != 0 || shape[i] % wiData[i] != 0) { - return false; - } - shape[i] /= wiLayout[i]; - } - } - return true; -} - -/// @brief the base builder for CreateNdDescOp -/// @param builder, the mlir OpBuilder -/// @param state , the mlir OperationState -/// @param TensorDesc, the TensorDescType of the result -/// @param source, the base address of the data. It can be either 2D memref -/// object or simple integer value (pointer) -/// @param offsets, the dynamic offset given as Value -/// @param shape, the dynamic shape given as array of Values -/// @param strides, the dynamic shape given as array of Values -/// @param static_offsets, the static offset. 
If it is not used it should be -/// filled with ShapeType::kDynamic -/// @param mode, VC or SIMT +//===----------------------------------------------------------------------===// +// XeGPU_CreateNdDescOp +//===----------------------------------------------------------------------===// void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, Type TensorDesc, Value source, ValueRange offsets, ValueRange shape, ValueRange strides, @@ -541,166 +527,10 @@ llvm::ArrayRef CreateNdDescOp::getTensorDescShape() { return getTensorDescType().getShape(); } -ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand sourceRawOperands[1]; - llvm::ArrayRef sourceOperands( - sourceRawOperands); - llvm::SMLoc sourceOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(sourceRawOperands[0])) - return failure(); - - if (parser.parseComma()) - return failure(); - - OpAsmParser::UnresolvedOperand offsetsRawOperands[1]; - llvm::ArrayRef offsetsOperands( - offsetsRawOperands); - llvm::SMLoc offsetsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(offsetsRawOperands[0])) - return failure(); - - if (parseOptionalAttrDict(parser, result, {"chunk_size_per_lane", "mode"})) - return failure(); - - if (parser.parseColon()) - return failure(); - - Type sourceRawTypes[1]; - llvm::ArrayRef sourceTypes(sourceRawTypes); - if (parser.parseType(sourceRawTypes[0])) - return failure(); - if (parser.parseComma()) - return failure(); - - Type offsetsRawTypes[1]; - llvm::ArrayRef offsetsTypes(offsetsRawTypes); - if (parser.parseType(offsetsRawTypes[0])) - return failure(); - if (parser.parseArrow()) - return failure(); - - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); - if (parser.parseType(TensorDescRawTypes[0])) - return failure(); - - result.addTypes(TensorDescTypes); - if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc, - result.operands)) - return failure(); - if (parser.resolveOperands(offsetsOperands, offsetsTypes, offsetsOperandsLoc, - result.operands)) - return failure(); - return success(); -} - -void CreateDescOp::print(OpAsmPrinter &printer) { - auto mode = getMode(); - bool printSep = false; - auto chunk = getChunkSizePerLane(); - auto printDefaults = printDefaultValues(); - - printer << ' '; - printer << getSource(); - printer << ","; - printer << ' '; - printer << getOffsets(); - - if (printDefaults || mode != ModeKind::SIMT || chunk != 1) { - printer << ' ' << "{"; - } - - if (printDefaults || mode != ModeKind::SIMT) { - printer << "mode = " << mode; - printSep = true; - } - - if (printDefaults || chunk != 1) { - if (printSep) - printer << "," << ' '; - printer << "chunk_size_per_lane = " << chunk; - } - - if (printDefaults || mode != ModeKind::SIMT || chunk != 1) { - printer << "}"; - } - - printer << ' ' << ":"; - printer << ' '; - printer << getSource().getType(); - printer << ","; - printer << ' '; - printer << getOffsets().getType(); - printer << ' ' << "->"; - printer << ' '; - printer << getTensorDesc().getType(); -} - -LogicalResult CreateDescOp::verify() { - auto mode = getMode(); - auto mapping = getTensorDesc().getType().getMapping(); - auto offsetTy = getOffsets().getType(); - auto tdescTy = getTensorDesc().getType(); - auto chunkSize = getChunkSizePerLane(); - - if (mode == ModeKind::SIMT || mapping) { - return emitOpError("CreateDescOp only support VC mode and mapping " - "attribute of TensorDesc is not expected.\n"); - } - - if 
(getRankOf(getSource()) > 2) - return emitOpError( - "Expecting the source is a 1D/2D memref or pointer (uint64_t)."); - - if (!tdescTy.getScattered()) - return emitOpError( - "Expecting the presence of ScatteredAttr for tensor descriptor."); - - // Infer the TensorDesc shape - std::vector shape; - if (llvm::isa(offsetTy)) { - shape = llvm::dyn_cast(offsetTy).getShape().vec(); - if (shape.size() != 1) - return emitOpError("Expecting the offset is a 1D vector."); - } - - if (chunkSize != 1) { - shape.push_back(chunkSize); - } - - auto tdescShape = tdescTy.getShape(); - if (shape != tdescShape.vec()) { - return emitOpError("Expecting dimensions of offsets is the same as the " - "tensor descriptor, or one less than."); - } - - return success(); -} - -void CreateDescOp::build(OpBuilder &builder, OperationState &state, - TensorDescType TensorDesc, Value source, Value offsets, - uint32_t chunk_size_per_lane) { - state.addOperands(source); - state.addOperands(offsets); - state.getOrAddProperties().chunk_size_per_lane = - builder.getIntegerAttr(builder.getIntegerType(32), chunk_size_per_lane); - state.getOrAddProperties().mode = - ModeKindAttr::get(builder.getContext(), ModeKind::VC); - state.addTypes(TensorDesc); -} -void CreateDescOp::build(OpBuilder &builder, OperationState &state, - TensorDescType TensorDesc, Value source, Value offsets, - IntegerAttr chunk_size_per_lane) { - state.addOperands(source); - state.addOperands(offsets); - if (chunk_size_per_lane) - state.getOrAddProperties().chunk_size_per_lane = - chunk_size_per_lane; - state.getOrAddProperties().mode = - ModeKindAttr::get(builder.getContext(), ModeKind::VC); - state.addTypes(TensorDesc); -} +//===----------------------------------------------------------------------===// +// XeGPU_LoadNDOp +//===----------------------------------------------------------------------===// ParseResult LoadNDOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; @@ -864,6 +694,9 @@ LogicalResult LoadNDOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_StoreNDOp +//===----------------------------------------------------------------------===// ParseResult StoreNDOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand valueRawOperands[1]; llvm::ArrayRef valueOperands( @@ -1004,6 +837,9 @@ LogicalResult StoreNDOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_PrefetchNDOp +//===----------------------------------------------------------------------===// ParseResult PrefetchNDOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; llvm::ArrayRef TensorDescOperands( @@ -1059,30 +895,280 @@ void PrefetchNDOp::print(OpAsmPrinter &printer) { printer << getTensorDesc().getType(); } -LogicalResult DpasOp::verify() { - - int64_t lhsRank = getLhsType().getRank(); - int64_t rhsRank = getRhsType().getRank(); - Type lhsElemType = getLhsType().getElementType(); - Type rhsElemType = getRhsType().getElementType(); - - if (lhsElemType != rhsElemType) { - return emitOpError("lhs and rhs element type does not match for dpas op"); - } - - if (getAcc() && getAccType() != getResultType()) { - return emitOpError("Accumulator and Result for dpas op should have the " - "same type (both shape and element type)."); 
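// A minimal usage sketch of the scattered-descriptor flow these verifiers
// guard. The IR below is an illustrative assumption (the value names %src,
// %offsets, %step and the 16-lane f32 shapes are hypothetical, not taken
// from the test suite):
//
//   %tdesc = xegpu.create_tdesc %src, %offsets {mode = vc}
//       : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
//   %next = xegpu.update_offset %tdesc, %step {mode = vc}
//       : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xindex>
//         -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
//
// Both ops accept only VC mode; update_offset additionally requires a
// TensorDesc carrying ScatteredAttr and a 1D offset vector whose length
// matches dim-0 of the descriptor shape.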
+//===----------------------------------------------------------------------===//
+// XeGPU_UpdateNDOffsetOp
+//===----------------------------------------------------------------------===//
+LogicalResult UpdateNDOffsetOp::verify() {
+  // number of offsets specified must match the rank of the tensor descriptor
+  if (getTensorDesc().getType().getRank() != (int64_t)getOffsets().size()) {
+    return emitOpError("Invalid number of offsets.");
   }
+  return success();
+}

-  if (lhsRank != rhsRank || lhsRank != 3) {
+//===----------------------------------------------------------------------===//
+// XeGPU_CreateDescOp
+//===----------------------------------------------------------------------===//
+
+void CreateDescOp::build(OpBuilder &builder, OperationState &state,
+                         TensorDescType TensorDesc, Value source, Value offsets,
+                         uint32_t chunk_size_per_lane) {
+  state.addOperands(source);
+  state.addOperands(offsets);
+  state.getOrAddProperties().chunk_size_per_lane =
+      builder.getIntegerAttr(builder.getIntegerType(32), chunk_size_per_lane);
+  state.getOrAddProperties().mode =
+      ModeKindAttr::get(builder.getContext(), ModeKind::VC);
+  state.addTypes(TensorDesc);
+}
+
+void CreateDescOp::build(OpBuilder &builder, OperationState &state,
+                         TensorDescType TensorDesc, Value source, Value offsets,
+                         IntegerAttr chunk_size_per_lane) {
+  state.addOperands(source);
+  state.addOperands(offsets);
+  if (chunk_size_per_lane)
+    state.getOrAddProperties().chunk_size_per_lane =
+        chunk_size_per_lane;
+  state.getOrAddProperties().mode =
+      ModeKindAttr::get(builder.getContext(), ModeKind::VC);
+  state.addTypes(TensorDesc);
+}
+
+ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand rawOperands[2];
+  llvm::ArrayRef operands(rawOperands);
+  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
+  // parse the source operand
+  if (parser.parseOperand(rawOperands[0]))
+    return failure();
+
+  if (parser.parseComma())
+    return failure();
+
+  // parse the offset operand
+  if (parser.parseOperand(rawOperands[1]))
+    return failure();
+
+  // parse the optional attributes
+  if (parseOptionalAttrDict(parser, result, {"chunk_size_per_lane", "mode"}))
+    return failure();
+
+  if (parser.parseColon())
+    return failure();
+
+  Type rawTypes[2];
+  llvm::ArrayRef types(rawTypes);
+  if (parser.parseType(rawTypes[0]))
+    return failure();
+  if (parser.parseComma())
+    return failure();
+
+  if (parser.parseType(rawTypes[1]))
+    return failure();
+  if (parser.parseArrow())
+    return failure();
+
+  Type TensorDescRawTypes[1];
+  llvm::ArrayRef TensorDescTypes(TensorDescRawTypes);
+  if (parser.parseType(TensorDescRawTypes[0]))
+    return failure();
+
+  result.addTypes(TensorDescTypes);
+  if (parser.resolveOperands(operands, types, operandsLoc,
+                             result.operands))
+    return failure();
+  return success();
+}
+
+void CreateDescOp::print(OpAsmPrinter &printer) {
+  auto mode = getMode();
+  bool printSep = false;
+  auto chunk = getChunkSizePerLane();
+  auto printDefaults = printDefaultValues();
+
+  printer << ' ';
+  printer << getSource();
+  printer << ",";
+  printer << ' ';
+  printer << getOffsets();
+
+  if (printDefaults || mode != ModeKind::SIMT || chunk != 1) {
+    printer << ' ' << "{";
+  }
+
+  if (printDefaults || mode != ModeKind::SIMT) {
+    printer << "mode = " << mode;
+    printSep = true;
+  }
+
+  if (printDefaults || chunk != 1) {
+    if (printSep)
+      printer << "," << ' ';
+    printer << "chunk_size_per_lane = " << chunk;
+  }
+
+  if (printDefaults || mode != ModeKind::SIMT || chunk != 1) {
+    printer << "}";
+  }
+
+  printer << ' ' << ":";
+  printer << ' ';
+  printer << getSource().getType();
+  printer << ",";
+  printer << ' ';
+  printer << getOffsets().getType();
+  printer << ' ' << "->";
+  printer << ' ';
+  printer << getTensorDesc().getType();
+}
+
+LogicalResult CreateDescOp::verify() {
+  auto mode = getMode();
+  auto mapping = getTensorDesc().getType().getMapping();
+  auto offsetTy = getOffsets().getType();
+  auto tdescTy = getTensorDesc().getType();
+  auto chunkSize = getChunkSizePerLane();
+
+  if (mode == ModeKind::SIMT || mapping) {
+    return emitOpError("CreateDescOp only supports VC mode, and the mapping "
+                       "attribute of TensorDesc is not expected.\n");
+  }
+
+  if (getRankOf(getSource()) > 2)
+    return emitOpError(
+        "Expecting the source is a 1D/2D memref or pointer (uint64_t).");
+
+  if (!tdescTy.getScattered())
+    return emitOpError(
+        "Expecting the presence of ScatteredAttr for tensor descriptor.");
+
+  // Infer the TensorDesc shape
+  std::vector shape;
+  if (llvm::isa(offsetTy)) {
+    shape = llvm::dyn_cast(offsetTy).getShape().vec();
+    if (shape.size() != 1)
+      return emitOpError("Expecting the offset is a 1D vector.");
+  }
+
+  if (chunkSize != 1) {
+    shape.push_back(chunkSize);
+  }
+
+  auto tdescShape = tdescTy.getShape();
+  if (shape != tdescShape.vec()) {
+    return emitOpError("Expecting dimensions of offsets is the same as the "
+                       "tensor descriptor, or one less than.");
   }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadGatherOp
+//===----------------------------------------------------------------------===//
 void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value,
+                         Value TensorDesc, Value mask, IntegerAttr vnni_axis,
+                         DenseI64ArrayAttr transpose, 
ReadCacheKindAttr l1_hint, + ReadCacheKindAttr l2_hint, ReadCacheKindAttr l3_hint) { + state.addOperands(TensorDesc); + state.addOperands(mask); + if (vnni_axis) + state.getOrAddProperties().vnni_axis = vnni_axis; + + if (transpose) + state.getOrAddProperties().transpose = transpose; + + if (l1_hint) + state.getOrAddProperties().l1_hint = l1_hint; + + if (l2_hint) + state.getOrAddProperties().l2_hint = l2_hint; + + if (l3_hint) + state.getOrAddProperties().l3_hint = l3_hint; + + state.getOrAddProperties().mode = + ModeKindAttr::get(builder.getContext(), ModeKind::VC); + state.addTypes(value); +} + +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, + Value TensorDesc, Value mask, IntegerAttr vnni_axis, + DenseI64ArrayAttr transpose, ReadCacheKind l1_hint, + ReadCacheKind l2_hint, ReadCacheKind l3_hint) { + state.addOperands(TensorDesc); + state.addOperands(mask); + if (vnni_axis) + state.getOrAddProperties().vnni_axis = vnni_axis; + + if (transpose) + state.getOrAddProperties().transpose = transpose; + + state.getOrAddProperties().l1_hint = + ReadCacheKindAttr::get(builder.getContext(), l1_hint); + state.getOrAddProperties().l2_hint = + ReadCacheKindAttr::get(builder.getContext(), l2_hint); + state.getOrAddProperties().l3_hint = + ReadCacheKindAttr::get(builder.getContext(), l3_hint); + state.getOrAddProperties().mode = + ModeKindAttr::get(builder.getContext(), ModeKind::VC); + state.addTypes(value); +} + ParseResult LoadGatherOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; llvm::ArrayRef TensorDescOperands( @@ -1278,53 +1364,48 @@ LogicalResult LoadGatherOp::verify() { return success(); } -void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, - Value TensorDesc, Value mask, IntegerAttr vnni_axis, - DenseI64ArrayAttr transpose, ReadCacheKindAttr l1_hint, - ReadCacheKindAttr l2_hint, ReadCacheKindAttr l3_hint) { + +//===----------------------------------------------------------------------===// +// XeGPU_StoreScatterOp +//===----------------------------------------------------------------------===// +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value TensorDesc, Value mask, + WriteCacheKindAttr l1_hint, + WriteCacheKindAttr l2_hint, + WriteCacheKindAttr l3_hint) { + state.addOperands(value); state.addOperands(TensorDesc); state.addOperands(mask); - if (vnni_axis) - state.getOrAddProperties().vnni_axis = vnni_axis; - - if (transpose) - state.getOrAddProperties().transpose = transpose; - - if (l1_hint) + if (l1_hint) { state.getOrAddProperties().l1_hint = l1_hint; - - if (l2_hint) + } + if (l2_hint) { state.getOrAddProperties().l2_hint = l2_hint; - - if (l3_hint) + } + if (l3_hint) { state.getOrAddProperties().l3_hint = l3_hint; - + } state.getOrAddProperties().mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC); - state.addTypes(value); } -void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, - Value TensorDesc, Value mask, IntegerAttr vnni_axis, - DenseI64ArrayAttr transpose, ReadCacheKind l1_hint, - ReadCacheKind l2_hint, ReadCacheKind l3_hint) { +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value TensorDesc, Value mask, + WriteCacheKind l1_hint, WriteCacheKind l2_hint, + WriteCacheKind l3_hint) { + state.addOperands(value); state.addOperands(TensorDesc); state.addOperands(mask); - if (vnni_axis) - state.getOrAddProperties().vnni_axis = vnni_axis; - - if 
(transpose)
     state.getOrAddProperties().transpose = transpose;
-
   state.getOrAddProperties().l1_hint =
-      ReadCacheKindAttr::get(builder.getContext(), l1_hint);
+      WriteCacheKindAttr::get(builder.getContext(), l1_hint);
   state.getOrAddProperties().l2_hint =
-      ReadCacheKindAttr::get(builder.getContext(), l2_hint);
+      WriteCacheKindAttr::get(builder.getContext(), l2_hint);
   state.getOrAddProperties().l3_hint =
-      ReadCacheKindAttr::get(builder.getContext(), l3_hint);
+      WriteCacheKindAttr::get(builder.getContext(), l3_hint);
   state.getOrAddProperties().mode =
       ModeKindAttr::get(builder.getContext(), ModeKind::VC);
-  state.addTypes(value);
 }

 ParseResult StoreScatterOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1490,42 +1571,36 @@ LogicalResult StoreScatterOp::verify() {
   return success();
 }

-void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
-                           Value value, Value TensorDesc, Value mask,
-                           WriteCacheKindAttr l1_hint,
-                           WriteCacheKindAttr l2_hint,
-                           WriteCacheKindAttr l3_hint) {
-  state.addOperands(value);
+//===----------------------------------------------------------------------===//
+// XeGPU_PrefetchOp
+//===----------------------------------------------------------------------===//
+void PrefetchOp::build(OpBuilder &builder, OperationState &state,
+                       Value TensorDesc, ReadCacheKindAttr l1_hint,
+                       ReadCacheKindAttr l2_hint, ReadCacheKindAttr l3_hint) {
   state.addOperands(TensorDesc);
-  state.addOperands(mask);
-  if (l1_hint) {
+  if (l1_hint)
     state.getOrAddProperties().l1_hint = l1_hint;
-  }
-  if (l2_hint) {
+
+  if (l2_hint)
    state.getOrAddProperties().l2_hint = l2_hint;
-  }
-  if (l3_hint) {
+
+  if (l3_hint)
    state.getOrAddProperties().l3_hint = l3_hint;
-  }
+
   state.getOrAddProperties().mode =
       ModeKindAttr::get(builder.getContext(), ModeKind::VC);
 }

-void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
-                           Value value, Value TensorDesc, Value mask,
-                           WriteCacheKind l1_hint, WriteCacheKind l2_hint,
-                           WriteCacheKind l3_hint) {
-  state.addOperands(value);
+void PrefetchOp::build(OpBuilder &builder, OperationState &state,
+                       Value TensorDesc, ReadCacheKind l1_hint,
+                       ReadCacheKind l2_hint, ReadCacheKind l3_hint) {
- 
state.addOperands(TensorDesc); - state.getOrAddProperties().l1_hint = - ReadCacheKindAttr::get(builder.getContext(), l1_hint); - state.getOrAddProperties().l2_hint = - ReadCacheKindAttr::get(builder.getContext(), l2_hint); - state.getOrAddProperties().l3_hint = - ReadCacheKindAttr::get(builder.getContext(), l3_hint); - state.getOrAddProperties().mode = - ModeKindAttr::get(builder.getContext(), ModeKind::VC); -} - +//===----------------------------------------------------------------------===// +// XeGPU_UpdateOffsetOp +//===----------------------------------------------------------------------===// void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, Type result, Value TensorDesc, Value offsets) { state.addOperands(TensorDesc); @@ -1644,7 +1691,6 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, } LogicalResult UpdateOffsetOp::verify() { - auto mode = getMode(); if (mode != ModeKind::VC) return emitOpError("UpdateOffsetOp only work on VC mode.\n"); @@ -1672,14 +1718,35 @@ LogicalResult UpdateOffsetOp::verify() { return success(); } -LogicalResult UpdateNDOffsetOp::verify() { - // number of offsets specified must match the rank of the tensor descriptor - if (getTensorDesc().getType().getRank() != (int64_t)getOffsets().size()) { - return emitOpError("Invalid number of offsets."); +//===----------------------------------------------------------------------===// +// XeGPU_DpasOp +//===----------------------------------------------------------------------===// +LogicalResult DpasOp::verify() { + int64_t lhsRank = getLhsType().getRank(); + int64_t rhsRank = getRhsType().getRank(); + Type lhsElemType = getLhsType().getElementType(); + Type rhsElemType = getRhsType().getElementType(); + + if (lhsElemType != rhsElemType) { + return emitOpError("lhs and rhs element type does not match for dpas op"); } + + if (getAcc() && getAccType() != getResultType()) { + return emitOpError("Accumulator and Result for dpas op should have the " + "same type (both shape and element type)."); + } + + if (lhsRank != rhsRank || lhsRank != 3) { + return emitOpError( + "lhs and rhs rank does not match for dpas op, or their rank is not 3."); + } + return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_InvokeSIMDOp +//===----------------------------------------------------------------------===// void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state, SymbolRefAttr callee, TypeRange results, ArgTypeKindAttr argType, ValueRange operands) { @@ -1702,13 +1769,9 @@ void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state, results, argType, operands); } -LogicalResult AtomicRMWOp::verify() { - auto mode = getMode(); - if (mode != ModeKind::VC) - return emitOpError("AtomicRMWOp only work on VC mode.\n"); - return success(); -} - +//===----------------------------------------------------------------------===// +// XeGPU_AtomicRMWOp +//===----------------------------------------------------------------------===// void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result, AtomicRMWKindAttr kind, Value tensorDesc, Value mask, Value value) { state.addOperands(tensorDesc); @@ -1734,6 +1797,12 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result, state.addTypes(result); } +LogicalResult AtomicRMWOp::verify() { + auto mode = getMode(); + if (mode != ModeKind::VC) + return emitOpError("AtomicRMWOp only work on VC mode.\n"); + return success(); +} } // namespace xegpu } // 
namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir b/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
index 08cfa47104622..f80df161a543a 100644
--- a/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/atomic_rmw.mlir
@@ -10,7 +10,7 @@ func.func @test_atomic_rmw(%src: ui64, %offsets : vector<16 x index>, %value : v

   // CHECK: xegpu.atomic_rmw
   // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16xf32>
-  xegpu.atomic_rmw addf %1, %mask, %value {mode=vc}
+  xegpu.atomic_rmw #xegpu %1, %mask, %value {mode=vc}
         : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>, vector<16xf32> -> vector<16xf32>

   return

From e30cfa86417a164528058770121a81560eb52a20 Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Sat, 13 Jan 2024 01:08:26 +0000
Subject: [PATCH 4/8] unify ReadCacheAttr and WriteCacheAttr into CacheAttr and
 replace parseOptionalAttrDict with parseOptionalAttrDictWithCustomAttrs

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  31 +--
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 162 ++++++------
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |   1 -
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 248 ++++++------------
 4 files changed, 179 insertions(+), 263 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index edcf863691217..ed3d9bbc77256 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -98,24 +98,18 @@ def XeGPU_MemoryScopeKind: I32EnumAttr<"MemoryScopeKind",
   let cppNamespace = "::mlir::xegpu";
 }

-def CACHE_KIND_CACHED: I32EnumAttrCase<"CACHED", 0, "cached">;
-def CACHE_KIND_UNCACHED: I32EnumAttrCase<"UNCACHED", 1, "uncached">;
-def CACHE_KIND_STREAMING: I32EnumAttrCase<"STREAMING", 2, "streaming">;
-def CACHE_KIND_INVALIDATE: I32EnumAttrCase<"READ_INVALIDATE", 3, "read_invalidate">;
-def CACHE_KIND_WRITE_BACK: I32EnumAttrCase<"WRITE_BACK", 4, "write_back">;
-def CACHE_KIND_WRITE_THROUGH: I32EnumAttrCase<"WRITE_THROUGH", 5, "write_through">;
+def CACHE_KIND_CACHED: I32EnumAttrCase<"CACHED", 0, "cached">; // valid for read and write
+def CACHE_KIND_UNCACHED: I32EnumAttrCase<"UNCACHED", 1, "uncached">; // valid for read and write
+def CACHE_KIND_STREAMING: I32EnumAttrCase<"STREAMING", 2, "streaming">; // valid for read only
+def CACHE_KIND_INVALIDATE: I32EnumAttrCase<"READ_INVALIDATE", 3, "read_invalidate">; // valid for read only
+def CACHE_KIND_WRITE_BACK: I32EnumAttrCase<"WRITE_BACK", 4, "write_back">; // valid for write only
+def CACHE_KIND_WRITE_THROUGH: I32EnumAttrCase<"WRITE_THROUGH", 5, "write_through">; // valid for write only

-def XeGPU_ReadCacheKind : I32EnumAttr<"ReadCacheKind",
-    "Cache behavior for read",
-    [CACHE_KIND_CACHED, CACHE_KIND_UNCACHED,
-     CACHE_KIND_STREAMING, CACHE_KIND_INVALIDATE]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::xegpu";
-}
+
+
+def XeGPU_CacheKind : I32EnumAttr<"CacheKind", "Cache kind",
+    [CACHE_KIND_CACHED, CACHE_KIND_UNCACHED,
+     CACHE_KIND_STREAMING, CACHE_KIND_INVALIDATE,
     CACHE_KIND_WRITE_BACK, CACHE_KIND_WRITE_THROUGH]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::xegpu";
 }
@@ -124,8 +118,7 @@ def XeGPU_WriteCacheKind: I32EnumAttr<"WriteCacheKind",
 def XeGPU_ArgTypeAttr : EnumAttr;
 def XeGPU_ModeAttr : EnumAttr;
 def XeGPU_MemoryScopeAttr : EnumAttr;
-def XeGPU_ReadCacheAttr : EnumAttr;
-def 
XeGPU_WriteCacheAttr : EnumAttr; +def XeGPU_CacheAttr : EnumAttr; // RMW kind attribute def ATOMIC_RMW_KIND_ADDF : I32EnumAttrCase<"addf", 0>; diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 6ed65286c2524..52e149d12c09d 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -157,11 +157,11 @@ def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> { }]; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - OptionalAttr: $vnni_axis, + OptionalAttr: $vnni_axis, OptionalAttr: $transpose, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, DefaultValuedAttr: $mode); let results = (outs XeGPU_ValueType: $value); @@ -185,9 +185,9 @@ def XeGPU_StoreNDOp : XeGPU_Op<"store_nd", []> { let summary = "stores a n-D block register region back to memory, currently only supports 2D"; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, XeGPU_ValueType: $value, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, DefaultValuedAttr: $mode); // Format: xegpu.store_nd %3, %2 {l1_hint = write_back, l2_hint = uncached} @@ -199,9 +199,9 @@ def XeGPU_StoreNDOp : XeGPU_Op<"store_nd", []> { def XeGPU_PrefetchNDOp : XeGPU_Op<"prefetch_nd", []> { let summary = "prefetches a nD block to cache"; let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, DefaultValuedAttr: $mode); // Format: xegpu.prefetch_nd %tdesc {l1_hint = cached, l2_hint = uncached}: @@ -257,7 +257,7 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure]> { let arguments = (ins XeGPU_BaseAddrType: $source, XeGPU_OffsetType: $offsets, - DefaultValuedAttr: $chunk_size_per_lane, + DefaultValuedAttr: $chunk_size_per_lane, DefaultValuedAttr: $mode); let results = (outs XeGPU_TensorDesc:$TensorDesc); @@ -280,26 +280,26 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load"> { let arguments = (ins XeGPU_TensorDesc: $TensorDesc, XeGPU_MaskType: $mask, - OptionalAttr: $vnni_axis, - OptionalAttr: $transpose, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, + OptionalAttr: $vnni_axis, + OptionalAttr: $transpose, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, DefaultValuedAttr: $mode); let results = (outs XeGPU_ValueType: $value); let builders = [ OpBuilder<(ins "mlir::Type": $value, "mlir::Value": $TensorDesc, "mlir::Value": $mask, "mlir::IntegerAttr": $vnni_axis, CArg<"mlir::DenseI64ArrayAttr", "mlir::DenseI64ArrayAttr()">: $transpose, - CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l1_hint, - CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l2_hint, - CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l3_hint)>, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l1_hint, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l2_hint, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l3_hint)>, OpBuilder<(ins "Type": $value, "Value": $TensorDesc, "Value": $mask, "IntegerAttr": $vnni_axis, CArg<"DenseI64ArrayAttr", "DenseI64ArrayAttr()">: $transpose, - CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l1_hint, - CArg<"xegpu::ReadCacheKind", 
"xegpu::ReadCacheKind::CACHED">: $l2_hint, - CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l3_hint)> + CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l1_hint, + CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l2_hint, + CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l3_hint)> ]; let skipDefaultBuilders = 1; @@ -317,21 +317,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", []> { XeGPU_ValueType: $value, XeGPU_TensorDesc: $TensorDesc, XeGPU_MaskType: $mask, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, DefaultValuedAttr: $mode ); let builders = [ OpBuilder<(ins "Value": $value, "Value": $TensorDesc, "Value": $mask, - CArg<"xegpu::WriteCacheKindAttr", "xegpu::WriteCacheKindAttr()">: $l1_hint, - CArg<"xegpu::WriteCacheKindAttr", "xegpu::WriteCacheKindAttr()">: $l2_hint, - CArg<"xegpu::WriteCacheKindAttr", "xegpu::WriteCacheKindAttr()">: $l3_hint)>, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l1_hint, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l2_hint, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l3_hint)>, OpBuilder<(ins "Value": $value, "Value": $TensorDesc, "Value": $mask, - CArg<"xegpu::WriteCacheKind", "xegpu::WriteCacheKind::WRITE_BACK">: $l1_hint, - CArg<"xegpu::WriteCacheKind", "xegpu::WriteCacheKind::WRITE_BACK">: $l2_hint, - CArg<"xegpu::WriteCacheKind", "xegpu::WriteCacheKind::WRITE_BACK">: $l3_hint)> + CArg<"xegpu::CacheKind", "xegpu::CacheKind::WRITE_BACK">: $l1_hint, + CArg<"xegpu::CacheKind", "xegpu::CacheKind::WRITE_BACK">: $l2_hint, + CArg<"xegpu::CacheKind", "xegpu::CacheKind::WRITE_BACK">: $l3_hint)> ]; let skipDefaultBuilders = 1; @@ -341,6 +341,54 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", []> { let hasVerifier = 1; } +def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { + let summary = "prefetches a nD block to cache"; + let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + OptionalAttr: $l1_hint, + OptionalAttr: $l2_hint, + OptionalAttr: $l3_hint, + DefaultValuedAttr: $mode); + + let builders = [ + OpBuilder<(ins "Value": $TensorDesc, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l1_hint, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l2_hint, + CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l3_hint)>, + OpBuilder<(ins "Value": $TensorDesc, + CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l1_hint, + CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l2_hint, + CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l3_hint)> + ]; + + let skipDefaultBuilders = 1; + let hasVerifier = 1; + + // Format: xegpu.prefetch %tdesc {l1_hint = cached, l2_hint = uncached}: + // !xegpu.tensor_desc<8x16xf16> + let hasCustomAssemblyFormat = 1; +} + +def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", []> { + let summary = "update the offsets for the given tensor descriptor"; + let arguments = (ins XeGPU_TensorDesc: $TensorDesc, + XeGPU_OffsetType: $offsets, + DefaultValuedAttr: $mode); + let results = (outs XeGPU_TensorDesc: $result); + + let builders = [ + OpBuilder<(ins "Type": $result, "Value": $TensorDesc, "Value": $offsets)> + ]; + + let skipDefaultBuilders = 1; + + let assemblyFormat = [{ + $TensorDesc `,` $offsets (`{` `mode` `=` $mode^ `}`)? 
+ attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) + }]; + + let hasVerifier = 1; +} + def XeGPU_DpasOp : XeGPU_Op<"dpas"> { let summary = "performs dpas computation"; let arguments = (ins @@ -374,54 +422,6 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas"> { let hasVerifier = 1; } -def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", []> { - let summary = "update the offsets for the given tensor descriptor"; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - XeGPU_OffsetType: $offsets, - DefaultValuedAttr: $mode); - let results = (outs XeGPU_TensorDesc: $result); - - let builders = [ - OpBuilder<(ins "Type": $result, "Value": $TensorDesc, "Value": $offsets)> - ]; - - let skipDefaultBuilders = 1; - - let assemblyFormat = [{ - $TensorDesc `,` $offsets (`{` `mode` `=` $mode^ `}`)? - attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) - }]; - - let hasVerifier = 1; -} - -def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> { - let summary = "prefetches a nD block to cache"; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - OptionalAttr: $l1_hint, - OptionalAttr: $l2_hint, - OptionalAttr: $l3_hint, - DefaultValuedAttr: $mode); - - let builders = [ - OpBuilder<(ins "Value": $TensorDesc, - CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l1_hint, - CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l2_hint, - CArg<"xegpu::ReadCacheKindAttr", "xegpu::ReadCacheKindAttr()">: $l3_hint)>, - OpBuilder<(ins "Value": $TensorDesc, - CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l1_hint, - CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l2_hint, - CArg<"xegpu::ReadCacheKind", "xegpu::ReadCacheKind::CACHED">: $l3_hint)> - ]; - - let skipDefaultBuilders = 1; - let hasVerifier = 1; - - // Format: xegpu.prefetch %tdesc {l1_hint = cached, l2_hint = uncached}: - // !xegpu.tensor_desc<8x16xf16> - let hasCustomAssemblyFormat = 1; -} - def XeGPU_InvokeSIMDOp : XeGPU_Op<"invoke_SIMD", []> { let summary = "Invoke_SIMD operation"; let description = [{ @@ -451,7 +451,7 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> { XeGPU_TensorDesc:$tensorDesc, XeGPU_MaskType:$mask, Optional:$value, - DefaultValuedAttr: $mode + DefaultValuedAttr: $mode ); let results = (outs XeGPU_ValueType:$result); @@ -473,7 +473,7 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> { def XeGPU_AllocNbarrierOp: XeGPU_Op<"alloc_nbarrier", []> { let summary = "allocate a specific number of named barriers."; - let arguments = (ins I32Attr: $nbarrierCount); + let arguments = (ins I64Attr: $nbarrierCount); let assemblyFormat = "$nbarrierCount attr-dict"; } diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td index 4831cb1fa18c5..b3dceff9587ad 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td @@ -15,7 +15,6 @@ include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td" include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td" // An Integer array attribute with fixed 2 elements. 
-def XeGPU_IntArrayAttr2: ConfinedAttr<DenseI64ArrayAttr, [DenseArrayCount<2>]>;
 def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index fca6c10aae81b..d272d10d73a03 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -23,6 +23,8 @@
 #define DEBUG_TYPE "xegpu"

 namespace mlir {
+class Token;
+
 namespace xegpu {

 extern bool printDefaultValues();
@@ -85,86 +87,51 @@ static bool verifyAndInferShape(std::vector &shape,
   return true;
 }

-template
-static ParseResult parseCustomEnumAttr(OpAsmParser &parser,
-                                       OperationState &result,
-                                       llvm::StringRef attrKeyword) {
-  auto loc = parser.getCurrentLocation();
-  auto attrOptional = FieldParser::parse(parser);
-  if (failed(attrOptional))
-    return parser.emitError(loc, "invalid attribute specification");
-  auto attr =
-      CustomEnumAttr::get(parser.getBuilder().getContext(), *attrOptional);
-  result.addAttribute(attrKeyword, attr);
-  return success();
-}
-
-template
-static ParseResult parseBoolAndIntegerAttr(OpAsmParser &parser,
-                                           OperationState &result,
-                                           llvm::StringRef attrKeyword) {
-  AttrType attr;
-  Type ty;
-
-  if (std::is_same::value) {
-    ty = parser.getBuilder().getIntegerType(1);
-  } else if (std::is_same::value) {
-    ty = parser.getBuilder().getIntegerType(32);
-  } else if (std::is_same::value) {
-    ty = Type{};
-  } else {
-    llvm_unreachable("Unsupported Attribute Type.");
-  }
-
-  if (parser.parseCustomAttributeWithFallback(attr, ty))
-    return failure();
-
-  if (attr)
-    result.addAttribute(attrKeyword, attr);
-  return success();
-}
-
 static ParseResult
-parseOptionalAttrDict(OpAsmParser &parser, OperationState &result,
-                      llvm::ArrayRef allowedKeywords,
-                      bool isWrite = false) {
+parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, NamedAttrList &attributes) {
   // no optional attributes, return success
   if (failed(parser.parseOptionalLBrace()))
     return success();

+  llvm::SmallDenseSet seenKeys;
   auto parseElt = [&]() -> ParseResult {
+    // The name of an attribute can either be a keyword, or a string.
+    // As compared to mlir::parseOptionalAttrList, the cases of using
+    // Token::bare_identifier and Token::inttype as keys may not be handled.
+    std::string nameId;
     auto loc = parser.getCurrentLocation();
-    llvm::StringRef nameId;
-    if (parser.parseOptionalKeyword(&nameId, allowedKeywords))
-      return parser.emitError(loc, "invalid attribute keyword: ")
-             << nameId << ".\n";
-
-    if (parser.parseEqual())
-      return failure();
-
-    if (nameId == "l1_hint" || nameId == "l2_hint" || nameId == "l3_hint") {
-      if (isWrite)
-        return parseCustomEnumAttr(
-            parser, result, nameId);
-      else
-        return parseCustomEnumAttr(
-            parser, result, nameId);
+    if (parser.parseOptionalKeywordOrString(&nameId))
+      return parser.emitError(loc, "invalid attribute name: ") << nameId << ".\n";
+
+    if (nameId.empty())
+      return parser.emitError(loc, "expected valid attribute name");
+
+    if (!seenKeys.insert(nameId).second)
+      return parser.emitError(loc, "duplicate key '")
+             << nameId << "' in dictionary attribute.";
+
+    // Lazy load a dialect in the context if there is a possible namespace.
+    auto splitName = StringRef(nameId).split('.');
+    if (!splitName.second.empty())
+      parser.getContext()->getOrLoadDialect(splitName.first);
+
+    // Try to parse the '=' for the attribute value.
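+    // For example (illustrative input only): in `{mode = vc, l1_hint = cached}`
+    // every key carries a value, while a bare key such as `{boundary_check}`
+    // has no '=' and is recorded as a unit attribute below.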
+ if (parser.parseEqual()) { + // If there is no '=', it is treated as a unit attribute. + attributes.append(nameId, parser.getBuilder().getUnitAttr()); + return success(); } if (nameId == "mode") { - return parseCustomEnumAttr(parser, result, nameId); + ModeKindAttr attr; + return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, attributes); + } else if (nameId == "l1_hint" || nameId == "l2_hint" || nameId == "l3_hint") { + CacheKindAttr attr; + return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, attributes); + } else { + Attribute attr; + return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, attributes); } - - if (nameId == "chunk_size_per_lane" || nameId == "vnni_axis") - return parseBoolAndIntegerAttr(parser, result, nameId); - - if (nameId == "boundary_check") - return parseBoolAndIntegerAttr(parser, result, nameId); - - if (nameId == "transpose") - return parseBoolAndIntegerAttr(parser, result, nameId); - - llvm_unreachable("Unsupported attribute keyword."); }; if (parser.parseCommaSeparatedList(parseElt)) @@ -314,7 +281,7 @@ ParseResult CreateNdDescOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } - if (parseOptionalAttrDict(parser, result, {"boundary_check", "mode"})) + if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) return failure(); if (parser.parseColon()) @@ -540,9 +507,7 @@ ParseResult LoadNDOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(TensorDescRawOperands[0])) return failure(); - if (parseOptionalAttrDict( - parser, result, - {"mode", "vnni_axis", "transpose", "l1_hint", "l2_hint", "l3_hint"})) + if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) return failure(); if (parser.parseColon()) @@ -715,8 +680,7 @@ ParseResult StoreNDOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(TensorDescRawOperands[0])) return failure(); - if (parseOptionalAttrDict(parser, result, - {"mode", "l1_hint", "l2_hint", "l3_hint"}, true)) + if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) return failure(); if (parser.parseColon()) @@ -852,8 +816,7 @@ ParseResult PrefetchNDOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(TensorDescRawOperands[0])) return failure(); - if (parseOptionalAttrDict(parser, result, - {"mode", "l1_hint", "l2_hint", "l3_hint"})) + if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) return failure(); if (parser.parseColon()) @@ -951,7 +914,7 @@ ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // parse the optional attributes - if (parseOptionalAttrDict(parser, result, {"chunk_size_per_lane", "mode"})) + if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) return failure(); if (parser.parseColon()) @@ -981,59 +944,6 @@ ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -// ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) { -// OpAsmParser::UnresolvedOperand sourceRawOperands[1]; -// llvm::ArrayRef sourceOperands( -// sourceRawOperands); -// llvm::SMLoc sourceOperandsLoc = parser.getCurrentLocation(); -// if (parser.parseOperand(sourceRawOperands[0])) -// return failure(); - -// if (parser.parseComma()) -// return failure(); - -// OpAsmParser::UnresolvedOperand offsetsRawOperands[1]; -// llvm::ArrayRef offsetsOperands( -// offsetsRawOperands); -// llvm::SMLoc offsetsOperandsLoc = 
parser.getCurrentLocation(); -// if (parser.parseOperand(offsetsRawOperands[0])) -// return failure(); - -// if (parseOptionalAttrDict(parser, result, {"chunk_size_per_lane", "mode"})) -// return failure(); - -// if (parser.parseColon()) -// return failure(); - -// Type sourceRawTypes[1]; -// llvm::ArrayRef sourceTypes(sourceRawTypes); -// if (parser.parseType(sourceRawTypes[0])) -// return failure(); -// if (parser.parseComma()) -// return failure(); - -// Type offsetsRawTypes[1]; -// llvm::ArrayRef offsetsTypes(offsetsRawTypes); -// if (parser.parseType(offsetsRawTypes[0])) -// return failure(); -// if (parser.parseArrow()) -// return failure(); - -// Type TensorDescRawTypes[1]; -// llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); -// if (parser.parseType(TensorDescRawTypes[0])) -// return failure(); - -// result.addTypes(TensorDescTypes); -// if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc, -// result.operands)) -// return failure(); -// if (parser.resolveOperands(offsetsOperands, offsetsTypes, offsetsOperandsLoc, -// result.operands)) -// return failure(); -// return success(); -// } - void CreateDescOp::print(OpAsmPrinter &printer) { auto mode = getMode(); bool printSep = false; @@ -1122,8 +1032,8 @@ LogicalResult CreateDescOp::verify() { //===----------------------------------------------------------------------===// void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, Value TensorDesc, Value mask, IntegerAttr vnni_axis, - DenseI64ArrayAttr transpose, ReadCacheKindAttr l1_hint, - ReadCacheKindAttr l2_hint, ReadCacheKindAttr l3_hint) { + DenseI64ArrayAttr transpose, CacheKindAttr l1_hint, + CacheKindAttr l2_hint, CacheKindAttr l3_hint) { state.addOperands(TensorDesc); state.addOperands(mask); if (vnni_axis) @@ -1148,8 +1058,8 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, Value TensorDesc, Value mask, IntegerAttr vnni_axis, - DenseI64ArrayAttr transpose, ReadCacheKind l1_hint, - ReadCacheKind l2_hint, ReadCacheKind l3_hint) { + DenseI64ArrayAttr transpose, CacheKind l1_hint, + CacheKind l2_hint, CacheKind l3_hint) { state.addOperands(TensorDesc); state.addOperands(mask); if (vnni_axis) @@ -1159,11 +1069,11 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, state.getOrAddProperties().transpose = transpose; state.getOrAddProperties().l1_hint = - ReadCacheKindAttr::get(builder.getContext(), l1_hint); + CacheKindAttr::get(builder.getContext(), l1_hint); state.getOrAddProperties().l2_hint = - ReadCacheKindAttr::get(builder.getContext(), l2_hint); + CacheKindAttr::get(builder.getContext(), l2_hint); state.getOrAddProperties().l3_hint = - ReadCacheKindAttr::get(builder.getContext(), l3_hint); + CacheKindAttr::get(builder.getContext(), l3_hint); state.getOrAddProperties().mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC); state.addTypes(value); @@ -1196,9 +1106,7 @@ ParseResult LoadGatherOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(maskRawOperands[0])) return failure(); - if (parseOptionalAttrDict( - parser, result, - {"mode", "vnni_axis", "transpose", "l1_hint", "l2_hint", "l3_hint"})) + if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) return failure(); if (parser.parseColon()) @@ -1370,9 +1278,9 @@ LogicalResult LoadGatherOp::verify() { //===----------------------------------------------------------------------===// void 
StoreScatterOp::build(OpBuilder &builder, OperationState &state, Value value, Value TensorDesc, Value mask, - WriteCacheKindAttr l1_hint, - WriteCacheKindAttr l2_hint, - WriteCacheKindAttr l3_hint) { + CacheKindAttr l1_hint, + CacheKindAttr l2_hint, + CacheKindAttr l3_hint) { state.addOperands(value); state.addOperands(TensorDesc); state.addOperands(mask); @@ -1391,18 +1299,18 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, void StoreScatterOp::build(OpBuilder &builder, OperationState &state, Value value, Value TensorDesc, Value mask, - WriteCacheKind l1_hint, WriteCacheKind l2_hint, - WriteCacheKind l3_hint) { + CacheKind l1_hint, CacheKind l2_hint, + CacheKind l3_hint) { state.addOperands(value); state.addOperands(TensorDesc); state.addOperands(mask); state.getOrAddProperties().l1_hint = - WriteCacheKindAttr::get(builder.getContext(), l1_hint); + CacheKindAttr::get(builder.getContext(), l1_hint); state.getOrAddProperties().l2_hint = - WriteCacheKindAttr::get(builder.getContext(), l2_hint); + CacheKindAttr::get(builder.getContext(), l2_hint); ; state.getOrAddProperties().l3_hint = - WriteCacheKindAttr::get(builder.getContext(), l3_hint); + CacheKindAttr::get(builder.getContext(), l3_hint); ; state.getOrAddProperties().mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC); @@ -1450,8 +1358,7 @@ ParseResult StoreScatterOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(maskRawOperands[0])) return failure(); - if (parseOptionalAttrDict(parser, result, - {"mode", "l1_hint", "l2_hint", "l3_hint"}, true)) + if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) return failure(); if (parser.parseColon()) @@ -1575,8 +1482,8 @@ LogicalResult StoreScatterOp::verify() { // XeGPU_PrefetchOp //===----------------------------------------------------------------------===// void PrefetchOp::build(OpBuilder &builder, OperationState &state, - Value TensorDesc, ReadCacheKindAttr l1_hint, - ReadCacheKindAttr l2_hint, ReadCacheKindAttr l3_hint) { + Value TensorDesc, CacheKindAttr l1_hint, + CacheKindAttr l2_hint, CacheKindAttr l3_hint) { state.addOperands(TensorDesc); if (l1_hint) state.getOrAddProperties().l1_hint = l1_hint; @@ -1592,15 +1499,15 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, } void PrefetchOp::build(OpBuilder &builder, OperationState &state, - Value TensorDesc, ReadCacheKind l1_hint, - ReadCacheKind l2_hint, ReadCacheKind l3_hint) { + Value TensorDesc, CacheKind l1_hint, + CacheKind l2_hint, CacheKind l3_hint) { state.addOperands(TensorDesc); state.getOrAddProperties().l1_hint = - ReadCacheKindAttr::get(builder.getContext(), l1_hint); + CacheKindAttr::get(builder.getContext(), l1_hint); state.getOrAddProperties().l2_hint = - ReadCacheKindAttr::get(builder.getContext(), l2_hint); + CacheKindAttr::get(builder.getContext(), l2_hint); state.getOrAddProperties().l3_hint = - ReadCacheKindAttr::get(builder.getContext(), l3_hint); + CacheKindAttr::get(builder.getContext(), l3_hint); state.getOrAddProperties().mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC); } @@ -1617,8 +1524,7 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(TensorDescRawOperands[0])) return failure(); - if (parseOptionalAttrDict(parser, result, - {"mode", "l1_hint", "l2_hint", "l3_hint"})) + if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) return failure(); if (parser.parseColon()) @@ -1666,12 +1572,30 @@ LogicalResult PrefetchOp::verify() { auto 
tdescTy = getTensorDesc().getType();
   auto mapping = tdescTy.getMapping();

-  if (tdescTy.getScattered())
+  auto isValidHint = [&](CacheKindAttr attr) -> bool {
+    if (!attr) return true;
+    auto kind = attr.getValue();
+    return kind == CacheKind::CACHED ||
+           kind == CacheKind::UNCACHED ||
+           kind == CacheKind::STREAMING ||
+           kind == CacheKind::READ_INVALIDATE;
+  };
+
+  if (!isValidHint(getL1HintAttr()))
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
+
+  if (!isValidHint(getL2HintAttr()))
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
+
+  if (!isValidHint(getL3HintAttr()))
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  if (!tdescTy.getScattered())
     return emitOpError("Invalid TensorDesc. PrefetchOp only works on "
                        "TensorDescs with ScatteredAttr.");

   if (mode != ModeKind::VC || mapping) {
-    return emitOpError("PrefetchOp only supports VC mode. and mapping "
+    return emitOpError("PrefetchOp only supports VC mode, and mapping "
                        "attribute of TensorDesc is not expected.\n");
   }

From d3da5b2df8ae4e688d72fd820f5d42d520b1f798 Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Tue, 16 Jan 2024 01:02:42 +0000
Subject: [PATCH 5/8] sync from PVC

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  62 +--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 455 +++++++++++++++++-
 2 files changed, 463 insertions(+), 54 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 52e149d12c09d..846101030ca04 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -289,18 +289,19 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load"> {
   let results = (outs XeGPU_ValueType: $value);

   let builders = [
-    OpBuilder<(ins "mlir::Type": $value, "mlir::Value": $TensorDesc, "mlir::Value": $mask, "mlir::IntegerAttr": $vnni_axis,
-                   CArg<"mlir::DenseI64ArrayAttr", "mlir::DenseI64ArrayAttr()">: $transpose,
-                   CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l1_hint,
-                   CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l2_hint,
-                   CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l3_hint)>,
-
-    OpBuilder<(ins "Type": $value, "Value": $TensorDesc, "Value": $mask, "IntegerAttr": $vnni_axis,
-                   CArg<"DenseI64ArrayAttr", "DenseI64ArrayAttr()">: $transpose,
-                   CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l1_hint,
-                   CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l2_hint,
-                   CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l3_hint)>
-
+    OpBuilder<(ins "mlir::Type": $value, "mlir::Value": $TensorDesc,
+                   "mlir::Value": $mask, "mlir::IntegerAttr": $vnni_axis,
+                   CArg<"mlir::DenseI64ArrayAttr", "mlir::DenseI64ArrayAttr()">: $transpose,
+                   CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l1_hint,
+                   CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l2_hint,
+                   CArg<"xegpu::CacheKindAttr", "xegpu::CacheKindAttr()">: $l3_hint)>,
+
+    OpBuilder<(ins "mlir::Type": $value, "mlir::Value": $TensorDesc,
+                   "mlir::Value": $mask, "mlir::IntegerAttr": $vnni_axis,
+                   CArg<"DenseI64ArrayAttr", "DenseI64ArrayAttr()">: $transpose,
+                   CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l1_hint,
+                   CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l2_hint,
+                   CArg<"xegpu::CacheKind", "xegpu::CacheKind::CACHED">: $l3_hint)>
   ];

   let skipDefaultBuilders = 1;
@@ -380,11 +381,12 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", []> {
   ];

   let skipDefaultBuilders = 1;
+  let hasCustomAssemblyFormat = 1;

-  let assemblyFormat = [{
-    $TensorDesc `,` $offsets (`{`
`mode` `=` $mode^ `}`)? - attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) - }]; + // let assemblyFormat = [{ + // $TensorDesc `,` $offsets (`{` `mode` `=` $mode^ `}`)? + // attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) + // }]; let hasVerifier = 1; } @@ -398,10 +400,12 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas"> { DefaultValuedAttr: $mode ); let results = (outs XeGPU_Vector2DType: $result); - let assemblyFormat = [{ - $lhs `,` $rhs (`,` $acc^)? (` ``{` `mode` `=` $mode^ `}`)? attr-dict `:` - qualified(type($lhs)) `,` qualified(type($rhs)) (`,` qualified(type($acc))^)? `->` qualified(type($result)) - }]; +// let assemblyFormat = [{ +// $lhs `,` $rhs (`,` $acc^)? (` ``{` `mode` `=` $mode^ `}`)? attr-dict `:` +// qualified(type($lhs)) `,` qualified(type($rhs)) (`,` qualified(type($acc))^)? `->` qualified(type($result)) +// }]; + + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ VectorType getLhsType() { @@ -455,9 +459,10 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> { ); let results = (outs XeGPU_ValueType:$result); - let assemblyFormat = [{ - $kind $tensorDesc `,` $mask (`,` $value^)? (`{` `mode` `=` $mode^ `}`)? attr-dict `:` qualified(type(operands)) `->` type($result) - }]; + let hasCustomAssemblyFormat = 1; +// let assemblyFormat = [{ +// $kind $tensorDesc `,` $mask (`,` $value^)? (`{` `mode` `=` $mode^ `}`)? attr-dict `:` qualified(type(operands)) `->` type($result) +// }]; let builders = [ OpBuilder<(ins "Type": $result, "xegpu::AtomicRMWKindAttr": $kind, @@ -486,11 +491,12 @@ def XeGPU_CreateNbarrierOp: XeGPU_Op<"create_nbarrier", []> { I8Attr: $num_consumers, DefaultValuedAttr: $mode); let results = (outs XeGPU_Nbarrier: $result); - let assemblyFormat = [{ - $nbarrier_id `,` $nbarrier_role - attr-dict `:` `(` qualified(type($nbarrier_id)) `,` qualified(type($nbarrier_role)) `)` - `->` qualified(type($result)) - }]; + let hasCustomAssemblyFormat = 1; +// let assemblyFormat = [{ +// $nbarrier_id `,` $nbarrier_role +// attr-dict `:` `(` qualified(type($nbarrier_id)) `,` qualified(type($nbarrier_role)) `)` +// `->` qualified(type($result)) +// }]; } def XeGPU_NbarrierArriveOp: XeGPU_Op<"nbarrier_arrive", []> { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index d272d10d73a03..2883a5db2a734 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -88,7 +88,7 @@ static bool verifyAndInferShape(std::vector &shape, } static ParseResult -parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, NamedAttrList &attributes) { +parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result) { // no optional attributes, return success if (failed(parser.parseOptionalLBrace())) return success(); @@ -118,19 +118,24 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, NamedAttrList &attribu // Try to parse the '=' for the attribute value. if (parser.parseEqual()) { // If there is no '=', it is treated as a unit attribute. 
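+    // e.g., a bare key such as `boundary_check` in `{boundary_check}` takes
+    // this path (illustrative example, not tied to a specific op).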
- attributes.append(nameId, parser.getBuilder().getUnitAttr()); + result.addAttribute(nameId, parser.getBuilder().getUnitAttr()); return success(); } + // handle xegpu specific attributes if (nameId == "mode") { ModeKindAttr attr; - return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, attributes); + return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes); } else if (nameId == "l1_hint" || nameId == "l2_hint" || nameId == "l3_hint") { CacheKindAttr attr; - return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, attributes); + return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes); + } else if (nameId == "transpose") { + DenseI64ArrayAttr attr; + return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes); } else { + // for generic attribtes Attribute attr; - return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, attributes); + return parser.parseAttribute(attr, nameId, result.attributes); } }; @@ -281,7 +286,7 @@ ParseResult CreateNdDescOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } - if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); if (parser.parseColon()) @@ -507,7 +512,7 @@ ParseResult LoadNDOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(TensorDescRawOperands[0])) return failure(); - if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); if (parser.parseColon()) @@ -680,7 +685,7 @@ ParseResult StoreNDOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(TensorDescRawOperands[0])) return failure(); - if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); if (parser.parseColon()) @@ -816,7 +821,7 @@ ParseResult PrefetchNDOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(TensorDescRawOperands[0])) return failure(); - if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); if (parser.parseColon()) @@ -872,7 +877,6 @@ LogicalResult UpdateNDOffsetOp::verify() { //===----------------------------------------------------------------------===// // XeGPU_CreateDescOp //===----------------------------------------------------------------------===// - void CreateDescOp::build(OpBuilder &builder, OperationState &state, TensorDescType TensorDesc, Value source, Value offsets, uint32_t chunk_size_per_lane) { @@ -914,7 +918,7 @@ ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // parse the optional attributes - if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); if (parser.parseColon()) @@ -1106,7 +1110,7 @@ ParseResult LoadGatherOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOperand(maskRawOperands[0])) return failure(); - if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes)) + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); if (parser.parseColon()) @@ -1243,13 +1247,12 @@ LogicalResult LoadGatherOp::verify() { "attribute of TensorDesc is not expected.\n"); } - if 
(getTranspose()) {
+  if (getTransposeAttr()) {
     auto trans = getTranspose().value();
-    if (tdescShape.size() >= trans.size())
-      transpose(trans, tdescShape);
-    else
-      emitWarning("Invalid transpose attr. It is ignored.");
-  }
+    if (tdescShape.size() < trans.size())
+      return emitWarning("Invalid transpose attr. It is ignored.");
+    transpose(trans, tdescShape);
+  }

   if (getVnniAxis()) {
     auto axis = getVnniAxis().value();
@@ -1284,15 +1287,12 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
   state.addOperands(value);
   state.addOperands(TensorDesc);
   state.addOperands(mask);
-  if (l1_hint) {
+  if (l1_hint)
     state.getOrAddProperties().l1_hint = l1_hint;
-  }
-  if (l2_hint) {
+  if (l2_hint)
     state.getOrAddProperties().l2_hint = l2_hint;
-  }
-  if (l3_hint) {
+  if (l3_hint)
     state.getOrAddProperties().l3_hint = l3_hint;
-  }
   state.getOrAddProperties().mode =
       ModeKindAttr::get(builder.getContext(), ModeKind::VC);
 }
@@ -1358,7 +1358,7 @@ ParseResult StoreScatterOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parser.parseOperand(maskRawOperands[0]))
     return failure();

-  if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes))
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
     return failure();

   if (parser.parseColon())
@@ -1524,8 +1524,14 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parser.parseOperand(TensorDescRawOperands[0]))
     return failure();

-  if (parseOptionalAttrDictWithCustomAttrs(parser, result.attributes))
+  auto loc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
     return failure();

+  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
+      })))
+    return failure();
+
   if (parser.parseColon())
     return failure();

@@ -1614,6 +1620,83 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
   state.addTypes(result);
 }

+ParseResult UpdateOffsetOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand RawOperands[2];
+  llvm::ArrayRef Operands(RawOperands);
+
+  Type resultRawTypes[1];
+  llvm::ArrayRef resultTypes(resultRawTypes);
+
+  Type RawTypes[2];
+  llvm::ArrayRef Types(RawTypes);
+
+  auto OperandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperand(RawOperands[0]))
+    return failure();
+  if (parser.parseComma())
+    return failure();
+
+  if (parser.parseOperand(RawOperands[1]))
+    return failure();
+
+  auto AttrLoc = parser.getCurrentLocation();
+  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
+    return failure();
+
+  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+        return parser.emitError(AttrLoc) << "'" << result.name.getStringRef() << "' op ";
+      })))
+    return failure();
+
+  if (parser.parseColon())
+    return failure();
+
+  if (parser.parseType(RawTypes[0]))
+    return failure();
+
+  if (parser.parseComma())
+    return failure();
+
+  if (parser.parseType(RawTypes[1]))
+    return failure();
+  if (parser.parseArrow())
+    return failure();
+
+  if (parser.parseType(resultRawTypes[0]))
+    return failure();
+  result.addTypes(resultTypes);
+  if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands))
+    return failure();
+  return success();
+}
+
+void UpdateOffsetOp::print(OpAsmPrinter &printer) {
+  auto mode = getMode();
+  auto printDefaults = printDefaultValues();
+
+  printer << ' ';
+  printer << getTensorDesc();
+  printer << ",";
+  printer << ' ';
+  printer << getOffsets();
+
+ 
llvm::SmallVector elidedAttrs; + if (!printDefaults) { + if (mode == ModeKind::SIMT) + elidedAttrs.push_back("mode"); + } + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + printer << ' ' << ":"; + printer << ' '; + printer << getTensorDesc().getType(); + printer << ","; + printer << ' '; + printer << getOffsets().getType(); + printer << ' ' << "->"; + printer << ' '; + printer << getResult().getType(); +} + LogicalResult UpdateOffsetOp::verify() { auto mode = getMode(); if (mode != ModeKind::VC) @@ -1645,6 +1728,129 @@ LogicalResult UpdateOffsetOp::verify() { //===----------------------------------------------------------------------===// // XeGPU_DpasOp //===----------------------------------------------------------------------===// +ParseResult DpasOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand lhsRawOperands[1]; + llvm::ArrayRef lhsOperands(lhsRawOperands); + llvm::SMLoc lhsOperandsLoc; + OpAsmParser::UnresolvedOperand rhsRawOperands[1]; + llvm::ArrayRef rhsOperands(rhsRawOperands); + llvm::SMLoc rhsOperandsLoc; + llvm::SmallVector accOperands; + llvm::SMLoc accOperandsLoc; + Type lhsRawTypes[1]; + llvm::ArrayRef lhsTypes(lhsRawTypes); + Type rhsRawTypes[1]; + llvm::ArrayRef rhsTypes(rhsRawTypes); + llvm::SmallVector accTypes; + Type resultRawTypes[1]; + llvm::ArrayRef resultTypes(resultRawTypes); + + lhsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(lhsRawOperands[0])) + return failure(); + + if (parser.parseComma()) + return failure(); + + rhsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(rhsRawOperands[0])) + return failure(); + + // parse optional acc operand + if (succeeded(parser.parseOptionalComma())) { + accOperandsLoc = parser.getCurrentLocation(); + OpAsmParser::UnresolvedOperand operand; + OptionalParseResult parseResult = parser.parseOptionalOperand(operand); + if (parseResult.has_value()) { + if (failed(*parseResult)) + return failure(); + accOperands.push_back(operand); + } + } + + auto loc = parser.getCurrentLocation(); + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) + return failure(); + + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); + + if (parser.parseColon()) + return failure(); + + if (parser.parseType(lhsRawTypes[0])) + return failure(); + + if (parser.parseComma()) + return failure(); + + if (parser.parseType(rhsRawTypes[0])) + return failure(); + + // parse type for optional acc + if (succeeded(parser.parseOptionalComma())) { + Type optionalType; + OptionalParseResult parseResult = parser.parseOptionalType(optionalType); + if (parseResult.has_value()) { + if (failed(*parseResult)) + return failure(); + accTypes.push_back(optionalType); + } + + } + + if (parser.parseArrow()) + return failure(); + + if (parser.parseType(resultRawTypes[0])) + return failure(); + result.addTypes(resultTypes); + if (parser.resolveOperands(lhsOperands, lhsTypes, lhsOperandsLoc, result.operands)) + return failure(); + if (parser.resolveOperands(rhsOperands, rhsTypes, rhsOperandsLoc, result.operands)) + return failure(); + if (parser.resolveOperands(accOperands, accTypes, accOperandsLoc, result.operands)) + return failure(); + return success(); +} + +void DpasOp::print(OpAsmPrinter &printer) { + auto mode = getMode(); + auto printDefaults = printDefaultValues(); + + printer << ' '; + printer << getLhs(); + printer << ","; + printer << ' '; + 
printer << getRhs(); + if (Value value = getAcc()) + printer << ", " << value; + + llvm::SmallVector elidedAttrs; + if (!printDefaults) { + if (mode == ModeKind::SIMT) + elidedAttrs.push_back("mode"); + } + + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + printer << ' ' << ":"; + printer << ' '; + printer << getLhs().getType(); + printer << ","; + printer << ' '; + printer << getRhs().getType(); + if (getAcc()) { + printer << ","; + printer << ' '; + printer << llvm::ArrayRef(getAcc().getType()); + } + printer << ' ' << "->"; + printer << ' '; + printer << getResult().getType(); +} + LogicalResult DpasOp::verify() { int64_t lhsRank = getLhsType().getRank(); int64_t rhsRank = getRhsType().getRank(); @@ -1721,12 +1927,209 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result, state.addTypes(result); } +ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { + xegpu::AtomicRMWKindAttr kindAttr; + OpAsmParser::UnresolvedOperand tensorDescRawOperands[1]; + llvm::ArrayRef tensorDescOperands(tensorDescRawOperands); + llvm::SMLoc tensorDescOperandsLoc; + OpAsmParser::UnresolvedOperand maskRawOperands[1]; + llvm::ArrayRef maskOperands(maskRawOperands); llvm::SMLoc maskOperandsLoc; + llvm::SmallVector valueOperands; + llvm::SMLoc valueOperandsLoc; + xegpu::ModeKindAttr modeAttr; + llvm::SmallVector allOperandTypes; + Type resultRawTypes[1]; + llvm::ArrayRef resultTypes(resultRawTypes); + + if (parser.parseCustomAttributeWithFallback(kindAttr, Type{})) { + return failure(); + } + if (kindAttr) result.getOrAddProperties().kind = kindAttr; + + tensorDescOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(tensorDescRawOperands[0])) + return failure(); + if (parser.parseComma()) + return failure(); + + maskOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(maskRawOperands[0])) + return failure(); + + if (succeeded(parser.parseOptionalComma())) { + valueOperandsLoc = parser.getCurrentLocation(); + OpAsmParser::UnresolvedOperand operand; + OptionalParseResult parseResult = parser.parseOptionalOperand(operand); + if (parseResult.has_value()) { + if (failed(*parseResult)) + return failure(); + valueOperands.push_back(operand); + } + } + + auto loc = parser.getCurrentLocation();(void)loc; + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) + return failure(); + + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); + + if (parser.parseColon()) + return failure(); + + if (parser.parseTypeList(allOperandTypes)) + return failure(); + if (parser.parseArrow()) + return failure(); + + Type type; + if (parser.parseCustomTypeWithFallback(type)) + return failure(); + resultRawTypes[0] = type; + + result.addTypes(resultTypes); + if (parser.resolveOperands(llvm::concat(tensorDescOperands, maskOperands, valueOperands), allOperandTypes, parser.getNameLoc(), result.operands)) + return failure(); + return success(); +} + +void AtomicRMWOp::print(OpAsmPrinter &printer) { + auto mode = getMode(); + auto printDefaults = printDefaultValues(); + + printer << ' '; + printer.printStrippedAttrOrType(getKindAttr()); + printer << ' '; + printer << getTensorDesc(); + printer << ","; + printer << ' '; + printer << getMask(); + if (Value value = getValue()) + printer << ", " << value; + + llvm::SmallVector elidedAttrs; + if (!printDefaults) { + if (mode == ModeKind::SIMT) + 
elidedAttrs.push_back("mode"); + } + + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + printer << ' ' << ":"; + printer << ' '; + printer << getOperation()->getOperandTypes(); + printer << ' ' << "->"; + printer << ' '; + + auto type = getResult().getType(); + if (auto validType = llvm::dyn_cast(type)) + printer.printStrippedAttrOrType(validType); + else + printer << type; + +} + LogicalResult AtomicRMWOp::verify() { auto mode = getMode(); if (mode != ModeKind::VC) return emitOpError("AtomicRMWOp only work on VC mode.\n"); return success(); } + +//===----------------------------------------------------------------------===// +// XeGPU_CreateNbarrierOp +//===----------------------------------------------------------------------===// +ParseResult CreateNbarrierOp::parse(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand nbarrier_idRawOperands[1]; + llvm::ArrayRef nbarrier_idOperands(nbarrier_idRawOperands); + llvm::SMLoc nbarrier_idOperandsLoc; + OpAsmParser::UnresolvedOperand nbarrier_roleRawOperands[1]; + llvm::ArrayRef nbarrier_roleOperands(nbarrier_roleRawOperands); + llvm::SMLoc nbarrier_roleOperandsLoc; + Type nbarrier_idRawTypes[1]; + llvm::ArrayRef nbarrier_idTypes(nbarrier_idRawTypes); + Type nbarrier_roleRawTypes[1]; + llvm::ArrayRef nbarrier_roleTypes(nbarrier_roleRawTypes); + Type resultRawTypes[1]; + llvm::ArrayRef resultTypes(resultRawTypes); + + nbarrier_idOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(nbarrier_idRawOperands[0])) + return failure(); + if (parser.parseComma()) + return failure(); + + nbarrier_roleOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(nbarrier_roleRawOperands[0])) + return failure(); + + auto loc = parser.getCurrentLocation(); + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) + return failure(); + + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); + + if (parser.parseColon()) + return failure(); + if (parser.parseLParen()) + return failure(); + + if (parser.parseType(nbarrier_idRawTypes[0])) + return failure(); + if (parser.parseComma()) + return failure(); + + if (parser.parseType(nbarrier_roleRawTypes[0])) + return failure(); + if (parser.parseRParen()) + return failure(); + if (parser.parseArrow()) + return failure(); + + if (parser.parseType(resultRawTypes[0])) + return failure(); + result.addTypes(resultTypes); + if (parser.resolveOperands(nbarrier_idOperands, nbarrier_idTypes, nbarrier_idOperandsLoc, result.operands)) + return failure(); + if (parser.resolveOperands(nbarrier_roleOperands, nbarrier_roleTypes, nbarrier_roleOperandsLoc, result.operands)) + return failure(); + return success(); +} + +void CreateNbarrierOp::print(OpAsmPrinter &printer) { + auto mode = getMode(); + auto printDefaults = printDefaultValues(); + + printer << ' '; + printer << getNbarrierId(); + printer << ","; + printer << ' '; + printer << getNbarrierRole(); + llvm::SmallVector elidedAttrs; + if (!printDefaults) { + if (mode == ModeKind::SIMT) + elidedAttrs.push_back("mode"); + } + + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + printer << ' ' << ":"; + printer << ' ' << "("; + printer << getNbarrierId().getType(); + printer << ","; + printer << ' '; + printer << getNbarrierRole().getType(); + printer << ")"; + printer << ' ' << "->"; + printer << ' '; + printer << getResult().getType(); +} + + + } // namespace xegpu } // 
namespace mlir From 5375573fe1d7f10d5bc3ae30cbfb63a94ac6cf1c Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 17 Jan 2024 01:50:51 +0000 Subject: [PATCH 6/8] sync from pvc --- .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 36 +- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 848 +++++++----------- mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir | 46 +- .../Dialect/XeGPU/IR/create_tdesc_vc.mlir | 15 +- .../test/Dialect/XeGPU/IR/load_gather_vc.mlir | 22 +- 5 files changed, 346 insertions(+), 621 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 846101030ca04..766590f6a3f87 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -158,10 +158,10 @@ def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> { let arguments = (ins XeGPU_TensorDesc: $TensorDesc, OptionalAttr: $vnni_axis, - OptionalAttr: $transpose, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, OptionalAttr: $l3_hint, + OptionalAttr: $transpose, DefaultValuedAttr: $mode); let results = (outs XeGPU_ValueType: $value); @@ -183,8 +183,8 @@ def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> { def XeGPU_StoreNDOp : XeGPU_Op<"store_nd", []> { let summary = "stores a n-D block register region back to memory, currently only supports 2D"; - let arguments = (ins XeGPU_TensorDesc: $TensorDesc, - XeGPU_ValueType: $value, + let arguments = (ins XeGPU_ValueType: $value, + XeGPU_TensorDesc: $TensorDesc, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, OptionalAttr: $l3_hint, @@ -219,11 +219,7 @@ def XeGPU_UpdateNDOffsetOp : XeGPU_Op<"update_nd_offset", []> { let results = (outs XeGPU_TensorDesc: $result); - let assemblyFormat = [{ - $TensorDesc `,` (`[` $offsets^ `]`)? (`{` `mode` `=` $mode^ `}`)? - attr-dict `:` qualified(type($TensorDesc)) `->` qualified(type($result)) - }]; - + let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } @@ -382,12 +378,6 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset", []> { let skipDefaultBuilders = 1; let hasCustomAssemblyFormat = 1; - - // let assemblyFormat = [{ - // $TensorDesc `,` $offsets (`{` `mode` `=` $mode^ `}`)? - // attr-dict `:` qualified(type($TensorDesc)) `,` qualified(type($offsets)) `->` qualified(type($result)) - // }]; - let hasVerifier = 1; } @@ -400,11 +390,6 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas"> { DefaultValuedAttr: $mode ); let results = (outs XeGPU_Vector2DType: $result); -// let assemblyFormat = [{ -// $lhs `,` $rhs (`,` $acc^)? (` ``{` `mode` `=` $mode^ `}`)? attr-dict `:` -// qualified(type($lhs)) `,` qualified(type($rhs)) (`,` qualified(type($acc))^)? `->` qualified(type($result)) -// }]; - let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ @@ -420,7 +405,9 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas"> { return ::llvm::cast(getAcc().getType()); } - VectorType getResultType() { return getResult().getType(); } + VectorType getResultType() { + return getResult().getType(); + } }]; let hasVerifier = 1; @@ -460,9 +447,6 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> { let results = (outs XeGPU_ValueType:$result); let hasCustomAssemblyFormat = 1; -// let assemblyFormat = [{ -// $kind $tensorDesc `,` $mask (`,` $value^)? (`{` `mode` `=` $mode^ `}`)? 
attr-dict `:` qualified(type(operands)) `->` type($result) -// }]; let builders = [ OpBuilder<(ins "Type": $result, "xegpu::AtomicRMWKindAttr": $kind, @@ -475,7 +459,6 @@ def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", []> { let hasVerifier = 1; } - def XeGPU_AllocNbarrierOp: XeGPU_Op<"alloc_nbarrier", []> { let summary = "allocate a specific number of named barriers."; let arguments = (ins I64Attr: $nbarrierCount); @@ -492,11 +475,6 @@ def XeGPU_CreateNbarrierOp: XeGPU_Op<"create_nbarrier", []> { DefaultValuedAttr: $mode); let results = (outs XeGPU_Nbarrier: $result); let hasCustomAssemblyFormat = 1; -// let assemblyFormat = [{ -// $nbarrier_id `,` $nbarrier_role -// attr-dict `:` `(` qualified(type($nbarrier_id)) `,` qualified(type($nbarrier_role)) `)` -// `->` qualified(type($result)) -// }]; } def XeGPU_NbarrierArriveOp: XeGPU_Op<"nbarrier_arrive", []> { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 2883a5db2a734..5703c08430a30 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -93,7 +93,7 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result if (failed(parser.parseOptionalLBrace())) return success(); - llvm::SmallDenseSet seenKeys; + llvm::SmallDenseSet seenKeys; auto parseElt = [&]() -> ParseResult { // The name of an attribute can either be a keyword, or a string. // as compared to mlir::parseOptionalAttrList, the cases of using @@ -122,7 +122,7 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result return success(); } - // handle xegpu specific attributes + // for xegpu specific attributes if (nameId == "mode") { ModeKindAttr attr; return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes); @@ -130,10 +130,26 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result CacheKindAttr attr; return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes); } else if (nameId == "transpose") { - DenseI64ArrayAttr attr; - return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes); + // in form of [4, 5] + if (succeeded(parser.parseOptionalLSquare())) { + Attribute attr; + // handle empty list case + if (succeeded(parser.parseOptionalRSquare())) { + attr = DenseI64ArrayAttr::get(parser.getContext(), {}); + } else { + attr = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{}); + if (failed(parser.parseRSquare())) + return failure(); + } + if (!attr) return failure(); + result.addAttribute(nameId, attr); + return success(); + } else { + // in form of array + DenseI64ArrayAttr attr; + return parser.parseAttribute(attr, nameId, result.attributes); + } } else { - // for generic attribtes Attribute attr; return parser.parseAttribute(attr, nameId, result.attributes); } @@ -145,30 +161,6 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result return parser.parseRBrace(); } -template -static void printCacheHintAttrs(OpAsmPrinter &printer, T op, bool printSep) { - if (op.getL1HintAttr()) { - if (printSep) - printer << ", "; - printer << "l1_hint = " << op.getL1Hint().value(); - printSep = true; - } - - if (op.getL2HintAttr()) { - if (printSep) - printer << ", "; - printer << "l2_hint = " << op.getL2Hint().value(); - printSep = true; - } - - if (op.getL3HintAttr()) { - if (printSep) - printer << ", "; - printer << "l3_hint = " << op.getL3Hint().value(); - } -} - - 
//===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -241,11 +233,9 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, ParseResult CreateNdDescOp::parse(OpAsmParser &parser, OperationState &result) { // parse the source operand - OpAsmParser::UnresolvedOperand sourceRawOperands[1]; - llvm::ArrayRef sourceOperands( - sourceRawOperands); + llvm::SmallVector sourceOperands(1); llvm::SMLoc sourceOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(sourceRawOperands[0])) + if (parser.parseOperand(sourceOperands[0])) return failure(); // parse the offset operand, in format of [x, y] @@ -286,23 +276,27 @@ ParseResult CreateNdDescOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); + if (parser.parseColon()) return failure(); - Type sourceRawTypes[1]; - llvm::ArrayRef sourceTypes(sourceRawTypes); - if (parser.parseType(sourceRawTypes[0])) + llvm::SmallVector sourceTypes(1); + if (parser.parseType(sourceTypes[0])) return failure(); if (parser.parseArrow()) return failure(); - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); - if (parser.parseType(TensorDescRawTypes[0])) + llvm::SmallVector TensorDescTypes(1); + if (parser.parseType(TensorDescTypes[0])) return failure(); result.addAttribute("operandSegmentSizes", parser.getBuilder().getDenseI32ArrayAttr( @@ -310,11 +304,12 @@ ParseResult CreateNdDescOp::parse(OpAsmParser &parser, OperationState &result) { static_cast(shapeOperands.size()), static_cast(stridesOperands.size())})); - Type indexType = parser.getBuilder().getIndexType(); result.addTypes(TensorDescTypes); if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc, result.operands)) return failure(); + + Type indexType = parser.getBuilder().getIndexType(); if (parser.resolveOperands(offsetsOperands, indexType, offsetsOperandsLoc, result.operands)) return failure(); @@ -349,11 +344,13 @@ void CreateNdDescOp::print(OpAsmPrinter &printer) { printer << "]"; } - if (printDefaults || mode != ModeKind::SIMT) { - printer << ' ' << "{"; - printer << "mode = " << mode; - printer << "}"; - } + llvm::SmallVector elidedAttrs; + elidedAttrs.push_back("static_offsets"); + elidedAttrs.push_back("operandSegmentSizes"); + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; @@ -452,8 +449,8 @@ llvm::SmallVector CreateNdDescOp::getShape() { return shape; } - emitOpError("The shape information is missing."); - llvm_unreachable("Unexpected error in CreateNdDescOp.\n"); + llvm_unreachable("Unexpected error in CreateNdDescOp. 
" + "The shape information is missing.\n"); } llvm::ArrayRef CreateNdDescOp::getStaticStrides() { @@ -505,35 +502,36 @@ llvm::ArrayRef CreateNdDescOp::getTensorDescShape() { //===----------------------------------------------------------------------===// ParseResult LoadNDOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; - llvm::ArrayRef TensorDescOperands( - TensorDescRawOperands); - llvm::SMLoc TensorDescOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(TensorDescRawOperands[0])) + llvm::SmallVector Operands(1); + llvm::SMLoc OperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(Operands[0])) return failure(); + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); + if (parser.parseColon()) return failure(); - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); - if (parser.parseType(TensorDescRawTypes[0])) + llvm::SmallVector Types(1); + if (parser.parseType(Types[0])) return failure(); if (parser.parseArrow()) return failure(); - Type valueRawTypes[1]; - llvm::ArrayRef valueTypes(valueRawTypes); - if (parser.parseType(valueRawTypes[0])) + llvm::SmallVector valueTypes(1); + if (parser.parseType(valueTypes[0])) return failure(); result.addTypes(valueTypes); - if (parser.resolveOperands(TensorDescOperands, TensorDescTypes, - TensorDescOperandsLoc, result.operands)) + if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands)) return failure(); return success(); @@ -541,42 +539,16 @@ ParseResult LoadNDOp::parse(OpAsmParser &parser, OperationState &result) { void LoadNDOp::print(OpAsmPrinter &printer) { auto mode = getMode(); - bool printSep = false; auto printDefaults = printDefaultValues(); - auto numAttrs = (*this)->getAttrs().size(); printer << ' '; printer << getTensorDesc(); - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << ' ' << "{"; - } - - if (printDefaults || mode != ModeKind::SIMT) { - printer << "mode = " << mode; - printSep = true; - } - - if (getVnniAxisAttr()) { - if (printSep) - printer << "," << ' '; - printer << "vnni_axis = " << getVnniAxis().value(); - printSep = true; - } - - if (getTransposeAttr()) { - if (printSep) - printer << "," << ' '; - printer << "transpose = "; - getTransposeAttr().print(printer); - printSep = true; - } - - printCacheHintAttrs(printer, *this, printSep); - - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << "}"; - } + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; @@ -668,48 +640,37 @@ LogicalResult LoadNDOp::verify() { // XeGPU_StoreNDOp //===----------------------------------------------------------------------===// ParseResult StoreNDOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand valueRawOperands[1]; - llvm::ArrayRef valueOperands( - valueRawOperands); - llvm::SMLoc valueOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(valueRawOperands[0])) + llvm::SmallVector Operands(2); + llvm::SMLoc OperandsLoc = parser.getCurrentLocation(); + // parse value + if 
(parser.parseOperand(Operands[0])) return failure(); if (parser.parseComma()) return failure(); - OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; - llvm::ArrayRef TensorDescOperands( - TensorDescRawOperands); - llvm::SMLoc TensorDescOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(TensorDescRawOperands[0])) + // parse TensorDesc + if (parser.parseOperand(Operands[1])) return failure(); + // parse optional attributes + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); - if (parser.parseColon()) - return failure(); - - Type valueRawTypes[1]; - llvm::ArrayRef valueTypes(valueRawTypes); - if (parser.parseType(valueRawTypes[0])) - return failure(); - - if (parser.parseComma()) + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) return failure(); - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); - if (parser.parseType(TensorDescRawTypes[0])) + if (parser.parseColon()) return failure(); - if (parser.resolveOperands(TensorDescOperands, TensorDescTypes, - TensorDescOperandsLoc, result.operands)) + llvm::SmallVector Types; + if (parser.parseTypeList(Types)) return failure(); - if (parser.resolveOperands(valueOperands, valueTypes, valueOperandsLoc, - result.operands)) + if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands)) return failure(); return success(); @@ -717,9 +678,7 @@ ParseResult StoreNDOp::parse(OpAsmParser &parser, OperationState &result) { void StoreNDOp::print(OpAsmPrinter &printer) { auto mode = getMode(); - [[maybe_unused]] bool printSep = false; auto printDefaults = printDefaultValues(); - auto numAttrs = (*this)->getAttrs().size(); printer << ' '; printer << getValue(); @@ -727,20 +686,10 @@ void StoreNDOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getTensorDesc(); - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << ' ' << "{"; - } - - if (printDefaults || mode != ModeKind::SIMT) { - printer << "mode = " << getMode(); - printSep = true; - } - - printCacheHintAttrs(printer, *this, true); - - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << "}"; - } + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; @@ -810,24 +759,27 @@ LogicalResult StoreNDOp::verify() { // XeGPU_PrefetchNDOp //===----------------------------------------------------------------------===// ParseResult PrefetchNDOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; - llvm::ArrayRef TensorDescOperands( - TensorDescRawOperands); + llvm::SmallVector TensorDescOperands(1); + llvm::SmallVector TensorDescTypes(1); llvm::SMLoc TensorDescOperandsLoc; - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); TensorDescOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(TensorDescRawOperands[0])) + if (parser.parseOperand(TensorDescOperands[0])) return failure(); + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + 
}))) + return failure(); + if (parser.parseColon()) return failure(); - if (parser.parseType(TensorDescRawTypes[0])) + if (parser.parseType(TensorDescTypes[0])) return failure(); if (parser.resolveOperands(TensorDescOperands, TensorDescTypes, TensorDescOperandsLoc, result.operands)) @@ -837,35 +789,102 @@ ParseResult PrefetchNDOp::parse(OpAsmParser &parser, OperationState &result) { void PrefetchNDOp::print(OpAsmPrinter &printer) { auto mode = getMode(); - [[maybe_unused]] bool printSep = false; auto printDefaults = printDefaultValues(); - auto numAttrs = (*this)->getAttrs().size(); + printer << ' '; printer << getTensorDesc(); - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << ' ' << "{"; - } + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + + printer << ' ' << ":"; + printer << ' '; + printer << getTensorDesc().getType(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_UpdateNDOffsetOp +//===----------------------------------------------------------------------===// +ParseResult UpdateNDOffsetOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SmallVector TensorDescOperands(1); + llvm::SmallVector offsetsOperands; + llvm::SmallVector TensorDescTypes(1); + llvm::SmallVector resultTypes(1); + llvm::SMLoc TensorDescOperandsLoc; + llvm::SMLoc offsetsOperandsLoc; - if (printDefaults || mode != ModeKind::SIMT) { - printer << "mode = " << getMode(); - printSep = true; + TensorDescOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(TensorDescOperands[0])) + return failure(); + if (parser.parseComma()) + return failure(); + + // parse offsets, e.g., [x, y] + if (succeeded(parser.parseOptionalLSquare())) { + offsetsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(offsetsOperands)) + return failure(); + if (parser.parseRSquare()) + return failure(); } - printCacheHintAttrs(printer, *this, true); + if (parseOptionalAttrDictWithCustomAttrs(parser, result)) + return failure(); + + auto loc = parser.getCurrentLocation(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); + + if (parser.parseColon()) + return failure(); + + if (parser.parseType(TensorDescTypes[0])) + return failure(); + if (parser.parseArrow()) + return failure(); + + if (parser.parseType(resultTypes[0])) + return failure(); + result.addTypes(resultTypes); + if (parser.resolveOperands(TensorDescOperands, TensorDescTypes, TensorDescOperandsLoc, result.operands)) + return failure(); + + Type indexType = parser.getBuilder().getIndexType(); + if (parser.resolveOperands(offsetsOperands, indexType, offsetsOperandsLoc, result.operands)) + return failure(); + return success(); +} + +void UpdateNDOffsetOp::print(OpAsmPrinter &printer) { + auto mode = getMode(); + auto printDefaults = printDefaultValues(); - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << "}"; + printer << ' '; + printer << getTensorDesc(); + printer << ","; + if (!getOffsets().empty()) { + printer << ' ' << "["; + printer << getOffsets(); + printer << "]"; } + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + printer << ' ' 
<< ":"; printer << ' '; printer << getTensorDesc().getType(); + printer << ' ' << "->"; + printer << ' '; + printer << getResult().getType(); } -//===----------------------------------------------------------------------===// -// XeGPU_UpdateNDOffsetOp -//===----------------------------------------------------------------------===// LogicalResult UpdateNDOffsetOp::verify() { // number of offsets specified must match the rank of the tensor descriptor if (getTensorDesc().getType().getRank() != (int64_t)getOffsets().size()) { @@ -903,46 +922,49 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state, } ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand rawOperands[2]; - llvm::ArrayRef operands(rawOperands); + llvm::SmallVector Operands(2); + llvm::SmallVector Types(2); llvm::SMLoc operandsLoc = parser.getCurrentLocation(); // parse the source operand - if (parser.parseOperand(rawOperands[0])) + if (parser.parseOperand(Operands[0])) return failure(); if (parser.parseComma()) return failure(); // parse the offset operand - if (parser.parseOperand(rawOperands[1])) + if (parser.parseOperand(Operands[1])) return failure(); // parse the optional attributes + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); + if (parser.parseColon()) return failure(); - Type rawTypes[2]; - llvm::ArrayRef types(rawTypes); - if (parser.parseType(rawTypes[0])) + if (parser.parseType(Types[0])) return failure(); if (parser.parseComma()) return failure(); - if (parser.parseType(rawTypes[1])) + if (parser.parseType(Types[1])) return failure(); if (parser.parseArrow()) return failure(); - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); - if (parser.parseType(TensorDescRawTypes[0])) + llvm::SmallVector TensorDescTypes(1); + if (parser.parseType(TensorDescTypes[0])) return failure(); result.addTypes(TensorDescTypes); - if (parser.resolveOperands(operands, types, operandsLoc, + if (parser.resolveOperands(Operands, Types, operandsLoc, result.operands)) return failure(); return success(); @@ -950,7 +972,6 @@ ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) { void CreateDescOp::print(OpAsmPrinter &printer) { auto mode = getMode(); - bool printSep = false; auto chunk = getChunkSizePerLane(); auto printDefaults = printDefaultValues(); @@ -960,24 +981,14 @@ void CreateDescOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getOffsets(); - if (printDefaults || mode != ModeKind::SIMT || chunk != 1) { - printer << ' ' << "{"; - } - - if (printDefaults || mode != ModeKind::SIMT) { - printer << "mode = " << mode; - printSep = true; - } - - if (printDefaults || chunk != 1) { - if (printSep) - printer << "," << ' '; - printer << "chunk_size_per_lane = " << chunk; - } - - if (printDefaults || mode != ModeKind::SIMT || chunk != 1) { - printer << "}"; + llvm::SmallVector elidedAttrs; + if (!printDefaults) { + if (mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + if (chunk == 1) + elidedAttrs.push_back("chunk_size_per_lane"); } + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; @@ -1084,70 +1095,58 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, Type value, } 
ParseResult LoadGatherOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; - llvm::ArrayRef TensorDescOperands( - TensorDescRawOperands); - llvm::SMLoc TensorDescOperandsLoc; - OpAsmParser::UnresolvedOperand maskRawOperands[1]; - llvm::ArrayRef maskOperands(maskRawOperands); - llvm::SMLoc maskOperandsLoc; + llvm::SmallVector Operands(2); + llvm::SmallVector Types(2); + llvm::SmallVector valueTypes(1); + llvm::SMLoc OperandsLoc; - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); - Type maskRawTypes[1]; - llvm::ArrayRef maskTypes(maskRawTypes); - Type valueRawTypes[1]; - llvm::ArrayRef valueTypes(valueRawTypes); - - TensorDescOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(TensorDescRawOperands[0])) + OperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(Operands[0])) return failure(); if (parser.parseComma()) return failure(); - maskOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(maskRawOperands[0])) + if (parser.parseOperand(Operands[1])) return failure(); + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); if (parser.parseColon()) return failure(); - if (parser.parseType(TensorDescRawTypes[0])) + if (parser.parseType(Types[0])) return failure(); if (parser.parseComma()) return failure(); - if (parser.parseType(maskRawTypes[0])) + if (parser.parseType(Types[1])) return failure(); if (parser.parseArrow()) return failure(); - if (parser.parseType(valueRawTypes[0])) + if (parser.parseType(valueTypes[0])) return failure(); result.addTypes(valueTypes); - if (parser.resolveOperands(TensorDescOperands, TensorDescTypes, - TensorDescOperandsLoc, result.operands)) + if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands)) return failure(); - if (parser.resolveOperands(maskOperands, maskTypes, maskOperandsLoc, - result.operands)) - return failure(); return success(); } void LoadGatherOp::print(OpAsmPrinter &printer) { auto mode = getMode(); - bool printSep = false; auto printDefaults = printDefaultValues(); - auto numAttrs = (*this)->getAttrs().size(); printer << ' '; printer << getTensorDesc(); @@ -1155,35 +1154,10 @@ void LoadGatherOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getMask(); - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << ' ' << "{"; - } - - if (printDefaults || mode != ModeKind::SIMT) { - printer << "mode = " << getMode(); - printSep = true; - } - - if (getVnniAxisAttr()) { - if (printSep) - printer << "," << ' '; - printer << "vnni_axis = " << getVnniAxis().value(); - printSep = true; - } - - if (getTransposeAttr()) { - if (printSep) - printer << "," << ' '; - printer << "transpose = "; - getTransposeAttr().print(printer); - printSep = true; - } - - printCacheHintAttrs(printer, *this, printSep); - - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << "}"; - } + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; @@ -1317,87 +1291,37 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, } ParseResult 
StoreScatterOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; - llvm::ArrayRef TensorDescOperands( - TensorDescRawOperands); - llvm::SMLoc TensorDescOperandsLoc; - - OpAsmParser::UnresolvedOperand valueRawOperands[1]; - llvm::ArrayRef valueOperands( - valueRawOperands); - llvm::SMLoc valueOperandsLoc; - - OpAsmParser::UnresolvedOperand maskRawOperands[1]; - llvm::ArrayRef maskOperands(maskRawOperands); - llvm::SMLoc maskOperandsLoc; - - Type valueRawTypes[1]; - llvm::ArrayRef valueTypes(valueRawTypes); - - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); - - Type maskRawTypes[1]; - llvm::ArrayRef maskTypes(maskRawTypes); + llvm::SmallVector Operands; + llvm::SmallVector Types; + llvm::SMLoc OperandsLoc; - valueOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(valueRawOperands[0])) - return failure(); - - if (parser.parseComma()) - return failure(); - - TensorDescOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(TensorDescRawOperands[0])) - return failure(); - - if (parser.parseComma()) - return failure(); - - maskOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(maskRawOperands[0])) + OperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(Operands)) return failure(); + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) + return failure(); + if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; + }))) return failure(); if (parser.parseColon()) return failure(); - - if (parser.parseType(valueRawTypes[0])) - return failure(); - - if (parser.parseComma()) - return failure(); - - if (parser.parseType(TensorDescRawTypes[0])) - return failure(); - - if (parser.parseComma()) - return failure(); - - if (parser.parseType(maskRawTypes[0])) - return failure(); - - if (parser.resolveOperands(valueOperands, valueTypes, valueOperandsLoc, - result.operands)) + + if (parser.parseTypeList(Types)) return failure(); - if (parser.resolveOperands(TensorDescOperands, TensorDescTypes, - TensorDescOperandsLoc, result.operands)) + if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands)) return failure(); - if (parser.resolveOperands(maskOperands, maskTypes, maskOperandsLoc, - result.operands)) - return failure(); return success(); } void StoreScatterOp::print(OpAsmPrinter &printer) { auto mode = getMode(); - bool printSep = false; auto printDefaults = printDefaultValues(); - auto numAttrs = (*this)->getAttrs().size(); printer << ' '; printer << getValue(); @@ -1408,20 +1332,10 @@ void StoreScatterOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getMask(); - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << ' ' << "{"; - } - - if (printDefaults || mode != ModeKind::SIMT) { - printer << "mode = " << getMode(); - printSep = true; - } - - printCacheHintAttrs(printer, *this, printSep); - - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << "}"; - } + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; @@ -1513,15 +1427,12 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, } ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState 
&result) { - OpAsmParser::UnresolvedOperand TensorDescRawOperands[1]; - llvm::ArrayRef TensorDescOperands( - TensorDescRawOperands); + llvm::SmallVector TensorDescOperands(1); + llvm::SmallVector TensorDescTypes(1); llvm::SMLoc TensorDescOperandsLoc; - Type TensorDescRawTypes[1]; - llvm::ArrayRef TensorDescTypes(TensorDescRawTypes); TensorDescOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(TensorDescRawOperands[0])) + if (parser.parseOperand(TensorDescOperands[0])) return failure(); auto loc = parser.getCurrentLocation(); @@ -1532,12 +1443,12 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { }))) return failure(); - if (parser.parseColon()) return failure(); - if (parser.parseType(TensorDescRawTypes[0])) + if (parser.parseType(TensorDescTypes[0])) return failure(); + if (parser.resolveOperands(TensorDescOperands, TensorDescTypes, TensorDescOperandsLoc, result.operands)) return failure(); @@ -1546,27 +1457,15 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { void PrefetchOp::print(OpAsmPrinter &printer) { auto mode = getMode(); - bool printSep = false; auto printDefaults = printDefaultValues(); - auto numAttrs = (*this)->getAttrs().size(); printer << ' '; printer << getTensorDesc(); - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << ' ' << "{"; - } - - if (printDefaults || mode != ModeKind::SIMT) { - printer << "mode = " << getMode(); - printSep = true; - } - - printCacheHintAttrs(printer, *this, printSep); - - if (printDefaults || mode != ModeKind::SIMT || numAttrs > 1) { - printer << "}"; - } + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; @@ -1621,50 +1520,35 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, } ParseResult UpdateOffsetOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand RawOperands[2]; - llvm::ArrayRef Operands(RawOperands); - - Type resultRawTypes[1]; - llvm::ArrayRef resultTypes(resultRawTypes); - - Type RawTypes[2]; - llvm::ArrayRef Types(Types); + llvm::SmallVector Operands; + llvm::SmallVector Types; auto OperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(RawOperands[0])) - return failure(); - if (parser.parseComma()) - return failure(); - - if (parser.parseOperand(RawOperands[1])) + if (parser.parseOperandList(Operands)) return failure(); - auto AttrLoc = parser.getCurrentLocation(); + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); - if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { - return parser.emitError(AttrLoc) << "'" << result.name.getStringRef() << "' op "; + return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; }))) return failure(); if (parser.parseColon()) return failure(); - if (parser.parseType(RawTypes[0])) + if (parser.parseTypeList(Types)) return failure(); - if (parser.parseComma()) - return failure(); - - if (parser.parseType(RawTypes[1])) - return failure(); if (parser.parseArrow()) return failure(); - if (parser.parseType(resultRawTypes[0])) + llvm::SmallVector resultTypes(1); + if (parser.parseType(resultTypes[0])) return failure(); result.addTypes(resultTypes); + if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands)) return failure(); return success(); @@ 
-1680,11 +1564,9 @@ void UpdateOffsetOp::print(OpAsmPrinter &printer) { printer << ' '; printer << getOffsets(); - llvm::SmallVector elidedAttrs; - if (!printDefaults) { - if (mode == ModeKind::SIMT) - elidedAttrs.push_back("mode"); - } + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; @@ -1729,49 +1611,16 @@ LogicalResult UpdateOffsetOp::verify() { // XeGPU_DpasOp //===----------------------------------------------------------------------===// ParseResult DpasOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand lhsRawOperands[1]; - llvm::ArrayRef lhsOperands(lhsRawOperands); - llvm::SMLoc lhsOperandsLoc; - OpAsmParser::UnresolvedOperand rhsRawOperands[1]; - llvm::ArrayRef rhsOperands(rhsRawOperands); - llvm::SMLoc rhsOperandsLoc; - llvm::SmallVector accOperands; - llvm::SMLoc accOperandsLoc; - Type lhsRawTypes[1]; - llvm::ArrayRef lhsTypes(lhsRawTypes); - Type rhsRawTypes[1]; - llvm::ArrayRef rhsTypes(rhsRawTypes); - llvm::SmallVector accTypes; - Type resultRawTypes[1]; - llvm::ArrayRef resultTypes(resultRawTypes); - - lhsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(lhsRawOperands[0])) - return failure(); - - if (parser.parseComma()) - return failure(); + llvm::SmallVector Operands; + llvm::SmallVector Types; - rhsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(rhsRawOperands[0])) + llvm::SMLoc OperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(Operands)) return failure(); - // parse optional acc operand - if (succeeded(parser.parseOptionalComma())) { - accOperandsLoc = parser.getCurrentLocation(); - OpAsmParser::UnresolvedOperand operand; - OptionalParseResult parseResult = parser.parseOptionalOperand(operand); - if (parseResult.has_value()) { - if (failed(*parseResult)) - return failure(); - accOperands.push_back(operand); - } - } - auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); - if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; }))) @@ -1780,39 +1629,20 @@ ParseResult DpasOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseColon()) return failure(); - if (parser.parseType(lhsRawTypes[0])) - return failure(); - - if (parser.parseComma()) - return failure(); - - if (parser.parseType(rhsRawTypes[0])) + if (parser.parseTypeList(Types)) return failure(); - // parse type for optional acc - if (succeeded(parser.parseOptionalComma())) { - Type optionalType; - OptionalParseResult parseResult = parser.parseOptionalType(optionalType); - if (parseResult.has_value()) { - if (failed(*parseResult)) - return failure(); - accTypes.push_back(optionalType); - } - - } - if (parser.parseArrow()) return failure(); - if (parser.parseType(resultRawTypes[0])) + llvm::SmallVector resultTypes(1); + if (parser.parseType(resultTypes[0])) return failure(); result.addTypes(resultTypes); - if (parser.resolveOperands(lhsOperands, lhsTypes, lhsOperandsLoc, result.operands)) - return failure(); - if (parser.resolveOperands(rhsOperands, rhsTypes, rhsOperandsLoc, result.operands)) - return failure(); - if (parser.resolveOperands(accOperands, accTypes, accOperandsLoc, result.operands)) + + if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands)) 
return failure(); + return success(); } @@ -1829,10 +1659,8 @@ void DpasOp::print(OpAsmPrinter &printer) { printer << ", " << value; llvm::SmallVector elidedAttrs; - if (!printDefaults) { - if (mode == ModeKind::SIMT) - elidedAttrs.push_back("mode"); - } + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; @@ -1857,19 +1685,16 @@ LogicalResult DpasOp::verify() { Type lhsElemType = getLhsType().getElementType(); Type rhsElemType = getRhsType().getElementType(); - if (lhsElemType != rhsElemType) { - return emitOpError("lhs and rhs element type does not match for dpas op"); - } + if (lhsElemType != rhsElemType) + return emitOpError("lhs and rhs element type does not match for dpas op"); - if (getAcc() && getAccType() != getResultType()) { + if (getAcc() && getAccType() != getResultType()) return emitOpError("Accumulator and Result for dpas op should have the " - "same type (both shape and element type)."); - } + "same type (both shape and element type)."); - if (lhsRank != rhsRank || lhsRank != 3) { + if (lhsRank != rhsRank || lhsRank != 3) return emitOpError( "lhs and rhs rank does not match for dpas op, or their rank is not 3."); - } return success(); } @@ -1928,49 +1753,25 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result, } ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { - xegpu::AtomicRMWKindAttr kindAttr; - OpAsmParser::UnresolvedOperand tensorDescRawOperands[1]; - llvm::ArrayRef tensorDescOperands(tensorDescRawOperands); - llvm::SMLoc tensorDescOperandsLoc; - OpAsmParser::UnresolvedOperand maskRawOperands[1]; - llvm::ArrayRef maskOperands(maskRawOperands); llvm::SMLoc maskOperandsLoc; - llvm::SmallVector valueOperands; - llvm::SMLoc valueOperandsLoc; - xegpu::ModeKindAttr modeAttr; - llvm::SmallVector allOperandTypes; - Type resultRawTypes[1]; - llvm::ArrayRef resultTypes(resultRawTypes); - - if (parser.parseCustomAttributeWithFallback(kindAttr, Type{})) { - return failure(); - } - if (kindAttr) result.getOrAddProperties().kind = kindAttr; + llvm::SmallVector Operands; + llvm::SmallVector Types; + llvm::SMLoc OperandsLoc; + + llvm::SmallVector resultTypes(1); - tensorDescOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(tensorDescRawOperands[0])) - return failure(); - if (parser.parseComma()) + xegpu::AtomicRMWKindAttr kindAttr; + if (parser.parseCustomAttributeWithFallback(kindAttr, Type{})) return failure(); + if (kindAttr) + result.getOrAddProperties().kind = kindAttr; - maskOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(maskRawOperands[0])) + OperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(Operands)) return failure(); - if (succeeded(parser.parseOptionalComma())) { - valueOperandsLoc = parser.getCurrentLocation(); - OpAsmParser::UnresolvedOperand operand; - OptionalParseResult parseResult = parser.parseOptionalOperand(operand); - if (parseResult.has_value()) { - if (failed(*parseResult)) - return failure(); - valueOperands.push_back(operand); - } - } - - auto loc = parser.getCurrentLocation();(void)loc; + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); - if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; }))) @@ -1979,18 +1780,17 @@ ParseResult 
AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseColon()) return failure(); - if (parser.parseTypeList(allOperandTypes)) + if (parser.parseTypeList(Types)) return failure(); + if (parser.parseArrow()) return failure(); - Type type; - if (parser.parseCustomTypeWithFallback(type)) - return failure(); - resultRawTypes[0] = type; - + if (parser.parseCustomTypeWithFallback(resultTypes[0])) + return failure(); result.addTypes(resultTypes); - if (parser.resolveOperands(llvm::concat(tensorDescOperands, maskOperands, valueOperands), allOperandTypes, parser.getNameLoc(), result.operands)) + + if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands)) return failure(); return success(); } @@ -1999,7 +1799,6 @@ void AtomicRMWOp::print(OpAsmPrinter &printer) { auto mode = getMode(); auto printDefaults = printDefaultValues(); - printer << ' '; printer.printStrippedAttrOrType(getKindAttr()); printer << ' '; printer << getTensorDesc(); @@ -2010,24 +1809,17 @@ void AtomicRMWOp::print(OpAsmPrinter &printer) { printer << ", " << value; llvm::SmallVector elidedAttrs; - if (!printDefaults) { - if (mode == ModeKind::SIMT) - elidedAttrs.push_back("mode"); - } + elidedAttrs.push_back("kind"); + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' '; printer << getOperation()->getOperandTypes(); printer << ' ' << "->"; - printer << ' '; - - auto type = getResult().getType(); - if (auto validType = llvm::dyn_cast(type)) - printer.printStrippedAttrOrType(validType); - else - printer << type; - + printer << ' '; + printer << getResult().getType(); } LogicalResult AtomicRMWOp::verify() { @@ -2041,29 +1833,14 @@ LogicalResult AtomicRMWOp::verify() { // XeGPU_CreateNbarrierOp //===----------------------------------------------------------------------===// ParseResult CreateNbarrierOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand nbarrier_idRawOperands[1]; - llvm::ArrayRef nbarrier_idOperands(nbarrier_idRawOperands); - llvm::SMLoc nbarrier_idOperandsLoc; - OpAsmParser::UnresolvedOperand nbarrier_roleRawOperands[1]; - llvm::ArrayRef nbarrier_roleOperands(nbarrier_roleRawOperands); - llvm::SMLoc nbarrier_roleOperandsLoc; - Type nbarrier_idRawTypes[1]; - llvm::ArrayRef nbarrier_idTypes(nbarrier_idRawTypes); - Type nbarrier_roleRawTypes[1]; - llvm::ArrayRef nbarrier_roleTypes(nbarrier_roleRawTypes); - Type resultRawTypes[1]; - llvm::ArrayRef resultTypes(resultRawTypes); - - nbarrier_idOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(nbarrier_idRawOperands[0])) - return failure(); - if (parser.parseComma()) - return failure(); + llvm::SmallVector Operands; + llvm::SmallVector Types; + llvm::SMLoc OperandsLoc; - nbarrier_roleOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperand(nbarrier_roleRawOperands[0])) + OperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(Operands)) return failure(); - + auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); @@ -2075,27 +1852,25 @@ ParseResult CreateNbarrierOp::parse(OpAsmParser &parser, OperationState &result) if (parser.parseColon()) return failure(); + if (parser.parseLParen()) return failure(); - if (parser.parseType(nbarrier_idRawTypes[0])) - return failure(); - if (parser.parseComma()) + if (parser.parseTypeList(Types)) return failure(); - if 
(parser.parseType(nbarrier_roleRawTypes[0])) - return failure(); if (parser.parseRParen()) return failure(); + if (parser.parseArrow()) return failure(); - if (parser.parseType(resultRawTypes[0])) + llvm::SmallVector resultTypes(1); + if (parser.parseType(resultTypes[0])) return failure(); + result.addTypes(resultTypes); - if (parser.resolveOperands(nbarrier_idOperands, nbarrier_idTypes, nbarrier_idOperandsLoc, result.operands)) - return failure(); - if (parser.resolveOperands(nbarrier_roleOperands, nbarrier_roleTypes, nbarrier_roleOperandsLoc, result.operands)) + if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands)) return failure(); return success(); } @@ -2103,18 +1878,15 @@ ParseResult CreateNbarrierOp::parse(OpAsmParser &parser, OperationState &result) void CreateNbarrierOp::print(OpAsmPrinter &printer) { auto mode = getMode(); auto printDefaults = printDefaultValues(); + llvm::SmallVector elidedAttrs; + if (!printDefaults && mode == xegpu::ModeKind::SIMT) + elidedAttrs.push_back("mode"); printer << ' '; printer << getNbarrierId(); printer << ","; printer << ' '; printer << getNbarrierRole(); - llvm::SmallVector elidedAttrs; - if (!printDefaults) { - if (mode == ModeKind::SIMT) - elidedAttrs.push_back("mode"); - } - printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; printer << ' ' << "("; @@ -2128,8 +1900,6 @@ void CreateNbarrierOp::print(OpAsmPrinter &printer) { printer << getResult().getType(); } - - } // namespace xegpu } // namespace mlir diff --git a/mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir index 24aa836f80d44..64a6f547fbd29 100644 --- a/mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir +++ b/mlir/test/Dialect/XeGPU/IR/XeGPUOps.mlir @@ -9,12 +9,12 @@ func.func @test_create_nd_tdesc_vc(%src: memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.create_nd_tdesc + // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu} // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - // CHECK: xegpu.create_nd_tdesc + // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu} // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> %2 = xegpu.create_nd_tdesc %src[2, 4] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> @@ -24,8 +24,7 @@ func.func @test_create_nd_tdesc_vc(%src: memref<24x32xf32>) { // CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) { func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {mode = vc, chunk_size_per_lane = 2} + // CHECK: xegpu.create_tdesc {{.*}} {chunk_size_per_lane = 2 : i64, mode = #xegpu} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr> %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr> @@ -34,14 +33,12 @@ func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) { // CHECK-LABEL: func @test_load_nd_vc({{.*}}) { func.func @test_load_nd_vc(%src: memref<24x32xf16>, %x : index, %y : index) { - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg1, %arg2] + // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf16> -> 
!xegpu.tensor_desc<8x16xf16> - // CHECK: xegpu.load_nd - // CHECK-SAME: {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} + // CHECK: xegpu.load_nd {{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu, vnni_axis = 0 : i64} // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> %2 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> return @@ -52,71 +49,62 @@ func.func @test_store_nd_vc(%src: memref<24x32xf16>, %dst: memref<24x32xf16>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - // CHECK: xegpu.load_nd - // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK: xegpu.load_nd {{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu} // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> %3 = xegpu.load_nd %1 {mode=vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> - // CHECK: xegpu.store_nd - // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached} + // CHECK: xegpu.store_nd {{%[0-9], %[0-9]}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu} // CHECK-SAME: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> return } - // CHECK-LABEL: func @test_dpas_vc({{.*}}) { func.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) { - // CHECK: xegpu.dpas + // CHECK: xegpu.dpas {{.*}} {mode = #xegpu} // CHECK-SAME: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> %1 = xegpu.dpas %a, %b {mode = vc}: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> return } - // CHECK-LABEL: func @test_update_nd_offset_vc({{.*}}) { func.func @test_update_nd_offset_vc(%src: memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu} // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - // CHECK: xegpu.load_nd - // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK: xegpu.load_nd {{%[0-9]}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu} // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> %2 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - // CHECK: xegpu.update_nd_offset + // CHECK: xegpu.update_nd_offset {{%[0-9]}}, [{{%c[0-9], %c[0-9]}}] {mode = #xegpu} // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> - %3 = xegpu.update_nd_offset %1, [%c0, %c1]: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + %3 = xegpu.update_nd_offset %1, [%c0, %c1] {mode = vc}: !xegpu.tensor_desc<8x16xf32> 
-> !xegpu.tensor_desc<8x16xf32> return } // CHECK-LABEL: func @test_prefetch_nd_vc({{.*}}) { func.func @test_prefetch_nd_vc(%src: memref<24x32xf16>, %x : index, %y : index) { - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_nd_tdesc {{.*}} {mode = #xegpu} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - // CHECK: xegpu.prefetch_nd - // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> + // CHECK: xegpu.prefetch_nd {{%[0-9]}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> xegpu.prefetch_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> return } diff --git a/mlir/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir b/mlir/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir index 9cbc8b4f7d94b..245d862e302a7 100644 --- a/mlir/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir +++ b/mlir/test/Dialect/XeGPU/IR/create_tdesc_vc.mlir @@ -7,8 +7,7 @@ // CHECK-LABEL: func @test_create_tdesc_vc({{.*}}) { func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> %1 = xegpu.create_tdesc %src, %offsets {mode = vc}: ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> return @@ -16,8 +15,7 @@ func.func @test_create_tdesc_vc(%src: ui64, %offsets : vector<16 x index>) { // CHECK-LABEL: func @test_create_tdesc_vc_2({{.*}}) { func.func @test_create_tdesc_vc_2(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr> %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr> @@ -26,8 +24,7 @@ func.func @test_create_tdesc_vc_2(%src: ui64, %offsets : vector<16 x index>) { // CHECK-LABEL: func @test_create_tdesc_vc_3({{.*}}) { func.func @test_create_tdesc_vc_3(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 - // CHECK-SAME: {mode = vc, chunk_size_per_lane = 8} + // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 8 : i64, mode = #xegpu} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 8} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> @@ -36,8 +33,7 @@ func.func @test_create_tdesc_vc_3(%src: ui64, %offsets : vector<16 x index>) { // CHECK-LABEL: func @test_create_tdesc_vc_4({{.*}}) { func.func @test_create_tdesc_vc_4(%src: ui64, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc %arg0, %arg1 - // CHECK-SAME: {mode = vc, chunk_size_per_lane = 2} + // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 2 : i64, mode = #xegpu} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr> %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, 
#xegpu.tdesc_attr> @@ -47,8 +43,7 @@ func.func @test_create_tdesc_vc_4(%src: ui64, %offsets : vector<16 x index>) { // CHECK-LABEL: func @test_create_tdesc_vc_5({{.*}}) { func.func @test_create_tdesc_vc_5(%src: memref, %offsets : vector<16 x index>) { - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {mode = vc, chunk_size_per_lane = 2} + // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 2 : i64, mode = #xegpu} // CHECK-SAME: memref, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr> %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 2} : memref, vector<16 x index> -> !xegpu.tensor_desc<16x2xf32, #xegpu.tdesc_attr> diff --git a/mlir/test/Dialect/XeGPU/IR/load_gather_vc.mlir b/mlir/test/Dialect/XeGPU/IR/load_gather_vc.mlir index 2689d401dc316..a3cb890483e63 100644 --- a/mlir/test/Dialect/XeGPU/IR/load_gather_vc.mlir +++ b/mlir/test/Dialect/XeGPU/IR/load_gather_vc.mlir @@ -8,13 +8,11 @@ // CHECK-LABEL: func @test_load_gather_vc({{.*}}) { func.func @test_load_gather_vc(%src: ui64, %offsets : vector<16xindex>) { %0 = arith.constant dense<1>: vector<16xi1> - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> %1 = xegpu.create_tdesc %src, %offsets {mode = vc}: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - // CHECK: xegpu.load - // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu} // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> @@ -24,31 +22,27 @@ func.func @test_load_gather_vc(%src: ui64, %offsets : vector<16xindex>) { // CHECK-LABEL: func @test_load_gather_vc_2({{.*}}) { func.func @test_load_gather_vc_2(%src: ui64, %offsets : vector<16xindex>) { %0 = arith.constant dense<1>: vector<16x8xi1> - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {mode = vc, chunk_size_per_lane = 8} + // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {chunk_size_per_lane = 8 : i64, mode = #xegpu} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 8} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scattered> - // CHECK: xegpu.load - // CHECK-SAME: {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached} + // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu, transpose = array} // CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> %2 = xegpu.load %1, %0 {mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16x8xf32, #xegpu.scattered>, vector<16x8xi1> -> vector<8x16xf32> return } -// CHECK-LABEL: func @test_load_gather_vc_4({{.*}}) { -func.func @test_load_gather_vc_4(%src: ui64, %offsets : vector<16xindex>) { +// CHECK-LABEL: func @test_load_gather_vc_3({{.*}}) { +func.func @test_load_gather_vc_3(%src: ui64, %offsets : vector<16xindex>) { %0 = arith.constant dense<1>: vector<16xi1> - // CHECK: xegpu.create_tdesc - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} 
{mode = #xegpu} // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> %1 = xegpu.create_tdesc %src, %offsets {mode = vc, chunk_size_per_lane = 1} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered> - // CHECK: xegpu.load - // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu} // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32> From 28d8f13aefa17458e7172150912d5deb318d82c5 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Wed, 17 Jan 2024 16:23:40 +0000 Subject: [PATCH 7/8] update testcases --- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 2 +- mlir/test/Dialect/XeGPU/IR/load_nd.mlir | 103 ++++++------------ mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir | 48 ++++---- .../test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir | 29 +++-- mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir | 51 ++++----- mlir/test/Dialect/XeGPU/IR/store_scatter.mlir | 12 +- .../Dialect/XeGPU/IR/store_scatter_vc.mlir | 14 +-- .../Dialect/XeGPU/IR/update_nd_offset.mlir | 14 +-- .../Dialect/XeGPU/IR/update_offset_vc.mlir | 8 +- 9 files changed, 112 insertions(+), 169 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 5703c08430a30..41626093250e5 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -130,7 +130,7 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result CacheKindAttr attr; return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes); } else if (nameId == "transpose") { - // in form of [4, 5] + // in the form of [4, 5]; actually it is a copy of DenseI64ArrayAttr::parse() if (succeeded(parser.parseOptionalLSquare())) { Attribute attr; // handle empty list case diff --git a/mlir/test/Dialect/XeGPU/IR/load_nd.mlir b/mlir/test/Dialect/XeGPU/IR/load_nd.mlir index d05a0b523c51d..0644565c3f002 100644 --- a/mlir/test/Dialect/XeGPU/IR/load_nd.mlir +++ b/mlir/test/Dialect/XeGPU/IR/load_nd.mlir @@ -13,50 +13,38 @@ func.func @test_load_nd_fp16(%A: memref<24x32xf16>, %B : memref<24x32xf16>, %C : %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf16> // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map> %1 = xegpu.create_nd_tdesc %A[%c0, %c1] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> - // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 1} - // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map> - // CHECK-SAME: -> vector<4x1x2xf16> + // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 1 : i64} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map> -> vector<4x1x2xf16> %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_a> -> vector<4x1x2xf16> - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf16> // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map> %3 = xegpu.create_nd_tdesc %B[%c0, %c1] : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> - // CHECK: xegpu.load_nd - //
CHECK-SAME: {vnni_axis = 0} - // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map> - // CHECK-SAME: -> vector<8x1x2xf16> + // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 0 : i64} + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.sg_map> -> vector<8x1x2xf16> %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xf16, #sg_map_fp16_b> -> vector<8x1x2xf16> - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf16> // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> %5 = xegpu.create_nd_tdesc %C[%c0, %c1] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> - // CHECK: xegpu.load_nd - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> - // CHECK-SAME: -> vector<8x1xf32> + // CHECK: xegpu.load_nd %{{[0-9]}} : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x1xf32> %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> -> vector<8x1xf32> - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xf16> // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map> %7 = xegpu.create_nd_tdesc %A[%c0, %c1] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_d> - // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 1} - // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map> - // CHECK-SAME: -> vector<4x1x2xf16> + // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 1 : i64} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.sg_map> -> vector<4x1x2xf16> %8 = xegpu.load_nd %7 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xf16, #sg_map_fp16_d> -> vector<4x1x2xf16> return @@ -70,39 +58,27 @@ func.func @test_load_nd_bf16(%A: memref<24x32xbf16>, %B : memref<24x32xbf16>, %C %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xbf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xbf16> // CHECK-SAME: -> !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map> - %1 = xegpu.create_nd_tdesc %A[%c0, %c1] - : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a> + %1 = xegpu.create_nd_tdesc %A[%c0, %c1] : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a> - // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 1} - // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map> - // CHECK-SAME: -> vector<4x1x2xbf16> + // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 1 : i64} + // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16, #xegpu.sg_map> -> vector<4x1x2xbf16> %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x16xbf16, #sg_map_bf16_a> -> vector<4x1x2xbf16> - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xbf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xbf16> // CHECK-SAME: -> !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map> - %3 = xegpu.create_nd_tdesc %B[%c0, %c1] - : memref<24x32xbf16> -> !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b> + %3 = xegpu.create_nd_tdesc %B[%c0, %c1] : memref<24x32xbf16> -> !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b> - // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 0} - // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map> - // CHECK-SAME: -> vector<8x1x2xbf16> + // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 0 : i64} + // CHECK-SAME: !xegpu.tensor_desc<16x16xbf16, #xegpu.sg_map> -> 
vector<8x1x2xbf16> %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<16x16xbf16, #sg_map_bf16_b> -> vector<8x1x2xbf16> - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<24x32xbf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<24x32xbf16> // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> - %5 = xegpu.create_nd_tdesc %C[%c0, %c1] - : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> + %5 = xegpu.create_nd_tdesc %C[%c0, %c1] : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xf32, #sg_map_fp16_c> - // CHECK: xegpu.load_nd - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> - // CHECK-SAME: -> vector<8x1xf32> + // CHECK: xegpu.load_nd %{{[0-9]}} : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x1xf32> %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf32, #sg_map_bf16_c> -> vector<8x1xf32> return @@ -116,39 +92,28 @@ func.func @test_load_nd_i8(%A: memref<64x64xi8>, %B : memref<64x64xi8>, %C : mem %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<64x64xi8> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<64x64xi8> // CHECK-SAME: -> !xegpu.tensor_desc<8x32xi8, #xegpu.sg_map> - %1 = xegpu.create_nd_tdesc %A[%c0, %c1] - : memref<64x64xi8> -> !xegpu.tensor_desc<8x32xi8, #sg_map_i8_a> + %1 = xegpu.create_nd_tdesc %A[%c0, %c1] : memref<64x64xi8> -> !xegpu.tensor_desc<8x32xi8, #sg_map_i8_a> - // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 1} - // CHECK-SAME: !xegpu.tensor_desc<8x32xi8, #xegpu.sg_map> - // CHECK-SAME: -> vector<4x1x4xi8> + // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 1 : i64} + // CHECK-SAME: !xegpu.tensor_desc<8x32xi8, #xegpu.sg_map> -> vector<4x1x4xi8> %2 = xegpu.load_nd %1 {vnni_axis = 1} : !xegpu.tensor_desc<8x32xi8, #sg_map_i8_a> -> vector<4x1x4xi8> - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<64x64xi8> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<64x64xi8> // CHECK-SAME: -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map> - %3 = xegpu.create_nd_tdesc %B[%c0, %c1] - : memref<64x64xi8> -> !xegpu.tensor_desc<32x16xi8, #sg_map_i8_b> + %3 = xegpu.create_nd_tdesc %B[%c0, %c1] : memref<64x64xi8> -> !xegpu.tensor_desc<32x16xi8, #sg_map_i8_b> - // CHECK: xegpu.load_nd - // CHECK-SAME: {vnni_axis = 0} - // CHECK-SAME: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map> - // CHECK-SAME: -> vector<8x1x4xi8> + // CHECK: xegpu.load_nd %{{[0-9]}} {vnni_axis = 0 : i64} + // CHECK-SAME: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map> -> vector<8x1x4xi8> %4 = xegpu.load_nd %3 {vnni_axis = 0} : !xegpu.tensor_desc<32x16xi8, #sg_map_i8_b> -> vector<8x1x4xi8> - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: memref<64x64xi8> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] : memref<64x64xi8> // CHECK-SAME: -> !xegpu.tensor_desc<8x16xi32, #xegpu.sg_map> - %5 = xegpu.create_nd_tdesc %C[%c0, %c1] - : memref<64x64xi8> -> !xegpu.tensor_desc<8x16xi32, #sg_map_i8_c> + %5 = xegpu.create_nd_tdesc %C[%c0, %c1] : memref<64x64xi8> -> !xegpu.tensor_desc<8x16xi32, #sg_map_i8_c> - // CHECK: xegpu.load_nd - // CHECK-SAME: !xegpu.tensor_desc<8x16xi32, #xegpu.sg_map> - // CHECK-SAME: -> vector<8x1xi32> + // CHECK: xegpu.load_nd %{{[0-9]}} + // CHECK-SAME: !xegpu.tensor_desc<8x16xi32, #xegpu.sg_map> -> vector<8x1xi32> %6 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xi32, #sg_map_i8_c> -> vector<8x1xi32> return diff --git 
a/mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir b/mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir index 8703f171ac9df..78980b551c067 100644 --- a/mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir +++ b/mlir/test/Dialect/XeGPU/IR/load_nd_vc.mlir @@ -10,36 +10,31 @@ func.func @test_load_nd_simd_f32(%src: memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc} - // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]+}}, %{{c[0-9]+}}] + // CHECK-SAME: {mode = #xegpu} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - // CHECK: xegpu.load_nd - // CHECK-SAME: {mode = vc} - // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + // CHECK: xegpu.load_nd %{{[0-9]}} + // CHECK-SAME: {mode = #xegpu} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> %2 = xegpu.load_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - // CHECK: xegpu.load_nd - // CHECK-SAME:{mode = vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint = streaming} - // CHECK-SAME:!xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> + // CHECK: xegpu.load_nd %{{[0-9]}} + // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, l3_hint = #xegpu, mode = #xegpu, transpose = array} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> %3 = xegpu.load_nd %1 {mode= vc, transpose = [1, 0], l1_hint = cached, l2_hint = uncached, l3_hint=streaming} : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32> return } // CHECK-LABEL: func @test_load_nd_simd_f16({{.*}}) { func.func @test_load_nd_simd_f16(%src: memref<24x32xf16>, %x : index, %y : index) { - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg1, %arg2] - // CHECK-SAME: {mode = vc} - // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} - : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}] + // CHECK-SAME: {mode = #xegpu} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - // CHECK: xegpu.load_nd - // CHECK-SAME: {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} - // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> + // CHECK: xegpu.load_nd %{{[0-9]+}} + // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu, vnni_axis = 0 : i64} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> %2 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> return } @@ -47,13 +42,11 @@ func.func @test_load_nd_simd_f16(%src: memref<24x32xf16>, %x : index, %y : index // CHECK-LABEL: func @test_load_nd_simd_bf16({{.*}}) { func.func @test_load_nd_simd_bf16(%src: ui64, %w : index, %h : index, %x : index, %y : index) { %c1 = arith.constant 1 : index - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: %arg0[%arg3, %arg4], [%arg2, %arg1], [%arg1, %c1] - // CHECK-SAME: {mode = vc} - // CHECK-SAME: ui64 -> !xegpu.tensor_desc<8x16xbf16> + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{arg[0-9]}}], [%{{arg[0-9]}}, %{{c[0-9]}}] + // CHECK-SAME: {mode = #xegpu} : ui64 -> 
!xegpu.tensor_desc<8x16xbf16> %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] {mode = vc} : ui64 -> !xegpu.tensor_desc<8x16xbf16> - // CHECK: xegpu.load_nd - // CHECK-SAME: {mode = vc, vnni_axis = 1, l1_hint = cached, l2_hint = uncached} + // CHECK: xegpu.load_nd %{{[0-9]}} + // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu, vnni_axis = 1 : i64} // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16> %2 = xegpu.load_nd %1 {mode=vc, vnni_axis = 1, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xbf16> -> vector<8x8x2xbf16> @@ -62,14 +55,13 @@ func.func @test_load_nd_simd_bf16(%src: ui64, %w : index, %h : index, %x : index // CHECK-LABEL: func @test_load_nd_block_array_simd_f16({{.*}}) { func.func @test_load_nd_block_array_simd_f16(%src: memref<8x32xf16>) { - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[0, 0] {mode = #xegpu} // CHECK-SAME: memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> %1 = xegpu.create_nd_tdesc %src[0, 0] {mode = vc} : memref<8x32xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> - // CHECK: xegpu.load_nd - // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached} + // CHECK: xegpu.load_nd %{{[0-9]}} + // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu} // CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> -> vector<2x8x16xf16> %2 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> -> vector<2x8x16xf16> diff --git a/mlir/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir b/mlir/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir index aec7689e92e70..6e2cb4de4ce1d 100644 --- a/mlir/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir +++ b/mlir/test/Dialect/XeGPU/IR/prefetch_nd_vc.mlir @@ -8,9 +8,11 @@ func.func @test_prefetch_nd_tdesc_vc_0(%src: memref<24x32xf32>) { %c0 = arith.constant 2 : index %c1 = arith.constant 4 : index + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu} + // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> - // CHECK: xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xf32> + // CHECK: xegpu.prefetch_nd %{{[0-9]}} {mode = #xegpu} : !xegpu.tensor_desc<8x16xf32> xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xf32> return @@ -18,12 +20,14 @@ func.func @test_prefetch_nd_tdesc_vc_0(%src: memref<24x32xf32>) { // CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_1({{.*}}) { func.func @test_prefetch_nd_tdesc_vc_1(%src: memref<24x32xf16>, %x : index, %y : index) { - // CHECK: xegpu.create_nd_tdesc - // CHECK-SAME: {mode = vc} + // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}] + // CHECK-SAME: {mode = #xegpu} // CHECK-SAME: memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} - : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - // CHECK: xegpu.prefetch_nd %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> + + // CHECK: xegpu.prefetch_nd %{{[0-9]}} + // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu} + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16> xegpu.prefetch_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf16> return 
 }
@@ -34,9 +38,11 @@ func.func @test_prefetch_nd_tdesc_vc_i8(%src: memref<24x32xi8>) {
   %c0 = arith.constant 2 : index
   %c1 = arith.constant 4 : index
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu}
+  // CHECK-SAME: memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
   %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
-  // CHECK: xegpu.prefetch_nd %0 {mode = vc} : !xegpu.tensor_desc<8x16xi8>
+  // CHECK: xegpu.prefetch_nd %{{[0-9]}} {mode = #xegpu} : !xegpu.tensor_desc<8x16xi8>
   xegpu.prefetch_nd %1 {mode = vc} : !xegpu.tensor_desc<8x16xi8>
   return
@@ -44,12 +50,13 @@ func.func @test_prefetch_nd_tdesc_vc_i8(%src: memref<24x32xi8>) {
 // CHECK-LABEL: func @test_prefetch_nd_tdesc_vc_bf16({{.*}}) {
 func.func @test_prefetch_nd_tdesc_vc_bf16(%src: memref<24x32xbf16>, %x : index, %y : index) {
-  // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {mode = vc}
-  // CHECK-SAME: memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{arg[0-9]}}, %{{arg[0-9]}}]
+  // CHECK-SAME: {mode = #xegpu} : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
   %1 = xegpu.create_nd_tdesc %src[%x, %y] {mode = vc} : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-  // CHECK: xegpu.prefetch_nd %0 {mode = vc, l1_hint = uncached, l2_hint = cached} : !xegpu.tensor_desc<8x16xbf16>
+  // CHECK: xegpu.prefetch_nd %{{[0-9]}}
+  // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16>
   xegpu.prefetch_nd %1 {mode = vc, l1_hint = uncached, l2_hint = cached}: !xegpu.tensor_desc<8x16xbf16>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir b/mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir
index 695c189627e1a..170b3a9fe8147 100644
--- a/mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/store_nd_vc.mlir
@@ -9,25 +9,21 @@ func.func @test_store_nd_vc_bf16(%src: memref<24x32xbf16>, %dst: memref<24x32xbf
   %c0 = arith.constant 2 : index
   %c1 = arith.constant 4 : index
-  // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu}
   // CHECK-SAME: memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-  %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc}
-    : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+  %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-  // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu}
   // CHECK-SAME: memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-  %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc}
-    : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+  %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} : memref<24x32xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-  // CHECK: xegpu.load_nd
-  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
+  // CHECK: xegpu.load_nd %{{[0-9]}}
+  // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
   %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
-  // CHECK: xegpu.store_nd
-  // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached}
+  // CHECK: xegpu.store_nd %{{[0-9]}}, %{{[0-9]}}
+  // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16>
   xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16>
   return
@@ -38,25 +34,22 @@ func.func @test_store_nd_vc_f64(%src: memref<24x32xf64>, %dst: memref<24x32xf64>
   %c0 = arith.constant 2 : index
   %c1 = arith.constant 4 : index
-  // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu}
   // CHECK-SAME: memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
-  %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc}
-    : memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
+  %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
-  // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu}
   // CHECK-SAME: memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
   %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} : memref<24x32xf64> -> !xegpu.tensor_desc<8x16xf64>
-  // CHECK: xegpu.load_nd
-  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
+  // CHECK: xegpu.load_nd %{{[0-9]}}
+  // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf64> -> vector<8x16xf64>
   %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xf64> -> vector<8x16xf64>
-  // CHECK: xegpu.store_nd
-  // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached}
+  // CHECK: xegpu.store_nd %{{[0-9]}}, %{{[0-9]}}
+  // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: vector<8x16xf64>, !xegpu.tensor_desc<8x16xf64>
   xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xf64>, !xegpu.tensor_desc<8x16xf64>
   return
@@ -67,25 +60,23 @@ func.func @test_store_nd_vc_i8(%src: memref<24x32xi8>, %dst: memref<24x32xi8>) {
   %c0 = arith.constant 2 : index
   %c1 = arith.constant 4 : index
-  // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu}
   // CHECK-SAME: memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
   %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
-  // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu}
   // CHECK-SAME: memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
   %2 = xegpu.create_nd_tdesc %dst[%c0, %c1] {mode = vc} : memref<24x32xi8> -> !xegpu.tensor_desc<8x16xi8>
-  // CHECK: xegpu.load_nd
-  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
+  // CHECK: xegpu.load_nd %{{[0-9]}}
+  // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<8x16xi8> -> vector<8x16xi8>
   %3 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached}: !xegpu.tensor_desc<8x16xi8> -> vector<8x16xi8>
-  // CHECK: xegpu.store_nd
-  // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached}
+  // CHECK: xegpu.store_nd %{{[0-9]}}, %{{[0-9]}}
+  // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: vector<8x16xi8>, !xegpu.tensor_desc<8x16xi8>
   xegpu.store_nd %3, %2 {mode = vc, l1_hint = write_back, l2_hint = uncached}: vector<8x16xi8>, !xegpu.tensor_desc<8x16xi8>
   return
diff --git a/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir b/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
index 4bc631acc5125..6d98ac3950c31 100644
--- a/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/store_scatter.mlir
@@ -7,25 +7,21 @@
 // CHECK-LABEL: func @test_store_scatter({{.*}}) {
 func.func @test_store_scatter(%src: ui64, %offsets : vector<16xindex>, %dst: ui64) {
   %0 = arith.constant dense: vector<16xi1>
-  // CHECK: xegpu.create_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu}
   // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
   %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  // CHECK: xegpu.create_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu}
   // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
   %2 = xegpu.create_tdesc %dst, %offsets {mode = vc} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  // CHECK: xegpu.load
-  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
+  // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
   %3 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
-  // CHECK: xegpu.store
-  // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached}
+  // CHECK: xegpu.store %{{[0-9]}}, %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
   xegpu.store %3, %2, %0 {mode = vc, l1_hint = write_back, l2_hint = uncached} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
diff --git a/mlir/test/Dialect/XeGPU/IR/store_scatter_vc.mlir b/mlir/test/Dialect/XeGPU/IR/store_scatter_vc.mlir
index d1e57ddda45e2..c1a51712e7003 100644
--- a/mlir/test/Dialect/XeGPU/IR/store_scatter_vc.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/store_scatter_vc.mlir
@@ -7,26 +7,22 @@
 // CHECK-LABEL: func @test_store_scatter_vc({{.*}}) {
 func.func @test_store_scatter_vc(%src: ui64, %offsets : vector<16 x index>, %dst: ui64) {
   %0 = arith.constant dense<1>: vector<16xi1>
-  // CHECK: xegpu.create_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu}
   // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
   %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  // CHECK: xegpu.create_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu}
   // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
   %2 = xegpu.create_tdesc %dst, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  // CHECK: xegpu.load
-  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
+  // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
   %3 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
-  // CHECK: xegpu.store
-  // CHECK-SAME: {mode = vc, l1_hint = write_back, l2_hint = uncached}
-  // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
+  // CHECK: xegpu.store %{{[0-9]}}, %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
+  // CHECK-SAME: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
   xegpu.store %3, %2, %0 {mode = vc, l1_hint = write_back, l2_hint = uncached} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1>
   return
diff --git a/mlir/test/Dialect/XeGPU/IR/update_nd_offset.mlir b/mlir/test/Dialect/XeGPU/IR/update_nd_offset.mlir
index e25edcdd72b2a..1b97be77a2d79 100644
--- a/mlir/test/Dialect/XeGPU/IR/update_nd_offset.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/update_nd_offset.mlir
@@ -8,22 +8,20 @@ func.func @test_update_nd_offset_vc_0(%src: memref<24x32xf32>) {
   %c0 = arith.constant 2 : index
   %c1 = arith.constant 4 : index
-  // CHECK: xegpu.create_nd_tdesc
-  // CHECK-SAME: {mode = vc}
-  // CHECK-SAME: memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: xegpu.create_nd_tdesc %{{arg[0-9]}}[%{{c[0-9]}}, %{{c[0-9]}}]
+  // CHECK-SAME: {mode = #xegpu} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   %1 = xegpu.create_nd_tdesc %src[%c0, %c1] {mode = vc} : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-  // CHECK: xegpu.load_nd
-  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
+  // CHECK: xegpu.load_nd %{{[0-9]}}
+  // CHECK-SAME: {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
   %2 = xegpu.load_nd %1 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
-  // CHECK: xegpu.update_nd_offset
+  // CHECK: xegpu.update_nd_offset %{{[0-9]}}, [%{{c[0-9]}}, %{{c[0-9]}}] {mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  %3 = xegpu.update_nd_offset %1, [%c0, %c1] {mode = vc}
-    : !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %3 = xegpu.update_nd_offset %1, [%c0, %c1] {mode = vc} : !xegpu.tensor_desc<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/IR/update_offset_vc.mlir b/mlir/test/Dialect/XeGPU/IR/update_offset_vc.mlir
index 0852484423693..05b0092d2379b 100644
--- a/mlir/test/Dialect/XeGPU/IR/update_offset_vc.mlir
+++ b/mlir/test/Dialect/XeGPU/IR/update_offset_vc.mlir
@@ -7,14 +7,12 @@
 // CHECK-LABEL: func @test_update_offset_VC({{.*}}) {
 func.func @test_update_offset_VC(%src: ui64, %offsets : vector<16 x index>) {
   %0 = arith.constant dense<1>: vector<16xi1>
-  // CHECK: xegpu.create_tdesc
-  // CHECK-SAME: {mode = vc}
+  // CHECK: xegpu.create_tdesc %{{arg[0-9]}}, %{{arg[0-9]}} {mode = #xegpu}
   // CHECK-SAME: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
   %1 = xegpu.create_tdesc %src, %offsets {mode = vc} : ui64, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
-  // CHECK: xegpu.load
-  // CHECK-SAME: {mode = vc, l1_hint = cached, l2_hint = uncached}
+  // CHECK: xegpu.load %{{[0-9]}}, %{{.*}} {l1_hint = #xegpu, l2_hint = #xegpu, mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
   %2 = xegpu.load %1, %0 {mode = vc, l1_hint = cached, l2_hint = uncached} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xi1> -> vector<16xf32>
@@ -22,7 +20,7 @@ func.func @test_update_offset_VC(%src: ui64, %offsets : vector<16 x index>) {
   %3 = arith.constant dense<16>: vector<16 x index>
   %4 = arith.addi %offsets, %3: vector<16 x index>
-  // CHECK: xegpu.update_offset
+  // CHECK: xegpu.update_offset %{{[0-9]}}, %{{[0-9]}} {mode = #xegpu}
   // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>
   %5 = xegpu.update_offset %1, %4 {mode = vc} : !xegpu.tensor_desc<16xf32, #xegpu.scattered>, vector<16 x index> -> !xegpu.tensor_desc<16xf32, #xegpu.scattered>

From 6093760fdbb69016254e6850410a07e66a5d715f Mon Sep 17 00:00:00 2001
From: Chao Chen
Date: Wed, 17 Jan 2024 16:49:26 +0000
Subject: [PATCH 8/8] run clang-format

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp |  11 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp     | 255 +++++++++++----------
 2 files changed, 144 insertions(+), 122 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7c85f44cc9360..60ab50227c224 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -156,7 +156,8 @@ mlir::Attribute TensorDescAttr::parse(mlir::AsmParser &parser,
   if (parser.parseGreater())
     return {};
   return TensorDescAttr::get(
-      parser.getContext(), memory_scope.value_or(xegpu::MemoryScopeKind::GLOBAL),
+      parser.getContext(),
+      memory_scope.value_or(xegpu::MemoryScopeKind::GLOBAL),
       array_length.value_or(1), boundary_check.value_or(true),
       scattered.value_or(xegpu::ScatteredAttr()),
       map.value_or(xegpu::SubGroupMapAttr()));
@@ -285,12 +286,12 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
   if (printDefaultValues()) {
     auto encoding = getEncoding();
     if (auto attr = getEncodingAsMapAttr()) {
-      encoding =
-          TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1, {}, attr);
+      encoding = TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1,
+                                     {}, attr);
     }
     if (auto attr = getEncodingAsScatteredAttr()) {
-      encoding =
-          TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1, attr, {});
+      encoding = TensorDescAttr::get(getContext(), MemoryScopeKind::GLOBAL, 1,
+                                     attr, {});
     }
     printer << ", " << encoding;
   } else if (auto encoding = getEncodingAsTensorDescAttr()) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 41626093250e5..627680e84ec94 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -88,7 +88,8 @@ static bool verifyAndInferShape(std::vector &shape,
 }
 
 static ParseResult
-parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result) {
+parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser,
+                                     OperationState &result) {
   // no optional attributes, return success
   if (failed(parser.parseOptionalLBrace()))
     return success();
@@ -96,19 +97,20 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result
   llvm::SmallDenseSet seenKeys;
   auto parseElt = [&]() -> ParseResult {
     // The name of an attribute can either be a keyword, or a string.
-    // as compared to mlir::parseOptionalAttrList, the cases of using
+    // Token::bare_identifier and Token::inttype as key may not be handled
    std::string nameId;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeywordOrString(&nameId))
-      return parser.emitError(loc, "invalid attribute name: ") << nameId << ".\n";
+      return parser.emitError(loc, "invalid attribute name: ")
+             << nameId << ".\n";
 
    if (nameId.empty())
      return parser.emitError(loc, "expected valid attribute name");
 
    if (!seenKeys.insert(nameId).second)
-      return parser.emitError(loc, "duplicate key '")
-             << nameId << "' in dictionary attribute.";
+      return parser.emitError(loc, "duplicate key '")
+             << nameId << "' in dictionary attribute.";
 
    // Lazy load a dialect in the context if there is a possible namespace.
    auto splitName = StringRef(nameId).split('.');
@@ -125,25 +127,29 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result
    // for xegpu specific attributes
    if (nameId == "mode") {
      ModeKindAttr attr;
-      return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes);
-    } else if (nameId == "l1_hint" || nameId == "l2_hint" || nameId == "l3_hint") {
+      return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId,
+                                                     result.attributes);
+    } else if (nameId == "l1_hint" || nameId == "l2_hint" ||
+               nameId == "l3_hint") {
      CacheKindAttr attr;
-      return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId, result.attributes);
+      return parser.parseCustomAttributeWithFallback(attr, Type{}, nameId,
+                                                     result.attributes);
    } else if (nameId == "transpose") {
      // in form of [4, 5], actually it is a copy of DenseI64ArrayAttr::parse()
-      if (succeeded(parser.parseOptionalLSquare())) {
        Attribute attr;
        // handle empty list case
        if (succeeded(parser.parseOptionalRSquare())) {
          attr = DenseI64ArrayAttr::get(parser.getContext(), {});
        } else {
          attr = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
-          if (failed(parser.parseRSquare()))
-            return failure();
+          if (failed(parser.parseRSquare()))
+            return failure();
        }
-        if (!attr) return failure();
+        if (!attr)
+          return failure();
        result.addAttribute(nameId, attr);
-        return success();
+        return success();
      } else {
        // in form of array
        DenseI64ArrayAttr attr;
@@ -167,7 +173,8 @@ parseOptionalAttrDictWithCustomAttrs(OpAsmParser &parser, OperationState &result
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                            Type TensorDesc, Value source, ValueRange offsets,
                            ValueRange shape, ValueRange strides,
-                           llvm::ArrayRef static_offsets, ModeKind mode) {
+                           llvm::ArrayRef static_offsets,
+                           ModeKind mode) {
   auto offsetRank = static_offsets.size();
   auto shapeRank = shape.size() ? shape.size() : getRankOf(source);
@@ -199,7 +206,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                            Type tdesc, Value source,
-                           llvm::ArrayRef offsets, ModeKind mode) {
+                           llvm::ArrayRef offsets,
+                           ModeKind mode) {
   auto ty = llvm::dyn_cast_if_present(source.getType());
   assert(ty && ty.hasStaticShape() && offsets.size() == getRankOf(source));
 
@@ -216,8 +224,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                            Type tdesc, Value source,
                            llvm::ArrayRef offsets,
-                           ValueRange shape, ValueRange stride,
-                           ModeKind mode) {
+                           ValueRange shape, ValueRange stride, ModeKind mode) {
   assert(shape.size() && offsets.size() && stride.size() &&
          shape.size() == stride.size() && shape.size() == offsets.size());
 
@@ -281,8 +288,9 @@ ParseResult CreateNdDescOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
-        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
-      })))
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
     return failure();
 
   if (parser.parseColon())
@@ -349,7 +357,7 @@ void CreateNdDescOp::print(OpAsmPrinter &printer) {
     elidedAttrs.push_back("operandSegmentSizes");
   if (!printDefaults && mode == xegpu::ModeKind::SIMT)
     elidedAttrs.push_back("mode");
-  
+
   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
   printer << ' ' << ":";
 
@@ -449,7 +457,7 @@ llvm::SmallVector CreateNdDescOp::getShape() {
     return shape;
   }
 
-  llvm_unreachable("Unexpected error in CreateNdDescOp. "
" "The shape information is missing.\n"); } @@ -496,7 +504,6 @@ llvm::ArrayRef CreateNdDescOp::getTensorDescShape() { return getTensorDescType().getShape(); } - //===----------------------------------------------------------------------===// // XeGPU_LoadNDOp //===----------------------------------------------------------------------===// @@ -512,8 +519,9 @@ ParseResult LoadNDOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { - return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; - }))) + return parser.emitError(loc) + << "'" << result.name.getStringRef() << "' op "; + }))) return failure(); if (parser.parseColon()) @@ -547,7 +555,7 @@ void LoadNDOp::print(OpAsmPrinter &printer) { llvm::SmallVector elidedAttrs; if (!printDefaults && mode == xegpu::ModeKind::SIMT) elidedAttrs.push_back("mode"); - + printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; @@ -659,8 +667,9 @@ ParseResult StoreNDOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { - return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; - }))) + return parser.emitError(loc) + << "'" << result.name.getStringRef() << "' op "; + }))) return failure(); if (parser.parseColon()) @@ -688,7 +697,7 @@ void StoreNDOp::print(OpAsmPrinter &printer) { llvm::SmallVector elidedAttrs; if (!printDefaults && mode == xegpu::ModeKind::SIMT) - elidedAttrs.push_back("mode"); + elidedAttrs.push_back("mode"); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; @@ -772,8 +781,9 @@ ParseResult PrefetchNDOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { - return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; - }))) + return parser.emitError(loc) + << "'" << result.name.getStringRef() << "' op "; + }))) return failure(); if (parser.parseColon()) @@ -796,7 +806,7 @@ void PrefetchNDOp::print(OpAsmPrinter &printer) { llvm::SmallVector elidedAttrs; if (!printDefaults && mode == xegpu::ModeKind::SIMT) - elidedAttrs.push_back("mode"); + elidedAttrs.push_back("mode"); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; @@ -807,7 +817,8 @@ void PrefetchNDOp::print(OpAsmPrinter &printer) { //===----------------------------------------------------------------------===// // XeGPU_UpdateNDOffsetOp //===----------------------------------------------------------------------===// -ParseResult UpdateNDOffsetOp::parse(OpAsmParser &parser, OperationState &result) { +ParseResult UpdateNDOffsetOp::parse(OpAsmParser &parser, + OperationState &result) { llvm::SmallVector TensorDescOperands(1); llvm::SmallVector offsetsOperands; llvm::SmallVector TensorDescTypes(1); @@ -820,7 +831,7 @@ ParseResult UpdateNDOffsetOp::parse(OpAsmParser &parser, OperationState &result) return failure(); if (parser.parseComma()) return failure(); - + // parse offsets, e.g., [x, y] if (succeeded(parser.parseOptionalLSquare())) { offsetsOperandsLoc = parser.getCurrentLocation(); @@ -835,8 +846,9 @@ ParseResult UpdateNDOffsetOp::parse(OpAsmParser &parser, OperationState &result) auto loc = parser.getCurrentLocation(); if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { - return parser.emitError(loc) << "'" << 
-        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
-      })))
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
     return failure();
 
   if (parser.parseColon())
@@ -850,11 +862,13 @@ ParseResult UpdateNDOffsetOp::parse(OpAsmParser &parser, OperationState &result)
   if (parser.parseType(resultTypes[0]))
     return failure();
   result.addTypes(resultTypes);
-  if (parser.resolveOperands(TensorDescOperands, TensorDescTypes, TensorDescOperandsLoc, result.operands))
+  if (parser.resolveOperands(TensorDescOperands, TensorDescTypes,
+                             TensorDescOperandsLoc, result.operands))
     return failure();
 
   Type indexType = parser.getBuilder().getIndexType();
-  if (parser.resolveOperands(offsetsOperands, indexType, offsetsOperandsLoc, result.operands))
+  if (parser.resolveOperands(offsetsOperands, indexType, offsetsOperandsLoc,
+                             result.operands))
     return failure();
   return success();
 }
@@ -874,7 +888,7 @@ void UpdateNDOffsetOp::print(OpAsmPrinter &printer) {
 
   llvm::SmallVector elidedAttrs;
   if (!printDefaults && mode == xegpu::ModeKind::SIMT)
-    elidedAttrs.push_back("mode"); 
+    elidedAttrs.push_back("mode");
 
   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
   printer << ' ' << ":";
@@ -936,14 +950,15 @@ ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parser.parseOperand(Operands[1]))
     return failure();
 
-  // parse the optional attributes 
+  // parse the optional attributes
   auto loc = parser.getCurrentLocation();
   if (parseOptionalAttrDictWithCustomAttrs(parser, result))
     return failure();
 
   if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
-        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
-      })))
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
     return failure();
 
   if (parser.parseColon())
@@ -964,8 +979,7 @@ ParseResult CreateDescOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
   result.addTypes(TensorDescTypes);
 
-  if (parser.resolveOperands(Operands, Types, operandsLoc,
-                             result.operands))
+  if (parser.resolveOperands(Operands, Types, operandsLoc, result.operands))
     return failure();
   return success();
 }
@@ -984,7 +998,7 @@ void CreateDescOp::print(OpAsmPrinter &printer) {
   llvm::SmallVector elidedAttrs;
   if (!printDefaults) {
     if (mode == xegpu::ModeKind::SIMT)
-      elidedAttrs.push_back("mode"); 
+      elidedAttrs.push_back("mode");
     if (chunk == 1)
       elidedAttrs.push_back("chunk_size_per_lane");
   }
@@ -1114,8 +1128,9 @@ ParseResult LoadGatherOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseOptionalAttrDictWithCustomAttrs(parser, result))
     return failure();
 
  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
-        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
-      })))
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
     return failure();
 
   if (parser.parseColon())
@@ -1156,7 +1171,7 @@ void LoadGatherOp::print(OpAsmPrinter &printer) {
 
   llvm::SmallVector elidedAttrs;
   if (!printDefaults && mode == xegpu::ModeKind::SIMT)
-    elidedAttrs.push_back("mode"); 
+    elidedAttrs.push_back("mode");
 
   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
   printer << ' ' << ":";
@@ -1224,9 +1239,9 @@ LogicalResult LoadGatherOp::verify() {
   if (getTransposeAttr()) {
     auto trans = getTranspose().value();
     if (tdescShape.size() < trans.size())
-      return emitWarning("Invalid transpose attr. It is ignored.");
It is ignored."); transpose(trans, tdescShape); - } + } if (getVnniAxis()) { auto axis = getVnniAxis().value(); @@ -1249,23 +1264,21 @@ LogicalResult LoadGatherOp::verify() { return success(); } - //===----------------------------------------------------------------------===// // XeGPU_StoreScatterOp //===----------------------------------------------------------------------===// void StoreScatterOp::build(OpBuilder &builder, OperationState &state, Value value, Value TensorDesc, Value mask, - CacheKindAttr l1_hint, - CacheKindAttr l2_hint, + CacheKindAttr l1_hint, CacheKindAttr l2_hint, CacheKindAttr l3_hint) { state.addOperands(value); state.addOperands(TensorDesc); state.addOperands(mask); - if (l1_hint) + if (l1_hint) state.getOrAddProperties().l1_hint = l1_hint; - if (l2_hint) + if (l2_hint) state.getOrAddProperties().l2_hint = l2_hint; - if (l3_hint) + if (l3_hint) state.getOrAddProperties().l3_hint = l3_hint; state.getOrAddProperties().mode = ModeKindAttr::get(builder.getContext(), ModeKind::VC); @@ -1301,15 +1314,16 @@ ParseResult StoreScatterOp::parse(OpAsmParser &parser, OperationState &result) { auto loc = parser.getCurrentLocation(); if (parseOptionalAttrDictWithCustomAttrs(parser, result)) - return failure(); + return failure(); if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { - return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; - }))) + return parser.emitError(loc) + << "'" << result.name.getStringRef() << "' op "; + }))) return failure(); if (parser.parseColon()) return failure(); - + if (parser.parseTypeList(Types)) return failure(); @@ -1334,7 +1348,7 @@ void StoreScatterOp::print(OpAsmPrinter &printer) { llvm::SmallVector elidedAttrs; if (!printDefaults && mode == xegpu::ModeKind::SIMT) - elidedAttrs.push_back("mode"); + elidedAttrs.push_back("mode"); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; @@ -1413,8 +1427,8 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, } void PrefetchOp::build(OpBuilder &builder, OperationState &state, - Value TensorDesc, CacheKind l1_hint, - CacheKind l2_hint, CacheKind l3_hint) { + Value TensorDesc, CacheKind l1_hint, CacheKind l2_hint, + CacheKind l3_hint) { state.addOperands(TensorDesc); state.getOrAddProperties().l1_hint = CacheKindAttr::get(builder.getContext(), l1_hint); @@ -1439,10 +1453,11 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) { if (parseOptionalAttrDictWithCustomAttrs(parser, result)) return failure(); if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() { - return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op "; - }))) - return failure(); - + return parser.emitError(loc) + << "'" << result.name.getStringRef() << "' op "; + }))) + return failure(); + if (parser.parseColon()) return failure(); @@ -1464,7 +1479,7 @@ void PrefetchOp::print(OpAsmPrinter &printer) { llvm::SmallVector elidedAttrs; if (!printDefaults && mode == xegpu::ModeKind::SIMT) - elidedAttrs.push_back("mode"); + elidedAttrs.push_back("mode"); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << ' ' << ":"; @@ -1478,21 +1493,20 @@ LogicalResult PrefetchOp::verify() { auto mapping = tdescTy.getMapping(); auto isValidHint = [&](CacheKindAttr attr) -> bool { - if (!attr) return true; + if (!attr) + return true; auto kind = attr.getValue(); - return kind == CacheKind::CACHED || - kind == CacheKind::UNCACHED || - kind == CacheKind::STREAMING || - kind == 
+    return kind == CacheKind::CACHED || kind == CacheKind::UNCACHED ||
+           kind == CacheKind::STREAMING || kind == CacheKind::READ_INVALIDATE;
   };
 
-  if (!isValidHint(getL1HintAttr())) 
    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
-  if (!isValidHint(getL2HintAttr())) 
    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
 
-  if (!isValidHint(getL3HintAttr())) 
    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   if (!tdescTy.getScattered())
@@ -1510,31 +1524,32 @@ LogicalResult PrefetchOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_UpdateOffsetOp
 //===----------------------------------------------------------------------===//
-void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state, 
+void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
                            Type result, Value TensorDesc, Value offsets) {
   state.addOperands(TensorDesc);
   state.addOperands(offsets);
-  state.getOrAddProperties().mode = 
-      xegpu::ModeKindAttr::get(builder.getContext(), xegpu::ModeKind::VC); 
-  state.addTypes(result); 
+  state.getOrAddProperties().mode =
+      xegpu::ModeKindAttr::get(builder.getContext(), xegpu::ModeKind::VC);
+  state.addTypes(result);
 }
 
 ParseResult UpdateOffsetOp::parse(OpAsmParser &parser, OperationState &result) {
-  llvm::SmallVector Operands; 
  llvm::SmallVector Types;
  auto OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList(Operands))
    return failure();
-  
+
  auto loc = parser.getCurrentLocation();
  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
    return failure();
 
  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
-        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
-      })))
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
    return failure();
-  
+
  if (parser.parseColon())
    return failure();
 
@@ -1555,7 +1570,7 @@ ParseResult UpdateOffsetOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void UpdateOffsetOp::print(OpAsmPrinter &printer) {
-  auto mode =getMode();
+  auto mode = getMode();
   auto printDefaults = printDefaultValues();
 
   printer << ' ';
@@ -1566,7 +1581,7 @@ void UpdateOffsetOp::print(OpAsmPrinter &printer) {
 
   llvm::SmallVector elidedAttrs;
   if (!printDefaults && mode == xegpu::ModeKind::SIMT)
-    elidedAttrs.push_back("mode"); 
+    elidedAttrs.push_back("mode");
   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
   printer << ' ' << ":";
   printer << ' ';
@@ -1581,7 +1596,7 @@ void UpdateOffsetOp::print(OpAsmPrinter &printer) {
 
 LogicalResult UpdateOffsetOp::verify() {
   auto mode = getMode();
-  if (mode != ModeKind::VC) 
    return emitOpError("UpdateOffsetOp only works on VC mode.\n");
 
   auto srcTy = getTensorDesc().getType();
@@ -1611,7 +1626,7 @@ LogicalResult UpdateOffsetOp::verify() {
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
 ParseResult DpasOp::parse(OpAsmParser &parser, OperationState &result) {
-  llvm::SmallVector Operands; 
  llvm::SmallVector Types;
  llvm::SMLoc OperandsLoc = parser.getCurrentLocation();
@@ -1622,10 +1637,11 @@ ParseResult DpasOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
    return failure();
 
  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
-        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
-      })))
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
    return failure();
-  
+
  if (parser.parseColon())
    return failure();
 
@@ -1655,9 +1671,9 @@ void DpasOp::print(OpAsmPrinter &printer) {
   printer << ",";
   printer << ' ';
   printer << getRhs();
-  if (Value value = getAcc()) 
    printer << ", " << value;
-  
+
   llvm::SmallVector elidedAttrs;
   if (!printDefaults && mode == xegpu::ModeKind::SIMT)
     elidedAttrs.push_back("mode");
@@ -1685,12 +1701,12 @@ LogicalResult DpasOp::verify() {
   Type lhsElemType = getLhsType().getElementType();
   Type rhsElemType = getRhsType().getElementType();
 
-  if (lhsElemType != rhsElemType) 
-    return emitOpError("lhs and rhs element type does not match for dpas op"); 
+  if (lhsElemType != rhsElemType)
+    return emitOpError("lhs and rhs element types do not match for dpas op");
 
   if (getAcc() && getAccType() != getResultType())
     return emitOpError("Accumulator and Result for dpas op should have the "
-                       "same type (both shape and element type)."); 
+                       "same type (both shape and element type).");
 
   if (lhsRank != rhsRank || lhsRank != 3)
     return emitOpError(
@@ -1720,15 +1736,16 @@ void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state,
 void InvokeSIMDOp::build(OpBuilder &builder, OperationState &state,
                          llvm::StringRef callee, TypeRange results,
                          ArgTypeKindAttr argType, ValueRange operands) {
-  build(builder, state, StringAttr::get(builder.getContext(), callee),
-        results, argType, operands);
+  build(builder, state, StringAttr::get(builder.getContext(), callee), results,
+        argType, operands);
 }
 
 //===----------------------------------------------------------------------===//
 // XeGPU_AtomicRMWOp
 //===----------------------------------------------------------------------===//
 void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result,
-                        AtomicRMWKindAttr kind, Value tensorDesc, Value mask, Value value) {
+                        AtomicRMWKindAttr kind, Value tensorDesc, Value mask,
+                        Value value) {
   state.addOperands(tensorDesc);
   state.addOperands(mask);
   if (value)
@@ -1740,7 +1757,8 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result,
 }
 
 void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result,
-                        AtomicRMWKind kind, Value tensorDesc, Value mask, Value value) {
+                        AtomicRMWKind kind, Value tensorDesc, Value mask,
+                        Value value) {
   state.addOperands(tensorDesc);
   state.addOperands(mask);
   if (value)
@@ -1753,16 +1771,16 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state, Type result,
 }
 
 ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) {
-  llvm::SmallVector Operands; 
  llvm::SmallVector Types;
  llvm::SMLoc OperandsLoc;
-  
+
  llvm::SmallVector resultTypes(1);
 
  xegpu::AtomicRMWKindAttr kindAttr;
  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}))
    return failure();
-  if (kindAttr) 
    result.getOrAddProperties().kind = kindAttr;
 
  OperandsLoc = parser.getCurrentLocation();
@@ -1773,10 +1791,11 @@ ParseResult AtomicRMWOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parseOptionalAttrDictWithCustomAttrs(parser, result))
    return failure();
 
  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
-        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
-      })))
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
    return failure();
-  
+
  if (parser.parseColon())
    return failure();
 
  if (parser.parseCustomTypeWithFallback(Types[0]))
    return failure();
 
  if (parser.parseArrow())
    return failure();
-  
+
  if (parser.parseCustomTypeWithFallback(resultTypes[0]))
-    return failure(); 
+    return failure();
  result.addTypes(resultTypes);
 
  if (parser.resolveOperands(Operands, Types, OperandsLoc, result.operands))
    return failure();
  return success();
 }
 
@@ -1805,9 +1824,9 @@ void AtomicRMWOp::print(OpAsmPrinter &printer) {
   printer << ",";
   printer << ' ';
   printer << getMask();
-  if (Value value = getValue()) 
    printer << ", " << value;
-  
+
   llvm::SmallVector elidedAttrs;
   elidedAttrs.push_back("kind");
   if (!printDefaults && mode == xegpu::ModeKind::SIMT)
     elidedAttrs.push_back("mode");
@@ -1818,13 +1837,13 @@ void AtomicRMWOp::print(OpAsmPrinter &printer) {
   printer << ' ';
   printer << getOperation()->getOperandTypes();
   printer << ' ' << "->";
-  printer << ' '; 
+  printer << ' ';
   printer << getResult().getType();
 }
 
 LogicalResult AtomicRMWOp::verify() {
   auto mode = getMode();
-  if (mode != ModeKind::VC) 
    return emitOpError("AtomicRMWOp only works on VC mode.\n");
  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNbarrierOp
 //===----------------------------------------------------------------------===//
-ParseResult CreateNbarrierOp::parse(OpAsmParser &parser, OperationState &result) {
-  llvm::SmallVector Operands; 
+ParseResult CreateNbarrierOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  llvm::SmallVector Operands;
  llvm::SmallVector Types;
  llvm::SMLoc OperandsLoc;
@@ -1846,10 +1866,11 @@ ParseResult CreateNbarrierOp::parse(OpAsmParser &parser, OperationState &result)
    return failure();
 
  if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
-        return parser.emitError(loc) << "'" << result.name.getStringRef() << "' op ";
-      })))
+        return parser.emitError(loc)
+               << "'" << result.name.getStringRef() << "' op ";
+      })))
    return failure();
-  
+
  if (parser.parseColon())
    return failure();