[mlir][sparse] introduce operations to query sparse tensor slice offset/strides at the given dimension

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D141442
PeimingLiu committed Feb 16, 2023
1 parent a851d46 commit c738b43
Showing 5 changed files with 135 additions and 0 deletions.
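
For a quick overview, here is a minimal sketch of the two new ops working together, assembled from the op descriptions added below (the `#DCSR` and `#Slice` encodings are assumptions carried over from those examples):

```mlir
// Extract a slice with dynamic offsets and strides, then query its metadata.
%0 = tensor.extract_slice %s[%v1, %v2][64, 64][%s1, %s2] : tensor<128x128xf64, #DCSR>
                                                        to tensor<64x64xf64, #Slice>

// Offset and stride of the slice along dimension 0.
%off = sparse_tensor.slice.offset %0 at 0 : tensor<64x64xf64, #Slice>
%str = sparse_tensor.slice.stride %0 at 0 : tensor<64x64xf64, #Slice>
// %off == %v1, %str == %s1
```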
@@ -335,13 +335,21 @@ def SparseTensorStorageSpecifierKindAttr
def IsSparseTensorPred
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;

def IsSparseTensorSlicePred
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;

// The following four follow the same idiom as `TensorOf`, `AnyTensor`,
// `RankedTensorOf`, `AnyRankedTensor`.

class SparseTensorOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;

class SparseTensorSliceOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;

def AnySparseTensor : SparseTensorOf<[AnyType]>;
def AnySparseTensorSlice : SparseTensorSliceOf<[AnyType]>;

class RankedSparseTensorOf<list<Type> allowedTypes>
: RankedTensorOf<allowedTypes, [IsSparseTensorPred], "ranked sparse tensor">;
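
As a sketch of what the new `AnySparseTensorSlice` constraint accepts, the tests added later in this commit use an encoding whose `slice` attribute marks it as a slice, so `IsSparseTensorSlicePred` (and hence `isSlice()`) holds for types carrying it:

```mlir
// A sparse tensor slice encoding (copied from the tests below): the encoding
// carries per-dimension slice metadata in addition to the dimension level types.
#CSR_SLICE = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  slice = [ (1, 4, 1), (1, 4, 2) ]
}>
// A value of type tensor<2x8xf64, #CSR_SLICE> satisfies AnySparseTensorSlice.
```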
55 changes: 55 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -294,6 +294,61 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [Pure]>,
let hasVerifier = 1;
}

def SparseTensor_ToSliceOffsetOp : SparseTensor_Op<"slice.offset", [Pure]>,
Arguments<(ins AnySparseTensorSlice:$slice, IndexAttr:$dim)>,
Results<(outs Index:$offset)> {
let summary = "Extracts the offset of the sparse tensor slice at the given dimension";
let description = [{
Extracts the offset of the sparse tensor slice at the given dimension.

Currently, sparse tensor slices are still a work in progress, and this
operation only works when the runtime library is disabled (i.e., when
running the sparse compiler with `enable-runtime-library=false`).

Example:

```mlir
%0 = tensor.extract_slice %s[%v1, %v2][64, 64][1, 1] : tensor<128x128xf64, #DCSR>
to tensor<64x64xf64, #Slice>

%1 = sparse_tensor.slice.offset %0 at 0 : tensor<64x64xf64, #Slice>
%2 = sparse_tensor.slice.offset %0 at 1 : tensor<64x64xf64, #Slice>
// %1 = %v1
// %2 = %v2
```
}];
let assemblyFormat = "$slice `at` $dim attr-dict `:` type($slice)";
let hasVerifier = 1;
}

def SparseTensor_ToSliceStrideOp : SparseTensor_Op<"slice.stride", [Pure]>,
Arguments<(ins AnySparseTensorSlice:$slice, IndexAttr:$dim)>,
Results<(outs Index:$stride)> {
let summary = "Extracts the stride of the sparse tensor slice at the given dimension";
let description = [{
Extracts the stride of the sparse tensor slice at the given dimension.

Currently, sparse tensor slices are still a work in progress, and this
operation only works when the runtime library is disabled (i.e., when
running the sparse compiler with `enable-runtime-library=false`).

Example:

```mlir
%0 = tensor.extract_slice %s[%v1, %v2][64, 64][%s1, %s2] : tensor<128x128xf64, #DCSR>
to tensor<64x64xf64, #Slice>

%1 = sparse_tensor.slice.stride %0 at 0 : tensor<64x64xf64, #Slice>
%2 = sparse_tensor.slice.stride %0 at 1 : tensor<64x64xf64, #Slice>
// %1 = %s1
// %2 = %s2
```
}];
let assemblyFormat = "$slice `at` $dim attr-dict `:` type($slice)";
let hasVerifier = 1;
}

def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>,
Results<(outs SparseTensorStorageSpecifier:$result)> {
let summary = "";
14 changes: 14 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -765,6 +765,20 @@ LogicalResult ToValuesOp::verify() {
return success();
}

LogicalResult ToSliceOffsetOp::verify() {
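// The queried dimension must lie in [0, rank) of the sliced tensor type.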
auto rank = getRankedTensorType(getSlice()).getRank();
if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
return emitError("requested dimension out of bound");
return success();
}

LogicalResult ToSliceStrideOp::verify() {
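// Same bound check as for the offset: the dimension must lie in [0, rank).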
auto rank = getRankedTensorType(getSlice()).getRank();
if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
return emitError("requested dimension out of bound");
return success();
}

LogicalResult GetStorageSpecifierOp::verify() {
RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
26 changes: 26 additions & 0 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -224,6 +224,32 @@ func.func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<

// -----

#CSR_SLICE = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
slice = [ (1, 4, 1), (1, 4, 2) ]
}>

func.func @sparse_slice_offset(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
// expected-error@+1 {{requested dimension out of bound}}
%0 = sparse_tensor.slice.offset %arg0 at 2 : tensor<2x8xf64, #CSR_SLICE>
return %0 : index
}

// -----

#CSR_SLICE = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
slice = [ (1, 4, 1), (1, 4, 2) ]
}>

func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
// expected-error@+1 {{requested dimension out of bound}}
%0 = sparse_tensor.slice.stride %arg0 at 2 : tensor<2x8xf64, #CSR_SLICE>
return %0 : index
}

// -----

#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
32 changes: 32 additions & 0 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -148,6 +148,38 @@ func.func @sparse_values(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xf64>

// -----

#CSR_SLICE = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
slice = [ (1, 4, 1), (1, 4, 2) ]
}>

// CHECK-LABEL: func @sparse_slice_offset(
// CHECK-SAME: %[[A:.*]]: tensor<2x8xf64, #{{.*}}>)
// CHECK: %[[T:.*]] = sparse_tensor.slice.offset %[[A]] at 1 : tensor<2x8xf64, #{{.*}}>
// CHECK: return %[[T]] : index
func.func @sparse_slice_offset(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
%0 = sparse_tensor.slice.offset %arg0 at 1 : tensor<2x8xf64, #CSR_SLICE>
return %0 : index
}

// -----

#CSR_SLICE = #sparse_tensor.encoding<{
dimLevelType = [ "dense", "compressed" ],
slice = [ (1, 4, 1), (1, 4, 2) ]
}>

// CHECK-LABEL: func @sparse_slice_stride(
// CHECK-SAME: %[[A:.*]]: tensor<2x8xf64, #{{.*}}>)
// CHECK: %[[T:.*]] = sparse_tensor.slice.stride %[[A]] at 1 : tensor<2x8xf64, #{{.*}}>
// CHECK: return %[[T]] : index
func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
%0 = sparse_tensor.slice.stride %arg0 at 1 : tensor<2x8xf64, #CSR_SLICE>
return %0 : index
}

// -----

#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

// CHECK-LABEL: func @sparse_metadata_init(
