diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index b985d5450be0e..ed277ef7bd554 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -16,6 +16,24 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def GetDescOp : Op, + NavigationTransformOpTrait, MemoryEffectsOpInterface +]> { + + let summary = "Get a handle to the descriptor op of a value."; + let description = [{ + Traces the producers of the given value until an `xegpu.create_nd_tdesc` + descriptor op is found. Returns a handle to it. Currently traces + producers by following only the first operand of producer ops. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$target); + + let results = (outs TransformHandleTypeInterface:$descHandle); + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; +} + def SetDescLayoutOp : Op, @@ -31,16 +49,16 @@ def SetDescLayoutOp : Op : $sg_layout, - Variadic : $sg_data, - Variadic : $inst_data, + TransformHandleTypeInterface:$target, + Variadic:$sg_layout, + Variadic:$sg_data, + Variadic:$inst_data, DefaultValuedOptionalAttr:$static_sg_layout, DefaultValuedOptionalAttr:$static_sg_data, DefaultValuedOptionalAttr:$static_inst_data ); - let results = (outs TransformHandleTypeInterface : $transformed); + let results = (outs TransformHandleTypeInterface:$transformed); let builders = [ OpBuilder<(ins "Value":$target, "ArrayRef":$mixedSgLayout, diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index 8943ba09d9c34..0683699f467e9 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -13,6 +13,9 @@ 
#include +#include "llvm/Support/DebugLog.h" +#define DEBUG_TYPE "xegpu-transforms" + using namespace mlir; using namespace mlir::transform; @@ -76,6 +79,45 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt( return DiagnosedSilenceableFailure::success(); } +/// Find producer operation of type T for the given value. +/// It's assumed that producer ops are chained through their first operand. +/// Producer chain is traced through loop block arguments (init values). +template +static std::optional findProducerOfType(Value val) { + Value currentValue = val; + if (!currentValue.getDefiningOp()) { + // Value may be a block argument initialized outside a loop. + if (val.getNumUses() == 0) { + LDBG() << "Failed to find producer op, value has no uses."; + return std::nullopt; + } + auto userOp = val.getUsers().begin(); + auto parentLoop = userOp->getParentOfType(); + if (!parentLoop) { + LDBG() << "Failed to find producer op, not in a loop."; + return std::nullopt; + } + int64_t iterArgIdx; + if (auto iterArg = llvm::dyn_cast(currentValue)) { + auto numInductionVars = parentLoop.getLoopInductionVars()->size(); + iterArgIdx = iterArg.getArgNumber() - numInductionVars; + currentValue = parentLoop.getInits()[iterArgIdx]; + } else { + LDBG() << "Failed to find producer op, value not in init values."; + return std::nullopt; + } + } + Operation *producerOp = currentValue.getDefiningOp(); + + if (auto matchingOp = dyn_cast(producerOp)) + return matchingOp; + + if (producerOp->getNumOperands() == 0) + return std::nullopt; + + return findProducerOfType(producerOp->getOperand(0)); +} + /// Create a layout attribute from the given parameters. 
static xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef sgLayout, @@ -111,6 +153,29 @@ setDescLayout(transform::TransformRewriter &rewriter, return newDescOp; } +DiagnosedSilenceableFailure +transform::GetDescOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) { + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + } + + auto maybeDescOp = + findProducerOfType(*targetValues.begin()); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) + << "could not find a matching descriptor op when walking the " + "producer chain of the first operand"; + } + + results.set(llvm::cast(getResult()), {*maybeDescOp}); + return DiagnosedSilenceableFailure::success(); +} + void transform::SetDescLayoutOp::build(OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedSgLayout, diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index 2918bf592880a..d23f2ac16429f 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -7,6 +7,7 @@ try: from ...ir import * + from ...dialects import transform from .._ods_common import _cext as _ods_cext from .._ods_common import ( MixedValues, @@ -20,6 +21,26 @@ from typing import Union, Optional +@_ods_cext.register_operation(_Dialect, replace=True) +class GetDescOp(GetDescOp): + """Specialization for GetDescOp class.""" + + def __init__( + self, + target: Value, + *, + loc=None, + ip=None, + ): + desc_type = transform.AnyOpType.get() + super().__init__( + desc_type, + target, + loc=loc, + ip=ip, + ) + + @_ods_cext.register_operation(_Dialect, replace=True) class SetDescLayoutOp(SetDescLayoutOp): """Specialization for SetDescLayoutOp class.""" diff --git 
a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir index 23e1cd946b4cd..342de429d2e90 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -1,5 +1,67 @@ // RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s +// CHECK-LABEL: @get_desc_op_a +func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // expected-remark @below {{found desc op}} + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op + transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op + transform.yield + } +} + +// ----- + 
+// CHECK-LABEL: @get_desc_op_c +func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + // expected-remark @below {{found desc op}} + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value + %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op + transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op + transform.yield + } +} + +// ----- + // CHECK-LABEL: @set_desc_layout func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) { // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py index 1c8a2bcc6a2fb..f83c807f571e1 100644 --- 
a/mlir/test/python/dialects/transform_xegpu_ext.py +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -3,7 +3,7 @@ from mlir.ir import * from mlir.dialects import transform from mlir.dialects.transform import xegpu -from mlir.dialects.transform import structured +from mlir.dialects.transform import AnyValueType def run(f): @@ -16,6 +16,21 @@ def run(f): return f +@run +def getDescOpDefaultIndex(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + desc_handle = xegpu.GetDescOp(operand) + transform.YieldOp() + # CHECK-LABEL: TEST: getDescOpDefaultIndex + # CHECK: transform.xegpu.get_desc_op % + + @run def setDescLayoutMinimal(): sequence = transform.SequenceOp(