Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
[mlir][spirv] Query target environment for mapping memory space
Checks spirv::TargetEnv from op to see if it contains either Kernel or Shader capabilities.
If it does, then it will set the memory space mapping accordingly.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D134317
  • Loading branch information
raikonenfnu committed Sep 20, 2022
1 parent 079a5ff commit 1dc48a9
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
10 changes: 10 additions & 0 deletions mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -325,6 +326,15 @@ void MapMemRefStorageClassPass::runOnOperation() {
MLIRContext *context = &getContext();
Operation *op = getOperation();

if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
spirv::TargetEnv targetEnv(attr);
if (targetEnv.allows(spirv::Capability::Kernel)) {
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
} else if (targetEnv.allows(spirv::Capability::Shader)) {
memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
}
}

auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);

Expand Down
53 changes: 53 additions & 0 deletions mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
Expand Up @@ -114,3 +114,56 @@ func.func @missing_mapping() {
%0 = "dialect.memref_producer"() : () -> (memref<f32, 2>)
return
}

// -----

/// Checks memory maps to OpenCL mapping if Kernel capability is enabled.
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, #spv.resource_limits<>> } {
func.func @operand_result() {
// CHECK: memref<f32, #spv.storage_class<CrossWorkgroup>>
%0 = "dialect.memref_producer"() : () -> (memref<f32>)
// CHECK: memref<4xi32, #spv.storage_class<Generic>>
%1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
// CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
%2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
// CHECK: memref<*xf16, #spv.storage_class<UniformConstant>>
%3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)


"dialect.memref_consumer"(%0) : (memref<f32>) -> ()
// CHECK: memref<4xi32, #spv.storage_class<Generic>>
"dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
// CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
"dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
// CHECK: memref<*xf16, #spv.storage_class<UniformConstant>>
"dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()

return
}
}

// -----

/// Checks memory maps to Vulkan mapping if Shader capability is enabled.
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> } {
func.func @operand_result() {
// CHECK: memref<f32, #spv.storage_class<StorageBuffer>>
%0 = "dialect.memref_producer"() : () -> (memref<f32>)
// CHECK: memref<4xi32, #spv.storage_class<Generic>>
%1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
// CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
%2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
// CHECK: memref<*xf16, #spv.storage_class<Uniform>>
%3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)


"dialect.memref_consumer"(%0) : (memref<f32>) -> ()
// CHECK: memref<4xi32, #spv.storage_class<Generic>>
"dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
// CHECK: memref<?x4xf16, #spv.storage_class<Workgroup>>
"dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
// CHECK: memref<*xf16, #spv.storage_class<Uniform>>
"dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
return
}
}

0 comments on commit 1dc48a9

Please sign in to comment.