Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
"Run function signature conversion to convert vector types">,
Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
/*default=*/"true",
"Run vector unrolling to convert vector types in function bodies">
"Run vector unrolling to convert vector types in function bodies">,
Option<"convertGPUModules", "convert-gpu-modules", "bool",
/*default=*/"false",
"Clone and convert GPU modules">
];
}

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
MLIRArithToSPIRV
MLIRArithTransforms
MLIRFuncToSPIRV
MLIRGPUDialect
MLIRGPUToSPIRV
MLIRIndexToSPIRV
MLIRIR
Expand Down
100 changes: 69 additions & 31 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
Expand All @@ -40,6 +41,35 @@ using namespace mlir;

namespace {

/// Map memRef memory space to SPIR-V storage class.
void mapToMemRef(Operation *op, spirv::TargetEnvAttr &targetAttr) {
spirv::TargetEnv targetEnv(targetAttr);
bool targetEnvSupportsKernelCapability =
targetEnv.allows(spirv::Capability::Kernel);
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
targetEnvSupportsKernelCapability
? spirv::mapMemorySpaceToOpenCLStorageClass
: spirv::mapMemorySpaceToVulkanStorageClass;
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
spirv::convertMemRefTypesAndAttrs(op, converter);
}

/// Populate patterns for each dialect.
void populateConvertToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
RewritePatternSet &patterns) {
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
populateGPUToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateMemRefToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
}

/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
Expand All @@ -57,38 +87,46 @@ struct ConvertToSPIRVPass final
if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
return signalPassFailure();

spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;

// Map MemRef memory space to SPIR-V storage class.
spirv::TargetEnv targetEnv(targetAttr);
bool targetEnvSupportsKernelCapability =
targetEnv.allows(spirv::Capability::Kernel);
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
targetEnvSupportsKernelCapability
? spirv::mapMemorySpaceToOpenCLStorageClass
: spirv::mapMemorySpaceToVulkanStorageClass;
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
spirv::convertMemRefTypesAndAttrs(op, converter);
// Generic conversion.
if (!convertGPUModules) {
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
mapToMemRef(op, targetAttr);
populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
patterns);
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
return;
}

// Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
populateGPUToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateMemRefToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);

if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
// Clone each GPU kernel module for conversion, given that the GPU
// launch op still needs the original GPU kernel module.
SmallVector<Operation *, 1> gpuModules;
OpBuilder builder(context);
op->walk([&](gpu::GPUModuleOp gpuModule) {
builder.setInsertionPoint(gpuModule);
gpuModules.push_back(builder.clone(*gpuModule));
});
// Run conversion for each module independently as they can have
// different TargetEnv attributes.
for (Operation *gpuModule : gpuModules) {
spirv::TargetEnvAttr targetAttr =
spirv::lookupTargetEnvOrDefault(gpuModule);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
mapToMemRef(gpuModule, targetAttr);
populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
patterns);
if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
return signalPassFailure();
}
}
};

Expand Down
110 changes: 110 additions & 0 deletions mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// RUN: mlir-opt -convert-to-spirv="convert-gpu-modules=true run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
} {
// CHECK-LABEL: func.func @main
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: gpu.launch_func @[[$KERNELS_1:.*]]::@[[$BUILTIN_WG_ID_X:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]])
// CHECK: gpu.launch_func @[[$KERNELS_2:.*]]::@[[$BUILTIN_WG_ID_Y:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]])
func.func @main() {
%c1 = arith.constant 1 : index
gpu.launch_func @kernels_1::@builtin_workgroup_id_x
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
gpu.launch_func @KERNELS_2::@builtin_workgroup_id_y
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
return
}

// CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
// CHECK: spirv.func @[[$BUILTIN_WG_ID_X]]
// CHECK: spirv.mlir.addressof
// CHECK: spirv.Load "Input"
// CHECK: spirv.CompositeExtract
gpu.module @kernels_1 {
gpu.func @builtin_workgroup_id_x() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%0 = gpu.block_id x
gpu.return
}
}
// CHECK: gpu.module @[[$KERNELS_1]]
// CHECK: gpu.func @[[$BUILTIN_WG_ID_X]]
// CHECK gpu.block_id x
// CHECK: gpu.return

// CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
// CHECK: spirv.func @[[$BUILTIN_WG_ID_Y]]
// CHECK: spirv.mlir.addressof
// CHECK: spirv.Load "Input"
// CHECK: spirv.CompositeExtract
gpu.module @KERNELS_2 {
gpu.func @builtin_workgroup_id_y() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%0 = gpu.block_id y
gpu.return
}
}
// CHECK: gpu.module @[[$KERNELS_2]]
// CHECK: gpu.func @[[$BUILTIN_WG_ID_Y]]
// CHECK gpu.block_id y
// CHECK: gpu.return
}

// -----

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {
// CHECK-LABEL: func.func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2xi32>, %[[ARG1:.*]]: memref<4xi32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: gpu.launch_func @[[$KERNEL_MODULE:.*]]::@[[$KERNEL_FUNC:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]]) args(%[[ARG0]] : memref<2xi32>, %[[ARG1]] : memref<4xi32>)
func.func @main(%arg0 : memref<2xi32>, %arg2 : memref<4xi32>) {
%c1 = arith.constant 1 : index
gpu.launch_func @kernels::@kernel_foo
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
args(%arg0 : memref<2xi32>, %arg2 : memref<4xi32>)
return
}

// CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
// CHECK: spirv.func @[[$KERNEL_FUNC]]
// CHECK-SAME: %{{.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<2 x i32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
// CHECK-SAME: %{{.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
gpu.module @kernels {
gpu.func @kernel_foo(%arg0 : memref<2xi32>, %arg1 : memref<4xi32>)
kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
// CHECK: spirv.Constant
// CHECK: spirv.Constant dense<0>
%idx0 = arith.constant 0 : index
%vec0 = arith.constant dense<[0, 0]> : vector<2xi32>
// CHECK: spirv.AccessChain
// CHECK: spirv.Load "StorageBuffer"
%val = memref.load %arg0[%idx0] : memref<2xi32>
// CHECK: spirv.CompositeInsert
%vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32>
// CHECK: spirv.VectorShuffle
%shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32>
// CHECK: spirv.CompositeExtract
%res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32>
// CHECK: spirv.AccessChain
// CHECK: spirv.Store "StorageBuffer"
memref.store %res, %arg1[%idx0]: memref<4xi32>
// CHECK: spirv.Return
gpu.return
}
}
// CHECK: gpu.module @[[$KERNEL_MODULE]]
// CHECK: gpu.func @[[$KERNEL_FUNC]]
// CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32>
// CHECK: arith.constant
// CHECK: memref.load
// CHECK: vector.insertelement
// CHECK: vector.shuffle
// CHECK: vector.extractelement
// CHECK: memref.store
// CHECK: gpu.return
}
5 changes: 4 additions & 1 deletion mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
Expand Down Expand Up @@ -64,7 +65,9 @@ static LogicalResult runMLIRPasses(Operation *op,
passManager.addPass(createGpuKernelOutliningPass());
passManager.addPass(memref::createFoldMemRefAliasOpsPass());

passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
ConvertToSPIRVPassOptions convertToSPIRVOptions{};
convertToSPIRVOptions.convertGPUModules = true;
passManager.addPass(createConvertToSPIRVPass(convertToSPIRVOptions));
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
modulePM.addPass(spirv::createSPIRVLowerABIAttributesPass());
modulePM.addPass(spirv::createSPIRVUpdateVCEPass());
Expand Down
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8409,6 +8409,7 @@ cc_library(
":ArithTransforms",
":ConversionPassIncGen",
":FuncToSPIRV",
":GPUDialect",
":GPUToSPIRV",
":IR",
":IndexToSPIRV",
Expand Down
Loading