diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index c607ece418dff..310e72587eb81 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1132,35 +1132,22 @@ struct ConcatOpInterface // Extract the dimension for the concat op uint64_t concatDim = concatOp.getDim(); - bool dynamicConcatDim = false; SmallVector offsets(tensorType.getRank(), rewriter.getIndexAttr(0)); SmallVector strides(tensorType.getRank(), rewriter.getIndexAttr(1)); - SmallVector sizes; - - for (const auto &[dimIdx, dimSize] : - llvm::enumerate(tensorType.getShape())) { - if (dimSize == ShapedType::kDynamic) { - auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx); - sizes.push_back(dimOp.getResult()); - if (dimIdx == concatDim) - dynamicConcatDim = true; - } else { - sizes.push_back(rewriter.getIndexAttr(dimSize)); - } - } - - int64_t concatDimOffset = 0; - std::optional dynamicOffset; - std::optional dynamicSize; - if (dynamicConcatDim) { - // One or more operands have dynamic size, so we must accumulate the - // offset with arith ops. - dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); - } + SmallVector sizes = + memref::getMixedSizes(rewriter, loc, dstBuffer); + + AffineExpr s0, s1; + bindSymbols(rewriter.getContext(), s0, s1); + auto sum = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1, + {v1, v2}); + }; + OpFoldResult concatDimOffset = rewriter.getIndexAttr(0); for (auto operand : concatOp.getInputs()) { // Get the buffer for the operand. FailureOr srcBuffer = getBuffer(rewriter, operand, options, state); @@ -1171,18 +1158,10 @@ struct ConcatOpInterface // so the offset on that axis must accumulate through the loop, and the // size must change to the size of the current operand. auto operandTensorType = cast(operand.getType()); - int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim); - - if (dynamicConcatDim) { - offsets[concatDim] = dynamicOffset.value(); - dynamicSize = - memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim) - .getResult(); - sizes[concatDim] = dynamicSize.value(); - } else { - sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize); - offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset); - } + offsets[concatDim] = concatDimOffset; + OpFoldResult concatDimSize = + memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim); + sizes[concatDim] = concatDimSize; // Create a subview of the destination buffer. auto dstMemrefType = cast(memrefType); @@ -1197,12 +1176,7 @@ struct ConcatOpInterface if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview))) return failure(); - if (dynamicConcatDim) { - dynamicOffset = arith::AddIOp::create( - rewriter, loc, dynamicOffset.value(), dynamicSize.value()); - } else { - concatDimOffset += operandConcatDimSize; - } + concatDimOffset = sum(concatDimOffset, concatDimSize); } replaceOpWithBufferizedValues(rewriter, op, dstBuffer); diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index 5eb2360a29b8f..be8ce20d8f154 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3 // CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]] // CHECK: %[[ALLOC:.*]] = memref.alloc // CHECK-SAME: memref<8x?xf32> -// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index -// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1] // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] -// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index -// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1] // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] // CHECK: return %[[RET]] @@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te // CHECK: %[[ALLOC:.*]] = memref.alloc // CHECK-SAME: memref // CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]] -// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1] // CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] -// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index -// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1] // CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] // CHECK: return %[[RET]] @@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor, %g: tensor (s0 + s1)> + +// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static( +// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>, +// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>) +// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]] +// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]] +// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]] +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32> +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]] +// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1] +// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]] +// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]] +// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1] +// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]] +// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]] +// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1] +// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]] +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: return %[[RET]] +// CHECK: } +func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> { + %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32> + return %0 : tensor<8x10xf32> +} + +// ----- + // CHECK-LABEL: func @tensor.splat_dynamic( // CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32 // CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index