Skip to content

Commit

Permalink
[mlir][spirv] Fix UnifyAliasedResourcePass for 64-bit index
Browse files Browse the repository at this point in the history
Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D145079
  • Loading branch information
antiagainst committed Mar 14, 2023
1 parent 68c14f5 commit 141b7d4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
25 changes: 15 additions & 10 deletions mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
Expand Up @@ -366,7 +366,6 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
}

Location loc = acOp.getLoc();
auto i32Type = rewriter.getI32Type();

if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
// The source indices are for a buffer with scalar element types. Rewrite
Expand All @@ -376,16 +375,19 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
int ratio = dstNumBytes / srcNumBytes;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));

auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
Type indexType = oldIndex.getType();

int ratio = dstNumBytes / srcNumBytes;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, indexType, rewriter.getIntegerAttr(indexType, ratio));

indices.back() =
rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
indices.push_back(
rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));

rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
Expand All @@ -400,14 +402,17 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
int ratio = srcNumBytes / dstNumBytes;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));

auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
Type indexType = oldIndex.getType();

int ratio = srcNumBytes / dstNumBytes;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, indexType, rewriter.getIntegerAttr(indexType, ratio));

indices.back() =
rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);

rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
Expand Down
27 changes: 27 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
Expand Up @@ -32,6 +32,33 @@ spirv.module Logical GLSL450 {

// -----

spirv.module Logical GLSL450 {
spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>

spirv.func @load_store_scalar_64bit(%index: i64) -> f32 "None" {
%c0 = spirv.Constant 0 : i64
%addr = spirv.mlir.addressof @var01s : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
%ac = spirv.AccessChain %addr[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i64, i64
%value = spirv.Load "StorageBuffer" %ac : f32
spirv.Store "StorageBuffer" %ac, %value : f32
spirv.ReturnValue %value : f32
}
}

// CHECK-LABEL: spirv.module

// CHECK-NOT: @var01s
// CHECK: spirv.GlobalVariable @var01v bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
// CHECK-NOT: @var01s

// CHECK: spirv.func @load_store_scalar_64bit(%[[INDEX:.+]]: i64)
// CHECK-DAG: %[[C4:.+]] = spirv.Constant 4 : i64
// CHECK: spirv.SDiv %[[INDEX]], %[[C4]] : i64
// CHECK: spirv.SMod %[[INDEX]], %[[C4]] : i64

// -----

spirv.module Logical GLSL450 {
spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
Expand Down

0 comments on commit 141b7d4

Please sign in to comment.