4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -64,6 +64,10 @@ void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);

/// Appends patterns for XeGPU SIMT distribution into `patterns`.
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
/// Appends patterns for moving a gpu.func body into a
/// gpu.warp_execute_on_lane_0 op.
void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns);
/// Appends patterns for XeGPU workgroup to subgroup distribution into
/// `patterns`.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);

/// Collect a set of patterns to unroll xegpu operations into smaller shapes.
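For context, a minimal sketch of how a downstream pass could chain the new entry point with the existing SIMT distribution patterns, loosely mirroring the two-stage wiring in XeGPUSubgroupDistribute.cpp below. The pass name `MyXeGPUDistributePass` is hypothetical and not part of this PR.

#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch only; `MyXeGPUDistributePass` is an assumed downstream pass.
void MyXeGPUDistributePass::runOnOperation() {
  // Stage 1: move each gpu.func body into a gpu.warp_execute_on_lane_0 op.
  {
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
  // Stage 2: distribute the ops inside the warp op across the SIMT lanes.
  {
    RewritePatternSet patterns(&getContext());
    xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
}
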
10 changes: 7 additions & 3 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -195,8 +195,7 @@ static bool requireTranspose(const xegpu::LayoutAttr layout,
/// }
/// return %0
/// }
struct MoveFuncBodyToWarpExecuteOnLane0
: public OpRewritePattern<gpu::GPUFuncOp> {
struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
PatternRewriter &rewriter) const override {
@@ -1447,6 +1446,11 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
/*pattern benefit=*/highPatternBenefit);
}

void xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(
RewritePatternSet &patterns) {
patterns.add<MoveFuncBodyToWarpOp>(patterns.getContext());
}

void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 1: Attach layouts to op operands.
// TODO: Following assumptions are made:
@@ -1473,7 +1477,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// gpu.warp_execute_on_lane_0 operation.
{
RewritePatternSet patterns(&getContext());
patterns.add<MoveFuncBodyToWarpExecuteOnLane0>(&getContext());
xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
63 changes: 63 additions & 0 deletions mlir/test/Dialect/XeGPU/move-gpu-func-to-warp-op.mlir
@@ -0,0 +1,63 @@
// RUN: mlir-opt -test-xegpu-move-func-to-warp-op -split-input-file --allow-unregistered-dialect %s | FileCheck %s

gpu.module @test {
gpu.func @empty() {
gpu.return
}
}

// CHECK-LABEL: gpu.func @empty() {
// CHECK-NEXT: gpu.return
// CHECK-NEXT: }

// -----
gpu.module @test {
gpu.func @gemm(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
%2 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%3 = xegpu.load_nd %1[%c0, %c0] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
%5 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %4, %5[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}

// CHECK-LABEL: gpu.func @gemm(
// CHECK: %[[ARG0:[a-zA-Z0-9]+]]: memref<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<16x16xf16>,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<8x16xf32>) {
// CHECK: %[[LANEID:.*]] = gpu.lane_id
// CHECK-NEXT: gpu.warp_execute_on_lane_0(%[[LANEID]])[16]
// CHECK-SAME: args(%[[ARG0]], %[[ARG1]], %[[ARG2]] : memref<8x16xf16>, memref<16x16xf16>, memref<8x16xf32>) {
// CHECK: ^bb0(%[[ARG3:[a-zA-Z0-9]+]]: memref<8x16xf16>, %[[ARG4:[a-zA-Z0-9]+]]: memref<16x16xf16>,
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: memref<8x16xf32>):
// CHECK-NEXT: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG3]] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG4]] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T1]][{{.*}}] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// CHECK-NEXT: %[[T4:.*]] = xegpu.load_nd %[[T2]][{{.*}}] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
// CHECK-NEXT: %[[T5:.*]] = xegpu.dpas %[[T3]], %[[T4]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK-NEXT: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG5]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: xegpu.store_nd %[[T5]], %[[T6]][%{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: }
// CHECK-NEXT: gpu.return

// -----
gpu.module @test {
gpu.func @already_in_warp_op() {
%laneid = gpu.lane_id
gpu.warp_execute_on_lane_0(%laneid)[16] {
"some_op"() : () -> ()
gpu.yield
}
gpu.return
}
}

// CHECK-LABEL: gpu.func @already_in_warp_op() {
// CHECK: %[[LANEID:.*]] = gpu.lane_id
// CHECK: gpu.warp_execute_on_lane_0(%[[LANEID]])[16] {
// CHECK: "some_op"() : () -> ()
// CHECK: }
// CHECK: gpu.return
83 changes: 81 additions & 2 deletions mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -530,7 +530,7 @@ gpu.module @xevm_module{
// CHECK-NEXT: }
// CHECK-NEXT: %[[T1:.*]] = vector.transpose %[[W]]#1, [1, 0] : vector<1x2xf32> to vector<2x1xf32>
gpu.module @xevm_module{
gpu.func @vector_transpose(%arg0: memref<2x16xf32>, %laneid: index) {
gpu.func @vector_transpose(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<2x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
@@ -556,7 +556,7 @@ gpu.module @xevm_module{
// CHECK: }
// CHECK: vector.bitcast %[[W]]#1 : vector<4x2xi8> to vector<4x1xi16>
gpu.module @xevm_module{
gpu.func @vector_bitcast(%arg0: memref<4x16xi16>, %laneid: index) {
gpu.func @vector_bitcast(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<4x1xi16>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
@@ -573,3 +573,82 @@ gpu.module @xevm_module{
gpu.return
}
}

// -----
// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
// CHECK: gpu.yield %{{.*}} : vector<1x16xf32>, vector<16xf32>
// CHECK: }
// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1xf32> to vector<1x1xf32>
gpu.module @xevm_module {
gpu.func @vector_shapecast_rank_increasing(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
: () -> (vector<16xf32>)
%cast = vector.shape_cast %cst
{
layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16xf32> to vector<1x16xf32>
gpu.yield %cast : vector<1x16xf32>
}
"some_user_op"(%r) : (vector<1x1xf32>) -> ()
gpu.return
}
}

// -----
// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing(
// CHECK: %{{.*}}:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1xf32>, vector<1x1xf32>) {
// CHECK: gpu.yield %{{.*}} : vector<16xf32>, vector<1x16xf32>
// CHECK: }
// CHECK: %{{.*}} = vector.shape_cast %{{.*}}#1 : vector<1x1xf32> to vector<1xf32>
gpu.module @xevm_module {
gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
: () -> (vector<1x16xf32>)
%cast = vector.shape_cast %cst
{
layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
}
: vector<1x16xf32> to vector<16xf32>
gpu.yield %cast : vector<16xf32>
}
"some_user_op"(%r) : (vector<1xf32>) -> ()
gpu.return
}
}

// -----
// NOTE: The layouts are valid, but distribution requires a slice layout for the shape_cast operand;
// without one, the op is not distributed and remains inside the warp op.
//
// CHECK-LABEL: gpu.func @vector_shapecast_unsupported
// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>) {
// CHECK: %[[T1:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
// CHECK: gpu.yield %[[T1]] : vector<1x16xf32>
// CHECK: }
// CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
// CHECK: gpu.return
gpu.module @xevm_module {
gpu.func @vector_shapecast_unsupported(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
: () -> (vector<16xf32>)
%cast = vector.shape_cast %cst
{
layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
}
: vector<16xf32> to vector<1x16xf32>
gpu.yield %cast : vector<1x16xf32>
}
"some_user_op"(%r) : (vector<1x1xf32>) -> ()
gpu.return
}
}
32 changes: 32 additions & 0 deletions mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -247,6 +248,36 @@ struct TestXeGPUSGDistribute
}
};

struct TestXeGPUMoveFuncBodyToWarpOp
: public PassWrapper<TestXeGPUMoveFuncBodyToWarpOp,
OperationPass<gpu::GPUModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUMoveFuncBodyToWarpOp)

StringRef getArgument() const final {
return "test-xegpu-move-func-to-warp-op";
}

StringRef getDescription() const final {
return "Test the implementation of XeGPU move gpu function body to "
"WarpExecuteOnLane0 op.";
}

void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<xegpu::XeGPUDialect>();
registry.insert<gpu::GPUDialect>();
}

TestXeGPUMoveFuncBodyToWarpOp() = default;
TestXeGPUMoveFuncBodyToWarpOp(const TestXeGPUMoveFuncBodyToWarpOp &pass) =
default;

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
xegpu::populateXeGPUMoveFuncBodyToWarpOpPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

struct TestXeGPULayoutInterface
: public PassWrapper<TestXeGPULayoutInterface,
OperationPass<gpu::GPUModuleOp>> {
@@ -312,6 +343,7 @@ void registerTestXeGPULowerings() {
PassRegistration<TestXeGPUUnrollingPatterns>();
PassRegistration<TestXeGPULayoutInterface>();
PassRegistration<TestXeGPUSGDistribute>();
PassRegistration<TestXeGPUMoveFuncBodyToWarpOp>();
}
} // namespace test
} // namespace mlir