[mlir][tosa] Improve lowering support for tosa.concat
The existing lowering for tosa.concat fails in some instances when the
output shape contains more information than the input shapes. The result
is an illegal tensor.empty operation.

This change bases the output shape on the original tosa.concat
operation, while querying the input tensor shapes to build the slicing
operations.
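
For illustration, the problematic pattern (exercised by the new
@concat_axis_dyn_mixed test below) is a concat whose operands are all
dynamic along the concat axis while the result type is static there:
deriving the output shape from the first operand produced a tensor.empty
whose dynamic-size operands did not match the static result type. A
hypothetical reproducer (value names invented):

  // Result is statically 5 along axis 0, but every operand is dynamic.
  %0 = "tosa.concat"(%a, %b, %c) <{axis = 0 : i64}>
      : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x1xf32>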

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D151707
sabauma authored and rsuderman committed Jun 15, 2023
1 parent 4c2fc26 commit 86c4972
Showing 2 changed files with 122 additions and 75 deletions.
72 changes: 37 additions & 35 deletions mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -12,7 +12,9 @@
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -355,56 +357,56 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
   LogicalResult
   matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto inputType = cast<ShapedType>(op.getOperand(0).getType());
     auto resultType = dyn_cast<RankedTensorType>(op.getType());

     Location loc = op.getLoc();
     int axis = op.getAxis();
     Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
         loc, rewriter.getIndexAttr(axis));
-    int rank = resultType.getRank();
-    SmallVector<Value, 3> offsets, sizes, strides;
-    sizes.reserve(rank);
-    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
-    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    int64_t rank = resultType.getRank();

-    SmallVector<Value> dynDims;
-    for (int i = 0; i < rank; ++i) {
-      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
-          loc, adaptor.getOperands()[0], i));
-      if (inputType.isDynamicDim(i)) {
-        dynDims.push_back(
-            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
-      }
-    }
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes = tensor::createDimValues(
+        rewriter, op.getLoc(), adaptor.getOperands()[0]);

+    // Pre-compute the offsets along the axis dimension.
+    // The axisOffsets will be of size rank + 1, where the last value
+    // will hold the total size of the tensor along the 'axis' dimension.
+    SmallVector<OpFoldResult> axisOffsets;
+    axisOffsets.push_back(rewriter.getIndexAttr(0));
+    axisOffsets.push_back(sizes[axis]);
+
-    Value resultDimSize = sizes[axis];
     for (auto arg : adaptor.getOperands().drop_front()) {
       auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
-      resultDimSize =
-          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
+      auto currentOffset =
+          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
+      auto total =
+          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
+      axisOffsets.push_back(getAsOpFoldResult(total));
     }
+    sizes[axis] = axisOffsets.back();
+
+    // Compute the dynamic sizes of the tensor.empty operation.
+    // This is based off of the specified result type of the tosa.concat
+    // operation, since we don't want to change the result type of the operation
+    // during the conversion.
+    SmallVector<Value> dynDims;
+    for (int64_t i = 0; i < rank; ++i) {
+      if (resultType.isDynamicDim(i)) {
+        dynDims.push_back(
+            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
+      }
+    }
-    sizes[axis] = resultDimSize;

-    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
+    Value result = rewriter.create<tensor::EmptyOp>(
         loc, resultType.getShape(), resultType.getElementType(), dynDims);

-    auto toOpFoldResult = [](Value v) -> OpFoldResult {
-      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
-      if (!op)
-        return v;
-      return op.getValue();
-    };
-    Value result = emptyTensor;
-    for (auto arg : adaptor.getOperands()) {
-      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
+    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
+      auto sizes = tensor::createDimValues(rewriter, op.getLoc(), arg);
+      offsets[axis] = offset;
       result = rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, arg, result,
-          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
-          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
-      offsets[axis] =
-          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
+          loc, arg, result, offsets, sizes, strides);
     }
     rewriter.replaceOp(op, result);
     return success();
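
For a fully static concatenation, the net effect of the rewritten pattern
is a tensor.empty sized from the result type followed by one
tensor.insert_slice per operand at the pre-computed axis offsets. A sketch
of the emitted IR for the static case checked by the @concat test below
(SSA names are illustrative; constant offsets, sizes, and strides fold
directly into the slice operations):

  %init = tensor.empty() : tensor<11x1xf32>
  %ins0 = tensor.insert_slice %arg0 into %init[0, 0] [5, 1] [1, 1]
      : tensor<5x1xf32> into tensor<11x1xf32>
  %ins1 = tensor.insert_slice %arg1 into %ins0[5, 0] [6, 1] [1, 1]
      : tensor<6x1xf32> into tensor<11x1xf32>
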
125 changes: 85 additions & 40 deletions mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -202,23 +202,13 @@ func.func @pad_dyn_padding(%arg0 : tensor<1x2xf32>) -> (tensor<?x9xf32>) {
 // CHECK-SAME: %[[ARG0:.+]]: tensor<5x1xf32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<6x1xf32>
 func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
-  // CHECK: [[AXIS:%.+]] = arith.constant 0
-  // CHECK: [[STRIDE:%.+]] = arith.constant 1
-  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
-  // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
+  // CHECK-DAG: [[INIT:%.+]] = tensor.empty() : tensor<11x1xf32>
+  // CHECK-DAG: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+  // CHECK-DAG: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG1]] into [[INSERT0]][5, 0] [6, 1] [1, 1]
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x1xf32>, tensor<6x1xf32>) -> (tensor<11x1xf32>)

-  // CHECK: [[AXIS:%.+]] = arith.constant 1
-  // CHECK: [[STRIDE:%.+]] = arith.constant 1
-  // CHECK: [[OFFSET:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX0:%.+]] = arith.constant 0 : index
-  // CHECK: [[IDX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
-  // CHECK: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
+  // CHECK-DAG: [[INIT:%.+]] = tensor.empty() : tensor<5x2xf32>
+  // CHECK-DAG: [[INSERT0:%.+]] = tensor.insert_slice %[[ARG0]] into [[INIT]][0, 0] [5, 1] [1, 1]
   // CHECK: [[INSERT1:%.+]] = tensor.insert_slice %[[ARG0]] into [[INSERT0]][0, 1] [5, 1] [1, 1]
   %1 = "tosa.concat"(%arg0, %arg0) { axis = 1 : i64} : (tensor<5x1xf32>, tensor<5x1xf32>) -> (tensor<5x2xf32>)
   return
@@ -230,17 +220,16 @@ func.func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () {
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]
 func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
-  // CHECK: %[[AXIS:.+]] = arith.constant 0
-  // CHECK: %[[STRIDE:.+]] = arith.constant 1
-  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
-  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
-  // CHECK: %[[IDX1_2:.+]] = arith.constant 1 : index
-  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX1_2]]
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<11x?xf32>
-  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[SIZE]]] [1, 1]
-  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[SIZE]]] [1, 1]
+  // CHECK-DAG: %[[AXIS:.+]] = arith.constant 0
+  // CHECK-DAG: %[[IDX1:.+]] = arith.constant 1
+  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[IDX1]]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<11x?xf32>
+  // CHECK-DAG: %[[IDX1_1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[IDX1_1]]
+  // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [5, %[[DIM1]]] [1, 1]
+  // CHECK-DAG: %[[IDX1_2:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG1]], %[[IDX1_2]] : tensor<6x?xf32>
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][5, 0] [6, %[[DIM2]]] [1, 1]
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<5x?xf32>, tensor<6x?xf32>) -> (tensor<11x?xf32>)
   return
 }
@@ -251,20 +240,76 @@ func.func @concat_non_axis_dyn(%arg0: tensor<5x?xf32>, %arg1: tensor<6x?xf32>) -> () {
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 // CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
 func.func @concat_axis_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<?x3xf32>) -> () {
-  // CHECK: %[[AXIS:.+]] = arith.constant 0
-  // CHECK: %[[STRIDE:.+]] = arith.constant 1
-  // CHECK: %[[OFFSET:.+]] = arith.constant 0 : index
-  // CHECK: %[[IDX0:.+]] = arith.constant 0 : index
-  // CHECK: %[[SIZE:.+]] = tensor.dim %[[ARG0]], %[[IDX0]]
-  // CHECK: %[[IDX0_2:.+]] = arith.constant 0 : index
-  // CHECK: %[[DYN:.+]] = tensor.dim %[[ARG0]], %[[IDX0_2]]
-  // CHECK: %[[IDX1:.+]] = arith.constant 1 : index
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DYN]]) : tensor<?x3xf32>
-  // CHECK: %[[DYN1:.+]] = tensor.dim %[[ARG0]], %[[AXIS]]
-  // CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DYN1]], 3] [1, 1]
-  // CHECK: %[[SUM:.+]] = arith.addi %[[OFFSET]], %[[DYN1]]
-  // CHECK: %[[DYN2:.+]] = tensor.dim %[[ARG1]], %[[AXIS]]
-  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[SUM]], 0] [%[[DYN2]], 3] [1, 1]
+  // CHECK-DAG: %[[AXIS:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[IDX0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[IDX0]] : tensor<?x3xf32>
+  // CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[AXIS]] : tensor<?x3xf32>
+  // CHECK-DAG: %[[SUM:.+]] = arith.addi %[[DIM0]], %[[DIM1]] : index
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[SUM]]) : tensor<?x3xf32>
+  // CHECK-DAG: %[[IDX0_1:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[IDX0_1]] : tensor<?x3xf32>
+  // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM2]], 3] [1, 1] : tensor<?x3xf32> into tensor<?x3xf32>
+  // CHECK-DAG: %[[IDX0_2:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM3:.+]] = tensor.dim %[[ARG1]], %[[IDX0_2]] : tensor<?x3xf32>
+  // CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[DIM0]], 0] [%[[DIM3]], 3] [1, 1] : tensor<?x3xf32> into tensor<?x3xf32>
+
   %0 = "tosa.concat"(%arg0, %arg1) { axis = 0 : i64} : (tensor<?x3xf32>, tensor<?x3xf32>) -> (tensor<?x3xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @concat_axis_dyn_mixed
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z_]*]]:
+func.func @concat_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> () {
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C0_0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[OFFSET0:.+]] = tensor.dim %[[ARG0]], %[[C0_0]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[DIM1_0:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[OFFSET1:.+]] = arith.addi %[[OFFSET0]], %[[DIM1_0]] : index
+  // CHECK-DAG: %[[DIM2_2:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[OFFSET2:.+]] = arith.addi %[[OFFSET1]], %[[DIM2_2]] : index
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x1xf32>
+  // CHECK-DAG: %[[C0_3:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0_3]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM_4]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x1xf32>
+  // CHECK-DAG: %[[C0_4:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM_6:.+]] = tensor.dim %[[ARG1]], %[[C0_4]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][%[[OFFSET0]], 0] [%[[DIM_6]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x1xf32>
+  // CHECK-DAG: %[[C0_8:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM_9:.+]] = tensor.dim %[[ARG2]], %[[C0_8]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT3:.+]] = tensor.insert_slice %[[ARG2]] into %[[INSERT1]][%[[OFFSET1]], 0] [%[[DIM_9]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x1xf32>
+
+  // CHECK: return
+
+  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 0 : i64}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x1xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @concat_non_axis_dyn_mixed
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]:
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z_]*]]:
+func.func @concat_non_axis_dyn_mixed(%arg0: tensor<?x1xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x1xf32>) -> () {
+  // CHECK-DAG: %[[UNUSED0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[UNUSED1:.+]] = tensor.dim %[[ARG0]], %[[UNUSED0]] : tensor<?x1xf32>
+
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<5x3xf32>
+  // CHECK-DAG: %[[C0_0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM0_0:.+]] = tensor.dim %[[ARG0]], %[[C0_0]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [%[[DIM0_0]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
+  // CHECK-DAG: %[[C0_1:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM1_0:.+]] = tensor.dim %[[ARG1]], %[[C0_1]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[ARG1]] into %[[INSERT0]][0, 1] [%[[DIM1_0]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
+  // CHECK-DAG: %[[C0_2:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[DIM2_0:.+]] = tensor.dim %[[ARG2]], %[[C0_2]] : tensor<?x1xf32>
+  // CHECK-DAG: %[[INSERT2:.+]] = tensor.insert_slice %[[ARG2]] into %[[INSERT1]][0, 2] [%[[DIM2_0]], 1] [1, 1] : tensor<?x1xf32> into tensor<5x3xf32>
+  // CHECK: return
+
+  %0 = "tosa.concat"(%arg0, %arg1, %arg2) <{axis = 1 : i64}> : (tensor<?x1xf32>, tensor<?x1xf32>, tensor<?x1xf32>) -> tensor<5x3xf32>
+  return
+}

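When the concat axis is dynamic, as in @concat_axis_dyn above, the lowering
instead sums the operand sizes along the axis to size the tensor.empty, then
inserts each operand at its running offset. An illustrative expansion
consistent with the CHECK lines above (SSA names invented):

  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x3xf32>
  %d1 = tensor.dim %arg1, %c0 : tensor<?x3xf32>
  %sum = arith.addi %d0, %d1 : index
  %init = tensor.empty(%sum) : tensor<?x3xf32>
  %ins0 = tensor.insert_slice %arg0 into %init[0, 0] [%d0, 3] [1, 1]
      : tensor<?x3xf32> into tensor<?x3xf32>
  %ins1 = tensor.insert_slice %arg1 into %ins0[%d0, 0] [%d1, 3] [1, 1]
      : tensor<?x3xf32> into tensor<?x3xf32>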