Skip to content

Commit

Permalink
[mlir] translate memref.reshape with static shapes but dynamic dims
Browse files Browse the repository at this point in the history
Prior to this patch, the lowering of memref.reshape operations to the
LLVM dialect failed if the shape argument had a static shape with
dynamic dimensions.  This patch adds the necessary support so that when
the shape argument has dynamic values, the lowering probes the dimension
at runtime to set the size in the `MemRefDescriptor` type.  This patch
also computes the stride for dynamic dimensions by deriving it from the
sizes of the inner dimensions.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D126604
  • Loading branch information
ashay committed Jun 2, 2022
1 parent 01ba470 commit 5fee179
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 16 deletions.
41 changes: 35 additions & 6 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Expand Up @@ -1128,14 +1128,43 @@ struct MemRefReshapeOpLowering
if (!isStaticStrideOrOffset(offset))
return rewriter.notifyMatchFailure(reshapeOp,
"dynamic offset is unsupported");
if (!llvm::all_of(strides, isStaticStrideOrOffset))
return rewriter.notifyMatchFailure(reshapeOp,
"dynamic strides are unsupported");

desc.setConstantOffset(rewriter, loc, offset);
for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
desc.setConstantStride(rewriter, loc, i, strides[i]);

assert(targetMemRefType.getLayout().isIdentity() &&
"Identity layout map is a precondition of a valid reshape op");

Value stride = nullptr;
int64_t targetRank = targetMemRefType.getRank();
for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
// If the stride for this dimension is dynamic, then use the product
// of the sizes of the inner dimensions.
stride = createIndexConstant(rewriter, loc, strides[i]);
} else if (!stride) {
// `stride` is null only in the first iteration of the loop. However,
// since the target memref has an identity layout, we can safely set
// the innermost stride to 1.
stride = createIndexConstant(rewriter, loc, 1);
}

Value dimSize;
int64_t size = targetMemRefType.getDimSize(i);
// If the size of this dimension is dynamic, then load it at runtime
// from the shape operand.
if (!ShapedType::isDynamic(size)) {
dimSize = createIndexConstant(rewriter, loc, size);
} else {
Value shapeOp = reshapeOp.shape();
Value index = createIndexConstant(rewriter, loc, i);
dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
}

desc.setSize(rewriter, loc, i, dimSize);
desc.setStride(rewriter, loc, i, stride);

// Prepare the stride value for the next dimension.
stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
}

*descriptor = desc;
Expand Down
74 changes: 64 additions & 10 deletions mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
Expand Up @@ -232,23 +232,77 @@ func.func @memref.reshape(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
// CHECK: %[[elem1:.*]] = llvm.extractvalue %[[cast0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[insert0:.*]] = llvm.insertvalue %[[elem0]], %[[undef]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[insert1:.*]] = llvm.insertvalue %[[elem1]], %[[insert0:.*]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>

// CHECK: %[[zero:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero]], %[[insert1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
// CHECK: %[[insert3:.*]] = llvm.insertvalue %[[two]], %[[insert2]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
// CHECK: %[[insert4:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert3]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
// CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>

// CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[twenty0:.*]] = llvm.mlir.constant(20 : index) : i64
// CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty0]], %[[insert5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[insert3:.*]] = llvm.insertvalue %[[twenty0]], %[[insert2]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one]], %[[insert3]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>

// CHECK: %[[twenty1:.*]] = llvm.mlir.constant(20 : index) : i64
// CHECK: %[[insert7:.*]] = llvm.insertvalue %[[twenty1]], %[[insert6]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[insert8:.*]] = llvm.insertvalue %[[one]], %[[insert7]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
// CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty1]], %[[insert5]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>

// CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
// CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
// CHECK: %[[insert7:.*]] = llvm.insertvalue %[[two]], %[[insert6]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[insert8:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert7]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>

// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[insert8]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)> to memref<2x6x20xf32>
%1 = memref.reshape %arg0(%0) : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32>

// CHECK: return %[[cast1]] : memref<2x6x20xf32>
return %1 : memref<2x6x20xf32>
}

// -----

// CHECK-LABEL: func @memref.reshape.dynamic.dim
// CHECK-SAME: %[[arg:.*]]: memref<?x?x?xf32>, %[[shape:.*]]: memref<4xi64>) -> memref<?x?x12x32xf32>
func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4xi64>) -> memref<?x?x12x32xf32> {
// CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?x?xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<4xi64> to !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>

// CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>

// CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[thirty_two0:.*]] = llvm.mlir.constant(32 : index) : i64
// CHECK: %[[insert3:.*]] = llvm.insertvalue %[[thirty_two0]], %[[insert2]][3, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0]], %[[insert3]][4, 3] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>

// CHECK: %[[thirty_two1:.*]] = llvm.mlir.constant(32 : index) : i64
// CHECK: %[[twelve:.*]] = llvm.mlir.constant(12 : index) : i64
// CHECK: %[[insert5:.*]] = llvm.insertvalue %[[twelve]], %[[insert4]][3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: %[[insert6:.*]] = llvm.insertvalue %[[thirty_two1]], %[[insert5]][4, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>

// CHECK: %[[three_hundred_and_eighty_four:.*]] = llvm.mlir.constant(384 : index) : i64
// CHECK: %[[one1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0]][%[[one1]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
// CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0]] : !llvm.ptr<i64>
// CHECK: %[[insert7:.*]] = llvm.insertvalue %[[shape_load0]], %[[insert6]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: %[[insert8:.*]] = llvm.insertvalue %[[three_hundred_and_eighty_four]], %[[insert7]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>

// CHECK: %[[mul:.*]] = llvm.mul %19, %23 : i64
// CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[shape_ptr1:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[shape_gep1:.*]] = llvm.getelementptr %[[shape_ptr1]][%[[zero1]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
// CHECK: %[[shape_load1:.*]] = llvm.load %[[shape_gep1]] : !llvm.ptr<i64>
// CHECK: %[[insert9:.*]] = llvm.insertvalue %[[shape_load1]], %[[insert8]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>
// CHECK: %[[insert10:.*]] = llvm.insertvalue %[[mul]], %[[insert9]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)>

// CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert10]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<4 x i64>, array<4 x i64>)> to memref<?x?x12x32xf32>
%0 = memref.reshape %arg(%shape) : (memref<?x?x?xf32>, memref<4xi64>) -> memref<?x?x12x32xf32>

return %0 : memref<?x?x12x32xf32>
// CHECK: return %[[result_cast]] : memref<?x?x12x32xf32>
}

0 comments on commit 5fee179

Please sign in to comment.