[mlir] subview op lowering for target memrefs with const offset
The current Standard-to-LLVM conversion pass lowers subview ops only if
dynamic offsets are provided. This commit adds a code path that, when no
dynamic offsets are provided, instead takes the constant offset from the
target memref type (see Example 3 of the subview op definition).

Differential Revision: https://reviews.llvm.org/D74280
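
For illustration, this is the shape of op the new path handles (taken verbatim from the test added below): a subview with no offset, size, or stride operands whose result type carries the static offset 8 in its layout map:

  %1 = subview %0[][][] :
    memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to
    memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>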
Tobias Gysi authored and ftynse committed Feb 10, 2020
1 parent ed3527c commit 1555d7f
Showing 2 changed files with 43 additions and 8 deletions.
26 changes: 18 additions & 8 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -2304,7 +2304,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     // Currently, only rank > 0 and full or no operands are supported. Fail to
     // convert otherwise.
     unsigned rank = sourceMemRefType.getRank();
-    if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) ||
+    if (viewMemRefType.getRank() == 0 ||
+        (!dynamicOffsets.empty() && rank != dynamicOffsets.size()) ||
         (!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
         (!dynamicStrides.empty() && rank != dynamicStrides.size()))
       return matchFailure();
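
In other words, the offset operands may now be either complete (one per rank) or entirely absent, matching how sizes and strides were already handled. A sketch of the two accepted forms (the SSA operand names are hypothetical):

  // Full operands: rank-many offsets, sizes, and strides.
  %a = subview %0[%i0, %i1][%sz0, %sz1][%st0, %st1] :
    memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to
    memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
  // No operands: offsets, sizes, and strides all come from the static types.
  %b = subview %0[][][] :
    memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to
    memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>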
@@ -2315,6 +2316,11 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     if (failed(successStrides))
       return matchFailure();
 
+    // Fail to convert if neither a dynamic nor static offset is available.
+    if (dynamicOffsets.empty() &&
+        offset == MemRefType::getDynamicStrideOrOffset())
+      return matchFailure();
+
     // Create the descriptor.
     MemRefDescriptor sourceMemRef(operands.front());
     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
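
Here `offset` is the static offset that getStridesAndOffset extracted from the result memref type just above; the new guard only rejects the op when that offset is the dynamic sentinel and no offset operands were passed. A worked illustration on the layout maps used in the test below (the second type is a hypothetical counterexample):

  // Offset 8, strides [4, 1]: fully static, the lowering proceeds.
  memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
  // Symbolic offset: reported as dynamic, so with no offset operands
  // the pattern fails to match.
  memref<62x3xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + d1 + s0)>>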
@@ -2348,14 +2354,18 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     }
 
     // Offset.
-    Value baseOffset = sourceMemRef.offset(rewriter, loc);
-    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
-      Value min = dynamicOffsets[i];
-      baseOffset = rewriter.create<LLVM::AddOp>(
-          loc, baseOffset,
-          rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+    if (dynamicOffsets.empty()) {
+      targetMemRef.setConstantOffset(rewriter, loc, offset);
+    } else {
+      Value baseOffset = sourceMemRef.offset(rewriter, loc);
+      for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+        Value min = dynamicOffsets[i];
+        baseOffset = rewriter.create<LLVM::AddOp>(
+            loc, baseOffset,
+            rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+      }
+      targetMemRef.setOffset(rewriter, loc, baseOffset);
     }
-    targetMemRef.setOffset(rewriter, loc, baseOffset);
 
     // Update sizes and strides.
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
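Net effect on the emitted IR: with no offset operands, the offset field of the result descriptor (position [2]) is written from a single constant rather than from a chain of llvm.mul/llvm.add over the source offset and strides. Roughly, for the constant case exercised below (a sketch with hypothetical SSA names, not verbatim output):

  %c8 = llvm.mlir.constant(8 : index) : !llvm.i64
  %d2 = llvm.insertvalue %c8, %d1[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">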
25 changes: 25 additions & 0 deletions mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -815,6 +815,31 @@ func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4
   return
 }
 
+// CHECK-LABEL: func @subview_const_stride_and_offset(
+func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) {
+  // The last "insertvalue" that populates the memref descriptor from the function arguments.
+  // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST62:.*]] = llvm.mlir.constant(62 : i64)
+  // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i64)
+  // CHECK: %[[CST8:.*]] = llvm.mlir.constant(8 : index)
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[CST8]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i64)
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[][][] :
+    memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
+  return
+}
+
 // -----
 
 module {
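For reference, this test file is driven by mlir-opt plus FileCheck; its RUN line is along these lines (the exact flags are an assumption and may differ by revision):

  // RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s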
