Skip to content

Commit

Permalink
[mlir][spirv] Adapt subview legalization to the updated op semantics.
Browse files Browse the repository at this point in the history
The subview op semantics changed recently to allow for a more natural
representation of constant offsets and strides. The legalization of the
subview op for lowering to SPIR-V needs to account for this.
Also change the linearization to use the strides from the affine map
of a memref.

Differential Revision: https://reviews.llvm.org/D80270
  • Loading branch information
MaheshRavishankar committed May 20, 2020
1 parent 55430f5 commit 0e88eb5
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 89 deletions.
15 changes: 15 additions & 0 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -2685,6 +2685,21 @@ def SubViewOp : Std_Op<"subview", [
/// constructed with `b` at location `loc`.
SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);

/// Return the offsets as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed
/// with `b` at location `loc`
SmallVector<Value, 4> getOrCreateOffsets(OpBuilder &b, Location loc);

/// Return the sizes as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed
/// with `b` at location `loc`
SmallVector<Value, 4> getOrCreateSizes(OpBuilder &b, Location loc);

/// Return the strides as Values. Each Value is either the dynamic
/// value specified in the op or a ConstantIndexOp constructed with
/// `b` at location `loc`
SmallVector<Value, 4> getOrCreateStrides(OpBuilder &b, Location loc);

/// A subview result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Expand Down
31 changes: 6 additions & 25 deletions mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,34 +64,15 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
// TODO: Aborting when the offsets are static. There might be a way to fold
// the subview op with load even if the offsets have been canonicalized
// away.
if (subViewOp.getNumOffsets() == 0)
return failure();

ValueRange opOffsets = subViewOp.offsets();
SmallVector<Value, 2> opStrides;
if (subViewOp.getNumStrides()) {
// If the strides are dynamic, get the stride operands.
opStrides = llvm::to_vector<2>(subViewOp.strides());
} else {
// When static, the stride operands can be retrieved by taking the strides
// of the result of the subview op, and dividing the strides of the base
// memref.
SmallVector<int64_t, 2> staticStrides;
if (failed(subViewOp.getStaticStrides(staticStrides))) {
return failure();
}
opStrides.reserve(opOffsets.size());
for (auto stride : staticStrides) {
auto constValAttr = rewriter.getIntegerAttr(
IndexType::get(rewriter.getContext()), stride);
opStrides.emplace_back(rewriter.create<ConstantOp>(loc, constValAttr));
}
}
assert(opOffsets.size() == opStrides.size());
SmallVector<Value, 4> opOffsets = subViewOp.getOrCreateOffsets(rewriter, loc);
SmallVector<Value, 4> opStrides = subViewOp.getOrCreateStrides(rewriter, loc);
assert(opOffsets.size() == indices.size() &&
"expected as many indices as rank of subview op result type");
assert(opStrides.size() == indices.size() &&
"expected as many indices as rank of subview op result type");

// New indices for the load are the current indices * subview_stride +
// subview_offset.
assert(indices.size() == opStrides.size());
sourceIndices.resize(indices.size());
for (auto index : llvm::enumerate(indices)) {
auto offset = opOffsets[index.index()];
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2548,6 +2548,44 @@ SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
return res;
}

/// Returns the subview offsets as Values. A static offset is materialized as
/// a freshly-created ConstantIndexOp at `loc`; a dynamic offset (marked by
/// the dynamic sentinel in the `static_offsets` attribute) is taken from the
/// op's operand list.
SmallVector<Value, 4> SubViewOp::getOrCreateOffsets(OpBuilder &b,
                                                    Location loc) {
  // Dynamic offset operands start at operand 1 (operand 0 is presumably the
  // source memref — confirm against the op definition).
  unsigned operandPos = 1;
  auto offsetAttrs = static_offsets().cast<ArrayAttr>();
  SmallVector<Value, 4> result;
  result.reserve(offsetAttrs.size());
  for (Attribute attr : offsetAttrs) {
    int64_t staticVal = attr.cast<IntegerAttr>().getInt();
    if (ShapedType::isDynamicStrideOrOffset(staticVal)) {
      // Sentinel value: consume the next dynamic operand.
      result.push_back(getOperand(operandPos++));
    } else {
      result.push_back(b.create<ConstantIndexOp>(loc, staticVal));
    }
  }
  return result;
}

