[MLIR][XeGPU][TransformOps] Add insert_prefetch op #167356
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Tuomas Kärnä (tkarna)

Changes

Adds a `transform.xegpu.insert_prefetch` transform op that inserts `xegpu.prefetch_nd` ops for the given `Value` in an `scf.for` loop.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document. There are some changes with respect to the RFC:

- `insert_prefetch` only inserts the prefetch ops; it does not set the `xegpu.layout` attributes. Those must be set separately using the `transform.xegpu.set_desc_layout` transform op.
- The `insert_prefetch` op returns a handle to the newly created `xegpu.create_nd_desc` op of the prefetch ops.
- The `scf.for` loop is matched automatically (as the parent of the load op); it is no longer passed in as an argument.
- `insert_prefetch` no longer takes an operation handle and operand index to define the operand value. Instead, it takes a handle to the value itself.

Example:

%tile_a = transform.get_operand %dpas_op[0] : (!transform.any_op) -> !transform.any_value
%desc_op = transform.xegpu.insert_prefetch %tile_a nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
%desc_op2 = transform.xegpu.set_desc_layout %desc_op sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op

Full diff: https://github.com/llvm/llvm-project/pull/167356.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556deb..41f7afa78212e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,47 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
}];
}
+def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+]> {
+
+ let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
+ let description = [{
+ Given a target value (e.g., `vector`) residing in a `scf.for` loop, this
+ transform finds the corresponding `xegpu.load_nd` op and inserts
+ `xegpu.prefetch` operations for the tile. The load op must reside within the
+ `scf.for` loop. Number of prefetch steps is set by the `nb_prefetch`
+ argument. Returns a handle to the created `xegpu.create_nd_desc` op.
+ }];
+
+ let arguments = (ins TransformValueHandleTypeInterface:$target,
+ Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
+ DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
+ );
+
+ let results = (outs TransformHandleTypeInterface:$desc_op);
+
+ let assemblyFormat = [{
+ $target
+ `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)?
+ attr-dict `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ OpFoldResult getNbPrefetch() {
+ auto cxt = getContext();
+ if (getDynamicNbPrefetch())
+ return OpFoldResult(getDynamicNbPrefetch());
+ return OpFoldResult(IntegerAttr::get(
+ IntegerType::get(cxt, 64), getStaticNbPrefetch()));
+ }
+ }];
+}
+
#endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 5fdd8534e4e51..d9222d10a5e70 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -341,6 +342,143 @@ void transform::SetOpLayoutAttrOp::getEffects(
modifiesPayload(effects);
}
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues)) {
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ }
+ auto value = *targetValues.begin();
+
+ int64_t nbPrefetch = getStaticNbPrefetch();
+ if (getDynamicNbPrefetch()) {
+ // Get dynamic prefetch count from transform param or handle.
+ SmallVector<int32_t> dynamicNbPrefetch;
+ auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+ {getDynamicNbPrefetch()});
+ if (!status.succeeded())
+ return status;
+ if (dynamicNbPrefetch.size() != 1) {
+ return emitDefiniteFailure()
+ << "requires exactly one value for dynamic_nb_prefetch";
+ }
+ nbPrefetch = dynamicNbPrefetch[0];
+ }
+ if (nbPrefetch <= 0) {
+ return emitSilenceableFailure(getLoc())
+ << "nb_prefetch must be a positive integer.";
+ }
+
+ // Find load operation of the operand.
+ auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+ if (!maybeLoadOp) {
+ return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+ }
+ auto loadOp = *maybeLoadOp;
+ if (loadOp.getMixedOffsets().size() == 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op must have offsets.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find the parent scf.for loop.
+ auto forOp = loadOp->getParentOfType<scf::ForOp>();
+ if (!forOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op is not contained in a scf.for loop.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find descriptor op.
+ auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+ if (!maybeDescOp) {
+ return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+ }
+ auto descOp = *maybeDescOp;
+ if (descOp.getMixedOffsets().size() > 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "desc op with offsets is not supported.";
+ diag.attachNote(descOp.getLoc()) << "desc op";
+ return diag;
+ }
+
+ // Clone desc op outside the loop.
+ rewriter.setInsertionPoint(forOp);
+ auto newDescOp =
+ cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+ // Clone reduction loop to emit initial prefetches.
+ // Compute upper bound of the init loop: start + nbPrefetch * step.
+ auto nbPrefetchCst =
+ arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+ auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+ auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+ forOp.getLoc(), forOp.getLowerBound(), nbStep);
+ auto initForOp =
+ scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+ initUpBound, forOp.getStep());
+
+ auto ctx = rewriter.getContext();
+ auto readCacheHint =
+ xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+ // Modify loadOp mixedOffsets by replacing the for loop induction variable
+ // with the given value.
+ auto getPrefetchOffsets =
+ [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+ IRMapping mapping;
+ mapping.map(forOp.getInductionVar(), replacementVal);
+ SmallVector<Value> dynamicOffsets =
+ llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+ return mapping.lookupOrDefault(v);
+ }));
+ auto constOffsets = loadOp.getConstOffsets().value();
+ return getMixedValues(constOffsets, dynamicOffsets, ctx);
+ };
+
+ // Insert prefetch op in init loop.
+ // Replace induction var with the init loop induction var.
+ rewriter.setInsertionPointToStart(initForOp.getBody());
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(initForOp.getInductionVar()),
+ readCacheHint, readCacheHint, readCacheHint);
+
+ // Insert prefetch op in main loop.
+ // Calculate prefetch offset after the init prefetches have been issued.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+ forOp.getInductionVar(), nbStep);
+ // Replace induction var with correct offset.
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(prefetchOffset), readCacheHint,
+ readCacheHint, readCacheHint);
+
+ // Unroll the init loop.
+ if (failed(loopUnrollFull(initForOp))) {
+ return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
+ }
+
+ results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::InsertPrefetchOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index ce8015d8f557b..b14426b13fae3 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -11,6 +11,7 @@
from .._ods_common import _cext as _ods_cext
from .._ods_common import (
MixedValues,
+ MixedInt,
get_op_result_or_value as _get_op_result_or_value,
_dispatch_dynamic_index_list,
)
@@ -132,3 +133,34 @@ def __init__(
loc=loc,
ip=ip,
)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InsertPrefetchOp(InsertPrefetchOp):
+ """Specialization for InsertPrefetchOp class."""
+
+ def __init__(
+ self,
+ target: Value,
+ *,
+ nb_prefetch: Optional[MixedInt] = 1,
+ loc=None,
+ ip=None,
+ ):
+ static_nb_prefetch = 1
+ dynamic_nb_prefetch = None
+ if isinstance(nb_prefetch, int):
+ static_nb_prefetch = nb_prefetch
+ elif isinstance(nb_prefetch, IntegerAttr):
+ static_nb_prefetch = nb_prefetch.value # pytype: disable=attribute-error
+ elif isinstance(nb_prefetch, (Operation, Value, OpView)):
+ dynamic_nb_prefetch = nb_prefetch
+
+ super().__init__(
+ transform.AnyOpType.get(),
+ target,
+ dynamic_nb_prefetch=dynamic_nb_prefetch,
+ static_nb_prefetch=static_nb_prefetch,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 726b6748452ae..aaacf2a5f9280 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -71,3 +71,34 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_c
+func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ // expected-note@below {{load op}}
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+ // expected-error@below {{Load op is not contained in a scf.for loop.}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bd6a79244ed30..1fb1571b2bb43 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -252,3 +252,79 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a
+func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: xegpu.create_nd_tdesc %arg0
+ // CHECK: xegpu.create_nd_tdesc %arg1
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+ // CHECK: xegpu.prefetch_nd %[[V0]]
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ // CHECK: scf.for
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ // CHECK: xegpu.prefetch_nd %[[V0]]
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
+func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: xegpu.create_nd_tdesc %arg0
+ // CHECK: xegpu.create_nd_tdesc %arg1
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+ // CHECK: xegpu.prefetch_nd %[[V0]]
+ // CHECK: xegpu.prefetch_nd %[[V0]]
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ // CHECK: scf.for
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ // CHECK: xegpu.prefetch_nd %[[V0]]
+ %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
+ // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 0b587d2020aa6..4884f6d2b9cb5 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -3,7 +3,7 @@
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import xegpu
-from mlir.dialects.transform import AnyValueType
+from mlir.dialects.transform import structured, AnyValueType
def run(f):
@@ -113,3 +113,68 @@ def setOpLayoutAttrResult():
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
+
+
+@run
+def insertPrefetch0():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ xegpu.InsertPrefetchOp(
+ operand,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetch0
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+
+
+@run
+def insertPrefetchNbPrefetch():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ xegpu.InsertPrefetchOp(
+ operand,
+ nb_prefetch=2,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetchNbPrefetch
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK-SAME: nb_prefetch = 2
+
+
+@run
+def insertPrefetchNbPrefetchParam():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ int32_t = IntegerType.get_signless(32)
+ param_int32_t = transform.ParamType.get(int32_t)
+ nb_param = transform.ParamConstantOp(
+ param_int32_t,
+ IntegerAttr.get(int32_t, 2),
+ )
+ xegpu.InsertPrefetchOp(
+ operand,
+ nb_prefetch=nb_param,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK-SAME: nb_prefetch = %[[PARAM_OP]]
What happens when there is no parent `scf.for` loop — will prefetch insertion still happen?
The op fails if the load op is not contained in an `scf.for` loop. To be exact, the logic is: for the given value, find the producing `xegpu.load_nd` op; that load op must reside within an `scf.for` loop, otherwise the transform emits a silenceable failure.
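The producer lookup described in this thread can be sketched in plain Python. This is a hypothetical toy model of the PR's `findProducerOfType` helper, not the actual MLIR C++ API: starting from a value, walk backwards through defining ops (following the first operand) until an op of the requested kind is found, or the chain ends.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Op:
    kind: str
    operands: list = field(default_factory=list)

@dataclass
class Value:
    defining_op: Optional[Op] = None

def find_producer_of_type(value: Value, kind: str) -> Optional[Op]:
    # Walk the use-def chain backwards via each op's first operand
    # until an op of the requested kind is found (or the chain ends).
    op = value.defining_op
    while op is not None:
        if op.kind == kind:
            return op
        op = op.operands[0].defining_op if op.operands else None
    return None

# Model the chain create_nd_desc -> load_nd -> (dpas operand value),
# mirroring the xegpu IR the transform matches against.
desc = Op("create_nd_desc")
load = Op("load_nd", operands=[Value(desc)])
a_tile = Value(load)

assert find_producer_of_type(a_tile, "load_nd") is load
assert find_producer_of_type(a_tile, "create_nd_desc") is desc
assert find_producer_of_type(a_tile, "prefetch_nd") is None
```

In the actual implementation both the load op and the descriptor op are located this way from the same target value, which is why both `findProducerOfType<xegpu::LoadNdOp>` and `findProducerOfType<xegpu::CreateNdDescOp>` calls appear in `apply`.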
// CHECK: xegpu.create_nd_tdesc %arg1
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
// CHECK-SAME: !xegpu.tensor_desc<256x32xf16
// CHECK: xegpu.prefetch_nd %[[V0]]
I think the tests should check the generated offsets for the `prefetch_nd` ops, especially for the in-loop `prefetch_nd`, which needs its offsets adjusted based on `nb_prefetch`.
Test is updated and proper offsets are now checked.
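The offset adjustment being discussed can be modeled as a substitution on the load's mixed offsets. This is a hypothetical Python sketch (not the MLIR `IRMapping` API): the loop induction variable is replaced by a value `nb_prefetch` steps ahead, while all other offsets pass through unchanged.

```python
def prefetch_offsets(load_offsets, induction_var, nb_prefetch, step):
    """Replace the induction variable in the load's offsets with the
    position nb_prefetch loop steps ahead; leave other offsets as-is."""
    def subst(off):
        if off == induction_var:
            # Symbolic stand-in for: arith.addi %iv, (nb_prefetch * step)
            return ("add", induction_var, nb_prefetch * step)
        return off
    return [subst(off) for off in load_offsets]

# Offsets of `xegpu.load_nd %3[%c0, %arg3]` with %arg3 the induction var:
offs = prefetch_offsets(["%c0", "%arg3"], "%arg3", nb_prefetch=2, step=32)
assert offs == ["%c0", ("add", "%arg3", 64)]
```

This mirrors the `getPrefetchOffsets` lambda in the C++ implementation, which maps the `scf.for` induction variable to `%iv + nb_prefetch * step` for the in-loop prefetch, and to the init-loop induction variable for the initial prefetches.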
rolfmorel left a comment:
A couple of comments. Others are maybe better suited to judge whether the transform "plays nice" with regard to how xegpu ops interact.
// Find the parent scf.for loop.
auto forOp = loadOp->getParentOfType<scf::ForOp>();
Minor suggestion: could this maybe be generalized to support any parent that implements the `LoopLikeOpInterface` (or whichever name it goes by)?
Yes, I think it can be generalized. I suggest, however, deferring this until we have a concrete use case.
I think the op design could be more flexible if we further split the current op into a `get_load_op`, with `insert_prefetch` then applied to the load op instead of the dpas.
This API is not dpas-specific: `insert_prefetch` takes a handle to a `Value` (typically a `vector`) and finds the producing load op itself. That said, I'd propose we postpone this change rather than adding a new op now.
Jianhui-Li left a comment:
I am fine with this as an initial version.
I think it is better to evolve toward associating the prefetch with a load op, not a value, with the implementation implicitly finding the producer load op. A value is more flexible, but I think it becomes vague when the use case gets more complex. Say, if the B matrix is computed from two loads (A, and a bias or scale), it is not clear whether the user should use this dpas's input B operand value, or the two loads' result values.
let description = [{
  Given a target value (e.g., `vector`) residing in a `scf.for` loop, this
  transform finds the corresponding `xegpu.load_nd` op and inserts
  `xegpu.prefetch` operations for the tile. The load op must reside within the
`xegpu.prefetch` -> `xegpu.prefetch_nd`
Worth mentioning the default `nb_prefetch` value here.
updated
I think exposing […]. We don't really depend on […].
Force-pushed from 71a38e8 to 85aafcc.
I too like the idea of having a generic op for this. Though as previously discussed with @tkarna, to make it truly generic we probably need some kind of […]. My take on it is that in general it is non-trivial. We should return to it when we encounter more of this kind of "matching" in schedules.
@Jianhui-Li @rolfmorel Yes, in general this is a question of how we should do op matching for more complex workloads. I'm inclined to think that making better use of existing transform dialect matchers, combined with some xegpu-specific ones, would probably be sufficient. But we can refine as we encounter new use cases.
Adds a `transform.xegpu.insert_prefetch` transform op that inserts `xegpu.prefetch_nd` ops for the given `Value` in an `scf.for` loop.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document. There are some changes with respect to the RFC:

- `insert_prefetch` only inserts the prefetch ops; it does not set the `xegpu.layout` attributes. Those must be set separately using the `transform.xegpu.set_desc_layout` transform op.
- The `insert_prefetch` op returns a handle to the newly created `xegpu.create_nd_desc` op of the prefetch ops.
- The `scf.for` loop is matched automatically (as the parent of the load op); it is no longer passed in as an argument.
- `insert_prefetch` no longer takes an operation handle and operand index to define the operand value. Instead, it takes a handle to the value itself.

Example:

%tile_a = transform.get_operand %dpas_op[0] : (!transform.any_op) -> !transform.any_value
%desc_op = transform.xegpu.insert_prefetch %tile_a nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
%desc_op2 = transform.xegpu.set_desc_layout %desc_op sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
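The schedule the transform sets up can be modeled in a few lines of Python. This is an illustrative sketch of the pipelining arithmetic only (function and event names are invented for the model): with `nb_prefetch = n`, `n` prefetches are issued before the loop via a fully unrolled init loop over offsets `lb, lb + step, ..., lb + (n-1)*step`, and each main-loop iteration at induction value `iv` prefetches the tile at `iv + n*step` while loading the tile at `iv`.

```python
def prefetch_schedule(lb, ub, step, nb_prefetch):
    """Model the prefetch/load interleaving produced by insert_prefetch.

    Returns a list of (event, offset) tuples: the init loop issues
    nb_prefetch prefetches, then the main loop prefetches nb_prefetch
    steps ahead of each tile it loads.
    """
    events = []
    # Init loop: [lb, lb + nb_prefetch * step), fully unrolled by the pass.
    for iv in range(lb, lb + nb_prefetch * step, step):
        events.append(("prefetch", iv))
    # Main loop: prefetch runs nb_prefetch iterations ahead of the load.
    for iv in range(lb, ub, step):
        events.append(("prefetch", iv + nb_prefetch * step))
        events.append(("load", iv))
    return events

sched = prefetch_schedule(lb=0, ub=128, step=32, nb_prefetch=2)
# Every loaded offset was prefetched in an earlier event.
prefetched = set()
for ev, off in sched:
    if ev == "prefetch":
        prefetched.add(off)
    else:
        assert off in prefetched
```

Note the model also shows why the final `nb_prefetch` prefetches run past the last loaded tile (offsets `ub, ..., ub + (n-1)*step`); the transform does not guard against that, which is consistent with prefetches being hints that may read ahead harmlessly.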