Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [

let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
let description = [{
Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
attribute to the result tensor descriptor. The layout is defined by the
`sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle
to the transformed op.
Given an `xegpu.create_nd_desc` operation, this transform adds
`xegpu.layout` attribute to the result tensor descriptor. The layout is
defined by the `sg_layout`, and `sg_data` and optional `inst_data`
attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute. Returns a handle to
the transformed op.
}];

let arguments = (ins
Expand All @@ -55,15 +57,17 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
);

let results = (outs TransformHandleTypeInterface:$transformed);
let builders = [
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedSgLayout,
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData
"ArrayRef<OpFoldResult>":$mixedInstData,
"ArrayRef<int64_t>":$sliceDims
)>,
];

Expand All @@ -72,6 +76,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
(`slice_dims` `=` $slice_dims^)?
attr-dict `:` functional-type(operands, results)
}];

Expand Down Expand Up @@ -107,7 +112,9 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
`layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
target operand/result value is defined by the `index` argument. The layout
is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
is defined by the `sg_layout`, `sg_data` and optional `inst_data`
attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
Expand All @@ -118,6 +125,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
DefaultValuedAttr<UnitAttr, "false">:$result
);

Expand All @@ -128,6 +136,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
"ArrayRef<OpFoldResult>":$mixedSgLayout,
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
"ArrayRef<int64_t>":$sliceDims,
CArg<"bool", "false">:$result
)>,
];
Expand All @@ -137,6 +146,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
(`slice_dims` `=` $slice_dims^)?
attr-dict `:` qualified(type(operands))
}];

Expand Down
35 changes: 28 additions & 7 deletions mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
xegpu::CreateNdDescOp descOp,
xegpu::DistributeLayoutAttr layout) {
assert(descOp.getMixedOffsets().size() == 0 &&
"create desc op with offsets is not supported");
auto oldTensorDesc = descOp.getType();
Expand Down Expand Up @@ -212,7 +213,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedSgLayout,
ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData) {
ArrayRef<OpFoldResult> mixedInstData,
ArrayRef<int64_t> sliceDims) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
Expand All @@ -225,7 +227,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
/*inst_data=*/dynamicInstData,
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData);
/*static_inst_data=*/staticInstData,
/*slice_dims=*/sliceDims);
}

DiagnosedSilenceableFailure
Expand All @@ -246,6 +249,14 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
if (!status.succeeded())
return status;

xegpu::DistributeLayoutAttr layout = layoutAttr;
auto sliceDims = getSliceDims();
if (sliceDims.size() > 0) {
// Wrap layoutAttr in a slice attribute.
layout = xegpu::SliceAttr::get(
getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
}

// For now only create_nd_desc op is supported.
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
if (!descOp) {
Expand All @@ -257,7 +268,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}

// Set layout attr in desc op's return type. Replaces old desc op.
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
auto newdescOp = setDescLayout(rewriter, descOp, layout);

// Map result handles.
results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
Expand All @@ -278,7 +289,8 @@ void transform::SetDescLayoutOp::getEffects(
void transform::SetOpLayoutAttrOp::build(
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData, bool result) {
ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
bool result) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
Expand All @@ -293,6 +305,7 @@ void transform::SetOpLayoutAttrOp::build(
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
/*slice_dims=*/sliceDims,
/*result=*/result);
}

Expand Down Expand Up @@ -326,11 +339,19 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
if (!status.succeeded())
return status;

xegpu::DistributeLayoutAttr layout = layoutAttr;
auto sliceDims = getSliceDims();
if (sliceDims.size() > 0) {
// Wrap layoutAttr in a slice attribute.
layout = xegpu::SliceAttr::get(
getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims));
}

// Set layout attribute for the op result or operand
if (resultTarget)
xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
else
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
return DiagnosedSilenceableFailure::success();
}

Expand Down
8 changes: 8 additions & 0 deletions mlir/python/mlir/dialects/transform/xegpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
slice_dims: Optional[MixedInt] = None,
loc=None,
ip=None,
):
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
static_sg_layout=static_sg_layout,
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
slice_dims=slice_dims,
loc=loc,
ip=ip,
)
Expand All @@ -103,6 +105,7 @@ def set_desc_layout(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
slice_dims: Optional[MixedInt] = None,
loc=None,
ip=None,
) -> OpResult:
Expand All @@ -111,6 +114,7 @@ def set_desc_layout(
sg_layout,
sg_data,
inst_data=inst_data,
slice_dims=slice_dims,
loc=loc,
ip=ip,
).result
Expand All @@ -127,6 +131,7 @@ def __init__(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
loc=None,
Expand Down Expand Up @@ -156,6 +161,7 @@ def __init__(
static_sg_layout=static_sg_layout,
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
slice_dims=slice_dims,
index=index,
result=result,
loc=loc,
Expand All @@ -169,6 +175,7 @@ def set_op_layout_attr(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
loc=None,
Expand All @@ -179,6 +186,7 @@ def set_op_layout_attr(
sg_layout,
sg_data,
inst_data=inst_data,
slice_dims=slice_dims,
index=index,
result=result,
loc=loc,
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Dialect/XeGPU/transform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,25 @@ module attributes {transform.with_named_sequence} {

// -----

// CHECK-LABEL: @set_desc_layout_slice
func.func @set_desc_layout_slice(%arg0: memref<4096xf16>) {
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
// CHECK-SAME: #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096xf16> -> !xegpu.tensor_desc<256xf16>
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_desc_layout %{{.*}}
%1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] slice_dims = [0] : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: @set_op_layout_attr_result_default_index
func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
Expand Down Expand Up @@ -212,6 +231,25 @@ module attributes {transform.with_named_sequence} {

// -----

// CHECK-LABEL: @set_op_layout_attr_result_slice
func.func @set_op_layout_attr_result_slice(%arg0: vector<256xf16>) {
// CHECK: = arith.extf
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>, dims = [0]>}
%2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] slice_dims = [0] : !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: @set_op_layout_attr_operand_minimal
func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
Expand Down
49 changes: 48 additions & 1 deletion mlir/test/python/dialects/transform_xegpu_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ def setDescLayoutInstData():
# CHECK: inst_data = [8, 16]


@run
def setDescLayoutSlice():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.OperationType.get("xegpu.create_nd_tdesc"),
)
with InsertionPoint(sequence.body):
xegpu.set_desc_layout(
sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], slice_dims=[0]
)
transform.YieldOp()
# CHECK-LABEL: TEST: setDescLayoutSlice
# CHECK: %0 = transform.xegpu.set_desc_layout %
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: slice_dims = [0]


@run
def setOpLayoutAttrOperandMinimal():
sequence = transform.SequenceOp(
Expand Down Expand Up @@ -106,13 +125,41 @@ def setOpLayoutAttrResult():
result=True,
)
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttr
# CHECK-LABEL: TEST: setOpLayoutAttrResult
# CHECK: transform.xegpu.set_op_layout_attr %
# NO-CHECK: index = 0
# CHECK: result
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]


@run
def setOpLayoutAttrResultSlice():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate,
[],
transform.OperationType.get("xegpu.dpas"),
)
with InsertionPoint(sequence.body):
xegpu.set_op_layout_attr(
sequence.bodyTarget,
index=0,
sg_layout=[6, 4],
sg_data=[32, 16],
inst_data=[8, 16],
slice_dims=[0],
result=True,
)
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
# CHECK: transform.xegpu.set_op_layout_attr %
# NO-CHECK: index = 0
# CHECK: result
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
# CHECK: slice_dims = [0]


@run
Expand Down