/// Returns the subview sizes as Values. A static size is materialized as a
/// freshly-created ConstantIndexOp at `loc`; a dynamic size (marked by the
/// dynamic sentinel in the `static_sizes` attribute) is taken from the op's
/// operand list.
SmallVector<Value, 4> SubViewOp::getOrCreateSizes(OpBuilder &b, Location loc) {
  // Dynamic size operands follow the dynamic offset operands (which start
  // after operand 0 — presumably the source memref).
  unsigned operandPos = 1 + offsets().size();
  auto sizeAttrs = static_sizes().cast<ArrayAttr>();
  SmallVector<Value, 4> result;
  result.reserve(sizeAttrs.size());
  for (Attribute attr : sizeAttrs) {
    int64_t staticVal = attr.cast<IntegerAttr>().getInt();
    if (ShapedType::isDynamic(staticVal)) {
      // Sentinel value: consume the next dynamic operand.
      result.push_back(getOperand(operandPos++));
    } else {
      result.push_back(b.create<ConstantIndexOp>(loc, staticVal));
    }
  }
  return result;
}

/// Returns the subview strides as Values. A static stride is materialized as
/// a freshly-created ConstantIndexOp at `loc`; a dynamic stride (marked by
/// the dynamic sentinel in the `static_strides` attribute) is taken from the
/// op's operand list.
SmallVector<Value, 4> SubViewOp::getOrCreateStrides(OpBuilder &b,
                                                    Location loc) {
  // Dynamic stride operands follow the dynamic offset and size operands
  // (which start after operand 0 — presumably the source memref).
  unsigned operandPos = 1 + offsets().size() + sizes().size();
  auto strideAttrs = static_strides().cast<ArrayAttr>();
  SmallVector<Value, 4> result;
  result.reserve(strideAttrs.size());
  for (Attribute attr : strideAttrs) {
    int64_t staticVal = attr.cast<IntegerAttr>().getInt();
    if (ShapedType::isDynamicStrideOrOffset(staticVal)) {
      // Sentinel value: consume the next dynamic operand.
      result.push_back(getOperand(operandPos++));
    } else {
      result.push_back(b.create<ConstantIndexOp>(loc, staticVal));
    }
  }
  return result;
}

