diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index d04933423ecd0..a5b1df0f93f57 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1255,7 +1255,6 @@ struct WgToSgMultiDimReductionOp bool isScalarResult = !dstVecType; auto originalSrcShape = srcType.getShape(); - int srcVecRank = originalSrcShape.size(); Type elemTy = srcType.getElementType(); xegpu::DistributeLayoutAttr layout = @@ -1268,9 +1267,11 @@ struct WgToSgMultiDimReductionOp // Get sg_layout and sg_data from the parent layout SmallVector sgLayout; SmallVector sgData; + xegpu::DistributeLayoutAttr parentLayout; if (auto sliceAttr = dyn_cast(layout)) { - sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt(); - sgData = sliceAttr.getParent().getEffectiveSgDataAsInt(); + parentLayout = sliceAttr.getParent(); + sgLayout = parentLayout.getEffectiveSgLayoutAsInt(); + sgData = parentLayout.getEffectiveSgDataAsInt(); } else return rewriter.notifyMatchFailure( op, "Reduction should have SliceAttr layout"); @@ -1330,26 +1331,33 @@ struct WgToSgMultiDimReductionOp return success(); } - // Step 2: cross-subgroup reduction using SLM + // Step 2: cross-subgroup reduction using SLM - allocating slm memory auto slmStoreDataShape = sgSrcShape; for (int64_t dim : reductionDims) slmStoreDataShape[dim] = 1; VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy); - Value slmStoreData; - if (isScalarResult) { - // Scalar result: broadcast scalar to vector<1x...x1> for SLM store - slmStoreData = vector::BroadcastOp::create( - rewriter, loc, slmStoreDataType, localReductions[0]); - } else { - slmStoreData = vector::ShapeCastOp::create( - rewriter, loc, slmStoreDataType, localReductions[0]); + SmallVector slmStoreData; + for (auto localResult : localReductions) { + if (isScalarResult) { + // Scalar result: broadcast scalar to vector<1x...x1> for SLM store + slmStoreData.push_back(vector::BroadcastOp::create( + rewriter, loc, slmStoreDataType, localResult)); + } else { + slmStoreData.push_back(vector::ShapeCastOp::create( + rewriter, loc, slmStoreDataType, localResult)); + } } - + // for reduction dimension, SLM stores partial results from each subgroup SmallVector slmShape(originalSrcShape.begin(), originalSrcShape.end()); - // for reduction dimension, SLM stores partial results from each subgroup - for (int64_t dim : reductionDims) + SmallVector slmSgData(sgData.begin(), sgData.end()); + SmallVector slmSgLayout(sgLayout.begin(), sgLayout.end()); + for (int dim : reductionDims) { slmShape[dim] = sgLayout[dim]; + slmSgData[dim] = 1; + } + xegpu::LayoutAttr slmStoreLayout = + xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData); // Allocate SLM auto bitWidth = elemTy.getIntOrFloatBitWidth(); @@ -1363,82 +1371,61 @@ struct WgToSgMultiDimReductionOp auto memDesc = xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm); - // if localReductions have more than 1 result, not support - if (localReductions.size() > 1) { - return rewriter.notifyMatchFailure( - op, - "Multiple local reductions not supported in current implementation."); - } - - // Step 4: Store local results to SLM + // Step 3: Store local results to SLM auto sgId = gpu::SubgroupIdOp::create(rewriter, loc, rewriter.getIndexType(), nullptr); - // Convert sgLayout to Values for delinearizeIndex - SmallVector sgLayoutValues; - for (int64_t dim : sgLayout) - sgLayoutValues.push_back( - arith::ConstantIndexOp::create(rewriter, loc, dim)); - - auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(), - sgLayoutValues); - if (failed(sgIdsResult)) + auto slmStoreCoords = + slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape); + if (failed(slmStoreCoords)) return failure(); - SmallVector sgIds = *sgIdsResult; - - auto getSlmOffsets = [&](int64_t reductionDimStride) { - SmallVector offsets; - offsets.reserve(srcVecRank); - for (int i = 0; i < srcVecRank; ++i) { - Value dimVal = sgIds[i]; - int64_t sgDataStride = (llvm::is_contained(reductionDims, i)) - ? reductionDimStride - : sgSrcShape[i]; - Value strideVal = - arith::ConstantIndexOp::create(rewriter, loc, sgDataStride); - Value offsetVal = - arith::MulIOp::create(rewriter, loc, dimVal, strideVal); - offsets.push_back(offsetVal); - } - return offsets; - }; - - SmallVector slmStoreOffsets = - getSlmOffsets(/*reductionDimStride=*/1); - - xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData, - memDesc.getResult(), slmStoreOffsets, - /*layout=*/nullptr); + for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) { + SmallVector coordOfr(coord.begin(), coord.end()); + xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(), + coordOfr, + /*layout=*/nullptr); + } gpu::BarrierOp::create(rewriter, loc); - // Step 5: Load from SLM for final reduction + // Step 4: Load from SLM for final reduction SmallVector slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end()); - for (int64_t dim : reductionDims) + for (int64_t dim : reductionDims) { slmLoadDataShape[dim] = slmShape[dim]; - - SmallVector slmLoadOffsets = - getSlmOffsets(/*reductionDimStride=*/0); + slmSgData[dim] = slmShape[dim]; + } + xegpu::LayoutAttr slmLoadLayout = + xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData); + auto slmLoadCoords = + slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape); + if (failed(slmLoadCoords)) + return failure(); VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy); - auto slmLoadOp = xegpu::LoadMatrixOp::create( - rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets, - /*layout=*/nullptr); + SmallVector slmLoadData; + for (auto coord : *slmLoadCoords) { + SmallVector coordOfr(coord.begin(), coord.end()); + slmLoadData.push_back(xegpu::LoadMatrixOp::create( + rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr, + /*layout=*/nullptr)); + } - // Step 6: Perform final reduction with neutral accumulator + // Step 5: Perform final reduction with neutral accumulator and add the + // original accumulator at the end Value neutralFinalAcc = xegpu::createReductionNeutralValue( rewriter, loc, sgDstType, op.getKind()); - auto finalReduce = vector::MultiDimReductionOp::create( - rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(), - neutralFinalAcc, reductionDims); - - // Step 7: Add the original accumulator at the end - auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(), - finalReduce.getResult(), - adaptor.getAcc()[0]); - - rewriter.replaceOp(op, finalResult); + SmallVector finalResults; + for (size_t i = 0; i < slmLoadData.size(); ++i) { + auto loaded = slmLoadData[i]; + auto finalReduce = vector::MultiDimReductionOp::create( + rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc, + reductionDims); + finalResults.push_back(vector::makeArithReduction( + rewriter, loc, op.getKind(), finalReduce.getResult(), + adaptor.getAcc()[i])); + } + rewriter.replaceOpWithMultiple(op, {finalResults}); return success(); } }; diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 320a2fb1f72ac..897eab12329e2 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -165,6 +165,58 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: gpu.func @reduction_cross_sg_rr + gpu.func @reduction_cross_sg_rr(%arg0: memref<2048xf32, 1>) kernel { + // CHECK: %[[CST_OFFSETS0:.*]] = arith.constant dense<0> : vector<4x16xindex> + // CHECK: %[[CST_OFFSETS1:.*]] = arith.constant dense<0> : vector<4x16xindex> + // CHECK: %[[CST_ACC0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[CST_ACC1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[CST_MASK0:.*]] = arith.constant dense : vector<4x16xi1> + // CHECK: %[[CST_MASK1:.*]] = arith.constant dense : vector<4x16xi1> + // + // CHECK: %[[LOAD0:.*]] = xegpu.load %arg0[%[[CST_OFFSETS0]]], %[[CST_MASK0]] + // CHECK-SAME: -> vector<4x16xf32> + // CHECK: %[[LOAD1:.*]] = xegpu.load %arg0[%[[CST_OFFSETS1]]], %[[CST_MASK1]] + // CHECK-SAME: -> vector<4x16xf32> + // + // Local reductions + // CHECK: %[[NEUTRAL0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[LOCAL_RED0:.*]] = vector.multi_reduction , %[[LOAD0]], %[[NEUTRAL0]] [1] : vector<4x16xf32> to vector<4xf32> + // CHECK: %[[NEUTRAL1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[LOCAL_RED1:.*]] = vector.multi_reduction , %[[LOAD1]], %[[NEUTRAL1]] [1] : vector<4x16xf32> to vector<4xf32> + // + // Shape cast for SLM store + // CHECK: %[[SC0:.*]] = vector.shape_cast %[[LOCAL_RED0]] : vector<4xf32> to vector<4x1xf32> + // CHECK: %[[SC1:.*]] = vector.shape_cast %[[LOCAL_RED1]] : vector<4xf32> to vector<4x1xf32> + // + // SLM allocation and mem_desc + // CHECK: %[[SLM:.*]] = memref.alloca() : memref<512xi8, 3> + // CHECK: %[[MEMDESC:.*]] = xegpu.create_mem_desc %[[SLM]] : memref<512xi8, 3> -> !xegpu.mem_desc<8x16xf32> + // + // Store to SLM + // CHECK: xegpu.store_matrix %[[SC0]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32> + // CHECK: xegpu.store_matrix %[[SC1]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32> + // CHECK: gpu.barrier + // + // Load from SLM + // CHECK: %[[SLM_LOAD0:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32> + // CHECK: %[[SLM_LOAD1:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32> + // + // Final reduction + // CHECK: %[[FINAL_NEUTRAL:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[FINAL_RED0:.*]] = vector.multi_reduction , %[[SLM_LOAD0]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32> + // CHECK: %[[RES0:.*]] = arith.addf %[[FINAL_RED0]], %[[CST_ACC0]] : vector<4xf32> + // CHECK: %[[FINAL_RED1:.*]] = vector.multi_reduction , %[[SLM_LOAD1]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32> + // CHECK: %[[RES1:.*]] = arith.addf %[[FINAL_RED1]], %[[CST_ACC1]] : vector<4xf32> + + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<8x256xindex> + %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} dense<0.000000e+00> : vector<8xf32> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<8x256xi1> + %val = xegpu.load %arg0[%offset], %mask <{layout = #xegpu.layout}> : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32> + %reduce = vector.multi_reduction , %val, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32> + gpu.return + } + // CHECK-LABEL: splat_constant gpu.func @splat_constant() { // CHECK-COUNT-2: %[[CST:.*]] = arith.constant dense<0> : vector<4xindex> diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 90c6a73497630..bbdffa0986962 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -1,13 +1,4 @@ // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s - -// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 4)> -// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 4)> -// CHECK-DAG: #map2 = affine_map<()[s0] -> (s0 floordiv 32)> -// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 mod 32)> -// CHECK-DAG: #map4 = affine_map<()[s0] -> (0)> -// CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)> -// CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)> -// CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)> gpu.module @test_distribution { // CHECK-LABEL: create_nd_tdesc_no_offset // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32> @@ -681,18 +672,9 @@ gpu.module @test_distribution { // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3> // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[AFF0:.*]] = affine.apply #map2()[%[[SGID]]] - // CHECK-DAG: %[[AFF1:.*]] = affine.apply #map3()[%[[SGID]]] - // CHECK-DAG: %[[AFF2:.*]] = affine.apply #map4()[%[[SGID]]] - // CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index - // CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index - // CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index - // CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]][%[[ROW]], %[[COL0]], %[[COL1]]] : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index + // CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]]{{.*}} : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index // CHECK-DAG: gpu.barrier - // CHECK-DAG: %[[ROW_L:.*]] = arith.muli %[[AFF0]], %[[C1C:.*]] : index - // CHECK-DAG: %[[COL0_L:.*]] = arith.muli %[[AFF1]], %[[C0:.*]] : index - // CHECK-DAG: %[[COL1_L:.*]] = arith.muli %[[AFF2]], %[[C32B:.*]] : index - // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ROW_L]], %[[COL0_L]], %[[COL1_L]]] : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32> + // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32> // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32> // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction , %[[LOAD_SLM]], %[[CST_3]] [1] : vector<1x32x32xf32> to vector<1x32xf32> // CHECK-DAG: %[[ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x32xf32> @@ -725,15 +707,9 @@ gpu.module @test_distribution { // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3> // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32> // CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]] - // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]] - // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index - // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[AFFINE2]], %[[C32_1:.*]] : index - // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index + // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index // CHECK-DAG: gpu.barrier - // CHECK-DAG: %[[ZERO_ROW:.*]] = arith.muli %[[AFFINE1]], %[[C0:.*]] : index - // CHECK-DAG: %[[COL_OFFSET2:.*]] = arith.muli %[[AFFINE2]], %[[C32_2:.*]] : index - // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ZERO_ROW]], %[[COL_OFFSET2]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32> + // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32> // CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32> // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction , %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32> // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<32xf32> @@ -761,31 +737,9 @@ gpu.module @test_distribution { // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi8, 3> // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<2x2x4x4xf32> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[I0:.*]] = arith.muli %[[AFFINE0]], %[[C1]] : index - // CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[I1:.*]] = arith.muli %[[AFFINE2]], %[[C1_0]] : index - // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[I2:.*]] = arith.muli %[[AFFINE4]], %[[C1_1]] : index - // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[I3:.*]] = arith.muli %[[AFFINE5]], %[[C1_2]] : index - // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[I0]], %[[I1]], %[[I2]], %[[I3]]] : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index + // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index // CHECK-DAG: gpu.barrier - // CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C1_3]] : index - // CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C1_4]] : index - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0]] : index - // CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_0]] : index - // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32> + // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32> // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32> // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction , %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<1x1x4x4xf32> to vector<1x1xf32> // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x1xf32> @@ -811,23 +765,9 @@ gpu.module @test_distribution { // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<65536xi8, 3> // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<65536xi8, 3> -> !xegpu.mem_desc<32x32x4x4xf32> // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index - // CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]] - // CHECK-DAG: %[[R0:.*]] = arith.muli %[[AFFINE0]], %[[C16_0:.*]] : index - // CHECK-DAG: %[[R1:.*]] = arith.muli %[[AFFINE2]], %[[C16_1:.*]] : index - // CHECK-DAG: %[[R2:.*]] = arith.muli %[[AFFINE4]], %[[C1_0:.*]] : index - // CHECK-DAG: %[[R3:.*]] = arith.muli %[[AFFINE5]], %[[C1_1:.*]] : index - // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[R0]], %[[R1]], %[[R2]], %[[R3]]] : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index + // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index // CHECK-DAG: gpu.barrier - // CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C16_2:.*]] : index - // CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C16_3:.*]] : index - // CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0_0:.*]] : index - // CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_1:.*]] : index - // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32> + // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32> // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<16x16xf32> // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction , %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<16x16x4x4xf32> to vector<16x16xf32> // CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<16x16xf32>