56 changes: 55 additions & 1 deletion mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h
@@ -1,4 +1,4 @@
//===- Passes.h - GPU NVVM pipeline entry points --------------------------===//
//===- Passes.h - GPU NVVM/XeVM pipeline entry points ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -60,6 +60,53 @@ struct GPUToNVVMPipelineOptions
llvm::cl::init(false)};
};

// Options for the GPU to XeVM pipeline.
struct GPUToXeVMPipelineOptions
: public PassPipelineOptions<GPUToXeVMPipelineOptions> {
// XeGPU op granularity selection: workgroup | subgroup | lane
PassOptions::Option<std::string> xegpuOpLevel{
*this, "xegpu-op-level",
llvm::cl::desc("Granularity of XeGPU operations to target: workgroup | "
"subgroup | lane"),
llvm::cl::init("workgroup")};
// General lowering controls.
PassOptions::Option<bool> use64bitIndex{
*this, "use-64bit-index",
llvm::cl::desc("Use 64-bit index type in the lowering (host & device)"),
llvm::cl::init(true)};
PassOptions::Option<bool> kernelBarePtrCallConv{
*this, "kernel-bare-ptr-calling-convention",
llvm::cl::desc("Use bare pointer calling convention for device kernels"),
llvm::cl::init(false)};
PassOptions::Option<bool> hostBarePtrCallConv{
*this, "host-bare-ptr-calling-convention",
llvm::cl::desc("Use bare pointer calling convention for host launches"),
llvm::cl::init(false)};
PassOptions::Option<std::string> binaryFormat{
*this, "binary-format",
llvm::cl::desc("Final GPU binary emission format (e.g. fatbin)"),
llvm::cl::init("fatbin")};
// Options mirroring xevm-attach-target (GpuXeVMAttachTarget).
PassOptions::Option<std::string> xevmModuleMatcher{
*this, "xevm-module-matcher",
llvm::cl::desc("Regex to match gpu.module names for XeVM target attach"),
llvm::cl::init("")};
PassOptions::Option<std::string> zebinTriple{
*this, "zebin-triple", llvm::cl::desc("Target triple for XeVM codegen"),
llvm::cl::init("spirv64-unknown-unknown")};
PassOptions::Option<std::string> zebinChip{
*this, "zebin-chip", llvm::cl::desc("Target chip (e.g. pvc, bmg)"),
llvm::cl::init("bmg")};
PassOptions::Option<unsigned> optLevel{
*this, "opt-level",
llvm::cl::desc("Optimization level for attached target/codegen"),
llvm::cl::init(2)};
PassOptions::Option<std::string> cmdOptions{
*this, "igc-cmd-options",
llvm::cl::desc("Additional downstream compiler command line options"),
llvm::cl::init("")};
};

//===----------------------------------------------------------------------===//
// Building and Registering.
//===----------------------------------------------------------------------===//
@@ -70,8 +117,15 @@ struct GPUToNVVMPipelineOptions
void buildLowerToNVVMPassPipeline(OpPassManager &pm,
const GPUToNVVMPipelineOptions &options);

/// Adds the GPU to XeVM pipeline to the given pass manager. Lowers the main
/// dialects to XeVM targets, transforming the GPU code regions first and the
/// host code after.
void buildLowerToXeVMPassPipeline(OpPassManager &pm,
const GPUToXeVMPipelineOptions &options);

/// Register all pipelines for the `gpu` dialect.
void registerGPUToNVVMPipeline();
void registerGPUToXeVMPipeline();

} // namespace gpu
} // namespace mlir
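
For orientation, a minimal sketch of driving this entry point in-process; the
context setup is assumed, and the option values shown are illustrative rather
than the defaults:

#include "mlir/Dialect/GPU/Pipelines/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Hedged sketch: run the XeVM lowering pipeline on an already-parsed module.
static mlir::LogicalResult lowerToXeVM(mlir::ModuleOp module) {
  mlir::gpu::GPUToXeVMPipelineOptions options;
  options.xegpuOpLevel = "subgroup"; // workgroup | subgroup | lane
  options.zebinChip = "pvc";         // override the "bmg" default
  mlir::PassManager pm(module.getContext());
  mlir::gpu::buildLowerToXeVMPassPipeline(pm, options);
  return pm.run(module);
}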
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/GPU/Pipelines/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRGPUPipelines
GPUToNVVMPipeline.cpp
GPUToXeVMPipeline.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -12,11 +13,15 @@ add_mlir_dialect_library(MLIRGPUPipelines
MLIRLinalgTransforms
MLIRAffineToStandard
MLIRGPUToNVVMTransforms
MLIRIndexToLLVM
MLIRMathToLLVM
MLIRNVGPUToNVVM
MLIRNVVMToLLVM
MLIRReconcileUnrealizedCasts
MLIRSCFToControlFlow
MLIRVectorToSCF
MLIRXeGPUTransforms
MLIRXeGPUToXeVM
MLIRXeVMToLLVM
)
145 changes: 145 additions & 0 deletions mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -0,0 +1,145 @@
//===- GPUToXeVMPipeline.cpp - Lowering pipeline to XeVM/LLVM -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU to XeVM lowering pipeline, intended as a
// generally usable sink pipeline. When XeGPU ops are used, they are expected
// to already be embedded inside gpu.module code.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Target/LLVM/XeVM/Target.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

namespace {
//===----------------------------------------------------------------------===//
// Common pipeline
//===----------------------------------------------------------------------===//
void buildCommonPassPipeline(
OpPassManager &pm, const mlir::gpu::GPUToXeVMPipelineOptions &options) {
// builtin.module scope passes
pm.addPass(createCSEPass());
{
GpuXeVMAttachTargetOptions xevmTargetOptions;
xevmTargetOptions.moduleMatcher = options.xevmModuleMatcher;
xevmTargetOptions.triple = options.zebinTriple;
xevmTargetOptions.chip = options.zebinChip;
xevmTargetOptions.optLevel = options.optLevel;
xevmTargetOptions.cmdOptions = options.cmdOptions;
pm.addPass(createGpuXeVMAttachTarget(xevmTargetOptions));
}
}

//===----------------------------------------------------------------------===//
// GPUModule-specific stuff.
//===----------------------------------------------------------------------===//
void buildGpuPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
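// Each granularity funnels into the next lower one: workgroup-level input is
// first distributed to subgroups, subgroup-level input joins at the lane
// distribution below, and lane-level input goes straight to XeVM conversion.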
if (options.xegpuOpLevel == "workgroup") {
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass());
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
}
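// Workgroup input has now been brought to subgroup level; from here both
// paths propagate layouts, distribute to lanes (SIMT), and linearize vectors.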
if (options.xegpuOpLevel == "subgroup" ||
options.xegpuOpLevel == "workgroup") {
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
pm.addNestedPass<gpu::GPUModuleOp>(createLoopInvariantCodeMotionPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUVectorLinearize());
}
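// All remaining XeGPU ops are lane-level: convert them to XeVM, lower other
// gpu ops to LLVM-SPV, then lower the XeVM dialect itself to LLVM.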
pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeGPUToXeVMPass());
{
ConvertGpuOpsToLLVMSPVOpsOptions gpuToLLVMSPVOptions;
gpuToLLVMSPVOptions.use64bitIndex = options.use64bitIndex;
pm.addNestedPass<gpu::GPUModuleOp>(
createConvertGpuOpsToLLVMSPVOps(gpuToLLVMSPVOptions));
}
pm.addNestedPass<gpu::GPUModuleOp>(createConvertXeVMToLLVMPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
}

//===----------------------------------------------------------------------===//
// Host Post-GPU pipeline
//===----------------------------------------------------------------------===//
void buildHostPostPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
pm.addNestedPass<func::FuncOp>(LLVM::createLLVMRequestCWrappersPass());
pm.addNestedPass<func::FuncOp>(createGpuAsyncRegionPass());
pm.addPass(createReconcileUnrealizedCastsPass());
pm.addPass(createConvertVectorToSCFPass());
pm.addPass(createSCFToControlFlowPass());
pm.addPass(memref::createExpandStridedMetadataPass());
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
{
GpuToLLVMConversionPassOptions gpuToLLVMOptions;
gpuToLLVMOptions.hostBarePtrCallConv = options.hostBarePtrCallConv;
gpuToLLVMOptions.kernelBarePtrCallConv = options.kernelBarePtrCallConv;
pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMOptions));
}
pm.addPass(createConvertToLLVMPass());
pm.addPass(createLowerAffinePass());
pm.addPass(createReconcileUnrealizedCastsPass());
// gpu-module-to-binary
{
GpuModuleToBinaryPassOptions gpuToModuleBinOptions;
gpuToModuleBinOptions.compilationTarget = options.binaryFormat;
gpuToModuleBinOptions.cmdOptions = options.cmdOptions;
pm.addPass(createGpuModuleToBinaryPass(gpuToModuleBinOptions));
}
}
} // namespace

void mlir::gpu::buildLowerToXeVMPassPipeline(
OpPassManager &pm, const GPUToXeVMPipelineOptions &options) {
// Common pipelines
buildCommonPassPipeline(pm, options);

// GPUModule-specific stuff
buildGpuPassPipeline(pm, options);

// Host post-GPUModule-specific stuff
buildHostPostPipeline(pm, options);
}

void mlir::gpu::registerGPUToXeVMPipeline() {
PassPipelineRegistration<GPUToXeVMPipelineOptions>(
"gpu-lower-to-xevm-pipeline",
"The default GPU to XeVM lowering pipeline. It starts by lowering GPU "
"code to the specified compilation target (default is fatbin), then "
"lowers the host code.",
buildLowerToXeVMPassPipeline);
}
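
For a tool that does not link registerAllPasses, the registration can be
called directly. A minimal sketch of a standalone opt-like driver, assuming
standard MLIR tool linkage (the tool name is hypothetical):

#include "mlir/Dialect/GPU/Pipelines/Passes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  // Makes --gpu-lower-to-xevm-pipeline="xegpu-op-level=...,..." available.
  mlir::gpu::registerGPUToXeVMPipeline();
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "xevm-demo-opt", registry));
}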
1 change: 1 addition & 0 deletions mlir/lib/RegisterAllPasses.cpp
@@ -98,4 +98,5 @@ void mlir::registerAllPasses() {
sparse_tensor::registerSparseTensorPipelines();
tosa::registerTosaToLinalgPipelines();
gpu::registerGPUToNVVMPipeline();
gpu::registerGPUToXeVMPipeline();
}
121 changes: 121 additions & 0 deletions mlir/test/Integration/Dialect/XeGPU/LANE/simple_gemm.mlir
@@ -0,0 +1,121 @@
// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane" \
// RUN: | mlir-runner \
// RUN: --shared-libs=%mlir_levelzero_runtime \
// RUN: --shared-libs=%mlir_runner_utils \
// RUN: --entry-point-result=void \
// RUN: | FileCheck %s

module @gemm attributes {gpu.container_module} {
gpu.module @kernel {
gpu.func @simple_gemm(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel {
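// One workgroup (a single 16-lane subgroup) computes an 8x16 tile of C,
// accumulating over K in steps of 16 with xegpu.dpas; at lane level each
// work-item carries a vector<8xf32> fragment of the tile.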
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c256 = arith.constant 256 : index
%block_x = gpu.block_id x
%block_y = gpu.block_id y
%x_block_offset = arith.muli %block_x, %c8 : index
%y_block_offset = arith.muli %block_y, %c16 : index

%c_tdesc = xegpu.create_nd_tdesc %c : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
%c_init_value = xegpu.load_nd %c_tdesc[%x_block_offset, %y_block_offset] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
%a_tdesc = xegpu.create_nd_tdesc %a : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16>
%b_tdesc = xegpu.create_nd_tdesc %b : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>

%r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8xf32>) {
%a_val = xegpu.load_nd %a_tdesc[%x_block_offset, %k] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
%b_val = xegpu.load_nd %b_tdesc[%k, %y_block_offset] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
%dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
scf.yield %dpas : vector<8xf32>
}
xegpu.store_nd %r, %c_tdesc[%x_block_offset, %y_block_offset] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}

func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%memref_a = gpu.alloc () : memref<256x256xf16>
gpu.memcpy %memref_a, %a : memref<256x256xf16>, memref<256x256xf16>
%memref_b = gpu.alloc () : memref<256x256xf16>
gpu.memcpy %memref_b, %b : memref<256x256xf16>, memref<256x256xf16>
%memref_c = gpu.alloc () : memref<256x256xf32>
gpu.memcpy %memref_c, %c : memref<256x256xf32>, memref<256x256xf32>
gpu.launch_func @kernel::@simple_gemm blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>)
gpu.wait // Wait for the kernel to finish.
gpu.memcpy %c, %memref_c : memref<256x256xf32>, memref<256x256xf32>
gpu.dealloc %memref_a : memref<256x256xf16>
gpu.dealloc %memref_b : memref<256x256xf16>
gpu.dealloc %memref_c : memref<256x256xf32>
return %c : memref<256x256xf32>
}

func.func @main() attributes {llvm.emit_c_interface} {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c256 = arith.constant 256 : index
%cf_0 = arith.constant 0.0 : f16
%cf_1 = arith.constant 1.0 : f16
%A = memref.alloc() : memref<256x256xf16>
%B = memref.alloc() : memref<256x256xf16>
%C = memref.alloc() : memref<256x256xf32>
%C_ref = memref.alloc() : memref<256x256xf32>

// Initialize matrix A ; A[i, j] = j
scf.for %i = %c0 to %c256 step %c1 {
scf.for %j = %c0 to %c256 step %c1 {
%t = index.castu %j : index to i16
%val = arith.uitofp %t : i16 to f16
memref.store %val, %A[%i, %j] : memref<256x256xf16>
}
}

// Initialize the B matrix.
// Make matrix B an identity matrix.
scf.for %i = %c0 to %c256 step %c1 {
scf.for %j = %c0 to %c256 step %c1 {
%i_i32 = index.castu %i : index to i32
%j_i32 = index.castu %j : index to i32
%i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32

scf.if %i_j_same {
memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
} else {
memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
}
}
}

// Initialize matrix C and C_ref ; C[i, j] = 0
%c0_f32 = arith.constant 0.0 : f32
scf.for %i = %c0 to %c256 step %c1 {
scf.for %j = %c0 to %c256 step %c1 {
memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
}
}

// Run GPU version.
%2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
%gpu_result_cast = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
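// B is the identity matrix, so the kernel computes C = A: every printed row
// of the result is the sequence 0, 1, ..., 255.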

// CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
// CHECK-COUNT-256: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
call @printMemrefF32(%gpu_result_cast) : (memref<*xf32>) -> ()
memref.dealloc %A : memref<256x256xf16>
memref.dealloc %B : memref<256x256xf16>
memref.dealloc %C : memref<256x256xf32>
memref.dealloc %C_ref : memref<256x256xf32>
return
}
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
}