[mlir][sparse] introduce sparse_tensor.lvl operation. #69978

Merged 4 commits on Oct 23, 2023
@@ -90,6 +90,7 @@ def SparseTensor_Dialect : Dialect {

let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
let hasConstantMaterializer = 1;
}

#endif // SPARSETENSOR_BASE
60 changes: 59 additions & 1 deletion mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -521,9 +521,67 @@ def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set"
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Coordinate Translation Operation.
// Sparse Tensor Coordinate Operations.
//===----------------------------------------------------------------------===//

def SparseTensor_LvlOp : SparseTensor_Op<"lvl", [ConditionallySpeculatable, NoMemoryEffect]>,
Arguments<(ins AnySparseTensor:$source, Index:$index)>,
Results<(outs Index:$result)> {
let summary = "level index operation";
let description = [{
The `sparse_tensor.lvl` operation behaves similarly to the `tensor.dim`
operation. It takes a sparse tensor and a level operand of type `index`, and
returns the size of the requested level of the given sparse tensor. If the
sparse tensor has an identity dimension-to-level mapping, it returns the same
result as `tensor.dim`. If the level index is out of bounds, the behavior is
undefined.

Example:

```mlir
#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

// Always returns 2 (4 floordiv 2); can be constant folded:
%c0 = arith.constant 0 : index
%x = sparse_tensor.lvl %A, %c0 : tensor<4x?xf32, #BSR>

// Returns the dynamic level size computed by j floordiv 3;
// cannot be constant folded since dimension 1 is dynamic:
%c1 = arith.constant 1 : index
%y = sparse_tensor.lvl %A, %c1 : tensor<4x?xf32, #BSR>

// Always returns 3 (since j mod 3 < 3); can be constant folded:
%c3 = arith.constant 3 : index
%z = sparse_tensor.lvl %A, %c3 : tensor<4x?xf32, #BSR>
```
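
For an identity (permutation) dimension-to-level mapping, the result always
coincides with `tensor.dim`. A minimal sketch, assuming a conventional `#CSR`
encoding (the encoding here is illustrative, not part of this patch):

```mlir
#CSR = #sparse_tensor.encoding<{
  map = (i, j) -> (i : dense, j : compressed)
}>

// With an identity mapping, level sizes coincide with dimension sizes,
// so this is equivalent to tensor.dim %B, %c1:
%c1 = arith.constant 1 : index
%n = sparse_tensor.lvl %B, %c1 : tensor<8x?xf64, #CSR>
```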
}];

let assemblyFormat = [{
attr-dict $source `,` $index `:` type($source)
}];

let builders = [
OpBuilder<(ins "Value":$source, "int64_t":$index)>
];

let extraClassDeclaration = [{
/// Helper function to get the index as a simple integer if it is constant.
std::optional<uint64_t> getConstantLvlIndex();

/// Interface method for ConditionallySpeculatable.
Speculation::Speculatability getSpeculatability();
}];

let hasVerifier = 1;
let hasFolder = 1;
}

def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
Arguments<(ins Variadic<Index>:$in_crds,
SparseTensorCrdTransDirectionAttr:$direction,
88 changes: 88 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1208,6 +1208,84 @@ LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
return success();
}

LogicalResult LvlOp::verify() {
if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
auto stt = getSparseTensorType(getSource());
if (lvl.value() >= stt.getLvlRank())
return emitError("Level index exceeds the rank of the input sparse tensor");
}
return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
auto constantIndex = getConstantLvlIndex();
if (!constantIndex)
return Speculation::NotSpeculatable;

// The verifier guarantees that constant level indices are in bounds.
assert(*constantIndex < getSparseTensorType(getSource()).getLvlRank());
return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!lvlIndex)
return {};

Level lvl = lvlIndex.getAPSInt().getZExtValue();
auto stt = getSparseTensorType(getSource());
if (lvl >= stt.getLvlRank()) {
// Follows the same convention as the tensor.dim operation: out-of-bounds
// indices produce undefined behavior but are still valid IR. Don't choke
// on them.
return {};
}

// Helper lambda to build an IndexAttr.
auto getIndexAttr = [this](int64_t lvlSz) {
return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
};

// TODO: this can be removed once SparseTensorEncoding always returns a
// non-null dimToLvl map.
ArrayRef<DynSize> shape = stt.getDimShape();
if (stt.isPermutation()) {
Dimension dim = toOrigDim(stt, lvl);
if (!ShapedType::isDynamic(shape[dim])) {
return getIndexAttr(shape[dim]);
}
return {};
}

// Non-permutation dim2lvl/lvl2dim maps.
AffineExpr lvlExpr = stt.getDimToLvl().getResult(lvl);
if (auto binExpr = lvlExpr.dyn_cast<AffineBinaryOpExpr>()) {
if (lvlExpr.getKind() == AffineExprKind::Mod) {
// j % block_sz: the level size equals the block size.
int64_t lvlSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
return getIndexAttr(lvlSz);
}
if (lvlExpr.getKind() == AffineExprKind::FloorDiv) {
// j / block_sz: the level size equals dim[j] / block_sz.
Dimension dim = binExpr.getLHS().cast<AffineDimExpr>().getPosition();
int64_t blockSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
if (ShapedType::isDynamic(shape[dim]))
return {};
return getIndexAttr(shape[dim] / blockSz);
}
}

// Otherwise the level maps to a dimension directly; the level size equals
// the dimension size when it is static.
auto dim = lvlExpr.cast<AffineDimExpr>().getPosition();
if (!ShapedType::isDynamic(shape[dim]))
return getIndexAttr(shape[dim]);

return {};
}
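
As a rough illustration of the dynamic case, assuming the `#BSR` encoding from
the op description above: level 1 is computed as `j floordiv 3`, and dimension 1
is dynamic, so the folder returns nothing and the op is left intact:

```mlir
// Level 1 (j floordiv 3) depends on a dynamic dimension; not folded.
%c1 = arith.constant 1 : index
%l1 = sparse_tensor.lvl %A, %c1 : tensor<4x?xf32, #BSR>
```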

LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1639,6 +1717,16 @@ LogicalResult YieldOp::verify() {
// TensorDialect Methods.
//===----------------------------------------------------------------------===//

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
return op;
return nullptr;
}
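
A sketch of the end-to-end effect, mirroring the `fold.mlir` test below: once
`LvlOp::fold` produces an index attribute, this materializer rebuilds it as an
`arith.constant`:

```mlir
// Before folding (level 0 of #BSR is i floordiv 2):
%c0 = arith.constant 0 : index
%l0 = sparse_tensor.lvl %t, %c0 : tensor<10x?xi32, #BSR>

// After folding: the level size 5 (10 floordiv 2) is materialized.
%l0 = arith.constant 5 : index
```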

void SparseTensorDialect::initialize() {
addAttributes<
#define GET_ATTRDEF_LIST
18 changes: 18 additions & 0 deletions mlir/test/Dialect/SparseTensor/fold.mlir
@@ -93,3 +93,21 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index) {
%d0, %d1 = sparse_tensor.crd_translate lvl_to_dim [%l0, %l1, %l2, %l3] as #BSR : index, index
return %d0, %d1 : index, index
}

// CHECK-LABEL: func.func @sparse_lvl_0(
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: return %[[C5]] : index
func.func @sparse_lvl_0(%t : tensor<10x?xi32, #BSR>) -> index {
%lvl = arith.constant 0 : index
%l0 = sparse_tensor.lvl %t, %lvl : tensor<10x?xi32, #BSR>
return %l0 : index
}

// CHECK-LABEL: func.func @sparse_lvl_3(
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: return %[[C3]] : index
func.func @sparse_lvl_3(%t : tensor<?x?xi32, #BSR>) -> index {
%lvl = arith.constant 3 : index
%l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
return %l0 : index
}
18 changes: 18 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -895,3 +895,21 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index, %arg2: index) -> (in
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1, %arg2] as #BSR : index, index, index, index
return %l0, %l1, %l2, %l3 : index, index, index, index
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

func.func @sparse_lvl(%t : tensor<?x?xi32, #BSR>) -> index {
%lvl = arith.constant 5 : index
// expected-error@+1 {{Level index exceeds the rank of the input sparse tensor}}
%l0 = sparse_tensor.lvl %t, %lvl : tensor<?x?xi32, #BSR>
return %l0 : index
}
21 changes: 21 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -669,3 +669,24 @@ func.func @sparse_crd_translate(%arg0: index, %arg1: index) -> (index, index, in
%l0, %l1, %l2, %l3 = sparse_tensor.crd_translate dim_to_lvl [%arg0, %arg1] as #BSR : index, index, index, index
return %l0, %l1, %l2, %l3 : index, index, index, index
}

// -----

#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i floordiv 2 : dense,
j floordiv 3 : compressed,
i mod 2 : dense,
j mod 3 : dense
)
}>

// CHECK-LABEL: func.func @sparse_lvl(
// CHECK-SAME: %[[VAL_0:.*]]: index,
// CHECK-SAME: %[[VAL_1:.*]]: tensor
// CHECK: %[[VAL_2:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_0]]
// CHECK: return %[[VAL_2]]
func.func @sparse_lvl(%arg0: index, %t : tensor<?x?xi32, #BSR>) -> index {
%l0 = sparse_tensor.lvl %t, %arg0 : tensor<?x?xi32, #BSR>
return %l0 : index
}