tkarna (Contributor) commented Nov 6, 2025

Add a transform.xegpu.get_desc_op transform op that finds the xegpu.create_nd_tdesc producer op of a Value.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.

Contrary to the RFC, get_desc_op takes a value handle instead of an operation handle and an operand index. The operand value handle can be obtained with transform.get_operand:

%tile_c = transform.get_operand %dpas_op[2] : (!transform.any_op) -> !transform.any_value
%desc_op_c = transform.xegpu.get_desc_op %tile_c : (!transform.any_value) -> !transform.any_op
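
get_desc_op also traces the producer chain through loop iteration arguments back to the loop's init values. As a minimal sketch of that case (a hypothetical payload, not one of the tests in this PR), consider a descriptor carried as an scf.for iter_arg:

func.func @desc_in_loop(%arg0: memref<4096x4096xf16>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // Descriptor created outside the loop and passed in as an init value.
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  %1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%desc = %0) -> (!xegpu.tensor_desc<256x32xf16>) {
    // The load only sees the block argument %desc, not the create_nd_tdesc.
    %v = xegpu.load_nd %desc[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
    scf.yield %desc : !xegpu.tensor_desc<256x32xf16>
  }
  return
}

%load = transform.structured.match ops{["xegpu.load_nd"]} in %root : (!transform.any_op) -> !transform.any_op
%tile = transform.get_operand %load[0] : (!transform.any_op) -> !transform.any_value
%desc_op = transform.xegpu.get_desc_op %tile : (!transform.any_value) -> !transform.any_op

Here get_desc_op resolves the block argument %desc to the corresponding init value and returns a handle to the xegpu.create_nd_tdesc above the loop.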

llvmbot (Member) commented Nov 6, 2025

@llvm/pr-subscribers-mlir-gpu, @llvm/pr-subscribers-mlir

Author: Tuomas Kärnä (tkarna)

Changes

Add a transform.xegpu.get_desc_op transform op that finds the xegpu.create_nd_tdesc producer op of a Value.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.

Contrary to the RFC, get_desc_op takes a value handle instead of an operation handle and an operand index. The operand value handle can be obtained with transform.get_operand:

%tile_c = transform.get_operand %dpas_op[2] : (!transform.any_op) -> !transform.any_value
%desc_op_c = transform.xegpu.get_desc_op %tile_c : (!transform.any_value) -> !transform.any_op

Full diff: https://github.com/llvm/llvm-project/pull/166801.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td (+17)
  • (modified) mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp (+66)
  • (modified) mlir/python/mlir/dialects/transform/xegpu.py (+21)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops.mlir (+25)
  • (modified) mlir/test/python/dialects/transform_xegpu_ext.py (+16-1)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..199bd2024c373 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -16,6 +16,23 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
+  DeclareOpInterfaceMethods<TransformOpInterface>,
+  NavigationTransformOpTrait, MemoryEffectsOpInterface
+]> {
+
+  let summary = "Get a handle to the descriptor op of a value.";
+  let description = [{
+    Traces the producers of the given value until an `xegpu.create_nd_tdesc`
+    descriptor op is found. Returns a handle to it.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface : $target);
+
+  let results = (outs TransformHandleTypeInterface : $descHandle);
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   AttrSizedOperandSegments,
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..a4aaed14dfff6 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -13,6 +13,9 @@
 
 #include <optional>
 
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
 using namespace mlir;
 using namespace mlir::transform;
 
@@ -76,6 +79,47 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Find a producer operation of type T for the given value.
+/// It is assumed that producer ops are chained through their first operand.
+/// The producer chain is traced through loop block arguments (init values).
+template <typename T>
+static std::optional<T> findProducerOfType(Value val) {
+  Value currentValue = val;
+  if (!currentValue.getDefiningOp()) {
+    // Value may be a block argument initialized outside a loop.
+    if (val.getNumUses() == 0) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Failed to find producer op, value has no uses.");
+      return std::nullopt;
+    }
+    auto userOp = val.getUsers().begin();
+    auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
+    if (!parentLoop) {
+      LLVM_DEBUG(llvm::dbgs() << "Failed to find producer op, not in a loop.");
+      return std::nullopt;
+    }
+    int64_t iterArgIdx;
+    if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
+      auto numInductionVars = parentLoop.getLoopInductionVars()->size();
+      iterArgIdx = iterArg.getArgNumber() - numInductionVars;
+      currentValue = parentLoop.getInits()[iterArgIdx];
+    } else {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Failed to find producer op, value not in init values.");
+      return std::nullopt;
+    }
+  }
+  Operation *producerOp = currentValue.getDefiningOp();
+
+  if (auto matchingOp = dyn_cast<T>(producerOp))
+    return matchingOp;
+
+  if (producerOp->getNumOperands() == 0)
+    return std::nullopt;
+
+  return findProducerOfType<T>(producerOp->getOperand(0));
+}
+
 /// Create a layout attribute from the given parameters.
 static xegpu::LayoutAttr
 createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -111,6 +155,28 @@ setDescLayout(transform::TransformRewriter &rewriter,
   return newDescOp;
 }
 
+DiagnosedSilenceableFailure
+transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+
+  auto maybeDescOp =
+      findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
+  if (!maybeDescOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  }
+
+  results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
+  return DiagnosedSilenceableFailure::success();
+}
+
 void transform::SetDescLayoutOp::build(OpBuilder &builder,
                                        OperationState &result, Value target,
                                        ArrayRef<OpFoldResult> mixedSgLayout,
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..d23f2ac16429f 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -7,6 +7,7 @@
 
 try:
     from ...ir import *
+    from ...dialects import transform
     from .._ods_common import _cext as _ods_cext
     from .._ods_common import (
         MixedValues,
@@ -20,6 +21,26 @@
 from typing import Union, Optional
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GetDescOp(GetDescOp):
+    """Specialization for GetDescOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        desc_type = transform.AnyOpType.get()
+        super().__init__(
+            desc_type,
+            target,
+            loc=loc,
+            ip=ip,
+        )
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class SetDescLayoutOp(SetDescLayoutOp):
     """Specialization for SetDescLayoutOp class."""
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..be8b3155be270 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -1,5 +1,30 @@
 // RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
 
+// CHECK-LABEL: @get_desc_op
+func.func @get_desc_op(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // expected-remark @below {{found desc op}}
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[1] : (!transform.any_op) -> !transform.any_value
+    %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
+    transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
 // CHECK-LABEL: @set_desc_layout
 func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 1c8a2bcc6a2fb..f83c807f571e1 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -3,7 +3,7 @@
 from mlir.ir import *
 from mlir.dialects import transform
 from mlir.dialects.transform import xegpu
-from mlir.dialects.transform import structured
+from mlir.dialects.transform import AnyValueType
 
 
 def run(f):
@@ -16,6 +16,21 @@ def run(f):
     return f
 
 
+@run
+def getDescOpDefaultIndex():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        desc_handle = xegpu.GetDescOp(operand)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: getDescOpDefaultIndex
+    # CHECK: transform.xegpu.get_desc_op %
+
+
 @run
 def setDescLayoutMinimal():
     sequence = transform.SequenceOp(


rolfmorel (Contributor) left a comment

Looking good!

Only a couple of nits and a request for expanding the tests a little. Other than that, it looks good to go.

Jianhui-Li (Contributor) left a comment

LGTM with minor comments

auto maybeDescOp =
findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
if (!maybeDescOp) {
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
Jianhui-Li (Contributor) commented:

Actually, a user may write a program like the one below, and findProducerOfType will miss it since it only searches along operand(0).
A_desc = Create_nd_tdesc : <1024x1024xbf16>
A = Load_nd A_desc : <1024x1024xbf16>
Bias = load bias_memref : <1024xbf16>
Bias_bcast = broadcast Bias: <1024x1024xbf16>
A_Bias = add Bias_bcast, A  // the A tensor is used as the 2nd operand
C = DPAS A_Bias, ...
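
In concrete MLIR this could look roughly as follows (an illustrative sketch, reusing the tile shapes from the test above and simplifying the bias load to a scalar vector.broadcast):

func.func @bias_case(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 1.0 : f16
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
  %bias = vector.broadcast %cst : f16 to vector<256x32xf16>
  // The A tile is the *second* operand of addf, so a walk that only follows
  // operand(0) goes broadcast -> constant and stops without finding the desc op.
  %a_bias = arith.addf %bias, %1 : vector<256x32xf16>
  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
  %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
  %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
  %6 = xegpu.dpas %a_bias, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
  return
}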

I don't think it's worth complicating findProducerOfType(). Maybe the message could add a hint, like "Could not find a matching descriptor op when walking the producer chain of the first operand."

tkarna (Contributor, Author) replied:

Thanks, I agree. Updated the error message as suggested.

In the future, we can hopefully replace xegpu.get_desc_op with a generic transform dialect find-producer method that traces through all operands and also supports more than loop block args.

When tracing through all operands, the trace will branch out and we could find more than one create desc op. In those cases we could define additional anchor ops to deduce the right tiles (e.g. in the above example one could match the broadcast op to get a handle to the bias).

adam-smnk merged commit 1553f90 into llvm:main on Nov 10, 2025
10 checks passed
tkarna deleted the xegpu-tr-ops-get-desc-op branch on November 10, 2025 at 15:03