Merged
@@ -1132,35 +1132,22 @@ struct ConcatOpInterface

// Extract the dimension for the concat op
uint64_t concatDim = concatOp.getDim();
bool dynamicConcatDim = false;

SmallVector<OpFoldResult> offsets(tensorType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(tensorType.getRank(),
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes;

for (const auto &[dimIdx, dimSize] :
llvm::enumerate(tensorType.getShape())) {
if (dimSize == ShapedType::kDynamic) {
auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
sizes.push_back(dimOp.getResult());
if (dimIdx == concatDim)
dynamicConcatDim = true;
} else {
sizes.push_back(rewriter.getIndexAttr(dimSize));
}
}

int64_t concatDimOffset = 0;
std::optional<Value> dynamicOffset;
std::optional<Value> dynamicSize;
if (dynamicConcatDim) {
// One or more operands have dynamic size, so we must accumulate the
// offset with arith ops.
dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
}
SmallVector<OpFoldResult> sizes =
memref::getMixedSizes(rewriter, loc, dstBuffer);

AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
auto sum = [&](OpFoldResult v1, OpFoldResult v2) {
return affine::makeComposedFoldedAffineApply(rewriter, loc, s0 + s1,
{v1, v2});
};

OpFoldResult concatDimOffset = rewriter.getIndexAttr(0);
for (auto operand : concatOp.getInputs()) {
// Get the buffer for the operand.
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
@@ -1171,18 +1158,10 @@ struct ConcatOpInterface
// so the offset on that axis must accumulate through the loop, and the
// size must change to the size of the current operand.
auto operandTensorType = cast<RankedTensorType>(operand.getType());
int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);

if (dynamicConcatDim) {
offsets[concatDim] = dynamicOffset.value();
dynamicSize =
memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
.getResult();
sizes[concatDim] = dynamicSize.value();
} else {
sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
}
offsets[concatDim] = concatDimOffset;
OpFoldResult concatDimSize =
memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim);
sizes[concatDim] = concatDimSize;

// Create a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1197,12 +1176,7 @@
if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
return failure();

if (dynamicConcatDim) {
dynamicOffset = arith::AddIOp::create(
rewriter, loc, dynamicOffset.value(), dynamicSize.value());
} else {
concatDimOffset += operandConcatDimSize;
}
concatDimOffset = sum(concatDimOffset, concatDimSize);
}

replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
40 changes: 33 additions & 7 deletions mlir/test/Dialect/Tensor/bufferize.mlir
@@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<8x?xf32>
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
@@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<?x?xf32>
// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?

// -----

// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>

// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static(
// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>)
// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
// CHECK: }
func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
%0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
Member

This looks odd to me: I expected that the result dimension must be dynamic if and only if one of the input dimensions is dynamic. We follow this design in other operations such as tensor.collapse_shape. E.g., see CollapseShapeOp::inferCollapsedType for details. Can the verifier be made more strict instead? @qedawkins

Contributor Author
@CoTinker Nov 24, 2025

Thanks, actually the docs give an example like this:

// Dynamic + dynamic -> static
%0 = tensor.concat dim(1) %0, %1, %2 :
(tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>

But setting aside whether this verifier check is reasonable, the way this PR unifies the computation using OpFoldResult should still be valuable?

Contributor

I intentionally allowed dynamic + dynamic -> static since it's possible for such situations to arise where forcing the result to be dynamic would require a tensor.cast to introduce the static info. Arguably collapse could do the same, though in the collapse case it's probably a lot less likely someone is choosing to collapse two truly dynamic dimensions into a static one intentionally, so being defensive was probably a win there.
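
For concreteness, here is a minimal sketch (function and value names are invented for illustration) of the IR that a stricter verifier would force when every input is dynamic along the concat dimension:

func.func @static_concat_via_cast(%a: tensor<3x?xf32>, %b: tensor<3x?xf32>) -> tensor<3x10xf32> {
  // With a dynamic-only result type, the concat itself drops the static extent...
  %concat = tensor.concat dim(1) %a, %b
      : (tensor<3x?xf32>, tensor<3x?xf32>) -> tensor<3x?xf32>
  // ...so the statically known size 10 has to be reintroduced with a cast.
  %static = tensor.cast %concat : tensor<3x?xf32> to tensor<3x10xf32>
  return %static : tensor<3x10xf32>
}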

Member

Is there a problem with an explicit tensor.cast?

It would be nice to have a consistent op design across the tensor dialect. I believe one reason why we chose input dynamicity == output dynamicity for collapse_shape/expand_shape is that we can print better error messages: if there's only one allowable output type, you can print it during verification errors.

Contributor

Sorry for the delayed response. We aren't currently using the functionality for dynamic + dynamic -> static downstream so we wouldn't notice if it was removed (right now), but in general casts hinder optimization by cluttering use-def chains.

As an example, imagine if we added a tensor.split op as the inverse of concat. The folder for it would look for tensor.split(tensor.concat), but the cast would get in the way.
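
A rough sketch of that use-def problem, writing the purely hypothetical tensor.split in generic syntax since no such op exists today:

%concat = tensor.concat dim(1) %a, %b
    : (tensor<3x?xf32>, tensor<3x?xf32>) -> tensor<3x?xf32>
%cast = tensor.cast %concat : tensor<3x?xf32> to tensor<3x10xf32>
// The operand of the hypothetical split is defined by tensor.cast, not
// tensor.concat, so a naive split(concat) folder would fail to match unless
// it also looks through casts.
%lhs, %rhs = "tensor.split"(%cast) {dim = 1 : i64}
    : (tensor<3x10xf32>) -> (tensor<3x?xf32>, tensor<3x?xf32>)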

return %0 : tensor<8x10xf32>
}

// -----

// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
Expand Down