@@ -200,4 +200,48 @@ def SetGPULaunchThreadsOp
}];
}

def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface
]> {

let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
let description = [{
Given a target value (e.g., a `vector` operand of a matmul-like op) produced
by an `xegpu.load_nd` op, this transform inserts `xegpu.prefetch_nd`
operations for the corresponding tile. The load op must reside within an
`scf.for` loop. The number of prefetch steps is set by the `nb_prefetch`
argument (default: 1). Returns a handle to the newly created
`xegpu.create_nd_tdesc` op (see the usage example below).
}];

let arguments = (ins TransformValueHandleTypeInterface:$target,
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
);

let results = (outs TransformHandleTypeInterface:$desc_op);

let assemblyFormat = [{
$target
`nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)?
attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure apply(
::mlir::transform::TransformRewriter &rewriter,
::mlir::transform::TransformResults &transformResults,
::mlir::transform::TransformState &state);

OpFoldResult getNbPrefetch() {
auto cxt = getContext();
if (getDynamicNbPrefetch())
return OpFoldResult(getDynamicNbPrefetch());
return OpFoldResult(IntegerAttr::get(
IntegerType::get(cxt, 64), getStaticNbPrefetch()));
}
}];
}

#endif // XEGPU_TRANSFORM_OPS
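
A minimal usage sketch of the new op, mirroring the tests added later in this patch; it assumes `%root` is the handle passed to the enclosing named sequence and that the payload contains an `xegpu.dpas` matmul accumulating inside an `scf.for` loop:

%dpas = transform.structured.match ops{["xegpu.dpas"]} in %root
    : (!transform.any_op) -> !transform.any_op
// Take a handle to the A operand of the matmul.
%a = transform.get_operand %dpas[0] : (!transform.any_op) -> !transform.any_value
// Prefetch two steps ahead; returns a handle to the hoisted create_nd_tdesc op.
%desc = transform.xegpu.insert_prefetch %a nb_prefetch = 2
    : (!transform.any_value) -> !transform.any_op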
132 changes: 132 additions & 0 deletions mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

@@ -405,6 +406,137 @@ void transform::SetGPULaunchThreadsOp::getEffects(
modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto targetValues = state.getPayloadValues(getTarget());
if (!llvm::hasSingleElement(targetValues))
return emitDefiniteFailure()
<< "requires exactly one target value handle (got "
<< llvm::range_size(targetValues) << ")";
auto value = *targetValues.begin();

int64_t nbPrefetch = getStaticNbPrefetch();
if (getDynamicNbPrefetch()) {
// Get dynamic prefetch count from transform param or handle.
SmallVector<int32_t> dynamicNbPrefetch;
auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
{getDynamicNbPrefetch()});
if (!status.succeeded())
return status;
if (dynamicNbPrefetch.size() != 1)
return emitDefiniteFailure()
<< "requires exactly one value for dynamic_nb_prefetch";
nbPrefetch = dynamicNbPrefetch[0];
}
if (nbPrefetch <= 0)
return emitSilenceableFailure(getLoc())
<< "nb_prefetch must be a positive integer.";

// Find load operation of the operand.
auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
if (!maybeLoadOp)
return emitSilenceableFailure(getLoc()) << "Could not find load op.";
auto loadOp = *maybeLoadOp;
if (loadOp.getMixedOffsets().size() == 0) {
auto diag = emitSilenceableFailure(getLoc())
<< "Load op must have offsets.";
diag.attachNote(loadOp.getLoc()) << "load op";
return diag;
}

// Find the parent scf.for loop.
auto forOp = loadOp->getParentOfType<scf::ForOp>();
Review comment (Contributor):
Minor suggestion: could this maybe be generalized to support any parent who implements the LoopLikeOpInterface (or whichever name it goes by)?

Reply (Contributor, author):
Yes, I think it can be generalized. I however suggest deferring this to the time when we have a concrete use case.
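
For illustration only, a rough sketch of what that generalization might look like — a hypothetical variant that is not part of this patch, replacing the hard-coded `scf::ForOp` lookup with `LoopLikeOpInterface`:

// Hypothetical sketch: accept any parent loop implementing LoopLikeOpInterface.
auto loopOp = loadOp->getParentOfType<LoopLikeOpInterface>();
if (!loopOp) {
  auto diag = emitSilenceableFailure(getLoc())
              << "Load op is not contained in a loop-like op.";
  diag.attachNote(loadOp.getLoc()) << "load op";
  return diag;
}
// The lower bound, step, and induction variable would then come from
// loopOp.getSingleLowerBound(), loopOp.getSingleStep(), and
// loopOp.getSingleInductionVar(), bailing out when any of them is absent.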

if (!forOp) {
auto diag = emitSilenceableFailure(getLoc())
<< "Load op is not contained in a scf.for loop.";
diag.attachNote(loadOp.getLoc()) << "load op";
return diag;
}

// Find descriptor op.
auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
if (!maybeDescOp)
return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
auto descOp = *maybeDescOp;
if (descOp.getMixedOffsets().size() > 0) {
auto diag = emitSilenceableFailure(getLoc())
<< "desc op with offsets is not supported.";
diag.attachNote(descOp.getLoc()) << "desc op";
return diag;
}

// Clone desc op outside the loop.
rewriter.setInsertionPoint(forOp);
auto newDescOp =
cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));

// Clone reduction loop to emit initial prefetches.
// Compute upper bound of the init loop: start + nbPrefetch * step.
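// For example, with lowerBound = 0, step = 32, and nbPrefetch = 2, the init
// loop below issues prefetches for offsets {0, 32} and the main-loop prefetch
// added later runs at iv + 64.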
auto nbPrefetchCst =
arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
auto nbStep = rewriter.createOrFold<arith::MulIOp>(
forOp.getLoc(), nbPrefetchCst, forOp.getStep());
auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
forOp.getLoc(), forOp.getLowerBound(), nbStep);
auto initForOp =
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
initUpBound, forOp.getStep());

auto ctx = rewriter.getContext();
auto readCacheHint =
xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);

// Modify loadOp mixedOffsets by replacing the for loop induction variable
// with the given value.
auto getPrefetchOffsets =
[&](Value replacementVal) -> SmallVector<OpFoldResult> {
IRMapping mapping;
mapping.map(forOp.getInductionVar(), replacementVal);
SmallVector<Value> dynamicOffsets =
llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
return mapping.lookupOrDefault(v);
}));
auto constOffsets = loadOp.getConstOffsets().value();
return getMixedValues(constOffsets, dynamicOffsets, ctx);
};

// Insert prefetch op in init loop.
// Replace induction var with the init loop induction var.
rewriter.setInsertionPointToStart(initForOp.getBody());
xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
newDescOp.getResult(),
getPrefetchOffsets(initForOp.getInductionVar()),
readCacheHint, readCacheHint, readCacheHint);

// Insert prefetch op in main loop.
// Calculate prefetch offset after the init prefetches have been issued.
rewriter.setInsertionPointToStart(forOp.getBody());
auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
forOp.getInductionVar(), nbStep);
// Replace induction var with correct offset.
xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
newDescOp.getResult(),
getPrefetchOffsets(prefetchOffset), readCacheHint,
readCacheHint, readCacheHint);

// Unroll the init loop.
if (failed(loopUnrollFull(initForOp)))
return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";

results.set(llvm::cast<OpResult>(getResult()), {newDescOp});

return DiagnosedSilenceableFailure::success();
}

void transform::InsertPrefetchOp::getEffects(
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTargetMutable(), effects);
onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}

namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
43 changes: 43 additions & 0 deletions mlir/python/mlir/dialects/transform/xegpu.py
@@ -11,6 +11,7 @@
from .._ods_common import _cext as _ods_cext
from .._ods_common import (
MixedValues,
MixedInt,
get_op_result_or_value as _get_op_result_or_value,
_dispatch_dynamic_index_list,
)
@@ -134,6 +135,7 @@ def __init__(
)


@_ods_cext.register_operation(_Dialect, replace=True)
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
"""Specialization for SetGPULaunchThreadsOp class."""

@@ -168,3 +170,44 @@ def set_gpu_launch_threads(
ip=None,
) -> SetGPULaunchThreadsOp:
return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class InsertPrefetchOp(InsertPrefetchOp):
    """Specialization for InsertPrefetchOp class."""

    def __init__(
        self,
        target: Value,
        *,
        nb_prefetch: Optional[MixedInt] = 1,
        loc=None,
        ip=None,
    ):
        static_nb_prefetch = 1
        dynamic_nb_prefetch = None
        if isinstance(nb_prefetch, int):
            static_nb_prefetch = nb_prefetch
        elif isinstance(nb_prefetch, IntegerAttr):
            static_nb_prefetch = nb_prefetch.value  # pytype: disable=attribute-error
        elif isinstance(nb_prefetch, (Operation, Value, OpView)):
            dynamic_nb_prefetch = nb_prefetch

        super().__init__(
            transform.AnyOpType.get(),
            target,
            dynamic_nb_prefetch=dynamic_nb_prefetch,
            static_nb_prefetch=static_nb_prefetch,
            loc=loc,
            ip=ip,
        )


def insert_prefetch(
    target: Value,
    *,
    nb_prefetch: Optional[MixedInt] = 1,
    loc=None,
    ip=None,
) -> OpResult:
    return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
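
A hedged usage sketch of the Python helper; it assumes a transform-sequence building context in which `operand` is a `!transform.any_value` handle to a matmul operand (e.g., the result of a `transform.get_operand` op) and `nb_param` is a previously created `!transform.param<i64>` handle — both names are illustrative only:

# Static prefetch count baked into the op as an attribute.
desc = insert_prefetch(operand, nb_prefetch=2)
# Or drive the count dynamically from a transform param / handle.
desc = insert_prefetch(operand, nb_prefetch=nb_param)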
31 changes: 31 additions & 0 deletions mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -124,3 +124,34 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// CHECK-LABEL: @insert_prefetch_dpas_c
func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%c32 = arith.constant 32 : index
%c4096 = arith.constant 4096 : index
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
// expected-note@below {{load op}}
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
%5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
scf.yield %7 : vector<256x256xf16>
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
// expected-error@below {{Load op is not contained in a scf.for loop.}}
%2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
transform.yield
}
}
92 changes: 92 additions & 0 deletions mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -308,3 +308,95 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// CHECK-LABEL: @insert_prefetch_dpas_a
func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
// CHECK: %[[C32:.+]] = arith.constant 32 : index
%c32 = arith.constant 32 : index
%c4096 = arith.constant 4096 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
// CHECK: xegpu.create_nd_tdesc %arg0
// CHECK: xegpu.create_nd_tdesc %arg1
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
// CHECK-SAME: !xegpu.tensor_desc<256x32xf16
// CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[C0]]]
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
// CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C32]]
// CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[ADD]]]
%5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
scf.yield %7 : vector<256x256xf16>
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
// CHECK: transform.xegpu.insert_prefetch %{{.*}}
%2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.canonicalization
} : !transform.any_op

transform.yield
}
}

// -----

// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
// CHECK: %[[C64:.+]] = arith.constant 64 : index
// CHECK: %[[C32:.+]] = arith.constant 32 : index
%c32 = arith.constant 32 : index
%c4096 = arith.constant 4096 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
// CHECK: xegpu.create_nd_tdesc %arg0
// CHECK: xegpu.create_nd_tdesc %arg1
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
// CHECK-SAME: !xegpu.tensor_desc<256x32xf16
// CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C0]]]
// CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C32]]]
%3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
%2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
// CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C64]]
// CHECK: xegpu.prefetch_nd %[[V0]][0, %[[ADD]]]
%5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
%7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
scf.yield %7 : vector<256x256xf16>
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
%nb = transform.param.constant 2 : i64 -> !transform.param<i64>
// CHECK: transform.xegpu.insert_prefetch %{{.*}}
%2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.yield
}
}