[MLIR][Tosa] Pass encoding through tosa-to-linalg
As pointed out by @Sinclair-Dee in
#62304, the `tosa-to-linalg`
conversion ignored the `encoding` attribute.

Also, this patch avoids crashing with an assertion failure on unranked
tensors. Instead, the conversion now emits a "failed to legalize" error.

Fixes #62304 and fixes #63165.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D152171
rikhuijzer committed Jun 15, 2023
1 parent b1c683f commit c8ac14d
Showing 3 changed files with 30 additions and 12 deletions.
19 changes: 8 additions & 11 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -526,12 +526,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   assert(operation->getNumResults() == 1 &&
          "All TOSA elementwise ops should only return a single result.");
 
-  auto results = operation->getResults();
-  auto resultTy = dyn_cast<ShapedType>(operation->getResult(0).getType());
+  auto result = operation->getResult(0);
+  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
 
   if (!resultTy)
-    return rewriter.notifyMatchFailure(operation,
-                                       "All results must be a shaped type");
+    return rewriter.notifyMatchFailure(
+        operation, "All results must be a ranked tensor type");
 
   unsigned rank = resultTy.getRank();
 
@@ -545,7 +545,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
   SmallVector<Value> emptyTensors;
 
   SmallVector<Value> dynDims;
-  dynDims.resize(cast<ShapedType>(results.front().getType()).getRank());
+  dynDims.resize(rank);
 
   for (auto arg : operation->getOperands()) {
     auto operandTy = cast<ShapedType>(arg.getType());
@@ -557,12 +557,9 @@ elementwiseMatchAndRewriteHelper(Operation *operation,
 
   SmallVector<Value> filteredDims = condenseValues(dynDims);
 
-  for (auto result : results) {
-    auto resultTy = cast<ShapedType>(result.getType());
-    emptyTensors.push_back(rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims));
-    opResultTypes.push_back(result.getType());
-  }
+  emptyTensors.push_back(
+      rewriter.create<tensor::EmptyOp>(loc, resultTy, filteredDims));
+  opResultTypes.push_back(result.getType());
 
   auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
       emptyTensors, [](Value v) { return getElementTypeOrSelf(v); }));
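The key to the encoding fix is visible in the last hunk: the `tensor.empty` op is now created from the full `RankedTensorType` of the result, so any `encoding` attribute (e.g. a `#sparse_tensor.encoding`) rides along, whereas rebuilding the type from shape plus element type silently drops it. A minimal standalone sketch of that type-level difference (illustrative only, not code from the patch; the helper name is made up):

```cpp
#include <cassert>

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Illustrative sketch: rebuilding a tensor type from shape + element type
// loses the encoding attribute, while forwarding the original
// RankedTensorType keeps it.
static void illustrateEncodingPassthrough(RankedTensorType resultTy) {
  Attribute enc = resultTy.getEncoding(); // e.g. #sparse_tensor.encoding<...>

  // What the old lowering effectively did: shape + element type only.
  auto dropped =
      RankedTensorType::get(resultTy.getShape(), resultTy.getElementType());

  // What the new lowering does: the full type, encoding included.
  auto kept = RankedTensorType::get(resultTy.getShape(),
                                    resultTy.getElementType(), enc);

  assert(!dropped.getEncoding() && "encoding silently dropped");
  assert(kept.getEncoding() == enc && "encoding preserved");
}
```

This is also why the old per-result loop could go away: with exactly one result (as the assert above guarantees), the single `resultTy` already carries everything the builder needs.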
11 changes: 10 additions & 1 deletion mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -1,8 +1,17 @@
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics
 
 // CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type
 func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
   // expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
   %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
   return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
 }
+
+// -----
+
+// CHECK-LABEL: @tensor_with_unknown_rank
+func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.abs'}}
+  %0 = "tosa.abs"(%arg0) : (tensor<*xi8>) -> tensor<*xi8>
+  return %0 : tensor<*xi8>
+}
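For context on the `failed to legalize operation 'tosa.abs'` message expected in the new test: the elementwise helper now returns a match failure for unranked results instead of asserting, and the dialect-conversion driver reports any op left unlegalized with that generic diagnostic. A rough skeleton of that control flow (an illustrative sketch, not the actual pattern in TosaToLinalg.cpp, which handles all elementwise ops through one helper):

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative skeleton only; the real lowering covers every TOSA
// elementwise op through elementwiseMatchAndRewriteHelper.
struct AbsLoweringSketch : public OpConversionPattern<tosa::AbsOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::AbsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // An unranked result (e.g. tensor<*xi8>) bails out with a match failure
    // instead of tripping an assertion; if no other pattern applies, the
    // driver then emits "failed to legalize operation 'tosa.abs'".
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy)
      return rewriter.notifyMatchFailure(
          op, "All results must be a ranked tensor type");

    // The ranked case would build tensor.empty + linalg.generic here.
    return failure();
  }
};
```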
12 changes: 12 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -85,8 +85,20 @@ func.func @test_abs_dyn(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> {
   %0 = "tosa.abs"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+
+// CHECK-LABEL: @test_encoding_passthrough
+func.func @test_encoding_passthrough(%arg0: tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector> {
+  // CHECK: linalg.generic
+  // CHECK: sparse_tensor
+  %0 = "tosa.abs"(%arg0) : (tensor<2xi8, #SparseVector>) -> tensor<2xi8, #SparseVector>
+  return %0 : tensor<2xi8, #SparseVector>
+}
 
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> ()>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
