From 4e17136b602c82b25af6b0f65e77770de994e473 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Mon, 29 Sep 2025 15:21:56 +0000 Subject: [PATCH 1/7] Add GPUToXeVM lowering pipeline pass. It's the default GPU to XeVM lowering pipeline. It starts by lowering GPU code to the specified compilation target (default is fatbin), then lowers the host code. If XeGPU ops are used, it expects the MLIR code to have XeGPU ops already embedded in gpu code. --- .../mlir/Dialect/GPU/Pipelines/Passes.h | 50 ++++++- mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt | 5 + .../GPU/Pipelines/GPUToXeVMPipeline.cpp | 138 ++++++++++++++++++ mlir/lib/RegisterAllPasses.cpp | 1 + 4 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h index 035235fc7174a..c634236139d6f 100644 --- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h @@ -1,4 +1,4 @@ -//===- Passes.h - GPU NVVM pipeline entry points --------------------------===// +//===- Passes.h - GPU NVVM/XeVM pipeline entry points----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -60,6 +60,47 @@ struct GPUToNVVMPipelineOptions llvm::cl::init(false)}; }; +// Options for the gpu to xevm pipeline. +struct GPUToXeVMPipelineOptions + : public PassPipelineOptions { + // General lowering controls. + PassOptions::Option indexBitWidth{ + *this, "index-bitwidth", + llvm::cl::desc("Bitwidth of the index type (host & device)"), + llvm::cl::init(64)}; + PassOptions::Option kernelBarePtrCallConv{ + *this, "kernel-bare-ptr-calling-convention", + llvm::cl::desc("Use bare pointer calling convention for device kernels"), + llvm::cl::init(false)}; + PassOptions::Option hostBarePtrCallConv{ + *this, "host-bare-ptr-calling-convention", + llvm::cl::desc("Use bare pointer calling convention for host launches"), + llvm::cl::init(false)}; + PassOptions::Option binaryFormat{ + *this, "binary-format", + llvm::cl::desc("Final GPU binary emission format (e.g. fatbin)"), + llvm::cl::init("fatbin")}; + // Options mirroring xevm-attach-target (GpuXeVMAttachTarget). + PassOptions::Option xevmModuleMatcher{ + *this, "xevm-module-matcher", + llvm::cl::desc("Regex to match gpu.module names for XeVM target attach"), + llvm::cl::init("")}; + PassOptions::Option zebinTriple{ + *this, "zebin-triple", llvm::cl::desc("Target triple for XeVM codegen"), + llvm::cl::init("spirv64-unknown-unknown")}; + PassOptions::Option zebinChip{ + *this, "zebin-chip", llvm::cl::desc("Target chip (e.g. pvc, bmg)"), + llvm::cl::init("bmg")}; + PassOptions::Option optLevel{ + *this, "opt-level", + llvm::cl::desc("Optimization level for attached target/codegen"), + llvm::cl::init(2)}; + PassOptions::Option cmdOptions{ + *this, "igc-cmd-options", + llvm::cl::desc("Additional downstream compiler command line options"), + llvm::cl::init("")}; +}; + //===----------------------------------------------------------------------===// // Building and Registering. //===----------------------------------------------------------------------===// @@ -70,8 +111,15 @@ struct GPUToNVVMPipelineOptions void buildLowerToNVVMPassPipeline(OpPassManager &pm, const GPUToNVVMPipelineOptions &options); +/// Adds the GPU to XeVM pipeline to the given pass manager. 
Transforms main +/// dialects into XeVM targets. Begins with GPU code regions, then handles host +/// code. +void buildLowerToXeVMPassPipeline(OpPassManager &pm, + const GPUToXeVMPipelineOptions &options); + /// Register all pipeleines for the `gpu` dialect. void registerGPUToNVVMPipeline(); +void registerGPUToXeVMPipeline(); } // namespace gpu } // namespace mlir diff --git a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt index 70a9c77a6d796..f231eebe6d82b 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRGPUPipelines GPUToNVVMPipeline.cpp + GPUToXeVMPipeline.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU @@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRGPUPipelines MLIRLinalgTransforms MLIRAffineToStandard MLIRGPUToNVVMTransforms + MLIRXeGPUToXeVM MLIRIndexToLLVM MLIRMathToLLVM MLIRNVGPUToNVVM @@ -19,4 +21,7 @@ add_mlir_dialect_library(MLIRGPUPipelines MLIRReconcileUnrealizedCasts MLIRSCFToControlFlow MLIRVectorToSCF + MLIRXeGPUTransforms + MLIRXeGPUToXeVM + MLIRXeVMToLLVM ) diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp new file mode 100644 index 0000000000000..eedd11df7f8af --- /dev/null +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -0,0 +1,138 @@ +//===- GPUToXeVMPipeline.cpp - Lowering pipeline to XeVM/LLVM -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass for testing the lowering to XeVM as a generally +// usable sink pass. If XeGPU ops are used, it expects the MLIR code to have +// XeGPU ops already embedded in gpu code. 
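+//
+// The pipeline runs in three stages: a common stage (module-level CSE and
+// attaching an XeVM target to gpu.module ops), a gpu.module stage (XeGPU
+// distribution and conversion to XeVM and LLVM-SPV ops), and a host stage
+// (lowering the remaining host dialects to LLVM and serializing gpu.module
+// to the requested binary format).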
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Pipelines/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Target/LLVM/XeVM/Target.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { +//===----------------------------------------------------------------------===// +// Common pipeline +//===----------------------------------------------------------------------===// +void buildCommonPassPipeline( + OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) { + // builtin.module scope passes + pm.addPass(createCSEPass()); + { + GpuXeVMAttachTargetOptions xevmTargetOptions; + xevmTargetOptions.moduleMatcher = options.xevmModuleMatcher; + xevmTargetOptions.triple = options.zebinTriple; + xevmTargetOptions.chip = options.zebinChip; + xevmTargetOptions.optLevel = options.optLevel; + xevmTargetOptions.cmdOptions = options.cmdOptions; + pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions)); + } +} + +//===----------------------------------------------------------------------===// +// GPUModule-specific stuff. 
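+// Passes in this stage are nested on gpu.module: XeGPU distribution and
+// blocking, then conversion of XeGPU to XeVM, gpu ops to LLVM-SPV, and
+// XeVM to LLVM.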
+//===----------------------------------------------------------------------===// +void buildGpuPassPipeline(OpPassManager &pm, + const mlir::gpu::GPUToXeVMPipelineOptions &options) { + pm.addNestedPass(xegpu::createXeGPUWgToSgDistribute()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass(createLowerAffinePass()); + pm.addNestedPass(xegpu::createXeGPUBlocking()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass(xegpu::createXeGPUPropagateLayout()); + pm.addNestedPass(xegpu::createXeGPUSubgroupDistribute()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass(createLoopInvariantCodeMotionPass()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass(xegpu::createXeGPUVectorLinearize()); + pm.addNestedPass(createConvertXeGPUToXeVMPass()); + ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions; + gpuToLLVMSPVOptions.use64bitIndex = options.indexBitWidth; + pm.addNestedPass( + createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions)); + pm.addNestedPass(createConvertXeVMToLLVMPass()); + pm.addNestedPass(createCSEPass()); +} + +//===----------------------------------------------------------------------===// +// Host Post-GPU pipeline +//===----------------------------------------------------------------------===// +void buildHostPostPipeline(OpPassManager &pm, + const mlir::gpu::GPUToXeVMPipelineOptions &options) { + pm.addNestedPass(LLVM::createLLVMRequestCWrappersPass()); + pm.addNestedPass(createGpuAsyncRegionPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addPass(createConvertVectorToSCFPass()); + pm.addPass(createSCFToControlFlowPass()); + pm.addPass(memref::createExpandStridedMetadataPass()); + pm.addPass(createFinalizeMemRefToLLVMConversionPass()); + { + GpuToLLVMConversionPassOptions gpuToLLVMOptions; + gpuToLLVMOptions.hostBarePtrCallConv = options.hostBarePtrCallConv; + gpuToLLVMOptions.kernelBarePtrCallConv = options.kernelBarePtrCallConv; + pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions)); + } + pm.addPass(createConvertToLLVMPass()); + pm.addPass(createLowerAffinePass()); + // gpu-module-to-binary + { + GpuModuleToBinaryPassOptions gpuToModuleBinOptions; + gpuToModuleBinOptions.compilationTarget = options.binaryFormat; + gpuToModuleBinOptions.cmdOptions = options.cmdOptions; + pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions)); + } + pm.addPass(createReconcileUnrealizedCastsPass()); +} +} // namespace + +void mlir::gpu::buildLowerToXeVMPassPipeline( + OpPassManager &pm, const GPUToXeVMPipelineOptions &options) { + // Common pipelines + buildCommonPassPipeline(pm, options); + + // GPUModule-specific stuff + buildGpuPassPipeline(pm, options); + + // Host post-GPUModule-specific stuff + buildHostPostPipeline(pm, options); +} + +void mlir::gpu::registerGPUToXeVMPipeline() { + PassPipelineRegistration( + "gpu-lower-to-xevm-pipeline", + "The default GPU to XeVM lowering pipeline. 
It starts by lowering GPU "
      "code to the "
      "specified compilation target (default is fatbin), then lowers the host "
      "code.",
      buildLowerToXeVMPassPipeline);
}

diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
index c67b24226ae45..dd413d2de8710 100644
--- a/mlir/lib/RegisterAllPasses.cpp
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -98,4 +98,5 @@ void mlir::registerAllPasses() {
   sparse_tensor::registerSparseTensorPipelines();
   tosa::registerTosaToLinalgPipelines();
   gpu::registerGPUToNVVMPipeline();
+  gpu::registerGPUToXeVMPipeline();
 }

From dd56eb090b28402fdf6788f56f834d164c0a2c56 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Tue, 30 Sep 2025 00:07:45 +0000
Subject: [PATCH 2/7] Add a pass option to provide the XeGPU code level.

XeGPU allows workgroup, subgroup, and workitem level programming. This
option lets the pass manager know which level the XeGPU ops belong to.
---
 .../mlir/Dialect/GPU/Pipelines/Passes.h       | 12 ++++--
 .../GPU/Pipelines/GPUToXeVMPipeline.cpp       | 42 +++++++++++--------
 2 files changed, 33 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
index c634236139d6f..67dc6415008f3 100644
--- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -63,11 +63,17 @@ struct GPUToNVVMPipelineOptions
 // Options for the gpu to xevm pipeline.
 struct GPUToXeVMPipelineOptions
     : public PassPipelineOptions {
+  // XeGPU op granularity selection: workgroup | subgroup | workitem
+  PassOptions::Option xegpuOpLevel{
+      *this, "xegpu-op-level",
+      llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
+                     "subgroup | workitem"),
+      llvm::cl::init("workgroup")};
   // General lowering controls.
- PassOptions::Option indexBitWidth{ - *this, "index-bitwidth", + PassOptions::Option use64bitIndex{ + *this, "use-64bit-index", llvm::cl::desc("Bitwidth of the index type (host & device)"), - llvm::cl::init(64)}; + llvm::cl::init(true)}; PassOptions::Option kernelBarePtrCallConv{ *this, "kernel-bare-ptr-calling-convention", llvm::cl::desc("Use bare pointer calling convention for device kernels"), diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp index eedd11df7f8af..ae77143f6a66d 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -62,24 +62,30 @@ void buildCommonPassPipeline( //===----------------------------------------------------------------------===// void buildGpuPassPipeline(OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) { - pm.addNestedPass(xegpu::createXeGPUWgToSgDistribute()); - pm.addNestedPass(createCSEPass()); - pm.addNestedPass(createLowerAffinePass()); - pm.addNestedPass(xegpu::createXeGPUBlocking()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass(createCSEPass()); - pm.addNestedPass(xegpu::createXeGPUPropagateLayout()); - pm.addNestedPass(xegpu::createXeGPUSubgroupDistribute()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass(createCSEPass()); - pm.addNestedPass(createLoopInvariantCodeMotionPass()); - pm.addNestedPass(createCSEPass()); - pm.addNestedPass(xegpu::createXeGPUVectorLinearize()); + if (options.xegpuOpLevel == "workgroup") { + pm.addNestedPass(xegpu::createXeGPUWgToSgDistribute()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass(createLowerAffinePass()); + pm.addNestedPass(xegpu::createXeGPUBlocking()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); + } + if (options.xegpuOpLevel == "subgroup") { + pm.addNestedPass(xegpu::createXeGPUPropagateLayout()); + pm.addNestedPass(xegpu::createXeGPUSubgroupDistribute()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass(createLoopInvariantCodeMotionPass()); + pm.addNestedPass(createCSEPass()); + pm.addNestedPass(xegpu::createXeGPUVectorLinearize()); + } pm.addNestedPass(createConvertXeGPUToXeVMPass()); - ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions; - gpuToLLVMSPVOptions.use64bitIndex = options.indexBitWidth; - pm.addNestedPass( - createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions)); + { + ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions; + gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex; + pm.addNestedPass( + createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions)); + } pm.addNestedPass(createConvertXeVMToLLVMPass()); pm.addNestedPass(createCSEPass()); } @@ -104,6 +110,7 @@ void buildHostPostPipeline(OpPassManager &pm, } pm.addPass(createConvertToLLVMPass()); pm.addPass(createLowerAffinePass()); + pm.addPass(createReconcileUnrealizedCastsPass()); // gpu-module-to-binary { GpuModuleToBinaryPassOptions gpuToModuleBinOptions; @@ -111,7 +118,6 @@ void buildHostPostPipeline(OpPassManager &pm, gpuToModuleBinOptions.cmdOptions = options.cmdOptions; pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions)); } - pm.addPass(createReconcileUnrealizedCastsPass()); } } // namespace From 8fcae213df599dffe31d0dc5d45a382691c0b55a Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 30 Sep 2025 00:10:16 +0000 Subject: [PATCH 3/7] Add a workitem level gemm test for XeGPU. 
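
The test initializes A with A[i, j] = j and uses the identity matrix for B,
so the expected result is C = A x I = A: every row of C should read
[0, 1, ..., 255], which is what the FileCheck pattern in the test verifies.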
--- .../Dialect/XeGPU/SIMT/simple_gemm.mlir | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir diff --git a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir new file mode 100644 index 0000000000000..ddae9c1e7eb8f --- /dev/null +++ b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir @@ -0,0 +1,123 @@ +// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workitem" \ +// RUN: | mlir-runner \ +// RUN: --shared-libs=%mlir_levelzero_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void \ +// RUN: | FileCheck %s + +module @gemm attributes {gpu.container_module} { + gpu.module @kernel { + gpu.func @simple_gemm(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %block_x = gpu.block_id x + %block_y = gpu.block_id y + %x_block_offset = arith.muli %block_x, %c8 : index + %y_block_offset = arith.muli %block_y, %c16 : index + + %c_tdesc = xegpu.create_nd_tdesc %c : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> + %c_init_value = xegpu.load_nd %c_tdesc[%x_block_offset, %y_block_offset] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32> + %a_tdesc = xegpu.create_nd_tdesc %a : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16> + %b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16> + + %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8xf32>) { + + %a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16> + %b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16> + %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> + scf.yield %dpas : vector<8xf32> + } + xegpu.store_nd %r, %c_tdesc[%x_block_offset, %y_block_offset] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<8xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.return + } + } + + func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %memref_a = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_a, %a : memref<256x256xf16>, memref<256x256xf16> + %memref_b = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_b, %b : memref<256x256xf16>, memref<256x256xf16> + %memref_c = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %memref_c, %c : memref<256x256xf32>, memref<256x256xf32> + gpu.launch_func @kernel::@simple_gemm blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>) + gpu.wait // Wait for the kernel to finish. 
+ gpu.memcpy %c, %memref_c : memref<256x256xf32>, memref<256x256xf32> + gpu.dealloc %memref_a : memref<256x256xf16> + gpu.dealloc %memref_b : memref<256x256xf16> + gpu.dealloc %memref_c : memref<256x256xf32> + return %c : memref<256x256xf32> + } + + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_f16 = arith.constant 1.0 : f16 + %c2_f16 = arith.constant 2.0 : f16 + %c256 = arith.constant 256 : index + %cf_0 = arith.constant 0.0 : f16 + %cf_1 = arith.constant 1.0 : f16 + %A = memref.alloc() : memref<256x256xf16> + %B = memref.alloc() : memref<256x256xf16> + %C = memref.alloc() : memref<256x256xf32> + %C_ref = memref.alloc() : memref<256x256xf32> + %c_gen_int = arith.constant 0 : i1 + %cf_lower = arith.constant -0.5 : f32 + %cf_upper = arith.constant 0.5 : f32 + + // Initialize matrix A ; A[i, j] = j + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + %t = index.castu %j : index to i16 + %val = arith.uitofp %t : i16 to f16 + memref.store %val, %A[%i, %j] : memref<256x256xf16> + } + } + + // Initialize the B matrix. + // Make matrix B an identity matrix. + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + %i_i32 = index.castu %i : index to i32 + %j_i32 = index.castu %j : index to i32 + %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + scf.if %i_j_same { + memref.store %cf_1, %B[%i, %j] : memref<256x256xf16> + } else { + memref.store %cf_0, %B[%i, %j] : memref<256x256xf16> + } + } + } + + // Initialize matrix C and C_ref ; C[i, j] = 0 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> + } + } + + // Run GPU version. 
+ %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %gpu_result_cast = memref.cast %2 : memref<256x256xf32> to memref<*xf32> + + // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}} + // CHECK-COUNT-256: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] + call @printMemrefF32(%gpu_result_cast) : (memref<*xf32>) -> () + memref.dealloc %A : memref<256x256xf16> + memref.dealloc %B : memref<256x256xf16> + memref.dealloc %C : memref<256x256xf32> + memref.dealloc %C_ref : memref<256x256xf32> + return + } + func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} From 5483acf8893edab7271d8b9d1fedcddcc911a961 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 30 Sep 2025 15:05:08 +0000 Subject: [PATCH 4/7] Fix a small logic. --- mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp index ae77143f6a66d..e3e1bab3e12d1 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -70,7 +70,8 @@ void buildGpuPassPipeline(OpPassManager &pm, pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); } - if (options.xegpuOpLevel == "subgroup") { + if (options.xegpuOpLevel == "subgroup" || + options.xegpuOpLevel == "workgroup") { pm.addNestedPass(xegpu::createXeGPUPropagateLayout()); pm.addNestedPass(xegpu::createXeGPUSubgroupDistribute()); pm.addNestedPass(createCanonicalizerPass()); From 7c6ba629b659dce0cf38687ec738200948e2c0bf Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 30 Sep 2025 21:57:55 +0000 Subject: [PATCH 5/7] Add test cases for SG and WG. 
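
SG covers the subgroup level (xegpu-op-level=subgroup) and WG covers the
workgroup level (xegpu-op-level=workgroup). Both reuse the same GEMM setup
as the workitem test (A[i, j] = j, B = identity), so the expected output
rows are again [0, 1, ..., 255].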
--- .../Dialect/XeGPU/SG/simple_gemm.mlir | 120 ++++++++++++++ .../Dialect/XeGPU/SIMT/simple_gemm.mlir | 2 - .../Dialect/XeGPU/WG/simple_gemm.mlir | 149 ++++++++++++++++++ 3 files changed, 269 insertions(+), 2 deletions(-) create mode 100644 mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir create mode 100644 mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir diff --git a/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir new file mode 100644 index 0000000000000..877edf47fcd15 --- /dev/null +++ b/mlir/test/Integration/Dialect/XeGPU/SG/simple_gemm.mlir @@ -0,0 +1,120 @@ +// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=subgroup" \ +// RUN: | mlir-runner \ +// RUN: --shared-libs=%mlir_levelzero_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void \ +// RUN: | FileCheck %s + +module @gemm attributes {gpu.container_module} { + gpu.module @kernel { + gpu.func @simple_gemm(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %block_x = gpu.block_id x + %block_y = gpu.block_id y + %x_block_offset = arith.muli %block_x, %c8 : index + %y_block_offset = arith.muli %block_y, %c16 : index + + %c_tdesc = xegpu.create_nd_tdesc %c : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32> + %c_init_value = xegpu.load_nd %c_tdesc[%x_block_offset, %y_block_offset] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %a_tdesc = xegpu.create_nd_tdesc %a : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16> + %b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16> + + %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8x16xf32>) { + %a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %dpas : vector<8x16xf32> + } + xegpu.store_nd %r, %c_tdesc[%x_block_offset, %y_block_offset] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.return + } + } + + func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %memref_a = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_a, %a : memref<256x256xf16>, memref<256x256xf16> + %memref_b = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %memref_b, %b : memref<256x256xf16>, memref<256x256xf16> + %memref_c = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %memref_c, %c : memref<256x256xf32>, memref<256x256xf32> + gpu.launch_func @kernel::@simple_gemm blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>) + gpu.wait // Wait for the kernel to finish. 
+ gpu.memcpy %c, %memref_c : memref<256x256xf32>, memref<256x256xf32> + gpu.dealloc %memref_a : memref<256x256xf16> + gpu.dealloc %memref_b : memref<256x256xf16> + gpu.dealloc %memref_c : memref<256x256xf32> + return %c : memref<256x256xf32> + } + + + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_f16 = arith.constant 1.0 : f16 + %c2_f16 = arith.constant 2.0 : f16 + %c256 = arith.constant 256 : index + %cf_0 = arith.constant 0.0 : f16 + %cf_1 = arith.constant 1.0 : f16 + %A = memref.alloc() : memref<256x256xf16> + %B = memref.alloc() : memref<256x256xf16> + %C = memref.alloc() : memref<256x256xf32> + %C_ref = memref.alloc() : memref<256x256xf32> + %c_gen_int = arith.constant 0 : i1 + %cf_lower = arith.constant -0.5 : f32 + %cf_upper = arith.constant 0.5 : f32 + // Option 1: intialize matrix A ; A[i, j] = j + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + %t = index.castu %j : index to i16 + %val = arith.uitofp %t : i16 to f16 + memref.store %val, %A[%i, %j] : memref<256x256xf16> + } + } + + // Initialize the B matrix + // Make matrix B an identity matrix + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + %i_i32 = index.castu %i : index to i32 + %j_i32 = index.castu %j : index to i32 + %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + scf.if %i_j_same { + memref.store %cf_1, %B[%i, %j] : memref<256x256xf16> + } else { + memref.store %cf_0, %B[%i, %j] : memref<256x256xf16> + } + } + } + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> + } + } + + // Run GPU. 
+ %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32> + // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}} + // CHECK-COUNT-256: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] + call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + + memref.dealloc %A : memref<256x256xf16> + memref.dealloc %B : memref<256x256xf16> + memref.dealloc %C : memref<256x256xf32> + memref.dealloc %C_ref : memref<256x256xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir index ddae9c1e7eb8f..36b04791ee2dd 100644 --- a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir +++ b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir @@ -25,7 +25,6 @@ module @gemm attributes {gpu.container_module} { %b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16> %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8xf32>) { - %a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16> %b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16> %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> @@ -118,6 +117,5 @@ module @gemm attributes {gpu.container_module} { memref.dealloc %C_ref : memref<256x256xf32> return } - func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} } diff --git a/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir new file mode 100644 index 0000000000000..fdc24c02f9d98 --- /dev/null +++ b/mlir/test/Integration/Dialect/XeGPU/WG/simple_gemm.mlir @@ -0,0 +1,149 @@ +// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workgroup" \ +// RUN: | mlir-runner \ +// RUN: --shared-libs=%mlir_levelzero_runtime \ +// RUN: --shared-libs=%mlir_runner_utils \ +// RUN: --entry-point-result=void \ +// RUN: | FileCheck %s + +#a = #xegpu.layout +#b = #xegpu.layout +#c = #xegpu.layout +#a_prefetch = #xegpu.layout 
+#b_prefetch = #xegpu.layout +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %A_gpu = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %A_gpu, %A : memref<256x256xf16>, memref<256x256xf16> + %B_gpu = gpu.alloc () : memref<256x256xf16> + gpu.memcpy %B_gpu, %B : memref<256x256xf16>, memref<256x256xf16> + %C_gpu = gpu.alloc () : memref<256x256xf32> + gpu.memcpy %C_gpu, %C : memref<256x256xf32>, memref<256x256xf32> + // NOTE: Here we can't use [8, 64] wi threads following the SG thread layout of [8, 4]. Because runtime will linearize the x dimension first (we need y dimension to be linearized first). + // So just use linearized thread layout of [512, 1] wi threads. + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c512, %c1, %c1) args(%A_gpu : memref<256x256xf16>, %B_gpu : memref<256x256xf16>, %C_gpu : memref<256x256xf32>) + gpu.wait // Wait for the kernel to finish. + gpu.memcpy %C, %C_gpu : memref<256x256xf32>, memref<256x256xf32> + gpu.dealloc %A_gpu : memref<256x256xf16> + gpu.dealloc %B_gpu : memref<256x256xf16> + gpu.dealloc %C_gpu : memref<256x256xf32> + return %C : memref<256x256xf32> + } + + gpu.module @test_kernel { + gpu.func @test_kernel(%A: memref<256x256xf16>, %B: memref<256x256xf16>, %C: memref<256x256xf32>) kernel { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c96 = arith.constant 96 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c256 : index + %n = arith.muli %block_id_y, %c256 : index + %c_tdesc = xegpu.create_nd_tdesc %C : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32, #c> + %c_init_value = xegpu.load_nd %c_tdesc[%m, %n] : !xegpu.tensor_desc<256x256xf32, #c> -> vector<256x256xf32> + %a_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a> + %b_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b> + // Prefetch A 3 times. + %a_prefetch_tdesc = xegpu.create_nd_tdesc %A : memref<256x256xf16> -> !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c0] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c32] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %c64] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + // Prefetch B 3 times. 
+ %b_prefetch_tdesc = xegpu.create_nd_tdesc %B : memref<256x256xf16> -> !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c0, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c32, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%c64, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> + + %out = scf.for %k = %c0 to %c256 step %c32 + iter_args(%c_value = %c_init_value) + -> (vector<256x256xf32>) { + %a_value = xegpu.load_nd %a_tdesc[%m, %k] : !xegpu.tensor_desc<256x32xf16, #a> -> vector<256x32xf16> + %b_value = xegpu.load_nd %b_tdesc[%k, %n] : !xegpu.tensor_desc<32x256xf16, #b> -> vector<32x256xf16> + // Prefetch next tiles. + %prefetch_offset = arith.addi %k, %c96 : index + xegpu.prefetch_nd %a_prefetch_tdesc[%m, %prefetch_offset] : !xegpu.tensor_desc<256x32xf16, #a_prefetch> + xegpu.prefetch_nd %b_prefetch_tdesc[%prefetch_offset, %n] : !xegpu.tensor_desc<32x256xf16, #b_prefetch> + %c_new_value = xegpu.dpas %a_value, %b_value, %c_value {layout_result_0 = #c} + : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> + scf.yield %c_new_value : vector<256x256xf32> + } + xegpu.store_nd %out, %c_tdesc[%m, %n] : vector<256x256xf32>, !xegpu.tensor_desc<256x256xf32, #c> + gpu.return + } + } + + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1_f16 = arith.constant 1.0 : f16 + %c2_f16 = arith.constant 2.0 : f16 + %c256 = arith.constant 256 : index + %cf_0 = arith.constant 0.0 : f16 + %cf_1 = arith.constant 1.0 : f16 + %A = memref.alloc() : memref<256x256xf16> + %B = memref.alloc() : memref<256x256xf16> + %C = memref.alloc() : memref<256x256xf32> + %C_ref = memref.alloc() : memref<256x256xf32> + %c_gen_int = arith.constant 0 : i1 + %cf_lower = arith.constant -0.5 : f32 + %cf_upper = arith.constant 0.5 : f32 + // Intialize matrix A ; A[i, j] = j + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + %t = index.castu %j : index to i16 + %val = arith.uitofp %t : i16 to f16 + memref.store %val, %A[%i, %j] : memref<256x256xf16> + } + } + + // Initialize the B matrix + // Make matrix B an identity matrix + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + %i_i32 = index.castu %i : index to i32 + %j_i32 = index.castu %j : index to i32 + %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + scf.if %i_j_same { + memref.store %cf_1, %B[%i, %j] : memref<256x256xf16> + } else { + memref.store %cf_0, %B[%i, %j] : memref<256x256xf16> + } + } + } + + // Initialize matrix C and C_ref ; C[i, j] = 0 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c256 step %c1 { + scf.for %j = %c0 to %c256 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32> + } + } + + // Run GPU version. 
+ %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32> + %gpu_result_cast = memref.cast %2 : memref<256x256xf32> to memref<*xf32> + // CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}} + // CHECK-COUNT-256: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] + call @printMemrefF32(%gpu_result_cast) : (memref<*xf32>) -> () + + memref.dealloc %A : memref<256x256xf16> + memref.dealloc %B : memref<256x256xf16> + memref.dealloc %C : memref<256x256xf32> + memref.dealloc %C_ref : memref<256x256xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} +} From 6d5daa098d5c4d1e0211eb07e311b7ce169b09df Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Fri, 3 Oct 2025 16:32:22 +0000 Subject: [PATCH 6/7] Address review comments. Change `workitem` to `lane`. --- mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h | 4 ++-- mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h index 67dc6415008f3..c545517ec7739 100644 --- a/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h @@ -63,11 +63,11 @@ struct GPUToNVVMPipelineOptions // Options for the gpu to xevm pipeline. struct GPUToXeVMPipelineOptions : public PassPipelineOptions { - // XeGPU op granularity selection: workgroup | subgroup | workitem + // XeGPU op granularity selection: workgroup | subgroup | lane PassOptions::Option xegpuOpLevel{ *this, "xegpu-op-level", llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | " - "subgroup | workitem"), + "subgroup | lane"), llvm::cl::init("workgroup")}; // General lowering controls. 
PassOptions::Option use64bitIndex{ diff --git a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir index 36b04791ee2dd..ffe29ef35a5c9 100644 --- a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir +++ b/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workitem" \ +// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane" \ // RUN: | mlir-runner \ // RUN: --shared-libs=%mlir_levelzero_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ From e71a858d185cb6121e12ca7ec15a3abca481ed77 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 7 Oct 2025 14:06:46 +0000 Subject: [PATCH 7/7] Address review comments. Rename `SIMT` test folder to `LANE`. --- .../Integration/Dialect/XeGPU/{SIMT => LANE}/simple_gemm.mlir | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename mlir/test/Integration/Dialect/XeGPU/{SIMT => LANE}/simple_gemm.mlir (100%) diff --git a/mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir b/mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir similarity index 100% rename from mlir/test/Integration/Dialect/XeGPU/SIMT/simple_gemm.mlir rename to mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir