diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556deb..41f7afa78212e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,47 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect,
+def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
+  let description = [{
+    Given a target value (e.g., a `vector` operand of a matmul) residing in an
+    `scf.for` loop, this transform finds the corresponding `xegpu.load_nd` op
+    and inserts `xegpu.prefetch_nd` operations for the tile. The load op must
+    reside within the `scf.for` loop. The number of prefetch steps is set by
+    the `nb_prefetch` argument. Returns a handle to the created
+    `xegpu.create_nd_tdesc` op.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target,
+    Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
+    DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
+  );
+
+  let results = (outs TransformHandleTypeInterface:$desc_op);
+
+  let assemblyFormat = [{
+    $target
+    `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    OpFoldResult getNbPrefetch() {
+      auto cxt = getContext();
+      if (getDynamicNbPrefetch())
+        return OpFoldResult(getDynamicNbPrefetch());
+      return OpFoldResult(IntegerAttr::get(
+          IntegerType::get(cxt, 64), getStaticNbPrefetch()));
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 5fdd8534e4e51..d9222d10a5e70 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -341,6 +342,143 @@ void transform::SetOpLayoutAttrOp::getEffects(
   modifiesPayload(effects);
 }
 
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+                                   transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+  auto value = *targetValues.begin();
+
+  int64_t nbPrefetch = getStaticNbPrefetch();
+  if (getDynamicNbPrefetch()) {
+    // Get dynamic prefetch count from transform param or handle.
+    SmallVector<int32_t> dynamicNbPrefetch;
+    auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+                                          {getDynamicNbPrefetch()});
+    if (!status.succeeded())
+      return status;
+    if (dynamicNbPrefetch.size() != 1) {
+      return emitDefiniteFailure()
+             << "requires exactly one value for dynamic_nb_prefetch";
+    }
+    nbPrefetch = dynamicNbPrefetch[0];
+  }
+  if (nbPrefetch <= 0) {
+    return emitSilenceableFailure(getLoc())
+           << "nb_prefetch must be a positive integer.";
+  }
+
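+  // Overview of the rewrite: find the xegpu.load_nd op that produces the
+  // target value and its enclosing scf.for loop; clone the backing
+  // xegpu.create_nd_tdesc above the loop; emit an init loop that issues
+  // nb_prefetch prefetches ahead of the main loop and fully unroll it; then
+  // issue one prefetch per main-loop iteration, nb_prefetch steps ahead of
+  // the current induction value.
+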
+  // Find load operation of the operand.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  }
+  auto loadOp = *maybeLoadOp;
+  if (loadOp.getMixedOffsets().empty()) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op must have offsets.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find the parent scf.for loop.
+  auto forOp = loadOp->getParentOfType<scf::ForOp>();
+  if (!forOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op is not contained in a scf.for loop.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find descriptor op.
+  auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+  if (!maybeDescOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  }
+  auto descOp = *maybeDescOp;
+  if (!descOp.getMixedOffsets().empty()) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
+  // Clone desc op outside the loop.
+  rewriter.setInsertionPoint(forOp);
+  auto newDescOp =
+      cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+  // Clone reduction loop to emit initial prefetches.
+  // Compute upper bound of the init loop: start + nbPrefetch * step.
+  auto nbPrefetchCst =
+      arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+  auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+  auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+      forOp.getLoc(), forOp.getLowerBound(), nbStep);
+  auto initForOp =
+      scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+                         initUpBound, forOp.getStep());
+
+  auto ctx = rewriter.getContext();
+  auto readCacheHint =
+      xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+  // Modify loadOp mixedOffsets by replacing the for loop induction variable
+  // with the given value.
+  auto getPrefetchOffsets =
+      [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+    IRMapping mapping;
+    mapping.map(forOp.getInductionVar(), replacementVal);
+    SmallVector<Value> dynamicOffsets =
+        llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+          return mapping.lookupOrDefault(v);
+        }));
+    auto constOffsets = loadOp.getConstOffsets().value();
+    return getMixedValues(constOffsets, dynamicOffsets, ctx);
+  };
+
+  // Insert prefetch op in init loop.
+  // Replace induction var with the init loop induction var.
+  rewriter.setInsertionPointToStart(initForOp.getBody());
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(initForOp.getInductionVar()),
+                              readCacheHint, readCacheHint, readCacheHint);
+
+  // Insert prefetch op in main loop.
+  // Calculate prefetch offset after the init prefetches have been issued.
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+                                              forOp.getInductionVar(), nbStep);
+  // Replace induction var with correct offset.
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(prefetchOffset), readCacheHint,
+                              readCacheHint, readCacheHint);
+
+  // Unroll the init loop.
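+  // Its trip count is the constant nbPrefetch, so full unrolling replaces it
+  // with straight-line prefetches issued before the main loop starts.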
+  if (failed(loopUnrollFull(initForOp))) {
+    return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
+  }
+
+  results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::InsertPrefetchOp::getEffects(
+    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index ce8015d8f557b..b14426b13fae3 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -11,6 +11,7 @@
 from .._ods_common import _cext as _ods_cext
 from .._ods_common import (
     MixedValues,
+    MixedInt,
     get_op_result_or_value as _get_op_result_or_value,
     _dispatch_dynamic_index_list,
 )
@@ -132,3 +133,34 @@ def __init__(
         loc=loc,
         ip=ip,
     )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InsertPrefetchOp(InsertPrefetchOp):
+    """Specialization for InsertPrefetchOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        *,
+        nb_prefetch: Optional[MixedInt] = 1,
+        loc=None,
+        ip=None,
+    ):
+        static_nb_prefetch = 1
+        dynamic_nb_prefetch = None
+        if isinstance(nb_prefetch, int):
+            static_nb_prefetch = nb_prefetch
+        elif isinstance(nb_prefetch, IntegerAttr):
+            static_nb_prefetch = nb_prefetch.value  # pytype: disable=attribute-error
+        elif isinstance(nb_prefetch, (Operation, Value, OpView)):
+            dynamic_nb_prefetch = nb_prefetch
+
+        super().__init__(
+            transform.AnyOpType.get(),
+            target,
+            dynamic_nb_prefetch=dynamic_nb_prefetch,
+            static_nb_prefetch=static_nb_prefetch,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 726b6748452ae..aaacf2a5f9280 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -71,3 +71,34 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_c
+func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  // expected-note@below {{load op}}
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+    // expected-error@below {{Load op is not contained in a scf.for loop.}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bd6a79244ed30..1fb1571b2bb43 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -252,3 +252,79 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a
+func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: xegpu.prefetch_nd %[[V0]]
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
+func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: xegpu.prefetch_nd %[[V0]]
+    %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 0b587d2020aa6..4884f6d2b9cb5 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -3,7 +3,7 @@
 from mlir.ir import *
 from mlir.dialects import transform
 from mlir.dialects.transform import xegpu
-from mlir.dialects.transform import AnyValueType
+from mlir.dialects.transform import structured, AnyValueType
 
 
 def run(f):
@@ -113,3 +113,68 @@ def setOpLayoutAttrResult():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def insertPrefetch0():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        xegpu.InsertPrefetchOp(
+            operand,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetch0
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+
+
+@run
+def insertPrefetchNbPrefetch():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        xegpu.InsertPrefetchOp(
+            operand,
+            nb_prefetch=2,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetchNbPrefetch
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+    # CHECK-SAME: nb_prefetch = 2
+
+
+@run
+def insertPrefetchNbPrefetchParam():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        int32_t = IntegerType.get_signless(32)
+        param_int32_t = transform.ParamType.get(int32_t)
+        nb_param = transform.ParamConstantOp(
+            param_int32_t,
+            IntegerAttr.get(int32_t, 2),
+        )
+        xegpu.InsertPrefetchOp(
+            operand,
+            nb_prefetch=nb_param,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
+    # CHECK: %[[OPR:.*]] = get_operand
+    # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
+    # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+    # CHECK-SAME: nb_prefetch = %[[PARAM_OP]]
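
For reference, a minimal sketch of driving the new op from the Python bindings end to end, mirroring the tests above. The Context/Module boilerplate is an assumption about a standard MLIR Python build with the XeGPU transform extension registered; only the sequence body reflects the API added in this change.

# Minimal driver sketch: build a transform sequence that prefetches the A
# operand of each xegpu.dpas two steps ahead of the main loop. The module
# setup here is assumed boilerplate, not part of this change.
from mlir.ir import Context, InsertionPoint, Location, Module
from mlir.dialects import transform
from mlir.dialects.transform import xegpu, AnyValueType

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            transform.OperationType.get("xegpu.dpas"),
        )
        with InsertionPoint(sequence.body):
            # Operand 0 of xegpu.dpas is the A tile.
            operand = transform.GetOperandOp(
                AnyValueType.get(), sequence.bodyTarget, [0]
            )
            # nb_prefetch accepts a Python int (static) or a transform
            # param/handle (dynamic); see the binding specialization above.
            xegpu.InsertPrefetchOp(operand, nb_prefetch=2)
            transform.YieldOp()
    print(module)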