From fdb548597573b51af867414ebee8046b23ec59da Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 20 Nov 2025 19:18:40 +0200 Subject: [PATCH] add optional slice_dims argument to set_op_layout_attr and set_desc_layout ops --- .../XeGPU/TransformOps/XeGPUTransformOps.td | 24 ++++++--- .../XeGPU/TransformOps/XeGPUTransformOps.cpp | 35 ++++++++++--- mlir/python/mlir/dialects/transform/xegpu.py | 8 +++ mlir/test/Dialect/XeGPU/transform-ops.mlir | 38 ++++++++++++++ .../python/dialects/transform_xegpu_ext.py | 49 ++++++++++++++++++- 5 files changed, 139 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index 16044838aa27d..29579acc727ed 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -42,10 +42,12 @@ def SetDescLayoutOp : Op` attribute. Returns a handle to + the transformed op. }]; let arguments = (ins @@ -55,7 +57,8 @@ def SetDescLayoutOp : Op:$inst_data, DefaultValuedOptionalAttr:$static_sg_layout, DefaultValuedOptionalAttr:$static_sg_data, - DefaultValuedOptionalAttr:$static_inst_data + DefaultValuedOptionalAttr:$static_inst_data, + DefaultValuedOptionalAttr:$slice_dims ); let results = (outs TransformHandleTypeInterface:$transformed); @@ -63,7 +66,8 @@ def SetDescLayoutOp : Op":$mixedSgLayout, "ArrayRef":$mixedSgData, - "ArrayRef":$mixedInstData + "ArrayRef":$mixedInstData, + "ArrayRef":$sliceDims )>, ]; @@ -72,6 +76,7 @@ def SetDescLayoutOp : Op($sg_layout, $static_sg_layout) `sg_data` `=` custom($sg_data, $static_sg_data) (`inst_data` `=` custom($inst_data, $static_inst_data)^)? + (`slice_dims` `=` $slice_dims^)? attr-dict `:` functional-type(operands, results) }]; @@ -107,7 +112,9 @@ def SetOpLayoutAttrOp : Op` attribute. }]; let arguments = (ins TransformHandleTypeInterface:$target, @@ -118,6 +125,7 @@ def SetOpLayoutAttrOp : Op:$static_sg_layout, DefaultValuedOptionalAttr:$static_sg_data, DefaultValuedOptionalAttr:$static_inst_data, + DefaultValuedOptionalAttr:$slice_dims, DefaultValuedAttr:$result ); @@ -128,6 +136,7 @@ def SetOpLayoutAttrOp : Op":$mixedSgLayout, "ArrayRef":$mixedSgData, "ArrayRef":$mixedInstData, + "ArrayRef":$sliceDims, CArg<"bool", "false">:$result )>, ]; @@ -137,6 +146,7 @@ def SetOpLayoutAttrOp : Op($sg_layout, $static_sg_layout) `sg_data` `=` custom($sg_data, $static_sg_data) (`inst_data` `=` custom($inst_data, $static_inst_data)^)? + (`slice_dims` `=` $slice_dims^)? attr-dict `:` qualified(type(operands)) }]; diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index e301d4d9bd108..8995ab3082d24 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -167,7 +167,8 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, /// Replace xegpu.create_nd_desc op with a new one with the given layout. static xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter, - xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) { + xegpu::CreateNdDescOp descOp, + xegpu::DistributeLayoutAttr layout) { assert(descOp.getMixedOffsets().size() == 0 && "create desc op with offsets is not supported"); auto oldTensorDesc = descOp.getType(); @@ -212,7 +213,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedSgLayout, ArrayRef mixedSgData, - ArrayRef mixedInstData) { + ArrayRef mixedInstData, + ArrayRef sliceDims) { SmallVector staticSgLayout, staticSgData, staticInstData; SmallVector dynamicSgLayout, dynamicSgData, dynamicInstData; dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); @@ -225,7 +227,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder, /*inst_data=*/dynamicInstData, /*static_sg_layout=*/staticSgLayout, /*static_sg_data=*/staticSgData, - /*static_inst_data=*/staticInstData); + /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims); } DiagnosedSilenceableFailure @@ -246,6 +249,14 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, if (!status.succeeded()) return status; + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. + layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + // For now only create_nd_desc op is supported. auto descOp = dyn_cast(target); if (!descOp) { @@ -257,7 +268,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, } // Set layout attr in desc op's return type. Replaces old desc op. - auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr); + auto newdescOp = setDescLayout(rewriter, descOp, layout); // Map result handles. results.set(cast(getTransformed()), {newdescOp.getOperation()}); @@ -278,7 +289,8 @@ void transform::SetDescLayoutOp::getEffects( void transform::SetOpLayoutAttrOp::build( OpBuilder &builder, OperationState &ostate, Value target, int64_t index, ArrayRef mixedSgLayout, ArrayRef mixedSgData, - ArrayRef mixedInstData, bool result) { + ArrayRef mixedInstData, ArrayRef sliceDims, + bool result) { SmallVector staticSgLayout, staticSgData, staticInstData; SmallVector dynamicSgLayout, dynamicSgData, dynamicInstData; dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); @@ -293,6 +305,7 @@ void transform::SetOpLayoutAttrOp::build( /*static_sg_layout=*/staticSgLayout, /*static_sg_data=*/staticSgData, /*static_inst_data=*/staticInstData, + /*slice_dims=*/sliceDims, /*result=*/result); } @@ -326,11 +339,19 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter, if (!status.succeeded()) return status; + xegpu::DistributeLayoutAttr layout = layoutAttr; + auto sliceDims = getSliceDims(); + if (sliceDims.size() > 0) { + // Wrap layoutAttr in a slice attribute. + layout = xegpu::SliceAttr::get( + getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); + } + // Set layout attribute for the op result or operand if (resultTarget) - xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr); + xegpu::setDistributeLayoutAttr(target->getResult(index), layout); else - xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr); + xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index 7169b5e28ab5e..5aa6453b7cb8a 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -62,6 +62,7 @@ def __init__( sg_data: MixedValues, *, inst_data: Optional[MixedValues] = None, + slice_dims: Optional[MixedInt] = None, loc=None, ip=None, ): @@ -92,6 +93,7 @@ def __init__( static_sg_layout=static_sg_layout, static_sg_data=static_sg_data, static_inst_data=static_inst_data, + slice_dims=slice_dims, loc=loc, ip=ip, ) @@ -103,6 +105,7 @@ def set_desc_layout( sg_data: MixedValues, *, inst_data: Optional[MixedValues] = None, + slice_dims: Optional[MixedInt] = None, loc=None, ip=None, ) -> OpResult: @@ -111,6 +114,7 @@ def set_desc_layout( sg_layout, sg_data, inst_data=inst_data, + slice_dims=slice_dims, loc=loc, ip=ip, ).result @@ -127,6 +131,7 @@ def __init__( sg_data: MixedValues, *, inst_data: Optional[MixedValues] = None, + slice_dims: Optional[MixedInt] = None, index: Optional[Union[int, Attribute]] = None, result: Optional[Union[bool, Attribute]] = None, loc=None, @@ -156,6 +161,7 @@ def __init__( static_sg_layout=static_sg_layout, static_sg_data=static_sg_data, static_inst_data=static_inst_data, + slice_dims=slice_dims, index=index, result=result, loc=loc, @@ -169,6 +175,7 @@ def set_op_layout_attr( sg_data: MixedValues, *, inst_data: Optional[MixedValues] = None, + slice_dims: Optional[MixedInt] = None, index: Optional[Union[int, Attribute]] = None, result: Optional[Union[bool, Attribute]] = None, loc=None, @@ -179,6 +186,7 @@ def set_op_layout_attr( sg_layout, sg_data, inst_data=inst_data, + slice_dims=slice_dims, index=index, result=result, loc=loc, diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir index ff0accdec7532..561034fb5880b 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -121,6 +121,25 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: @set_desc_layout_slice +func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.slice<#xegpu.layout, dims = [0]> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + // CHECK-LABEL: @set_op_layout_attr_result_default_index func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> @@ -212,6 +231,25 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: @set_op_layout_attr_result_slice +func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) { + // CHECK: = arith.extf + // CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} + %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_op_layout_attr %{{.*}} + transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op + transform.yield + } +} + +// ----- + // CHECK-LABEL: @set_op_layout_attr_operand_minimal func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) { %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py index 4f89982ad1c44..2b11acb04ed5b 100644 --- a/mlir/test/python/dialects/transform_xegpu_ext.py +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -66,6 +66,25 @@ def setDescLayoutInstData(): # CHECK: inst_data = [8, 16] +@run +def setDescLayoutSlice(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.create_nd_tdesc"), + ) + with InsertionPoint(sequence.body): + xegpu.set_desc_layout( + sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setDescLayoutSlice + # CHECK: %0 = transform.xegpu.set_desc_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: slice_dims = [0] + + @run def setOpLayoutAttrOperandMinimal(): sequence = transform.SequenceOp( @@ -106,13 +125,41 @@ def setOpLayoutAttrResult(): result=True, ) transform.YieldOp() - # CHECK-LABEL: TEST: setOpLayoutAttr + # CHECK-LABEL: TEST: setOpLayoutAttrResult + # CHECK: transform.xegpu.set_op_layout_attr % + # NO-CHECK: index = 0 + # CHECK: result + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16] + + +@run +def setOpLayoutAttrResultSlice(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + xegpu.set_op_layout_attr( + sequence.bodyTarget, + index=0, + sg_layout=[6, 4], + sg_data=[32, 16], + inst_data=[8, 16], + slice_dims=[0], + result=True, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setOpLayoutAttrResultSlice # CHECK: transform.xegpu.set_op_layout_attr % # NO-CHECK: index = 0 # CHECK: result # CHECK: sg_layout = [6, 4] # CHECK: sg_data = [32, 16] # CHECK: inst_data = [8, 16] + # CHECK: slice_dims = [0] @run