LogicalResult
SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
if (!strides().empty())
Expand Down
62 changes: 31 additions & 31 deletions mlir/test/Conversion/GPUToSPIRV/load-store.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,25 @@ module attributes {

// CHECK-LABEL: spv.module Logical GLSL450
gpu.module @kernels {
// CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable @[[NUMWORKGROUPSVAR:.*]] built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable @[[LOCALINVOCATIONIDVAR:.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable @[[WORKGROUPIDVAR:.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-LABEL: spv.func @load_store_kernel
// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>}
// CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
// CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>}
// CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>}
// CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>}
// CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>}
// CHECK-SAME: [[ARG6:%.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}
// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>}
// CHECK-SAME: %[[ARG1:.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}
// CHECK-SAME: %[[ARG2:.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 2)>}
// CHECK-SAME: %[[ARG3:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 3), StorageBuffer>}
// CHECK-SAME: %[[ARG4:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>}
// CHECK-SAME: %[[ARG5:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>}
// CHECK-SAME: %[[ARG6:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>}
gpu.func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel
attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
// CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]
// CHECK: [[WORKGROUPID:%.*]] = spv.Load "Input" [[ADDRESSWORKGROUPID]]
// CHECK: [[WORKGROUPIDX:%.*]] = spv.CompositeExtract [[WORKGROUPID]]{{\[}}0 : i32{{\]}}
// CHECK: [[ADDRESSLOCALINVOCATIONID:%.*]] = spv._address_of [[LOCALINVOCATIONIDVAR]]
// CHECK: [[LOCALINVOCATIONID:%.*]] = spv.Load "Input" [[ADDRESSLOCALINVOCATIONID]]
// CHECK: [[LOCALINVOCATIONIDX:%.*]] = spv.CompositeExtract [[LOCALINVOCATIONID]]{{\[}}0 : i32{{\]}}
// CHECK: %[[ADDRESSWORKGROUPID:.*]] = spv._address_of @[[WORKGROUPIDVAR]]
// CHECK: %[[WORKGROUPID:.*]] = spv.Load "Input" %[[ADDRESSWORKGROUPID]]
// CHECK: %[[WORKGROUPIDX:.*]] = spv.CompositeExtract %[[WORKGROUPID]]{{\[}}0 : i32{{\]}}
// CHECK: %[[ADDRESSLOCALINVOCATIONID:.*]] = spv._address_of @[[LOCALINVOCATIONIDVAR]]
// CHECK: %[[LOCALINVOCATIONID:.*]] = spv.Load "Input" %[[ADDRESSLOCALINVOCATIONID]]
// CHECK: %[[LOCALINVOCATIONIDX:.*]] = spv.CompositeExtract %[[LOCALINVOCATIONID]]{{\[}}0 : i32{{\]}}
%0 = "gpu.block_id"() {dimension = "x"} : () -> index
%1 = "gpu.block_id"() {dimension = "y"} : () -> index
%2 = "gpu.block_id"() {dimension = "z"} : () -> index
Expand All @@ -54,26 +54,26 @@ module attributes {
%9 = "gpu.block_dim"() {dimension = "x"} : () -> index
%10 = "gpu.block_dim"() {dimension = "y"} : () -> index
%11 = "gpu.block_dim"() {dimension = "z"} : () -> index
// CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], [[WORKGROUPIDX]]
// CHECK: %[[INDEX1:.*]] = spv.IAdd %[[ARG3]], %[[WORKGROUPIDX]]
%12 = addi %arg3, %0 : index
// CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], [[LOCALINVOCATIONIDX]]
// CHECK: %[[INDEX2:.*]] = spv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
%13 = addi %arg4, %3 : index
// CHECK: [[STRIDE1_1:%.*]] = spv.constant 4 : i32
// CHECK: [[OFFSET1_1:%.*]] = spv.IMul [[STRIDE1_1]], [[INDEX1]] : i32
// CHECK: [[STRIDE1_2:%.*]] = spv.constant 1 : i32
// CHECK: [[UPDATE1_2:%.*]] = spv.IMul [[STRIDE1_2]], [[INDEX2]] : i32
// CHECK: [[OFFSET1_2:%.*]] = spv.IAdd [[OFFSET1_1]], [[UPDATE1_2]] : i32
// CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
// CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[ZERO1]], [[OFFSET1_2]]{{\]}}
// CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
// CHECK: %[[STRIDE1_1:.*]] = spv.constant 4 : i32
// CHECK: %[[OFFSET1_1:.*]] = spv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
// CHECK: %[[STRIDE1_2:.*]] = spv.constant 1 : i32
// CHECK: %[[UPDATE1_2:.*]] = spv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
// CHECK: %[[OFFSET1_2:.*]] = spv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
// CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32
// CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO1]], %[[OFFSET1_2]]{{\]}}
// CHECK-NEXT: %[[VAL1:.*]] = spv.Load "StorageBuffer" %[[PTR1]]
%14 = load %arg0[%12, %13] : memref<12x4xf32>
// CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
// CHECK-NEXT: [[VAL2:%.*]] = spv.Load "StorageBuffer" [[PTR2]]
// CHECK: %[[PTR2:.*]] = spv.AccessChain %[[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
// CHECK-NEXT: %[[VAL2:.*]] = spv.Load "StorageBuffer" %[[PTR2]]
%15 = load %arg1[%12, %13] : memref<12x4xf32>
// CHECK: [[VAL3:%.*]] = spv.FAdd [[VAL1]], [[VAL2]]
// CHECK: %[[VAL3:.*]] = spv.FAdd %[[VAL1]], %[[VAL2]]
%16 = addf %14, %15 : f32
// CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
// CHECK-NEXT: spv.Store "StorageBuffer" [[PTR3]], [[VAL3]]
// CHECK: %[[PTR3:.*]] = spv.AccessChain %[[ARG2]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
// CHECK-NEXT: spv.Store "StorageBuffer" %[[PTR3]], %[[VAL3]]
store %16, %arg2[%12, %13] : memref<12x4xf32>
gpu.return
}
Expand Down
36 changes: 19 additions & 17 deletions mlir/test/Conversion/GPUToSPIRV/loop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,29 @@ module attributes {
gpu.module @kernels {
gpu.func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) kernel
attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
// CHECK: [[LB:%.*]] = spv.constant 4 : i32
// CHECK: %[[LB:.*]] = spv.constant 4 : i32
%lb = constant 4 : index
// CHECK: [[UB:%.*]] = spv.constant 42 : i32
// CHECK: %[[UB:.*]] = spv.constant 42 : i32
%ub = constant 42 : index
// CHECK: [[STEP:%.*]] = spv.constant 2 : i32
// CHECK: %[[STEP:.*]] = spv.constant 2 : i32
%step = constant 2 : index
// CHECK: spv.loop {
// CHECK-NEXT: spv.Branch [[HEADER:\^.*]]([[LB]] : i32)
// CHECK: [[HEADER]]([[INDVAR:%.*]]: i32):
// CHECK: [[CMP:%.*]] = spv.SLessThan [[INDVAR]], [[UB]] : i32
// CHECK: spv.BranchConditional [[CMP]], [[BODY:\^.*]], [[MERGE:\^.*]]
// CHECK: [[BODY]]:
// CHECK: [[STRIDE1:%.*]] = spv.constant 1 : i32
// CHECK: [[OFFSET1:%.*]] = spv.IMul [[STRIDE1]], [[INDVAR]] : i32
// CHECK: spv.AccessChain {{%.*}}{{\[}}{{%.*}}, [[OFFSET1]]{{\]}} : {{.*}}
// CHECK: [[STRIDE2:%.*]] = spv.constant 1 : i32
// CHECK: [[OFFSET2:%.*]] = spv.IMul [[STRIDE2]], [[INDVAR]] : i32
// CHECK: spv.AccessChain {{%.*}}{{\[}}{{%.*}}, [[OFFSET2]]{{\]}} : {{.*}}
// CHECK: [[INCREMENT:%.*]] = spv.IAdd [[INDVAR]], [[STEP]] : i32
// CHECK: spv.Branch [[HEADER]]([[INCREMENT]] : i32)
// CHECK: [[MERGE]]
// CHECK-NEXT: spv.Branch ^[[HEADER:.*]](%[[LB]] : i32)
// CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32):
// CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
// CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
// CHECK: ^[[BODY]]:
// CHECK: %[[STRIDE1:.*]] = spv.constant 1 : i32
// CHECK: %[[INDEX1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
// CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32
// CHECK: spv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}}
// CHECK: %[[STRIDE2:.*]] = spv.constant 1 : i32
// CHECK: %[[INDEX2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
// CHECK: %[[ZERO2:.*]] = spv.constant 0 : i32
// CHECK: spv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]]
// CHECK: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
// CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)
// CHECK: ^[[MERGE]]
// CHECK: spv._merge
// CHECK: }
scf.for %arg4 = %lb to %ub step %step {
Expand Down
32 changes: 16 additions & 16 deletions mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ module {
func @fold_static_stride_subview
(%arg0 : memref<12x32xf32>, %arg1 : index,
%arg2 : index, %arg3 : index, %arg4 : index) {
// CHECK: %[[C2:.*]] = constant 2
// CHECK: %[[C3:.*]] = constant 3
// CHECK: %[[T0:.*]] = muli %[[ARG3]], %[[C2]]
// CHECK: %[[T1:.*]] = addi %[[ARG1]], %[[T0]]
// CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[C3]]
// CHECK: %[[T3:.*]] = addi %[[ARG2]], %[[T2]]
// CHECK: %[[LOADVAL:.*]] = load %[[ARG0]][%[[T1]], %[[T3]]]
// CHECK: %[[STOREVAL:.*]] = sqrt %[[LOADVAL]]
// CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C2]]
// CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]]
// CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[C3]]
// CHECK: %[[T9:.*]] = addi %[[ARG2]], %[[T8]]
// CHECK store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]]
%0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
%1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
// CHECK-DAG: %[[C2:.*]] = constant 2
// CHECK-DAG: %[[C3:.*]] = constant 3
// CHECK: %[[T0:.*]] = muli %[[ARG3]], %[[C3]]
// CHECK: %[[T1:.*]] = addi %[[ARG1]], %[[T0]]
// CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[ARG2]]
// CHECK: %[[T3:.*]] = addi %[[T2]], %[[C2]]
// CHECK: %[[LOADVAL:.*]] = load %[[ARG0]][%[[T1]], %[[T3]]]
// CHECK: %[[STOREVAL:.*]] = sqrt %[[LOADVAL]]
// CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C3]]
// CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]]
// CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[ARG2]]
// CHECK: %[[T9:.*]] = addi %[[T8]], %[[C2]]
// CHECK: store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]]
%0 = subview %arg0[%arg1, 2][4, 4][3, %arg2] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [96, ?]>
%1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
%2 = sqrt %1 : f32
store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [96, ?]>
return
}

Expand Down

0 comments on commit 0e88eb5

Please sign in to comment.