Skip to content

Commit

Permalink
[mlir][tensor][sparse] don't drop encoding when infer result type (#9…
Browse files Browse the repository at this point in the history
…1817)

A general question is: is it possible to support hooks here to infer the
encoding? E.g., when the extracted tensor slice is rank-reduced, the
encoding need to be updated accordingly as well.
  • Loading branch information
Peiming Liu authored May 13, 2024
1 parent 6140b5b commit 37ffbbb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2020,7 +2020,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
sourceTensorType.getEncoding());
}

RankedTensorType ExtractSliceOp::inferResultType(
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/SparseTensor/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s

#BCOO = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
}>

// CHECK-DAG: #[[$BCOO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton) }>
// CHECK-LABEL: func @sparse_slice_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32, #[[$BCOO]]>
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
// CHECK-SAME: : tensor<?x?x?xf32, #[[$BCOO]]> to tensor<4x1x?xf32, #[[$BCOO]]>
// CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
// CHECK: return %[[RESULT]]
func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : index,
%arg2 : index) -> tensor<?x?x?xf32, #BCOO>
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
return %0 : tensor<?x?x?xf32, #BCOO>
}

0 comments on commit 37ffbbb

Please sign in to comment.