diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556deb..b33b0a6110b1e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,66 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
+def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Convert xegpu.layout attribute for a value.";
+  let description = [{
+    Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
+    of a value. First, the `xegpu.load_nd` producer op of the value is found.
+    It must already be annotated with a layout. An `xegpu.convert_layout` op,
+    whose destination layout is defined by the `sg_layout`, `sg_data` and
+    optional `inst_data` attributes, is inserted after the load op.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+  );
+
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData
+    )>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 5fdd8534e4e51..45c76a7859a19 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -341,6 +341,85 @@ void transform::SetOpLayoutAttrOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::ConvertLayoutOp::build(OpBuilder &builder,
+                                       OperationState &ostate, Value target,
+                                       ArrayRef<OpFoldResult> mixedSgLayout,
+                                       ArrayRef<OpFoldResult> mixedSgData,
+                                       ArrayRef<OpFoldResult> mixedInstData) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+
+  auto value = *targetValues.begin();
+
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
+  if (!status.succeeded())
+    return status;
+
+  // Get load op.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  }
+  auto loadOp = *maybeLoadOp;
+  // Get load op operand value layout
+  auto producerLayoutAttr =
+      xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
+  if (!producerLayoutAttr) {
+    return emitSilenceableFailure(getLoc())
+           << "Operand producer op does not have a layout attr.";
+  }
+
+  if (producerLayoutAttr != layoutAttr) {
+    rewriter.setInsertionPointAfter(loadOp.getOperation());
+    auto source = loadOp.getResult();
+    auto convLayoutOp = xegpu::ConvertLayoutOp::create(
+        rewriter, loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
+        layoutAttr);
+    // Replace load op result with the converted layout.
+    rewriter.replaceUsesWithIf(
+        source, convLayoutOp.getResult(), [&](OpOperand &use) {
+          return use.getOwner() != convLayoutOp.getOperation();
+        });
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ConvertLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index ce8015d8f557b..6bf8ad3064be1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -132,3 +132,46 @@ def __init__(
         loc=loc,
         ip=ip,
     )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConvertLayoutOp(ConvertLayoutOp):
+    """Specialization for ConvertLayoutOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        *,
+        inst_data: Optional[MixedValues] = None,
+        loc=None,
+        ip=None,
+    ):
+        inst_data = [] if inst_data is None else inst_data
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+        super().__init__(
+            target,
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bd6a79244ed30..2a914d7604ba9 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ 
b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -252,3 +252,66 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @convert_layout_a
+func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>> -> vector<256x32xf16>
+  // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
+  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
+  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas %[[V2]]
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.convert_layout %{{.*}}
+    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @convert_layout_a_sg_param
+func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: 
memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>> -> vector<256x32xf16>
+  // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
+  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
+  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas %[[V2]]
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.convert_layout %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 0b587d2020aa6..4fda801e48964 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -113,3 +113,47 @@ def setOpLayoutAttrResult():
 # 
CHECK: sg_layout = [6, 4] # CHECK: sg_data = [32, 16] # CHECK: inst_data = [8, 16] + + +@run +def ConvertLayoutMinimal(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + xegpu.ConvertLayoutOp( + operand, + sg_layout=[6, 4], + sg_data=[32, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: ConvertLayoutMinimal + # CHECK: transform.xegpu.convert_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + + +@run +def ConvertLayout(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1]) + xegpu.ConvertLayoutOp( + operand, + sg_layout=[6, 4], + sg_data=[32, 16], + inst_data=[8, 16], + ) + transform.YieldOp() + # CHECK-LABEL: TEST: ConvertLayout + # CHECK: transform.xegpu.convert_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16]