[mlir][tosa] Tosa reverse to linalg supporting dynamic shapes
The lowering needed to switch to tensor.extract to support tosa.reverse with dynamic shapes.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D108744
rsuderman committed Aug 26, 2021
1 parent 84cbd71 commit 9047825
Showing 2 changed files with 74 additions and 32 deletions.
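As a rough illustration of the change (this sketch is not part of the commit; the function and SSA value names are made up, and the authoritative output is what the updated test below checks), a reverse along a dynamically sized axis now lowers roughly as follows: the axis size is read at runtime with tensor.dim, and the body of the linalg.generic computes the mirrored index and reads the input with tensor.extract.

// TOSA input: reverse a tensor with a dynamic dimension along axis 0.
func @reverse_dyn(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<?xi32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// Approximate Linalg output after the TosaToLinalg conversion (illustrative only):
#map = affine_map<(d0) -> (d0)>
func @reverse_dyn(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %c0 = constant 0 : index
  %size = tensor.dim %arg0, %c0 : tensor<?xi32>       // runtime size of axis 0
  %init = linalg.init_tensor [%size] : tensor<?xi32>  // output with the same dynamic size
  %rev = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]}
      outs(%init : tensor<?xi32>) {
  ^bb0(%out: i32):
    %i = linalg.index 0 : index
    %c1 = constant 1 : index
    %last = subi %size, %c1 : index   // size - 1
    %src = subi %last, %i : index     // mirrored read index: size - 1 - i
    %elem = tensor.extract %arg0[%src] : tensor<?xi32>
    linalg.yield %elem : i32
  } -> tensor<?xi32>
  return %rev : tensor<?xi32>
}

Reading the input with tensor.extract inside the generic body, instead of passing it as an ins operand with a reversed affine map, removes the constant dimSize - 1 that the old lowering baked into the indexing map, which is why the static-shape restriction can be dropped.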
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (29 additions, 21 deletions)

@@ -2043,40 +2043,48 @@ class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
     Value input = op.input();
     auto inputTy = input.getType().template cast<ShapedType>();
     auto resultTy = op.getType().template cast<ShapedType>();
-    auto rank = resultTy.getRank();
     auto axis = op.axis();
 
-    if (!inputTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "No initial value found for reduction operation");
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < inputTy.getRank(); i++) {
+      if (inputTy.isDynamicDim(i)) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
+
+    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
 
     // First fill the output buffer with the init value.
     auto initTensor = rewriter
                           .create<linalg::InitTensorOp>(
-                              loc, ArrayRef<Value>({}), inputTy.getShape(),
-                              inputTy.getElementType())
+                              loc, ArrayRef<Value>({dynDims}),
+                              inputTy.getShape(), inputTy.getElementType())
                           .result();
-
-    SmallVector<AffineExpr, 2> inputExprs;
-    inputExprs.resize(resultTy.getRank());
-
-    for (int i = 0; i < rank; i++)
-      inputExprs[i] = rewriter.getAffineDimExpr(i);
-
-    inputExprs[axis] =
-        rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) -
-        inputExprs[axis];
-
     SmallVector<AffineMap, 2> affineMaps = {
-        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
-                       rewriter.getContext()),
         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
 
     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
-        op, resultTy, op.input(), ValueRange{initTensor}, affineMaps,
+        op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
         getNParallelLoopsAttrs(resultTy.getRank()),
         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
-          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+          llvm::SmallVector<Value> indices;
+          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
+            auto index =
+                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
+            if (i == axis) {
+              auto one = rewriter.create<ConstantIndexOp>(nestedLoc, 1);
+              auto sizeMinusOne =
+                  rewriter.create<SubIOp>(nestedLoc, axisDimSize, one);
+              index = rewriter.create<SubIOp>(nestedLoc, sizeMinusOne, index);
+            }
+
+            indices.push_back(index);
+          }
+
+          auto extract = nestedBuilder.create<tensor::ExtractOp>(
+              nestedLoc, input, indices);
+          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
+                                                extract.getResult());
        });
     return success();
   }
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (45 additions, 11 deletions)

@@ -881,28 +881,62 @@ func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (-d0 + 4, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 3)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @reverse
 func @reverse(%arg0: tensor<5x4xi32>) -> () {
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
-  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
-  // CHECK:   linalg.yield %arg1 : i32
+  // CHECK: %[[C0:.+]] = constant 0
+  // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>)
+  // CHECK-DAG: %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG: %[[I1:.+]] = linalg.index 1
+  // CHECK-DAG: %[[SUB1:.+]] = constant 1
+  // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[RDIM]], %[[SUB1]]
+  // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I0]]
+  // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]], %[[I1]]] : tensor<5x4xi32>
+  // CHECK: linalg.yield %[[EXTRACT]]
   %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
 
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) {
-  // CHECK: ^bb0(%arg1: i32, %arg2: i32):
-  // CHECK:   linalg.yield %arg1 : i32
+  // CHECK: %[[C1:.+]] = constant 1
+  // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C1]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>)
+  // CHECK-DAG: %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG: %[[I1:.+]] = linalg.index 1
+  // CHECK-DAG: %[[SUB1:.+]] = constant 1
+  // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[RDIM]], %[[SUB1]]
+  // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I1]]
+  // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[I0]], %[[READ_DIM]]] : tensor<5x4xi32>
+  // CHECK: linalg.yield %[[EXTRACT]]
   %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32>
   return
 }
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @reverse_dyn
+func @reverse_dyn(%arg0: tensor<?xi32>) -> () {
+  // CHECK: %[[C0_1:.+]] = constant 0
+  // CHECK: %[[D0_1:.+]] = tensor.dim %arg0, %[[C0_1]]
+  // CHECK: %[[C0_2:.+]] = constant 0
+  // CHECK: %[[D0_2:.+]] = tensor.dim %arg0, %[[C0_2]]
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0_1]]]
+  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel"]} outs(%[[INIT]] : tensor<?xi32>)
+  // CHECK-DAG: %[[I0:.+]] = linalg.index 0
+  // CHECK-DAG: %[[SUB1:.+]] = constant 1
+  // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[D0_2]], %[[SUB1]]
+  // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I0]]
+  // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]]] : tensor<?xi32>
+  // CHECK: linalg.yield %[[EXTRACT]]
+  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<?xi32>) -> tensor<?xi32>
+  return
+}
+
+// -----
+
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 