Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

// Navigation transform op: maps a payload value to the xegpu.create_nd_tdesc
// op that (transitively) produces it.
def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface
]> {

let summary = "Get a handle to the descriptor op of a value.";
let description = [{
Traces the producers of the given value until an `xegpu.create_nd_tdesc`
descriptor op is found. Returns a handle to it. Currently traces
producers by following only the first operand of producer ops.
}];

// Value handle whose producer chain is walked (must map to one payload value).
let arguments = (ins TransformValueHandleTypeInterface:$target);

// Op handle associated with the matched descriptor op.
let results = (outs TransformHandleTypeInterface:$descHandle);
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}

def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
Expand All @@ -31,16 +49,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];

let arguments = (ins
TransformHandleTypeInterface : $target,
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
TransformHandleTypeInterface:$target,
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
);

let results = (outs TransformHandleTypeInterface : $transformed);
let results = (outs TransformHandleTypeInterface:$transformed);
let builders = [
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedSgLayout,
Expand Down
65 changes: 65 additions & 0 deletions mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

#include <optional>

#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "xegpu-transforms"

using namespace mlir;
using namespace mlir::transform;

Expand Down Expand Up @@ -76,6 +79,45 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
return DiagnosedSilenceableFailure::success();
}

/// Find producer operation of type T for the given value.
/// It's assumed that producer ops are chained through their first operand.
/// The producer chain is traced through loop block arguments (init values),
/// including nested loops. Returns std::nullopt when no producer of type T
/// can be reached.
template <typename T>
static std::optional<T> findProducerOfType(Value val) {
  Value currentValue = val;
  if (!currentValue.getDefiningOp()) {
    // No defining op: the value may be a loop iter-arg whose producer is one
    // of the loop's init values.
    auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue);
    if (!iterArg) {
      LDBG() << "Failed to find producer op, value is not a block argument.";
      return std::nullopt;
    }
    // Derive the enclosing loop from the block argument's owner block rather
    // than from an arbitrary user, which may live in a nested region and
    // would yield the wrong (innermost) loop.
    auto parentLoop = llvm::dyn_cast_or_null<LoopLikeOpInterface>(
        iterArg.getOwner()->getParentOp());
    if (!parentLoop) {
      LDBG() << "Failed to find producer op, not in a loop.";
      return std::nullopt;
    }
    // Induction variables precede iter-args in the entry block; some loop
    // ops expose none, so guard the optional before using it.
    auto maybeInductionVars = parentLoop.getLoopInductionVars();
    int64_t numInductionVars =
        maybeInductionVars
            ? static_cast<int64_t>(maybeInductionVars->size())
            : 0;
    int64_t iterArgIdx =
        static_cast<int64_t>(iterArg.getArgNumber()) - numInductionVars;
    auto inits = parentLoop.getInits();
    if (iterArgIdx < 0 || iterArgIdx >= static_cast<int64_t>(inits.size())) {
      LDBG() << "Failed to find producer op, value not in init values.";
      return std::nullopt;
    }
    currentValue = inits[iterArgIdx];
    // The init value can itself be a block argument (e.g. of an enclosing
    // loop); recurse instead of dereferencing a null defining op.
    if (!currentValue.getDefiningOp())
      return findProducerOfType<T>(currentValue);
  }
  Operation *producerOp = currentValue.getDefiningOp();

  if (auto matchingOp = dyn_cast<T>(producerOp))
    return matchingOp;

  // By convention, only the first operand of a producer is followed.
  if (producerOp->getNumOperands() == 0)
    return std::nullopt;

  return findProducerOfType<T>(producerOp->getOperand(0));
}

/// Create a layout attribute from the given parameters.
static xegpu::LayoutAttr
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
Expand Down Expand Up @@ -111,6 +153,29 @@ setDescLayout(transform::TransformRewriter &rewriter,
return newDescOp;
}

/// Implements transform.xegpu.get_desc_op: maps the single payload value
/// bound to `target` to the `xegpu.create_nd_tdesc` op reached by walking
/// the first-operand producer chain (through loop init values).
DiagnosedSilenceableFailure
transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
// Exactly one payload value is required; more or fewer indicates a misuse
// of the transform script, hence a definite (non-recoverable) failure.
auto targetValues = state.getPayloadValues(getTarget());
if (!llvm::hasSingleElement(targetValues)) {
return emitDefiniteFailure()
<< "requires exactly one target value handle (got "
<< llvm::range_size(targetValues) << ")";
}

// Walk the producer chain; absence of a descriptor op depends on the
// payload IR, so report it as a silenceable failure.
auto maybeDescOp =
findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
if (!maybeDescOp) {
return emitSilenceableFailure(getLoc())
<< "Could not find a matching descriptor op when walking the "
"producer chain of the first operand.";
}

// Bind the found descriptor op to the result handle.
results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
return DiagnosedSilenceableFailure::success();
}

void transform::SetDescLayoutOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedSgLayout,
Expand Down
21 changes: 21 additions & 0 deletions mlir/python/mlir/dialects/transform/xegpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
from .._ods_common import (
MixedValues,
Expand All @@ -20,6 +21,26 @@
from typing import Union, Optional


@_ods_cext.register_operation(_Dialect, replace=True)
class GetDescOp(GetDescOp):
    """Specialization for GetDescOp class."""

    def __init__(self, target: Value, *, loc=None, ip=None):
        # The descriptor handle result is always typed as a generic op handle.
        super().__init__(transform.AnyOpType.get(), target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class SetDescLayoutOp(SetDescLayoutOp):
"""Specialization for SetDescLayoutOp class."""
Expand Down
62 changes: 62 additions & 0 deletions mlir/test/Dialect/XeGPU/transform-ops.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,67 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @get_desc_op_a
func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%c32 = arith.constant 32 : index
%c4096 = arith.constant 4096 : index
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
// expected-remark @below {{found desc op}}
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
%5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
scf.yield %7 : vector<256x32xf16> -> vector is wrong; see original
}
return
}

// Trace operand 0 of the dpas op (%5, a load from %3) back to the descriptor
// that produced it; the remark is expected on that create_nd_tdesc above.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
%2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: @get_desc_op_c
func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%c32 = arith.constant 32 : index
%c4096 = arith.constant 4096 : index
%c0 = arith.constant 0 : index
// expected-remark @below {{found desc op}}
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
%5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
scf.yield %7 : vector<256x256xf16>
}
return
}

// Trace operand 2 of the dpas op (%arg4, a loop iter-arg) back through the
// loop's init value (%1) to the descriptor marked with the remark above,
// exercising the block-argument path of the producer walk.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
%2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: @set_desc_layout
func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
Expand Down
17 changes: 16 additions & 1 deletion mlir/test/python/dialects/transform_xegpu_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import xegpu
from mlir.dialects.transform import structured
from mlir.dialects.transform import AnyValueType


def run(f):
Expand All @@ -16,6 +16,21 @@ def run(f):
return f


@run
def getDescOpDefaultIndex():
    # Builds a transform sequence targeting xegpu.dpas and verifies that the
    # GetDescOp Python binding constructs and round-trips through assembly.
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("xegpu.dpas"),
    )
    with InsertionPoint(sequence.body):
        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
        desc_handle = xegpu.GetDescOp(operand)  # handle unused; op only needs to print
        transform.YieldOp()
    # CHECK-LABEL: TEST: getDescOpDefaultIndex
    # CHECK: transform.xegpu.get_desc_op %


@run
def setDescLayoutMinimal():
sequence = transform.SequenceOp(
Expand Down