diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556deb..f5e4afad535e5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,43 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
   }];
 }
 
+def SetGPULaunchThreadsOp
+    : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      TransformOpInterface
+    ]> {
+
+  let summary = "Set the number of threads for a given gpu.launch operation";
+  let description = [{
+    Overrides the x, y and z thread operands of a given `gpu.launch` operation in-place.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       Variadic<TransformAnyParamTypeOrAnyHandle>:$threads,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
+  );
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
+      Builder b(getContext());
+      return getMixedValues(getStaticThreads(), getThreads(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
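For reference, the assembly format declared above round-trips as follows. This is a minimal sketch assuming a `%launch` handle produced elsewhere (e.g. by `transform.structured.match`); because the format uses `qualified(type(operands))`, each parameter operand appends its own type after the handle type:

```mlir
// All-static thread counts: the values fold into the static_threads attribute.
transform.xegpu.set_gpu_launch_threads %launch threads = [8, 4, 1] : !transform.any_op

// Mixed form: a transform parameter supplies the y dimension.
%th = transform.param.constant 4 : i64 -> !transform.param<i64>
transform.xegpu.set_gpu_launch_threads %launch threads = [8, %th, 1]
    : !transform.any_op, !transform.param<i64>
```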
<< threads.size() << ")"; + } + + rewriter.setInsertionPoint(launchOp); + auto createConstValue = [&](int value) { + return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value); + }; + + // Replace threads in-place. + launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0])); + launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1])); + launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2])); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetGPULaunchThreadsOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + onlyReadsHandle(getTargetMutable(), effects); + onlyReadsHandle(getThreadsMutable(), effects); + modifiesPayload(effects); +} + namespace { class XeGPUTransformDialectExtension : public transform::TransformDialectExtension< diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py index ce8015d8f557b..309883cfc4518 100644 --- a/mlir/python/mlir/dialects/transform/xegpu.py +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -132,3 +132,39 @@ def __init__( loc=loc, ip=ip, ) + + +class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp): + """Specialization for SetGPULaunchThreadsOp class.""" + + def __init__( + self, + launch_op: Union[Operation, Value], + threads: MixedValues, + *, + loc=None, + ip=None, + ): + ( + dynamic_threads, + static_threads, + _, + ) = _dispatch_dynamic_index_list(threads) + + super().__init__( + _get_op_result_or_value(launch_op), + dynamic_threads, + static_threads=static_threads, + loc=loc, + ip=ip, + ) + + +def set_gpu_launch_threads( + launch_op: Union[Operation, Value], + threads: MixedValues, + *, + loc=None, + ip=None, +) -> SetGPULaunchThreadsOp: + return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip) diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir index 726b6748452ae..24f500658f740 100644 --- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir @@ -71,3 +71,56 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index // expected-note {{target op}} + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Requires exactly one targetOp handle (got 2)}} + transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op + transform.yield + } +} + +// ----- + +func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) { + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : 
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index ce8015d8f557b..309883cfc4518 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -132,3 +132,39 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
+    """Specialization for SetGPULaunchThreadsOp class."""
+
+    def __init__(
+        self,
+        launch_op: Union[Operation, Value],
+        threads: MixedValues,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        (
+            dynamic_threads,
+            static_threads,
+            _,
+        ) = _dispatch_dynamic_index_list(threads)
+
+        super().__init__(
+            _get_op_result_or_value(launch_op),
+            dynamic_threads,
+            static_threads=static_threads,
+            loc=loc,
+            ip=ip,
+        )
+
+
+def set_gpu_launch_threads(
+    launch_op: Union[Operation, Value],
+    threads: MixedValues,
+    *,
+    loc=None,
+    ip=None,
+) -> SetGPULaunchThreadsOp:
+    return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)
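`_dispatch_dynamic_index_list` mirrors the C++ `dispatchIndexOpFoldResults` split: literal integers go into `static_threads`, while value or parameter entries become variadic `threads` operands, with the `ShapedType.kDynamic` sentinel (`INT64_MIN`) marking their slot in the static array. A hedged sketch of the generic form this is expected to produce for `threads=[8, th, 1]`, assuming `th` is an `i64` transform parameter:

```mlir
// Generic form: the dynamic slot is encoded by the kDynamic sentinel in
// static_threads, and the %th parameter operand supplies its value.
"transform.xegpu.set_gpu_launch_threads"(%launch, %th)
    {static_threads = array<i64: 8, -9223372036854775808, 1>}
    : (!transform.any_op, !transform.param<i64>) -> ()
```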
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 726b6748452ae..24f500658f740 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -71,3 +71,56 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index // expected-note {{target op}}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Expected threads argument to consist of three values (got 2)}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bd6a79244ed30..7f2fbe4271a43 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -230,6 +230,7 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
 // -----
 
 // CHECK-LABEL: @set_op_layout_attr_operand1
@@ -252,3 +253,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads
+func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[C16:.+]] = arith.constant 16 : index
+  %c16 = arith.constant 16 : index
+  // CHECK: %[[C8:.+]] = arith.constant 8 : index
+  // CHECK: %[[C4:.+]] = arith.constant 4 : index
+  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  }
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads_param
+func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[C16:.+]] = arith.constant 16 : index
+  %c16 = arith.constant 16 : index
+  // CHECK: %[[C8:.+]] = arith.constant 8 : index
+  // CHECK: %[[C4:.+]] = arith.constant 4 : index
+  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  }
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+    %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 0b587d2020aa6..dc91f5e982579 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -113,3 +113,18 @@ def setOpLayoutAttrResult():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def setGPULaunchThreadsOp():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("gpu.launch"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
+    # CHECK: transform.xegpu.set_gpu_launch_threads
+    # CHECK: threads = [8, 4, 1]
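For completeness, a sketch of the IR the Python test above is expected to build and print (handle names are up to the printer). The target handle is typed `!transform.op<"gpu.launch">` because the sequence is created over `transform.OperationType.get("gpu.launch")`, illustrating that `$target` accepts any handle type implementing `TransformHandleTypeInterface`:

```mlir
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"gpu.launch">):
  transform.xegpu.set_gpu_launch_threads %arg0 threads = [8, 4, 1]
      : !transform.op<"gpu.launch">
  transform.yield
}
```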