
Conversation

@tkarna (Contributor) commented Nov 10, 2025

Adds a transform.xegpu.insert_prefetch transform op that inserts xegpu.prefetch_nd ops for the given Value in an scf.for loop.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.

There are some changes with respect to the RFC:

  • insert_prefetch only inserts the prefetch ops; it does not set the xegpu.layout attributes. Those must be set separately using the transform.xegpu.set_desc_layout transform op.
  • The insert_prefetch op returns a handle to the newly created xegpu.create_nd_tdesc op used by the prefetch ops.
  • The parent scf.for loop is matched automatically (as the parent of the load op); it is no longer passed in as an argument.
  • Like the other revised ops, insert_prefetch no longer takes an operation handle and an operand index to define the operand value. Instead, it takes a handle to the value itself.

Example:

%tile_a = transform.get_operand %dpas_op[0] : (!transform.any_op) -> !transform.any_value
%desc_op  = transform.xegpu.insert_prefetch %tile_a nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
%desc_op2 = transform.xegpu.set_desc_layout %desc_op sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
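
To illustrate the effect on the payload IR, here is a hand-written sketch (not verbatim output of the included tests) of roughly what the result looks like after applying the sequence above to the A operand of a dpas inside a K loop, with nb_prefetch = 1. Names are illustrative, and the cache-hint attributes the transform sets on the prefetches (cached for L1/L2/L3) are omitted for brevity:

func.func @matmul_a_prefetched(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>,
                               %acc: vector<256x256xf16>) -> vector<256x256xf16> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %c4096 = arith.constant 4096 : index
  %a_desc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  %b_desc = xegpu.create_nd_tdesc %B : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
  // Descriptor cloned right before the loop; this is the op the returned handle points to.
  %p_desc = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  // Initial prefetches: an scf.for over the first nb_prefetch iterations, fully
  // unrolled, so with nb_prefetch = 1 a single prefetch at the loop start remains.
  xegpu.prefetch_nd %p_desc[%c0, %c0] : !xegpu.tensor_desc<256x32xf16>
  %res = scf.for %k = %c0 to %c4096 step %c32 iter_args(%it = %acc) -> (vector<256x256xf16>) {
    // The in-loop prefetch runs nb_prefetch * step ahead of the load below.
    %k_ahead = arith.addi %k, %c32 : index
    xegpu.prefetch_nd %p_desc[%c0, %k_ahead] : !xegpu.tensor_desc<256x32xf16>
    %a = xegpu.load_nd %a_desc[%c0, %k] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
    %b = xegpu.load_nd %b_desc[%k, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
    %d = xegpu.dpas %a, %b, %it : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
    scf.yield %d : vector<256x256xf16>
  }
  return %res : vector<256x256xf16>
}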

@llvmbot (Member) commented Nov 10, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Tuomas Kärnä (tkarna)

Full diff: https://github.com/llvm/llvm-project/pull/167356.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td (+43)
  • (modified) mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp (+138)
  • (modified) mlir/python/mlir/dialects/transform/xegpu.py (+32)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir (+31)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops.mlir (+76)
  • (modified) mlir/test/python/dialects/transform_xegpu_ext.py (+66-1)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556deb..41f7afa78212e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,47 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
   }];
 }
 
+def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
+  let description = [{
+    Given a target value (e.g., `vector`) residing in a `scf.for` loop, this
+    transform finds the corresponding `xegpu.load_nd` op and inserts
+    `xegpu.prefetch` operations for the tile. The load op must reside within the
+    `scf.for` loop. Number of prefetch steps is set by the `nb_prefetch`
+    argument. Returns a handle to the created `xegpu.create_nd_desc` op.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
+                   DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
+                   );
+
+  let results = (outs TransformHandleTypeInterface:$desc_op);
+
+  let assemblyFormat = [{
+    $target
+    `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    OpFoldResult getNbPrefetch() {
+      auto cxt = getContext();
+      if (getDynamicNbPrefetch())
+        return OpFoldResult(getDynamicNbPrefetch());
+      return OpFoldResult(IntegerAttr::get(
+                          IntegerType::get(cxt, 64), getStaticNbPrefetch()));
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 5fdd8534e4e51..d9222d10a5e70 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 
@@ -341,6 +342,143 @@ void transform::SetOpLayoutAttrOp::getEffects(
   modifiesPayload(effects);
 }
 
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+                                   transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+  auto value = *targetValues.begin();
+
+  int64_t nbPrefetch = getStaticNbPrefetch();
+  if (getDynamicNbPrefetch()) {
+    // Get dynamic prefetch count from transform param or handle.
+    SmallVector<int32_t> dynamicNbPrefetch;
+    auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+                                          {getDynamicNbPrefetch()});
+    if (!status.succeeded())
+      return status;
+    if (dynamicNbPrefetch.size() != 1) {
+      return emitDefiniteFailure()
+             << "requires exactly one value for dynamic_nb_prefetch";
+    }
+    nbPrefetch = dynamicNbPrefetch[0];
+  }
+  if (nbPrefetch <= 0) {
+    return emitSilenceableFailure(getLoc())
+           << "nb_prefetch must be a positive integer.";
+  }
+
+  // Find load operation of the operand.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  }
+  auto loadOp = *maybeLoadOp;
+  if (loadOp.getMixedOffsets().size() == 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op must have offsets.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find the parent scf.for loop.
+  auto forOp = loadOp->getParentOfType<scf::ForOp>();
+  if (!forOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op is not contained in a scf.for loop.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find descriptor op.
+  auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+  if (!maybeDescOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  }
+  auto descOp = *maybeDescOp;
+  if (descOp.getMixedOffsets().size() > 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
+  // Clone desc op outside the loop.
+  rewriter.setInsertionPoint(forOp);
+  auto newDescOp =
+      cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+  // Clone reduction loop to emit initial prefetches.
+  // Compute upper bound of the init loop: start + nbPrefetch * step.
+  auto nbPrefetchCst =
+      arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+  auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+  auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+      forOp.getLoc(), forOp.getLowerBound(), nbStep);
+  auto initForOp =
+      scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+                         initUpBound, forOp.getStep());
+
+  auto ctx = rewriter.getContext();
+  auto readCacheHint =
+      xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+  // Modify loadOp mixedOffsets by replacing the for loop induction variable
+  // with the given value.
+  auto getPrefetchOffsets =
+      [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+    IRMapping mapping;
+    mapping.map(forOp.getInductionVar(), replacementVal);
+    SmallVector<Value> dynamicOffsets =
+        llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+          return mapping.lookupOrDefault(v);
+        }));
+    auto constOffsets = loadOp.getConstOffsets().value();
+    return getMixedValues(constOffsets, dynamicOffsets, ctx);
+  };
+
+  // Insert prefetch op in init loop.
+  // Replace induction var with the init loop induction var.
+  rewriter.setInsertionPointToStart(initForOp.getBody());
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(initForOp.getInductionVar()),
+                              readCacheHint, readCacheHint, readCacheHint);
+
+  // Insert prefetch op in main loop.
+  // Calculate prefetch offset after the init prefetches have been issued.
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+                                              forOp.getInductionVar(), nbStep);
+  // Replace induction var with correct offset.
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(prefetchOffset), readCacheHint,
+                              readCacheHint, readCacheHint);
+
+  // Unroll the init loop.
+  if (failed(loopUnrollFull(initForOp))) {
+    return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
+  }
+
+  results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::InsertPrefetchOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index ce8015d8f557b..b14426b13fae3 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -11,6 +11,7 @@
     from .._ods_common import _cext as _ods_cext
     from .._ods_common import (
         MixedValues,
+        MixedInt,
         get_op_result_or_value as _get_op_result_or_value,
         _dispatch_dynamic_index_list,
     )
@@ -132,3 +133,34 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InsertPrefetchOp(InsertPrefetchOp):
+    """Specialization for InsertPrefetchOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        *,
+        nb_prefetch: Optional[MixedInt] = 1,
+        loc=None,
+        ip=None,
+    ):
+        static_nb_prefetch = 1
+        dynamic_nb_prefetch = None
+        if isinstance(nb_prefetch, int):
+            static_nb_prefetch = nb_prefetch
+        elif isinstance(nb_prefetch, IntegerAttr):
+            static_nb_prefetch = nb_prefetch.value  # pytype: disable=attribute-error
+        elif isinstance(nb_prefetch, (Operation, Value, OpView)):
+            dynamic_nb_prefetch = nb_prefetch
+
+        super().__init__(
+            transform.AnyOpType.get(),
+            target,
+            dynamic_nb_prefetch=dynamic_nb_prefetch,
+            static_nb_prefetch=static_nb_prefetch,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 726b6748452ae..aaacf2a5f9280 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -71,3 +71,34 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_c
+func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  // expected-note@below {{load op}}
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+    // expected-error@below {{Load op is not contained in a scf.for loop.}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bd6a79244ed30..1fb1571b2bb43 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -252,3 +252,79 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a
+func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: xegpu.prefetch_nd %[[V0]]
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
+func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: xegpu.prefetch_nd %[[V0]]
+    %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb :  (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 0b587d2020aa6..4884f6d2b9cb5 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -3,7 +3,7 @@
 from mlir.ir import *
 from mlir.dialects import transform
 from mlir.dialects.transform import xegpu
-from mlir.dialects.transform import AnyValueType
+from mlir.dialects.transform import structured, AnyValueType
 
 
 def run(f):
@@ -113,3 +113,68 @@ def setOpLayoutAttrResult():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def insertPrefetch0():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        xegpu.InsertPrefetchOp(
+            operand,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetch0
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+
+
+@run
+def insertPrefetchNbPrefetch():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        xegpu.InsertPrefetchOp(
+            operand,
+            nb_prefetch=2,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetchNbPrefetch
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+    # CHECK-SAME: nb_prefetch = 2
+
+
+@run
+def insertPrefetchNbPrefetchParam():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        int32_t = IntegerType.get_signless(32)
+        param_int32_t = transform.ParamType.get(int32_t)
+        nb_param = transform.ParamConstantOp(
+            param_int32_t,
+            IntegerAttr.get(int32_t, 2),
+        )
+        xegpu.InsertPrefetchOp(
+            operand,
+            nb_prefetch=nb_param,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+    # CHECK-SAME: nb_prefetch = %[[PARAM_OP]]

@silee2 (Contributor) commented Nov 10, 2025

The parent scf.for loop is matched automatically (as the parent of the load op); it is no longer passed in as an argument.

What happens when there is no parent scf.for, will prefetch insertion still happen?
If insertion does not happen, what would the returned create_nd_desc value be?
The previous definition required passing the scf.for, so this ambiguous case could not arise.

@silee2 requested review from Jianhui-Li and silee2 on November 10, 2025 18:18
@tkarna (Contributor, Author) commented Nov 10, 2025

The parent scf.for loop is matched automatically (as the parent of the load op); it is no longer passed in as an argument.

What happens when there is no parent scf.for, will prefetch insertion still happen? If insertion does not happen, what would the returned create_nd_desc value be? The previous definition required passing the scf.for, so this ambiguous case could not arise.

The op fails if the scf.for op cannot be found, or if the corresponding load op resides outside the scf.for loop (indicating that the tile is loop independent). Since the op fails, there is no return value. See the transform-ops-invalid.mlir test.

To be exact, the logic is: for the given Value (e.g. a vector), a producing load_nd op must be found, and its parent must be an scf.for loop. In that case, the prefetch ops are emitted in that scf.for loop.
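
As a minimal sketch (types and names are illustrative, and the loop body is reduced to the essentials), the expected payload shape is:

%desc = xegpu.create_nd_tdesc %src : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
scf.for %k = %c0 to %c4096 step %c32 {
  // The load must carry offsets and must sit inside the scf.for; the handle
  // passed to insert_prefetch points at %tile (or at a value derived from it).
  %tile = xegpu.load_nd %desc[%c0, %k] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
  // ... %tile feeds e.g. an xegpu.dpas here ...
}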

// CHECK: xegpu.create_nd_tdesc %arg1
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
// CHECK-SAME: !xegpu.tensor_desc<256x32xf16
// CHECK: xegpu.prefetch_nd %[[V0]]
Contributor:

I think the tests should check the generated offsets of the prefetch_nd ops,
especially the in-loop prefetch_nd, whose offsets need to be adjusted based on nb_prefetch.

Contributor Author:

The test is updated and the proper offsets are now checked.

@rolfmorel (Contributor) left a comment:

A couple of comments. Others are perhaps better suited to judge whether the transform "plays nice" with regard to how the xegpu ops interact.

}

// Find the parent scf.for loop.
auto forOp = loadOp->getParentOfType<scf::ForOp>();
Contributor:

Minor suggestion: could this maybe be generalized to support any parent that implements LoopLikeOpInterface (or whichever name it goes by)?

Contributor Author:

Yes, I think it can be generalized. I suggest, however, deferring this until we have a concrete use case.

@Jianhui-Li (Contributor)

I think the op design could be more flexible if we further split the current op into a get_load_op and then apply insert_prefetch to the load op instead of the dpas.
When the use case becomes more complex, insert_prefetch for dpas becomes less intuitive: operand A may not be inside the K loop, or it may come from the previous dpas result plus some post-op inside the loop, and sometimes the load can be a load from SLM. At that point, the user would rather work against the load directly and insert the prefetch there, without caring about the dpas op.

@tkarna (Contributor, Author) commented Nov 11, 2025

I think the op design could be more flexible if we further split the current op into a get_load_op and then apply insert_prefetch to the load op instead of the dpas. When the use case becomes more complex, insert_prefetch for dpas becomes less intuitive: operand A may not be inside the K loop, or it may come from the previous dpas result plus some post-op inside the loop, and sometimes the load can be a load from SLM. At that point, the user would rather work against the load directly and insert the prefetch there, without caring about the dpas op.

This API is not dpas-specific: insert_prefetch takes a handle to a Value (typically a vector). It is not immediately clear that having an API for the load op is better than an API for the value the load op produces. If the producer chain is complex, the user can pass some intermediate value to the insert_prefetch op; it does not have to be a dpas operand.

That said, I'd propose we postpone this change rather than add a new get_load_op now. If in the future we have a generic find_producer_of_type op in the transform dialect, it would cover both the get_desc_op and get_load_op use cases.
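
As a rough sketch of that flexibility (assuming the matched payload contains exactly one xegpu.load_nd; transform.structured.match and transform.get_result are existing transform-dialect ops, and %func is an illustrative handle), one could target a load result directly instead of a dpas operand:

%loads = transform.structured.match ops{["xegpu.load_nd"]} in %func : (!transform.any_op) -> !transform.any_op
%tile  = transform.get_result %loads[0] : (!transform.any_op) -> !transform.any_value
%desc  = transform.xegpu.insert_prefetch %tile nb_prefetch = 2 : (!transform.any_value) -> !transform.any_op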

@Jianhui-Li (Contributor) left a comment:

I am fine with this as an initial version.

I think it is better to evolve toward associating the prefetch with the load op rather than with a value that the implementation implicitly traces back to the producing load op. A value is more flexible, but I think it becomes vague once the use case grows more complex. Say the B matrix is computed from two loads (A, and a bias or scale): it is not clear whether the user should use the dpas's input B operand value or the two loads' result values.

let description = [{
Given a target value (e.g., `vector`) residing in a `scf.for` loop, this
transform finds the corresponding `xegpu.load_nd` op and inserts
`xegpu.prefetch` operations for the tile. The load op must reside within the
Contributor:

xegpu.prefetch -> xegpu.prefetch_nd

Contributor:

It would be worth mentioning the default nb_prefetch value here.

Contributor Author:

updated

@Jianhui-Li (Contributor)

I think the op design could be more flexible if we further split the current op into a get_load_op and then apply insert_prefetch to the load op instead of the dpas. When the use case becomes more complex, insert_prefetch for dpas becomes less intuitive: operand A may not be inside the K loop, or it may come from the previous dpas result plus some post-op inside the loop, and sometimes the load can be a load from SLM. At that point, the user would rather work against the load directly and insert the prefetch there, without caring about the dpas op.

This API is not dpas-specific: insert_prefetch takes a handle to a Value (typically a vector). It is not immediately clear that having an API for the load op is better than an API for the value the load op produces. If the producer chain is complex, the user can pass some intermediate value to the insert_prefetch op; it does not have to be a dpas operand.

That said, I'd propose we postpone this change rather than add a new get_load_op now. If in the future we have a generic find_producer_of_type op in the transform dialect, it would cover both the get_desc_op and get_load_op use cases.

I think exposing a find_producer_of_type op to the user gives them better control. I am not sure find_producer_of_type can be made generic and become part of the transform dialect, though. Say we have the code sequence
Load A
Load B
Load D
C = Matmul (A, B)
E = Matmul (C, D)
If we try to use a generic find_producer_of_type to find the load op for matrix C, will it return Load A? Or will it stop when it runs into the Matmul and report failure?

We don't really depend on a find_producer_of_type op being admitted to the transform dialect; if it works better for more examples, maybe we just expose it as a generic form of get_desc_op and get_load_op for xegpu.

@tkarna force-pushed the xegpu-tr-ops-insert-prefetch branch from 71a38e8 to 85aafcc on November 11, 2025 18:30
@rolfmorel (Contributor)

I too like the idea of having a generic find_producer_of_type op in the transform dialect.

Though, as previously discussed with @tkarna, to make it truly generic we probably need some kind of ProducerConsumerMappingOpInterface to be implemented by the ops we want to "thread through". Otherwise find_producer_of_type's algorithm for following producer-consumer chains is either going to be heuristic (e.g. always through the first operand) or partial (e.g. a table of ops embedded inside find_producer_of_type for which we know the mapping -- note that this mapping is pretty much guaranteed to get out of sync with the op definitions).

My take on it is that in general it is non-trivial. Though we should return to it when we encounter more of this kind of "matching" in schedules.

@tkarna (Contributor, Author) commented Nov 11, 2025

@Jianhui-Li @rolfmorel Yes, in general this is a question of how we should do op matching for more complex workloads. I'm inclined to think that making better use of existing transform dialect matchers combined with some xegpu specific ones would probably be sufficient. But we can refine as we encounter new use cases.

@rolfmorel merged commit 3c52f53 into llvm:main on Nov 12, 2025 (10 checks passed).
@tkarna deleted the xegpu-tr-ops-insert-prefetch branch on November 12, 2025 10:26.
WillFroom added a commit to WillFroom/llvm-project that referenced this pull request Nov 12, 2025
git-crd pushed a commit to git-crd/crd-llvm-project that referenced this pull request Nov 13, 2025
Adds `transform.xegpu.insert_prefetch` transform op that inserts
`xegpu.prefetch_nd` ops for the given `Value` in an `scf.for` loop.
git-crd pushed a commit to git-crd/crd-llvm-project that referenced this pull request Nov 13, 2025