diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h index 7badfaf4a8216..f85dde0b98c8e 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h @@ -27,6 +27,10 @@ class TensorDescType; class DistributeLayoutAttr; class LayoutAttr; class SliceAttr; + +/// Specifies the level of a layout hierarchy for comparison or propagation. +enum class LayoutKind { Lane, InstData, Subgroup }; + } // namespace xegpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 898fb7e1d8e6d..377967dfdb1e5 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -311,7 +311,32 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> { /*methodName=*/"isSliceOf", /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>, - InterfaceMethod, + InterfaceMethod @@ -546,9 +571,8 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { /// Check if this is slice of some other layout. bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; } - /// Check if this is identical to some other layout. + /// Check if this layout is equal to another layout. bool isEqualTo(const xegpu::DistributeLayoutAttr &other); - }]; let assemblyFormat = "`<` struct(params) `>`"; @@ -734,11 +758,11 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> { /// Check if this is slice of some other layout. bool isSliceOf(const xegpu::DistributeLayoutAttr &other); + /// Check if this layout is equal to another layout. + bool isEqualTo(const xegpu::DistributeLayoutAttr &other); + /// Drop the slice dims to get the original layout. SliceAttr dropSliceDims(ArrayRef sliceDimsToDrop); - - /// Check if this is identical to some other layout. 
- bool isEqualTo(const xegpu::DistributeLayoutAttr &other); }]; let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`"; diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 2cbec50772b98..fefc8c9903497 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -1593,9 +1593,8 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, }]; let description = [{ - This operation loads a 2D block of data from shared local memory (SLM) as specified - by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the - subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed. + This operation loads an nD block of data from shared local memory (SLM) as specified + by the provided nD `mem_desc`. Memory descriptors of rank 2 or higher are supported. This operation serves as an anchor through which users assign a layout attribute to govern computation distribution. @@ -1665,9 +1664,8 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, let assemblyFormat = [{ $data `,` $mem_desc `` custom($offsets, $const_offsets) prop-dict attr-dict `` `:` type(operands)}]; let description = [{ - This operation stores a 2D `data` fragment into the shared local memory region - specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the - subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed. + This operation stores an nD `data` fragment into the shared local memory region + specified by an nD `mem_desc`. Memory descriptors of rank 2 or higher are supported. This operation serves as an anchor through which users assign a layout attribute to govern computation distribution. 
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h index 0d5210f07f05a..35dc9541b5755 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h @@ -33,8 +33,6 @@ class TensorDescType; namespace xegpu { -enum class LayoutKind { Lane, InstData, Subgroup }; - LogicalResult propagateLayouts(OpBuilder &builder, Operation *target, LayoutKind layoutKind, bool printOnly = false); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index d99557e68f0ec..7ace00a746e21 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -755,6 +755,17 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { [&](int64_t dim) { return thisDims.contains(dim); }); } +bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { + if (dyn_cast(other)) + return false; + + auto flattenedThis = flatten(); + auto flattenedOther = dyn_cast(other).flatten(); + + return ((flattenedThis.getParent() == flattenedOther.getParent()) && + (flattenedThis.getDims() == flattenedOther.getDims())); +} + xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef sliceDimsToDrop) { if (sliceDimsToDrop.empty()) return *this; @@ -773,17 +784,6 @@ xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef sliceDimsToDrop) { return sliceWithoutDims; } -bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) { - if (dyn_cast(other)) - return false; - - auto flattenedThis = flatten(); - auto flattenedOther = dyn_cast(other).flatten(); - - return ((flattenedThis.getParent() == flattenedOther.getParent()) && - (flattenedThis.getDims() == flattenedOther.getDims())); -} - // Helper function to adjust dimensions from sliced space to parent space // say we have a parent shape of rank 4, and slice dims [1,3], so the sliced // shape is of rank 
2, if we want to set unit dim [0] in sliced space, it maps diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 91ba07a8e0256..c7226c7ebbd5d 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -186,8 +186,8 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, return success(); } - if (mdescTy.getRank() != 2) - return emitError() << "mem_desc must be 2D."; + if (mdescTy.getRank() < 2) + return emitError() << "mem_desc must be 2D or greater."; ArrayRef dataShape = dataTy.getShape(); ArrayRef mdescShape = mdescTy.getShape(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 566fb45f5bcb4..66eb7fc97aa1a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -602,44 +602,110 @@ struct WgToSgElementwiseOp : public ConversionPattern { struct WgToSgConvertLayoutOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); - auto input = op.getInputLayout(); - auto target = op.getTargetLayout(); + VectorType resultType = op.getResult().getType(); + ArrayRef wgShape = resultType.getShape(); + auto inputLayout = op.getInputLayout(); + auto targetLayout = op.getTargetLayout(); - if (!input || !target || !input.isForWorkgroup() || - !target.isForWorkgroup()) + if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() || + !targetLayout.isForWorkgroup()) return rewriter.notifyMatchFailure( op, "Input and target layouts must have subgroup layout"); - SmallVector inputSgLayout = input.getEffectiveSgLayoutAsInt(); - SmallVector inputSgData = input.getEffectiveSgDataAsInt(); - DenseI32ArrayAttr inputOrder = 
input.getOrder(); - SmallVector targetSgLayout = target.getEffectiveSgLayoutAsInt(); - SmallVector targetSgData = target.getEffectiveSgDataAsInt(); - DenseI32ArrayAttr targetOrder = target.getOrder(); - - // TODO: currently we only support for optimal case, where input and - // output has the same sg_layout and sg_data, so SLM is not involved. - if (inputSgLayout != targetSgLayout || inputSgData != targetSgData || - inputOrder != targetOrder) + SmallVector inputSgLayout = + inputLayout.getEffectiveSgLayoutAsInt(); + SmallVector inputSgData = inputLayout.getEffectiveSgDataAsInt(); + SmallVector targetSgLayout = + targetLayout.getEffectiveSgLayoutAsInt(); + SmallVector targetSgData = targetLayout.getEffectiveSgDataAsInt(); + + // Fast path: if sg_layout and sg_data are identical, no SLM needed + if (inputLayout.isCompatibleWith(targetLayout, + xegpu::LayoutKind::Subgroup)) { + inputLayout = inputLayout.dropSgLayoutAndData(); + targetLayout = targetLayout.dropSgLayoutAndData(); + + SmallVector newOps(adaptor.getSource()); + if (inputLayout && targetLayout) { + for (auto [i, src] : llvm::enumerate(adaptor.getSource())) { + auto newOp = xegpu::ConvertLayoutOp::create( + rewriter, loc, src.getType(), src, inputLayout, targetLayout); + newOps[i] = newOp; + } + } + rewriter.replaceOpWithMultiple(op, {newOps}); + return success(); + } + + // SLM path: layouts differ, need cross-subgroup data redistribution + Type elemTy = cast(op.getSource().getType()).getElementType(); + + SmallVector slmShape = llvm::to_vector(wgShape); + + // Calculate SLM size requirements + auto bitWidth = elemTy.getIntOrFloatBitWidth(); + auto bytesPerElement = bitWidth / 8; + auto slmSize = computeProduct(slmShape) * bytesPerElement; + + // Allocate SLM + auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3); + auto slm = memref::AllocaOp::create(rewriter, loc, slmTy); + + auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape, + elemTy, nullptr); + auto memDesc = + 
xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm); + + auto sgId = gpu::SubgroupIdOp::create(rewriter, loc, + rewriter.getIndexType(), nullptr); + + // STORE PHASE: Each subgroup stores in SLM using input layout + auto storeCoords = inputLayout.computeDistributedCoords( + rewriter, loc, sgId.getResult(), wgShape); + if (failed(storeCoords)) return failure(); - input = input.dropSgLayoutAndData(); - target = target.dropSgLayoutAndData(); + // Store to SLM + for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) { + SmallVector storeMatrixOffsets; + for (Value coord : coords) { + storeMatrixOffsets.push_back(coord); + } + xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(), + storeMatrixOffsets, nullptr /*layout*/); + } - SmallVector newOps(adaptor.getSource()); - if (input && target) { - // keep the ConvertLayoutOp for rest fields, e.g., inst_data. - for (auto [i, src] : llvm::enumerate(adaptor.getSource())) { - auto newOp = xegpu::ConvertLayoutOp::create( - rewriter, op.getLoc(), src.getType(), src, input, target); - newOps[i] = newOp; + gpu::BarrierOp::create(rewriter, loc); + + // LOAD PHASE: Each target subgroup loads from SLM using target layout + auto loadCoords = targetLayout.computeDistributedCoords( + rewriter, loc, sgId.getResult(), wgShape); + if (failed(loadCoords)) + return failure(); + + VectorType loadType = VectorType::get(targetSgData, elemTy); + + // Load vectors from SLM + SmallVector finalResults; + for (auto coords : *loadCoords) { + SmallVector loadMatrixOffsets; + for (Value coord : coords) { + loadMatrixOffsets.push_back(coord); } + auto loadOp = xegpu::LoadMatrixOp::create( + rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets, + targetLayout.dropSgLayoutAndData()); + + finalResults.push_back(loadOp.getResult()); } - rewriter.replaceOpWithMultiple(op, {newOps}); + + rewriter.replaceOpWithMultiple(op, {finalResults}); return success(); } }; diff --git 
a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index f2011ab86e9e9..e6376e3ecb4cd 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -852,7 +852,7 @@ func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) // ----- func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) { - // expected-error@+1 {{mem_desc must be 2D}} + // expected-error@+1 {{mem_desc must be 2D or greater}} %data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16> return } @@ -873,7 +873,7 @@ func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, % // ----- func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) { - // expected-error@+1 {{mem_desc must be 2D.}} + // expected-error@+1 {{mem_desc must be 2D or greater}} xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16> return } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 69a2ca7c49c2d..e2e94c5f0300f 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -828,6 +828,112 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: convert_layout_no_slm + gpu.func @convert_layout_no_slm(%arg0: memref<4096x4096xf32>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.muli %block_id_x, %c256 overflow : index + %1 = arith.muli %block_id_y, %c256 overflow : index + %2 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf32> -> !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr, #xegpu.layout> + %3 = xegpu.load_nd %2[%0, %1] <{layout = 
#xegpu.layout}> : !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr, #xegpu.layout> -> vector<256x256xf32> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr, #xegpu.layout> + %5 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #xegpu.block_tdesc_attr, #xegpu.layout> + %6 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %3) -> (vector<256x256xf32>) { + %7 = xegpu.load_nd %4[%0, %arg3] <{layout = #xegpu.layout}> : !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr, #xegpu.layout> -> vector<256x32xf16> + %8 = xegpu.load_nd %5[%arg3, %1] <{layout = #xegpu.layout}> : !xegpu.tensor_desc<32x256xf16, #xegpu.block_tdesc_attr, #xegpu.layout> -> vector<32x256xf16> + // CHECK: %[[CONVERT_A:.*]] = xegpu.convert_layout %{{.*}} <{input_layout = #xegpu.layout, target_layout = #xegpu.layout}> : vector<32x32xf16> + // CHECK: %[[CONVERT_B:.*]] = xegpu.convert_layout %{{.*}} <{input_layout = #xegpu.layout, target_layout = #xegpu.layout}> : vector<32x32xf16> + %9 = xegpu.convert_layout %7 <{input_layout = #xegpu.layout, target_layout = #xegpu.layout}> : vector<256x32xf16> + %10 = xegpu.convert_layout %8 <{input_layout = #xegpu.layout, target_layout = #xegpu.layout}> : vector<32x256xf16> + %11 = xegpu.dpas %9, %10, %arg4 {layout_a = #xegpu.layout, layout_b = #xegpu.layout, layout_cd = #xegpu.layout} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> + scf.yield %11 : vector<256x256xf32> + } {layout_result_0 = #xegpu.layout} + xegpu.store_nd %6, %2[%0, %1] <{layout = #xegpu.layout}> : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #xegpu.block_tdesc_attr, #xegpu.layout> + gpu.return + } + + // CHECK-LABEL: convert_layout_slm + // CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32> + gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[SGIDX:.*]] = 
arith.remui %[[SGID]], %[[C16:.*]] : index + // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index + // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C4:.*]] : index + // CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index + // CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[SGIDX]], %[[C16:.*]] : index + // CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[MUL_Y]], %[[C128:.*]] : index + // CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[MUL_X]], %[[C256:.*]] : index + // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[OFF_Y]], %[[OFF_X]]] : memref<128x256xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.layout> + // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout}> : !xegpu.tensor_desc<32x16xf32, #xegpu.layout> -> vector<32x16xf32> + // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<131072xi8, 3> + // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<131072xi8, 3> -> !xegpu.mem_desc<128x256xf32> + // CHECK-DAG: %[[SGID_STORE:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID_STORE]], %[[C16:.*]] : index + // CHECK-DAG: %[[STORE_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C16:.*]] : index + // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_Y_TMP]], %[[C4:.*]] : index + // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index + // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index + // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index + // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index + // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index + // CHECK-DAG: gpu.barrier + // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID_STORE]], %[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], 
%[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_Y_TMP]], %[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index + // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index + // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index + // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index + // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout}>: !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32> + %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout> + %1 = xegpu.load_nd %0 {layout = #xegpu.layout} : !xegpu.tensor_desc<128x256xf32, #xegpu.layout> -> vector<128x256xf32> + %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout, + target_layout = #xegpu.layout}> : vector<128x256xf32> + gpu.return + } + + gpu.func @convert_layout_3D(%arg0: memref) { + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x32x16xindex> + // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense : vector<1x32x16xi1> + // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST]]], %[[CST_0]] <{chunk_size = 1 : i64, layout = #xegpu.layout}> : memref, vector<1x32x16xindex>, vector<1x32x16xi1> -> vector<1x32x16xf32> + // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<1048576xi8, 3> + // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<1048576xi8, 3> -> !xegpu.mem_desc<8x128x256xf32> + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index + // CHECK-DAG: %[[STORE_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index + // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_YZ_TMP]], %[[C4:.*]] : index + // CHECK-DAG: %[[STORE_Z_TMP:.*]] = arith.divui %[[STORE_YZ_TMP]], %[[C4:.*]] : index + // CHECK-DAG: 
%[[STORE_Z:.*]] = arith.remui %[[STORE_Z_TMP]], %[[C8:.*]] : index + // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index + // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index + // CHECK-DAG: %[[STORE_OFF_Z:.*]] = arith.remui %[[STORE_Z]], %[[C8:.*]] : index + // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index + // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index + // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Z]], %[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<1x32x16xf32>, !xegpu.mem_desc<8x128x256xf32>, index, index, index + // CHECK-DAG: gpu.barrier + // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_Z_TMP:.*]] = arith.divui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_Z:.*]] = arith.remui %[[LOAD_Z_TMP]], %[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index + // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index + // CHECK-DAG: %[[LOAD_OFF_Z:.*]] = arith.remui %[[LOAD_Z]], %[[C8:.*]] : index + // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index + // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index + // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Z]], %[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout}>: !xegpu.mem_desc<8x128x256xf32>, index, index, index -> vector<1x16x32xf32> + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<8x128x256xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<8x128x256xi1> + %1 = xegpu.load %arg0[%offset], %mask {chunk_size = 1, layout = 
#xegpu.layout} : memref, vector<8x128x256xindex>, vector<8x128x256xi1> -> vector<8x128x256xf32> + %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout, + target_layout = #xegpu.layout}> : vector<8x128x256xf32> + gpu.return + } + // CHECK-LABEL: distribute_nested_slice // CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32> // CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>