Skip to content

Commit

Permalink
[mlir][spirv] GPUToSPIRVPass: support the case when the TargetEnv attribute is attached to the `gpu.module`
Browse files Browse the repository at this point in the history

Previously, only the case where `TargetEnv` was attached to the top-level `ModuleOp` was supported.

Differential Revision: https://reviews.llvm.org/D135907
  • Loading branch information
Hardcode84 committed Oct 14, 2022
1 parent 4d2d426 commit 8ca5058
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 26 deletions.
9 changes: 9 additions & 0 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,15 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
spvModuleRegion.begin());
// The spirv.module build method adds a block. Remove that.
rewriter.eraseBlock(&spvModuleRegion.back());

// Some of the patterns call `lookupTargetEnv` during conversion; they
// will fail if invoked after GPUModuleConversion unless we preserve the
// `TargetEnv` attribute.
// Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);

rewriter.eraseOp(moduleOp);
return success();
}
Expand Down
54 changes: 29 additions & 25 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,37 +63,41 @@ void GPUToSPIRVPass::runOnOperation() {
gpuModules.push_back(builder.clone(*moduleOp.getOperation()));
});

// Map MemRef memory space to SPIR-V storage class first if requested.
if (mapMemorySpace) {
// Run conversion for each module independently as they can have different
// TargetEnv attributes.
for (Operation *gpuModule : gpuModules) {
// Map MemRef memory space to SPIR-V storage class first if requested.
if (mapMemorySpace) {
std::unique_ptr<ConversionTarget> target =
spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
spirv::mapMemorySpaceToVulkanStorageClass;
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);

RewritePatternSet patterns(context);
spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);

if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
return signalPassFailure();
}

auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
std::unique_ptr<ConversionTarget> target =
spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
spirv::mapMemorySpaceToVulkanStorageClass;
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
SPIRVConversionTarget::get(targetAttr);

SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
populateGPUToSPIRVPatterns(typeConverter, patterns);

if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateMemRefToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);

if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
return signalPassFailure();
}

auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);

SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);

// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateMemRefToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);

if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
return signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>>
Expand Down
35 changes: 34 additions & 1 deletion mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -verify-diagnostics -split-input-file %s -o - | FileCheck %s

module attributes {
gpu.container_module,
Expand Down Expand Up @@ -28,3 +28,36 @@ module attributes {
return
}
}

// -----

module attributes {
gpu.container_module
} {
gpu.module @kernels attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Addresses], []>, #spirv.resource_limits<>>
} {
// CHECK-LABEL: spirv.module @{{.*}} Physical64 OpenCL
// CHECK-SAME: spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Addresses], []>, #spirv.resource_limits<>>
// CHECK: spirv.func
// CHECK-SAME: {{%.*}}: f32
// CHECK-NOT: spirv.interface_var_abi
// CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.array<12 x f32>, CrossWorkgroup>
// CHECK-NOT: spirv.interface_var_abi
// CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
gpu.return
}
}

func.func @main() {
%0 = "op"() : () -> (f32)
%1 = "op"() : () -> (memref<12xf32, #spirv.storage_class<CrossWorkgroup>>)
%cst = arith.constant 1 : index
gpu.launch_func @kernels::@basic_module_structure
blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
args(%0 : f32, %1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>)
return
}
}

0 comments on commit 8ca5058

Please sign in to comment.