diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 3dfd591b0fca8..97b631afc33cd 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1128,14 +1128,43 @@ struct MemRefReshapeOpLowering
     if (!isStaticStrideOrOffset(offset))
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "dynamic offset is unsupported");
-    if (!llvm::all_of(strides, isStaticStrideOrOffset))
-      return rewriter.notifyMatchFailure(reshapeOp,
-                                         "dynamic strides are unsupported");
 
     desc.setConstantOffset(rewriter, loc, offset);
-    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
-      desc.setConstantSize(rewriter, loc, i, targetMemRefType.getDimSize(i));
-      desc.setConstantStride(rewriter, loc, i, strides[i]);
+
+    assert(targetMemRefType.getLayout().isIdentity() &&
+           "Identity layout map is a precondition of a valid reshape op");
+
+    Value stride = nullptr;
+    int64_t targetRank = targetMemRefType.getRank();
+    for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
+      if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
+        // The stride for this dimension is static, so materialize it as a
+        // constant. Dynamic strides reuse the running product computed below.
+        stride = createIndexConstant(rewriter, loc, strides[i]);
+      } else if (!stride) {
+        // `stride` is null only in the first iteration of the loop. Since the
+        // target memref has an identity layout, we can safely set the
+        // innermost stride to 1.
+        stride = createIndexConstant(rewriter, loc, 1);
+      }
+
+      Value dimSize;
+      int64_t size = targetMemRefType.getDimSize(i);
+      // If the size of this dimension is dynamic, then load it at runtime
+      // from the shape operand.
+      if (!ShapedType::isDynamic(size)) {
+        dimSize = createIndexConstant(rewriter, loc, size);
+      } else {
+        Value shapeOp = reshapeOp.shape();
+        Value index = createIndexConstant(rewriter, loc, i);
+        dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
+      }
+
+      desc.setSize(rewriter, loc, i, dimSize);
+      desc.setStride(rewriter, loc, i, stride);
+
+      // Prepare the stride value for the next dimension.
+      stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
     }
 
     *descriptor = desc;
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index b8f3717682a05..1296f8e4881cd 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -232,23 +232,77 @@ func.func @memref.reshape(%arg0: memref<4x5x6xf32>) -> memref<2x6x20xf32> {
   // CHECK: %[[elem1:.*]] = llvm.extractvalue %[[cast0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
   // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[elem0]], %[[undef]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
   // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[elem1]], %[[insert0:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+
   // CHECK: %[[zero:.*]] = llvm.mlir.constant(0 : index) : i64
   // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero]], %[[insert1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
-  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[two]], %[[insert2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
-  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
-  // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+
+  // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
   // CHECK: %[[twenty0:.*]] = llvm.mlir.constant(20 : index) : i64
-  // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty0]], %[[insert5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[twenty0]], %[[insert2]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one]], %[[insert3]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+
   // CHECK: %[[twenty1:.*]] = llvm.mlir.constant(20 : index) : i64
-  // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[twenty1]], %[[insert6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-  // CHECK: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
-  // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[one]], %[[insert7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[six:.*]] = llvm.mlir.constant(6 : index) : i64
+  // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[six]], %[[insert4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[twenty1]], %[[insert5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+
+  // CHECK: %[[hundred_and_twenty:.*]] = llvm.mlir.constant(120 : index) : i64
+  // CHECK: %[[two:.*]] = llvm.mlir.constant(2 : index) : i64
+  // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[two]], %[[insert6]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[hundred_and_twenty]], %[[insert7]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+
   // CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[insert8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<2x6x20xf32>
   %1 = memref.reshape %arg0(%0) : (memref<4x5x6xf32>, memref<3xi64>) -> memref<2x6x20xf32>
   // CHECK: return %[[cast1]] : memref<2x6x20xf32>
   return %1 : memref<2x6x20xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @memref.reshape.dynamic.dim
+// CHECK-SAME: %[[arg:.*]]: memref<?x?x?xf32>, %[[shape:.*]]: memref<4xi64>) -> memref<?x?x12x32xf32>
+func.func @memref.reshape.dynamic.dim(%arg: memref<?x?x?xf32>, %shape: memref<4xi64>) -> memref<?x?x12x32xf32> {
+  // CHECK: %[[arg_cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : memref<?x?x?xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[shape_cast:.*]] = builtin.unrealized_conversion_cast %[[shape]] : memref<4xi64> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[undef:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[alloc_ptr:.*]] = llvm.extractvalue %[[arg_cast]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[align_ptr:.*]] = llvm.extractvalue %[[arg_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK: %[[insert0:.*]] = llvm.insertvalue %[[alloc_ptr]], %[[undef]][0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert1:.*]] = llvm.insertvalue %[[align_ptr]], %[[insert0]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[zero0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[insert2:.*]] = llvm.insertvalue %[[zero0]], %[[insert1]][2] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[one0:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[thirty_two0:.*]] = llvm.mlir.constant(32 : index) : i64
+  // CHECK: %[[insert3:.*]] = llvm.insertvalue %[[thirty_two0]], %[[insert2]][3, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert4:.*]] = llvm.insertvalue %[[one0]], %[[insert3]][4, 3] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[thirty_two1:.*]] = llvm.mlir.constant(32 : index) : i64
+  // CHECK: %[[twelve:.*]] = llvm.mlir.constant(12 : index) : i64
+  // CHECK: %[[insert5:.*]] = llvm.insertvalue %[[twelve]], %[[insert4]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert6:.*]] = llvm.insertvalue %[[thirty_two1]], %[[insert5]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[three_hundred_and_eighty_four:.*]] = llvm.mlir.constant(384 : index) : i64
+  // CHECK: %[[one1:.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: %[[shape_ptr0:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[shape_gep0:.*]] = llvm.getelementptr %[[shape_ptr0]][%[[one1]]] : (!llvm.ptr, i64) -> !llvm.ptr
+  // CHECK: %[[shape_load0:.*]] = llvm.load %[[shape_gep0]] : !llvm.ptr
+  // CHECK: %[[insert7:.*]] = llvm.insertvalue %[[shape_load0]], %[[insert6]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert8:.*]] = llvm.insertvalue %[[three_hundred_and_eighty_four]], %[[insert7]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[mul:.*]] = llvm.mul %[[three_hundred_and_eighty_four]], %[[shape_load0]] : i64
+  // CHECK: %[[zero1:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[shape_ptr1:.*]] = llvm.extractvalue %[[shape_cast]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[shape_gep1:.*]] = llvm.getelementptr %[[shape_ptr1]][%[[zero1]]] : (!llvm.ptr, i64) -> !llvm.ptr
+  // CHECK: %[[shape_load1:.*]] = llvm.load %[[shape_gep1]] : !llvm.ptr
+  // CHECK: %[[insert9:.*]] = llvm.insertvalue %[[shape_load1]], %[[insert8]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+  // CHECK: %[[insert10:.*]] = llvm.insertvalue %[[mul]], %[[insert9]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>
+
+  // CHECK: %[[result_cast:.*]] = builtin.unrealized_conversion_cast %[[insert10]] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> to memref<?x?x12x32xf32>
+  %0 = memref.reshape %arg(%shape) : (memref<?x?x?xf32>, memref<4xi64>) -> memref<?x?x12x32xf32>
+
+  return %0 : memref<?x?x12x32xf32>
+  // CHECK: return %[[result_cast]] : memref<?x?x12x32xf32>
+}
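
Worked example (illustration only, not part of the patch; the memref<?x?x12x32xf32> result type is reconstructed from the sizes and strides the CHECK lines verify): the rewritten loop walks the target dimensions from innermost to outermost, so each stride is the running product of the sizes already visited.

  dim 3: size 32 (constant)          stride 1    (innermost; identity layout)
  dim 2: size 12 (constant)          stride 32   (1 * 32, still static)
  dim 1: size llvm.load %shape[1]    stride 384  (32 * 12, still static)
  dim 0: size llvm.load %shape[0]    stride llvm.mul 384, %shape[1]  (dynamic)

Only the outermost stride depends on a dynamic size, so a single llvm.mul reaches the descriptor; the inner strides are already static in the target layout, which is why the test captures only the constants 1, 32, 12, and 384 plus one llvm.mul.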