From 5fee1799f4d8da59c251e2d04172fc2f387cbe54 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Sat, 28 May 2022 17:33:04 -0700 Subject: [PATCH] [mlir] translate memref.reshape with static shapes but dynamic dims Prior to this patch, the lowering of memref.reshape operations to the LLVM dialect failed if the shape argument had a static shape with dynamic dimensions. This patch adds the necessary support so that when the shape argument has dynamic values, the lowering probes the dimension at runtime to set the size in the `MemRefDescriptor` type. This patch also computes the stride for dynamic dimensions by deriving it from the sizes of the inner dimensions. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D126604 --- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 41 ++++++++-- .../convert-static-memref-ops.mlir | 74 ++++++++++++++++--- 2 files changed, 99 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 3dfd591b0fca8..97b631afc33cd 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1128,14 +1128,43 @@ struct MemRefReshapeOpLowering if (!isStaticStrideOrOffset(offset)) return rewriter.notifyMatchFailure(reshapeOp, "dynamic offset is unsupported"); - if (!llvm::all_of(strides, isStaticStrideOrOffset)) - return rewriter.notifyMatchFailure(reshapeOp, - "dynamic strides are unsupported"); desc.setConstantOffset(rewriter, loc, offset); - for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { - desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i)); - desc.setConstantStride(rewriter, loc, i, strides[i]); + + assert(targetMemRefType.getLayout().isIdentity() && + "Identity layout map is a precondition of a valid reshape op"); + + Value stride = nullptr; + int64_t targetRank = targetMemRefType.getRank(); + for (auto i : llvm::reverse(llvm::seq(0, 
targetRank))) { + if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { + // If the stride for this dimension is static, use it directly; a dynamic + // stride keeps the running product of the inner dimension sizes. + stride = createIndexConstant(rewriter, loc, strides[i]); + } else if (!stride) { + // `stride` is null only in the first iteration of the loop. However, + // since the target memref has an identity layout, we can safely set + // the innermost stride to 1. + stride = createIndexConstant(rewriter, loc, 1); + } + + Value dimSize; + int64_t size = targetMemRefType.getDimSize(i); + // If the size of this dimension is dynamic, then load it at runtime + // from the shape operand. + if (!ShapedType::isDynamic(size)) { + dimSize = createIndexConstant(rewriter, loc, size); + } else { + Value shapeOp = reshapeOp.shape(); + Value index = createIndexConstant(rewriter, loc, i); + dimSize = rewriter.create(loc, shapeOp, index); + } + + desc.setSize(rewriter, loc, i, dimSize); + desc.setStride(rewriter, loc, i, stride); + + // Prepare the stride value for the next dimension. 
+ stride = rewriter.create(loc, stride, dimSize); } *descriptor = desc; diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir index b8f3717682a05..1296f8e4881cd 100644 --- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir @@ -232,23 +232,77 @@ func.func @memref.reshape(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> { // CHECK: %[[elem1:.*]] = llvm.extractvalue %[[cast0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[elem0]], %[[undef]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[elem1]], %[[insert0:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[zero:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero]], %[[insert1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> - // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64 - // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[two]], %[[insert2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> - // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64 - // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> - // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64 - // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + + // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK: %[[twenty0:.*]] = llvm.mlir.constant(20 : index) : i64 - // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty0]], %[[insert5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x 
i64>)> + // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[twenty0]], %[[insert2]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one]], %[[insert3]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[twenty1:.*]] = llvm.mlir.constant(20 : index) : i64 - // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[twenty1]], %[[insert6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> - // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64 - // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[one]], %[[insert7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64 + // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty1]], %[[insert5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + + // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64 + // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64 + // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[two]], %[[insert6]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert7]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[insert8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<2x6x20xf32> %1 = memref.reshape %arg0(%0) : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32> // CHECK: return %[[cast1]] : memref<2x6x20xf32> return %1 : memref<2x6x20xf32> } + +// ----- + +// CHECK-LABEL: func @memref.reshape.dynamic.dim +// CHECK-SAME: %[[arg:.*]]: memref, %[[shape:.*]]: memref<4xi64>) -> memref +func.func 
@memref.reshape.dynamic.dim(%arg: memref, %shape: memref<4xi64>) -> memref { + // CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<4xi64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef]][0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + + // CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64 + // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1]][2] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + + // CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[thirty_two0:.*]] = llvm.mlir.constant(32 : index) : i64 + // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[thirty_two0]], %[[insert2]][3, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0]], %[[insert3]][4, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + + // CHECK: %[[thirty_two1:.*]] = llvm.mlir.constant(32 : index) : i64 + // CHECK: %[[twelve:.*]] = llvm.mlir.constant(12 : index) : i64 + // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[twelve]], %[[insert4]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: %[[insert6:.*]] = llvm.insertvalue 
%[[thirty_two1]], %[[insert5]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + + // CHECK: %[[three_hundred_and_eighty_four:.*]] = llvm.mlir.constant(384 : index) : i64 + // CHECK: %[[one1:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0]][%[[one1]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0]] : !llvm.ptr + // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[shape_load0]], %[[insert6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[three_hundred_and_eighty_four]], %[[insert7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + + // CHECK: %[[mul:.*]] = llvm.mul %19, %23 : i64 + // CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64 + // CHECK: %[[shape_ptr1:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[shape_gep1:.*]] = llvm.getelementptr %[[shape_ptr1]][%[[zero1]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: %[[shape_load1:.*]] = llvm.load %[[shape_gep1]] : !llvm.ptr + // CHECK: %[[insert9:.*]] = llvm.insertvalue %[[shape_load1]], %[[insert8]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + // CHECK: %[[insert10:.*]] = llvm.insertvalue %[[mul]], %[[insert9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> + + // CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert10]] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> to memref + %0 = memref.reshape %arg(%shape) : (memref, memref<4xi64>) -> memref + + return %0 : memref + // CHECK: return %[[result_cast]] : memref +}