diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..4e0eae1007c8f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 }
 
+def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Set the xegpu.layout attribute of an op.";
+  let description = [{
+    Sets the `xegpu.layout` attribute of an op. If `result=true`, this sets
+    the `layout_result_{index}` attribute, otherwise the
+    `layout_operand_{index}` attribute. The target operand/result value is
+    selected by the `index` argument. The layout is defined by the
+    `sg_layout`, `sg_data`, and optional `inst_data` attributes.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface : $target,
+                   DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+                   DefaultValuedAttr<BoolAttr, "false">:$result
+  );
+
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "int64_t":$index,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData,
+                   CArg<"bool", "false">:$result
+    )>,
+  ];
+
+  let assemblyFormat = [{
+    $target (`result` $result^)? (`index` `=` $index^)?
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..456cfb9ddd2bc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
                               /*order=*/nullptr);
 }
 
+/// Generate `xegpu::LayoutAttr` from op mixed layout values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
+                          transform::TransformState &state,
+                          TransformOpInterface transformOp,
+                          ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+                          ArrayRef<::mlir::OpFoldResult> mixedSgData,
+                          ArrayRef<::mlir::OpFoldResult> mixedInstData,
+                          xegpu::LayoutAttr &layoutAttr) {
+  SmallVector<int32_t> sgLayout, sgData, instData;
+  auto status =
+      convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+  if (!status.succeeded())
+    return status;
+
+  status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+  if (!status.succeeded())
+    return status;
+
+  status =
+      convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+  if (!status.succeeded())
+    return status;
+  auto maybeInstData = instData.empty()
+                           ? std::nullopt
+                           : std::optional<ArrayRef<int32_t>>(instData);
+
+  layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Replace xegpu.create_nd_desc op with a new one with the given layout.
 static xegpu::CreateNdDescOp
 setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
   Operation *target = *targetOps.begin();
 
-  SmallVector<int32_t> sgLayout;
-  DiagnosedSilenceableFailure status =
-      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
   if (!status.succeeded())
     return status;
 
-  SmallVector<int32_t> sgData;
-  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
-  if (!status.succeeded())
-    return status;
-
-  SmallVector<int32_t> instData;
-  status =
-      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
-  if (!status.succeeded())
-    return status;
-  auto maybeInstData = instData.empty()
-                           ? std::nullopt
-                           : std::optional<ArrayRef<int32_t>>(instData);
-
   // For now only create_nd_desc op is supported.
   auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
   if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Set layout attr in desc op's return type. Replaces old desc op.
-  auto layoutAttr =
-      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
   auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
 
   // Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::SetOpLayoutAttrOp::build(
+    OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+    ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+    ArrayRef<OpFoldResult> mixedInstData, bool result) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*index=*/index,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData,
+        /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  bool resultTarget = getResult();
+
+  int64_t index = getIndex();
+  if (resultTarget && index >= target->getNumResults()) {
+    return emitSilenceableFailure(getLoc())
+           << "Index exceeds the number of op results";
+  }
+  if (!resultTarget && index >= target->getNumOperands()) {
+    return emitSilenceableFailure(getLoc())
+           << "Index exceeds the number of op operands";
+  }
+
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
+  if (!status.succeeded())
+    return status;
+
+  // Set layout attribute for the op result or operand.
+  if (resultTarget) {
+    xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+  } else {
+    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..46a1f032630d1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,50 @@ def __init__(
         loc=loc,
         ip=ip,
     )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
+    """Specialization for SetOpLayoutAttrOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        *,
+        inst_data: MixedValues = None,
+        index: Union[int, Attribute] = None,
+        result: Union[bool, Attribute] = None,
+        loc=None,
+        ip=None,
+    ):
+        inst_data = [] if inst_data is None else inst_data
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+        super().__init__(
+            _get_op_result_or_value(target),
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            index=index,
+            result=result,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..726b6748452ae 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_result_index
+func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Index exceeds the number of op results}}
+    transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
+func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Index exceeds the number of op operands}}
+    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_multiple
+func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+    transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
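Note on the `@set_op_layout_attr_multiple` case above: the op deliberately requires a handle to exactly one payload op. A sketch of how a multi-op handle could still be annotated by iterating with `transform.foreach` — illustrative only, not part of this patch, and assuming the usual `transform.foreach` region syntax:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %all = transform.structured.match ops{["arith.extf"]} in %root : (!transform.any_op) -> !transform.any_op
    // Visit the matched ops one at a time so that each set_op_layout_attr
    // application sees a single payload op.
    transform.foreach %all : !transform.any_op {
    ^bb0(%op: !transform.any_op):
      transform.xegpu.set_op_layout_attr %op sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
      transform.yield
    }
    transform.yield
  }
}
```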
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..089a8fb4fd9b6 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,137 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_default_index
+func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param
+func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
+func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result0
+func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand_minimal
+func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand1
+func.func @set_op_layout_attr_operand1(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.addf %1, %3
+  // CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %6 = arith.addf %1, %3 : vector<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 1c8a2bcc6a2fb..0f48ef9dc529f 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -49,3 +49,52 @@ def setDescLayoutInstData():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def setOpLayoutAttrOperandMinimal():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetOpLayoutAttrOp(
+            sequence.bodyTarget,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setOpLayoutAttrOperandMinimal
+    # CHECK: transform.xegpu.set_op_layout_attr %
+    # CHECK-NOT: index = 0
+    # CHECK-NOT: result
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK-NOT: inst_data
+
+
+@run
+def setOpLayoutAttrResult():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetOpLayoutAttrOp(
+            sequence.bodyTarget,
+            index=0,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+            result=True,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setOpLayoutAttrResult
+    # CHECK: transform.xegpu.set_op_layout_attr %
+    # CHECK-NOT: index = 0
+    # CHECK: result
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]
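Taken together, the patch adds a transform op that annotates a single operand or result of a payload op with a subgroup-level XeGPU layout. A minimal end-to-end sketch of the intended use, assembled from the tests above (the payload shapes and layout values are illustrative only):

```mlir
// Payload: tag the lone result of an arith.extf with a subgroup layout.
func.func @payload(%src: vector<256x32xf16>) -> vector<256x32xf32> {
  // After the transform runs, this op carries:
  //   layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>
  %0 = arith.extf %src : vector<256x32xf16> to vector<256x32xf32>
  return %0 : vector<256x32xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %m = transform.structured.match ops{["arith.extf"]} in %root : (!transform.any_op) -> !transform.any_op
    // `result` selects results instead of operands; `index` defaults to 0.
    transform.xegpu.set_op_layout_attr %m result sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
    transform.yield
  }
}
```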