diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp index 3787fbb44e1b8..2431970884242 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp @@ -5,6 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -628,6 +629,14 @@ void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() { } }); } + // iterate the IR and attach each function op with sub_group_size attribute to + // support shuffle lowering in later stages. + root->walk([&](gpu::GPUFuncOp funcOp) { + auto uArch = getUArch(xegpu::getChipStr(funcOp).value_or("")); + funcOp->setAttr("intel_reqd_sub_group_size", + IntegerAttr::get(IntegerType::get(funcOp.getContext(), 32), + uArch->getSubgroupSize())); + }); } void xegpu::populateXeGPUSgToWiDistributeTypeConversions( diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 99c2da386fab6..29048e035e37a 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -2159,4 +2159,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() { op->erase(); return WalkResult::advance(); }); + + // iterate the IR and attach each function op with sub_group_size attribute to + // support shuffle lowering in later stages. + getOperation()->walk([&](gpu::GPUFuncOp funcOp) { + auto uArch = getUArch(xegpu::getChipStr(funcOp).value_or("")); + funcOp->setAttr("intel_reqd_sub_group_size", + IntegerAttr::get(IntegerType::get(funcOp.getContext(), 32), + uArch->getSubgroupSize())); + }); } diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir index 9172cd3018b71..2faa37ad2cd81 100644 --- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir +++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir @@ -216,4 +216,11 @@ gpu.func @gemm_with_postop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x102 gpu.return } +// CHECK-LABEL: gpu.func @dummy_func() +// CHECK-SAME: attributes {intel_reqd_sub_group_size = 16 : i32} +// CHECK : gpu.return +gpu.func @dummy_func(){ + gpu.return +} + } diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir index dae00838fdcb6..966bff4a6aa02 100644 --- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir +++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: gpu.func @load_dpas_postop_store // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, -// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) { +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) attributes {intel_reqd_sub_group_size = 16 : i32} { // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]][%{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16> // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> @@ -51,7 +51,7 @@ gpu.module @xevm_module{ // ----- // CHECK-LABEL: gpu.func @gemm // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, -// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) { +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) attributes {intel_reqd_sub_group_size = 16 : i32} { // CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x // CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y // CHECK-DAG: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index @@ -120,7 +120,7 @@ gpu.func @gemm(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %ar // ----- // CHECK-LABEL: gpu.func @scatter_ops_scf_yield -// CHECK: (%{{.*}}: memref<256xf16>, %[[PREDICATE:[a-zA-Z0-9]+]]: i1) { +// CHECK: (%{{.*}}: memref<256xf16>, %[[PREDICATE:[a-zA-Z0-9]+]]: i1) attributes {intel_reqd_sub_group_size = 16 : i32} { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.200000e+01> : vector<1x8xf16> // CHECK-DAG: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> // CHECK-DAG: %[[MASK:.*]] = arith.constant dense : vector<1xi1> @@ -156,7 +156,7 @@ gpu.module @xevm_module{ } // ----- -// CHECK-LABEL: gpu.func @scatter_ops_scf_non_yield({{.*}}) { +// CHECK-LABEL: gpu.func @scatter_ops_scf_non_yield({{.*}}) attributes {intel_reqd_sub_group_size = 16 : i32} { // CHECK: %[[OFFSET:.*]] = arith.constant dense<12> : vector<1xindex> // CHECK: %[[MASK:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[PREDICATE:.*]] = llvm.mlir.poison : i1 @@ -183,7 +183,7 @@ gpu.module @xevm_module{ // ----- // CHECK-LABEL: gpu.func @mma_transpose_b( -// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) { +// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) // CHECK-DAG: %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> // CHECK-DAG: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32> // CHECK-DAG: %[[A:.*]] = xegpu.load_nd %[[ADESC]][%{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16> @@ -194,7 +194,7 @@ gpu.module @xevm_module{ // CHECK-NEXT: %[[BCAST2:.*]] = vector.shape_cast %[[BCAST1]] : vector<1x16xf16> to vector<16xf16> // CHECK-NEXT: %[[C:.*]] = xegpu.dpas %[[A]], %[[BCAST2]] : vector<8xf16>, vector<16xf16> -> vector<8xf32> gpu.module @xevm_module{ - gpu.func @mma_transpose_b(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) { + gpu.func @mma_transpose_b(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) attributes {intel_reqd_sub_group_size = 16 : i32} { %c0 = arith.constant 0 : index %0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout> @@ -274,7 +274,7 @@ gpu.module @xevm_module{ } // ----- -// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) { +// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C8:.*]] = arith.constant 8 : index // CHECK: %[[LANE_ID:.*]] = gpu.lane_id @@ -295,7 +295,7 @@ gpu.module @xevm_module{ } // ----- -// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) { +// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) // CHECK: %[[C8:.*]] = arith.constant 8 : index // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[C4:.*]] = arith.constant 4 : index @@ -321,7 +321,7 @@ gpu.module @xevm_module{ } // ----- -// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) { +// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) // CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>: // CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout>, index, index -> vector<1x2xf32> // CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>: @@ -339,7 +339,7 @@ gpu.module @xevm_module{ } // ----- -// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) { +// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) gpu.module @xevm_module{ gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) { %c0 = arith.constant 0 : index @@ -358,7 +358,7 @@ gpu.module @xevm_module{ } // ----- -// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case({{.*}}) { +// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case({{.*}}) gpu.module @xevm_module{ gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) { %c0 = arith.constant 0 : index @@ -381,7 +381,7 @@ gpu.module @xevm_module{ } // ----- -// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector({{.*}}) { +// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector({{.*}}) gpu.module @xevm_module{ gpu.func @vector_shape_cast_scalar_to_vector(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) { %c0 = arith.constant 0 : index