From 7dc5a277715a3f817541e91e815907461ece6d1f Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Thu, 30 Oct 2025 10:59:51 +0200 Subject: [PATCH 1/3] [mlir][xegpu][transformops] add insert_prefetch op --- .../XeGPU/TransformOps/XeGPUTransformOps.td | 43 ++++++ .../XeGPU/TransformOps/XeGPUTransformOps.cpp | 138 ++++++++++++++++++ mlir/python/mlir/dialects/transform/xegpu.py | 33 +++++ .../Dialect/XeGPU/transform-ops-invalid.mlir | 31 ++++ mlir/test/Dialect/XeGPU/transform-ops.mlir | 76 ++++++++++ .../python/dialects/transform_xegpu_ext.py | 67 ++++++++- 6 files changed, 387 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index f5e4afad535e5..85ad91f94a379 100644 --- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -200,4 +200,47 @@ def SetGPULaunchThreadsOp }]; } +def InsertPrefetchOp : Op, + TransformOpInterface +]> { + + let summary = "Adds xegpu prefetch ops to matmul operand tiles."; + let description = [{ + Given a target value (e.g., `vector`) residing in a `scf.for` loop, this + transform finds the corresponding `xegpu.load_nd` op and inserts + `xegpu.prefetch` operations for the tile. The load op must reside within the + `scf.for` loop. Number of prefetch steps is set by the `nb_prefetch` + argument. Returns a handle to the created `xegpu.create_nd_desc` op. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$target, + Optional:$dynamic_nb_prefetch, + DefaultValuedOptionalAttr:$static_nb_prefetch + ); + + let results = (outs TransformHandleTypeInterface:$desc_op); + + let assemblyFormat = [{ + $target + `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)? 
+ attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + OpFoldResult getNbPrefetch() { + auto cxt = getContext(); + if (getDynamicNbPrefetch()) + return OpFoldResult(getDynamicNbPrefetch()); + return OpFoldResult(IntegerAttr::get( + IntegerType::get(cxt, 64), getStaticNbPrefetch())); + } + }]; +} + #endif // XEGPU_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index 7a7a8c9066f09..230b4aaaa8e8e 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" @@ -405,6 +406,143 @@ void transform::SetGPULaunchThreadsOp::getEffects( modifiesPayload(effects); } +DiagnosedSilenceableFailure +transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetValues = state.getPayloadValues(getTarget()); + if (!llvm::hasSingleElement(targetValues)) { + return emitDefiniteFailure() + << "requires exactly one target value handle (got " + << llvm::range_size(targetValues) << ")"; + } + auto value = *targetValues.begin(); + + int64_t nbPrefetch = getStaticNbPrefetch(); + if (getDynamicNbPrefetch()) { + // Get dynamic prefetch count from transform param or handle. 
+ SmallVector dynamicNbPrefetch; + auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch, + {getDynamicNbPrefetch()}); + if (!status.succeeded()) + return status; + if (dynamicNbPrefetch.size() != 1) { + return emitDefiniteFailure() + << "requires exactly one value for dynamic_nb_prefetch"; + } + nbPrefetch = dynamicNbPrefetch[0]; + } + if (nbPrefetch <= 0) { + return emitSilenceableFailure(getLoc()) + << "nb_prefetch must be a positive integer."; + } + + // Find load operation of the operand. + auto maybeLoadOp = findProducerOfType(value); + if (!maybeLoadOp) { + return emitSilenceableFailure(getLoc()) << "Could not find load op."; + } + auto loadOp = *maybeLoadOp; + if (loadOp.getMixedOffsets().size() == 0) { + auto diag = emitSilenceableFailure(getLoc()) + << "Load op must have offsets."; + diag.attachNote(loadOp.getLoc()) << "load op"; + return diag; + } + + // Find the parent scf.for loop. + auto forOp = loadOp->getParentOfType(); + if (!forOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Load op is not contained in a scf.for loop."; + diag.attachNote(loadOp.getLoc()) << "load op"; + return diag; + } + + // Find descriptor op. + auto maybeDescOp = findProducerOfType(value); + if (!maybeDescOp) { + return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; + } + auto descOp = *maybeDescOp; + if (descOp.getMixedOffsets().size() > 0) { + auto diag = emitSilenceableFailure(getLoc()) + << "desc op with offsets is not supported."; + diag.attachNote(descOp.getLoc()) << "desc op"; + // Stop here, as the load-op checks above do, so unsupported IR is never rewritten. + return diag; + } + + // Clone desc op outside the loop. + rewriter.setInsertionPoint(forOp); + auto newDescOp = + cast(rewriter.clone(*descOp.getOperation())); + + // Clone reduction loop to emit initial prefetches. + // Compute upper bound of the init loop: start + nbPrefetch * step. 
+ auto nbPrefetchCst = + arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch); + auto nbStep = rewriter.createOrFold( + forOp.getLoc(), nbPrefetchCst, forOp.getStep()); + auto initUpBound = rewriter.createOrFold( + forOp.getLoc(), forOp.getLowerBound(), nbStep); + auto initForOp = + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), + initUpBound, forOp.getStep()); + + auto ctx = rewriter.getContext(); + auto readCacheHint = + xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED); + + // Modify loadOp mixedOffsets by replacing the for loop induction variable + // with the given value. + auto getPrefetchOffsets = + [&](Value replacementVal) -> SmallVector { + IRMapping mapping; + mapping.map(forOp.getInductionVar(), replacementVal); + SmallVector dynamicOffsets = + llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) { + return mapping.lookupOrDefault(v); + })); + auto constOffsets = loadOp.getConstOffsets().value(); + return getMixedValues(constOffsets, dynamicOffsets, ctx); + }; + + // Insert prefetch op in init loop. + // Replace induction var with the init loop induction var. + rewriter.setInsertionPointToStart(initForOp.getBody()); + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(initForOp.getInductionVar()), + readCacheHint, readCacheHint, readCacheHint); + + // Insert prefetch op in main loop. + // Calculate prefetch offset after the init prefetches have been issued. + rewriter.setInsertionPointToStart(forOp.getBody()); + auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(), + forOp.getInductionVar(), nbStep); + // Replace induction var with correct offset. + xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), + newDescOp.getResult(), + getPrefetchOffsets(prefetchOffset), readCacheHint, + readCacheHint, readCacheHint); + + // Unroll the init loop. 
+ if (failed(loopUnrollFull(initForOp))) { + return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop"; + } + + results.set(llvm::cast(getResult()), {newDescOp}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::InsertPrefetchOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getDynamicNbPrefetchMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + namespace { class XeGPUTransformDialectExtension : public transform::TransformDialectExtension< diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index 309883cfc4518..6443d2a188ec1 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -11,6 +11,7 @@ from .._ods_common import _cext as _ods_cext from .._ods_common import ( MixedValues, + MixedInt, get_op_result_or_value as _get_op_result_or_value, _dispatch_dynamic_index_list, ) @@ -134,6 +135,7 @@ def __init__( ) +@_ods_cext.register_operation(_Dialect, replace=True) class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp): """Specialization for SetGPULaunchThreadsOp class.""" @@ -168,3 +170,34 @@ def set_gpu_launch_threads( ip=None, ) -> SetGPULaunchThreadsOp: return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class InsertPrefetchOp(InsertPrefetchOp): + """Specialization for InsertPrefetchOp class.""" + + def __init__( + self, + target: Value, + *, + nb_prefetch: Optional[MixedInt] = 1, + loc=None, + ip=None, + ): + static_nb_prefetch = 1 + dynamic_nb_prefetch = None + if isinstance(nb_prefetch, int): + static_nb_prefetch = nb_prefetch + elif isinstance(nb_prefetch, IntegerAttr): + static_nb_prefetch = nb_prefetch.value # pytype: disable=attribute-error + elif isinstance(nb_prefetch, (Operation, Value, OpView)): + 
dynamic_nb_prefetch = nb_prefetch + + super().__init__( + transform.AnyOpType.get(), + target, + dynamic_nb_prefetch=dynamic_nb_prefetch, + static_nb_prefetch=static_nb_prefetch, + loc=loc, + ip=ip, + ) diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir index 24f500658f740..dce4a41982550 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir @@ -124,3 +124,34 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_c +func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + // expected-note@below {{load op}} + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand 
%0[2] : (!transform.any_op) -> !transform.any_value + // expected-error@below {{Load op is not contained in a scf.for loop.}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir index 7f2fbe4271a43..aed8874723801 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -308,3 +308,79 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_a +func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: xegpu.create_nd_tdesc %arg0 + // CHECK: xegpu.create_nd_tdesc %arg1 + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: !xegpu.tensor_desc<256x32xf16 + // CHECK: xegpu.prefetch_nd %[[V0]] + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK: scf.for + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + // CHECK: xegpu.prefetch_nd %[[V0]] + %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes 
{transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + // CHECK: transform.xegpu.insert_prefetch %{{.*}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2 +func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c4096 = arith.constant 4096 : index + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> + %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> + // CHECK: xegpu.create_nd_tdesc %arg0 + // CHECK: xegpu.create_nd_tdesc %arg1 + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: !xegpu.tensor_desc<256x32xf16 + // CHECK: xegpu.prefetch_nd %[[V0]] + // CHECK: xegpu.prefetch_nd %[[V0]] + %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> + // CHECK: scf.for + %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { + // CHECK: xegpu.prefetch_nd %[[V0]] + %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> + %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> + %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> + scf.yield %7 : vector<256x256xf16> + } + return +} + +module attributes {transform.with_named_sequence} { + 
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value + %nb = transform.param.constant 2 : i64 -> !transform.param + // CHECK: transform.xegpu.insert_prefetch %{{.*}} + %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py index dc91f5e982579..cfe2281ba5eff 100644 --- a/mlir/test/python/dialects/transform_xegpu_ext.py +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -3,7 +3,7 @@ from mlir.ir import * from mlir.dialects import transform from mlir.dialects.transform import xegpu -from mlir.dialects.transform import AnyValueType +from mlir.dialects.transform import structured, AnyValueType def run(f): @@ -128,3 +128,68 @@ def setGPULaunchThreadsOp(): # CHECK-LABEL: TEST: setGPULaunchThreadsOp # CHECK: transform.xegpu.set_gpu_launch_threads # CHECK: threads = [8, 4, 1] + + +@run +def insertPrefetch0(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + xegpu.InsertPrefetchOp( + operand, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetch0 + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + + +@run +def insertPrefetchNbPrefetch(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + 
xegpu.InsertPrefetchOp( + operand, + nb_prefetch=2, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetchNbPrefetch + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + # CHECK-SAME: nb_prefetch = 2 + + +@run +def insertPrefetchNbPrefetchParam(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.dpas"), + ) + with InsertionPoint(sequence.body): + operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) + int32_t = IntegerType.get_signless(32) + param_int32_t = transform.ParamType.get(int32_t) + nb_param = transform.ParamConstantOp( + param_int32_t, + IntegerAttr.get(int32_t, 2), + ) + xegpu.InsertPrefetchOp( + operand, + nb_prefetch=nb_param, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam + # CHECK: %[[OPR:.*]] = get_operand + # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2 + # CHECK: transform.xegpu.insert_prefetch %[[OPR]] + # CHECK-SAME: nb_prefetch = %[[PARAM_OP]] From 85aafcc78ce2b242f2756271c9fd6cc2b784125d Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Tue, 11 Nov 2025 19:03:58 +0200 Subject: [PATCH 2/3] address review comments --- .../XeGPU/TransformOps/XeGPUTransformOps.cpp | 18 ++++------ mlir/python/mlir/dialects/transform/xegpu.py | 10 ++++++ mlir/test/Dialect/XeGPU/transform-ops.mlir | 34 ++++++++++++++----- .../python/dialects/transform_xegpu_ext.py | 6 ++-- 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp index 230b4aaaa8e8e..d2235f18ceaec 100644 --- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -411,11 +411,10 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) 
{ auto targetValues = state.getPayloadValues(getTarget()); - if (!llvm::hasSingleElement(targetValues)) { + if (!llvm::hasSingleElement(targetValues)) return emitDefiniteFailure() << "requires exactly one target value handle (got " << llvm::range_size(targetValues) << ")"; - } auto value = *targetValues.begin(); int64_t nbPrefetch = getStaticNbPrefetch(); @@ -426,22 +425,19 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, {getDynamicNbPrefetch()}); if (!status.succeeded()) return status; - if (dynamicNbPrefetch.size() != 1) { + if (dynamicNbPrefetch.size() != 1) return emitDefiniteFailure() << "requires exactly one value for dynamic_nb_prefetch"; - } nbPrefetch = dynamicNbPrefetch[0]; } - if (nbPrefetch <= 0) { + if (nbPrefetch <= 0) return emitSilenceableFailure(getLoc()) << "nb_prefetch must be a positive integer."; - } // Find load operation of the operand. auto maybeLoadOp = findProducerOfType(value); - if (!maybeLoadOp) { + if (!maybeLoadOp) return emitSilenceableFailure(getLoc()) << "Could not find load op."; - } auto loadOp = *maybeLoadOp; if (loadOp.getMixedOffsets().size() == 0) { auto diag = emitSilenceableFailure(getLoc()) @@ -461,9 +457,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, // Find descriptor op. auto maybeDescOp = findProducerOfType(value); - if (!maybeDescOp) { + if (!maybeDescOp) return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; - } auto descOp = *maybeDescOp; if (descOp.getMixedOffsets().size() > 0) { auto diag = emitSilenceableFailure(getLoc()) @@ -526,9 +521,8 @@ transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, readCacheHint, readCacheHint); // Unroll the init loop. 
- if (failed(loopUnrollFull(initForOp))) { + if (failed(loopUnrollFull(initForOp))) return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop"; - } results.set(llvm::cast(getResult()), {newDescOp}); diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index 6443d2a188ec1..aa3cea58623ea 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -201,3 +201,13 @@ def __init__( loc=loc, ip=ip, ) + + +def insert_prefetch( + target: Value, + *, + nb_prefetch: Optional[MixedInt] = 1, + loc=None, + ip=None, +) -> OpResult: + return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir index aed8874723801..b3b883826c1c8 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -313,8 +313,10 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: @insert_prefetch_dpas_a func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[C32:.+]] = arith.constant 32 : index %c32 = arith.constant 32 : index %c4096 = arith.constant 4096 : index + // CHECK: %[[C0:.+]] = arith.constant 0 : index %c0 = arith.constant 0 : index %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> @@ -322,12 +324,13 @@ func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<40 // CHECK: xegpu.create_nd_tdesc %arg1 // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 // CHECK-SAME: !xegpu.tensor_desc<256x32xf16 - // CHECK: xegpu.prefetch_nd %[[V0]] + // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[C0]]] %3 = xegpu.create_nd_tdesc %arg0 : 
memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> - // CHECK: scf.for + // CHECK: scf.for %[[ARG3:.+]] = %[[C0]] %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { - // CHECK: xegpu.prefetch_nd %[[V0]] + // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C32]] + // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[ADD]]] %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> @@ -338,10 +341,15 @@ func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<40 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value // CHECK: transform.xegpu.insert_prefetch %{{.*}} %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.yield } } @@ -350,8 +358,11 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2 func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) { + // CHECK: %[[C64:.+]] = arith.constant 64 : index + // CHECK: 
%[[C32:.+]] = arith.constant 32 : index %c32 = arith.constant 32 : index %c4096 = arith.constant 4096 : index + // CHECK: %[[C0:.+]] = arith.constant 0 : index %c0 = arith.constant 0 : index %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16> %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16> @@ -359,13 +370,14 @@ func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: // CHECK: xegpu.create_nd_tdesc %arg1 // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 // CHECK-SAME: !xegpu.tensor_desc<256x32xf16 - // CHECK: xegpu.prefetch_nd %[[V0]] - // CHECK: xegpu.prefetch_nd %[[V0]] + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C0]]] + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C32]]] %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16> - // CHECK: scf.for + // CHECK: scf.for %[[ARG3:.+]] = %[[C0]] %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) { - // CHECK: xegpu.prefetch_nd %[[V0]] + // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C64]] + // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[ADD]]] %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16> %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16> %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16> @@ -376,11 +388,15 @@ func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> 
!transform.any_op + %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value %nb = transform.param.constant 2 : i64 -> !transform.param // CHECK: transform.xegpu.insert_prefetch %{{.*}} %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op transform.yield } } diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py index cfe2281ba5eff..56c7d71f28431 100644 --- a/mlir/test/python/dialects/transform_xegpu_ext.py +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -139,7 +139,7 @@ def insertPrefetch0(): ) with InsertionPoint(sequence.body): operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) - xegpu.InsertPrefetchOp( + xegpu.insert_prefetch( operand, ) transform.YieldOp() @@ -157,7 +157,7 @@ def insertPrefetchNbPrefetch(): ) with InsertionPoint(sequence.body): operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0]) - xegpu.InsertPrefetchOp( + xegpu.insert_prefetch( operand, nb_prefetch=2, ) @@ -183,7 +183,7 @@ def insertPrefetchNbPrefetchParam(): param_int32_t, IntegerAttr.get(int32_t, 2), ) - xegpu.InsertPrefetchOp( + xegpu.insert_prefetch( operand, nb_prefetch=nb_param, ) From 25c4540b5e15074ff41ac1d519bc6dc78699b052 Mon Sep 17 00:00:00 2001 From: Tuomas Karna Date: Tue, 11 Nov 2025 21:09:00 +0200 Subject: [PATCH 3/3] update insert_prefetch op description --- .../mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td index 85ad91f94a379..68a75fdb5b9a5 100644 
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -209,9 +209,10 @@ def InsertPrefetchOp : Op