diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index 2db552f5fe039..c0c3d4920a1fe 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -93,4 +93,28 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding", }]; } +def IsSparseTensorPred + : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">; + +// The following four follow the same idiom as `TensorOf`, `AnyTensor`, +// `RankedTensorOf`, `AnyRankedTensor`. + +class SparseTensorOf allowedTypes> + : ShapedContainerType< + allowedTypes, + And<[IsTensorTypePred, IsSparseTensorPred]>, + "sparse tensor", + "::mlir::TensorType">; + +def AnySparseTensor : SparseTensorOf<[AnyType]>; + +class RankedSparseTensorOf allowedTypes> + : ShapedContainerType< + allowedTypes, + And<[IsTensorTypePred, HasRankPred, IsSparseTensorPred]>, + "ranked sparse tensor", + "::mlir::TensorType">; + +def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>; + #endif // SPARSETENSOR_ATTRDEFS diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 4f31031b1fe84..bdc27b57fe10f 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -27,7 +27,7 @@ class SparseTensor_Op traits = []> def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>, Arguments<(ins AnyType:$source)>, - Results<(outs TensorOf<[AnyType]>:$result)> { + Results<(outs AnySparseTensor:$result)> { string summary = "Materializes a new sparse tensor from given source"; string description = [{ Materializes a sparse tensor with contents taken from an opaque pointer @@ -46,7 +46,6 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>, ``` }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; - let hasVerifier = 1; } def SparseTensor_ConvertOp : SparseTensor_Op<"convert", @@ -92,7 +91,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert", } def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>, - Arguments<(ins AnyTensor:$tensor, Index:$dim)>, + Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>, Results<(outs AnyStridedMemRefOfRank<1>:$result)> { let summary = "Extracts pointers array at given dimension from a tensor"; let description = [{ @@ -117,7 +116,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>, } def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>, - Arguments<(ins AnyTensor:$tensor, Index:$dim)>, + Arguments<(ins AnySparseTensor:$tensor, Index:$dim)>, Results<(outs AnyStridedMemRefOfRank<1>:$result)> { let summary = "Extracts indices array at given dimension from a tensor"; let description = [{ @@ -142,7 +141,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>, } def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>, - Arguments<(ins AnyTensor:$tensor)>, + Arguments<(ins AnySparseTensor:$tensor)>, Results<(outs AnyStridedMemRefOfRank<1>:$result)> { let summary = "Extracts numerical values array from a tensor"; let description = [{ @@ -173,7 +172,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>, //===----------------------------------------------------------------------===// def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>, - Arguments<(ins AnyTensor:$tensor, + Arguments<(ins AnySparseTensor:$tensor, StridedMemRefRankOf<[Index], [1]>:$indices, AnyType:$value)> { string summary = "Inserts a value into given sparse tensor in lexicographical index order"; @@ -196,11 +195,10 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>, }]; let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`" " type($tensor) `,` type($indices) `,` type($value)"; - let hasVerifier = 1; } def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>, - Arguments<(ins AnyTensor:$tensor)>, + Arguments<(ins AnySparseTensor:$tensor)>, Results<(outs AnyStridedMemRefOfRank<1>:$values, StridedMemRefRankOf<[I1],[1]>:$filled, StridedMemRefRankOf<[Index],[1]>:$added, @@ -238,11 +236,10 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>, }]; let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)" " `,` type($filled) `,` type($added) `,` type($count)"; - let hasVerifier = 1; } def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>, - Arguments<(ins AnyTensor:$tensor, + Arguments<(ins AnySparseTensor:$tensor, StridedMemRefRankOf<[Index],[1]>:$indices, AnyStridedMemRefOfRank<1>:$values, StridedMemRefRankOf<[I1],[1]>:$filled, @@ -273,11 +270,10 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>, " $added `,` $count attr-dict `:` type($tensor) `,`" " type($indices) `,` type($values) `,` type($filled) `,`" " type($added) `,` type($count)"; - let hasVerifier = 1; } def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>, - Arguments<(ins AnyTensor:$tensor, UnitAttr:$hasInserts)>, + Arguments<(ins AnySparseTensor:$tensor, UnitAttr:$hasInserts)>, Results<(outs AnyTensor:$result)> { let summary = "Rematerializes tensor from underlying sparse storage format"; @@ -306,11 +302,10 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>, ``` }]; let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)"; - let hasVerifier = 1; } def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>, - Arguments<(ins AnyTensor:$tensor)> { + Arguments<(ins AnySparseTensor:$tensor)> { string summary = "Releases underlying sparse storage format of given tensor"; string description = [{ Releases the underlying sparse storage format for a tensor that @@ -332,11 +327,10 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>, ``` }]; let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; - let hasVerifier = 1; } def SparseTensor_OutOp : SparseTensor_Op<"out", []>, - Arguments<(ins AnyType:$tensor, AnyType:$dest)> { + Arguments<(ins AnySparseTensor:$tensor, AnyType:$dest)> { string summary = "Outputs a sparse tensor to the given destination"; string description = [{ Outputs the contents of a sparse tensor to the destination defined by an @@ -353,7 +347,6 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>, ``` }]; let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)"; - let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index b860f07528956..418e7fe3bd822 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -208,12 +208,6 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) { return failure(); } -LogicalResult NewOp::verify() { - if (!getSparseTensorEncoding(result().getType())) - return emitError("expected a sparse tensor result"); - return success(); -} - LogicalResult ConvertOp::verify() { if (auto tp1 = source().getType().dyn_cast()) { if (auto tp2 = dest().getType().dyn_cast()) { @@ -240,30 +234,24 @@ OpFoldResult ConvertOp::fold(ArrayRef operands) { } LogicalResult ToPointersOp::verify() { - if (auto e = getSparseTensorEncoding(tensor().getType())) { - if (failed(isInBounds(dim(), tensor()))) - return emitError("requested pointers dimension out of bounds"); - if (failed(isMatchingWidth(result(), e.getPointerBitWidth()))) - return emitError("unexpected type for pointers"); - return success(); - } - return emitError("expected a sparse tensor to get pointers"); + auto e = getSparseTensorEncoding(tensor().getType()); + if (failed(isInBounds(dim(), tensor()))) + return emitError("requested pointers dimension out of bounds"); + if (failed(isMatchingWidth(result(), e.getPointerBitWidth()))) + return emitError("unexpected type for pointers"); + return success(); } LogicalResult ToIndicesOp::verify() { - if (auto e = getSparseTensorEncoding(tensor().getType())) { - if (failed(isInBounds(dim(), tensor()))) - return emitError("requested indices dimension out of bounds"); - if (failed(isMatchingWidth(result(), e.getIndexBitWidth()))) - return emitError("unexpected type for indices"); - return success(); - } - return emitError("expected a sparse tensor to get indices"); + auto e = getSparseTensorEncoding(tensor().getType()); + if (failed(isInBounds(dim(), tensor()))) + return emitError("requested indices dimension out of bounds"); + if (failed(isMatchingWidth(result(), e.getIndexBitWidth()))) + return emitError("unexpected type for indices"); + return success(); } LogicalResult ToValuesOp::verify() { - if (!getSparseTensorEncoding(tensor().getType())) - return emitError("expected a sparse tensor to get values"); RankedTensorType ttp = tensor().getType().cast(); MemRefType mtp = result().getType().cast(); if (ttp.getElementType() != mtp.getElementType()) @@ -271,46 +259,6 @@ LogicalResult ToValuesOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// TensorDialect Management Operations. -//===----------------------------------------------------------------------===// - -LogicalResult LexInsertOp::verify() { - if (!getSparseTensorEncoding(tensor().getType())) - return emitError("expected a sparse tensor for insertion"); - return success(); -} - -LogicalResult ExpandOp::verify() { - if (!getSparseTensorEncoding(tensor().getType())) - return emitError("expected a sparse tensor for expansion"); - return success(); -} - -LogicalResult CompressOp::verify() { - if (!getSparseTensorEncoding(tensor().getType())) - return emitError("expected a sparse tensor for compression"); - return success(); -} - -LogicalResult LoadOp::verify() { - if (!getSparseTensorEncoding(tensor().getType())) - return emitError("expected a sparse tensor to materialize"); - return success(); -} - -LogicalResult ReleaseOp::verify() { - if (!getSparseTensorEncoding(tensor().getType())) - return emitError("expected a sparse tensor to release"); - return success(); -} - -LogicalResult OutOp::verify() { - if (!getSparseTensorEncoding(tensor().getType())) - return emitError("expected a sparse tensor for output"); - return success(); -} - //===----------------------------------------------------------------------===// // TensorDialect Linalg.Generic Operations. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index 68f3cf3b7c5e6..8df924c0b0404 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics func.func @invalid_new_dense(%arg0: !llvm.ptr) -> tensor<32xf32> { - // expected-error@+1 {{expected a sparse tensor result}} + // expected-error@+1 {{'sparse_tensor.new' op result #0 must be sparse tensor of any type values, but got 'tensor<32xf32>'}} %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<32xf32> return %0 : tensor<32xf32> } @@ -9,7 +9,7 @@ func.func @invalid_new_dense(%arg0: !llvm.ptr) -> tensor<32xf32> { // ----- func.func @invalid_release_dense(%arg0: tensor<4xi32>) { - // expected-error@+1 {{expected a sparse tensor to release}} + // expected-error@+1 {{'sparse_tensor.release' op operand #0 must be sparse tensor of any type values, but got 'tensor<4xi32>'}} sparse_tensor.release %arg0 : tensor<4xi32> return } @@ -18,7 +18,7 @@ func.func @invalid_release_dense(%arg0: tensor<4xi32>) { func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref { %c = arith.constant 0 : index - // expected-error@+1 {{expected a sparse tensor to get pointers}} + // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}} %0 = sparse_tensor.pointers %arg0, %c : tensor<128xf64> to memref return %0 : memref } @@ -27,7 +27,7 @@ func.func @invalid_pointers_dense(%arg0: tensor<128xf64>) -> memref { func.func @invalid_pointers_unranked(%arg0: tensor<*xf64>) -> memref { %c = arith.constant 0 : index - // expected-error@+1 {{expected a sparse tensor to get pointers}} + // expected-error@+1 {{'sparse_tensor.pointers' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}} %0 = sparse_tensor.pointers %arg0, %c : tensor<*xf64> to memref return %0 : memref } @@ -58,7 +58,7 @@ func.func @pointers_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref) -> memref { %c = arith.constant 1 : index - // expected-error@+1 {{expected a sparse tensor to get indices}} + // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<10x10xi32>'}} %0 = sparse_tensor.indices %arg0, %c : tensor<10x10xi32> to memref return %0 : memref } @@ -67,7 +67,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref { func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref { %c = arith.constant 0 : index - // expected-error@+1 {{expected a sparse tensor to get indices}} + // expected-error@+1 {{'sparse_tensor.indices' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}} %0 = sparse_tensor.indices %arg0, %c : tensor<*xf64> to memref return %0 : memref } @@ -97,7 +97,7 @@ func.func @indices_oob(%arg0: tensor<128xf64, #SparseVector>) -> memref // ----- func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref { - // expected-error@+1 {{expected a sparse tensor to get values}} + // expected-error@+1 {{'sparse_tensor.values' op operand #0 must be sparse tensor of any type values, but got 'tensor<1024xf32>'}} %0 = sparse_tensor.values %arg0 : tensor<1024xf32> to memref return %0 : memref } @@ -115,7 +115,7 @@ func.func @mismatch_values_types(%arg0: tensor) -> memref< // ----- func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64> { - // expected-error@+1 {{expected a sparse tensor to materialize}} + // expected-error@+1 {{'sparse_tensor.load' op operand #0 must be sparse tensor of any type values, but got 'tensor<16x32xf64>'}} %0 = sparse_tensor.load %arg0 : tensor<16x32xf64> return %0 : tensor<16x32xf64> } @@ -123,7 +123,7 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64 // ----- func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref, %arg2: f64) { - // expected-error@+1 {{expected a sparse tensor for insertion}} + // expected-error@+1 {{'sparse_tensor.lex_insert' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}} sparse_tensor.lex_insert %arg0, %arg1, %arg2 : tensor<128xf64>, memref, f64 return } @@ -131,7 +131,7 @@ func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: memref) { - // expected-error@+1 {{expected a sparse tensor for expansion}} + // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}} %values, %filled, %added, %count = sparse_tensor.expand %arg0 : tensor<128xf64> to memref, memref, memref, index return @@ -142,7 +142,7 @@ func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) { func.func @sparse_unannotated_compression(%arg0: tensor<128xf64>, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: index) { - // expected-error@+1 {{expected a sparse tensor for compression}} + // expected-error@+1 {{'sparse_tensor.compress' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}} sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : tensor<128xf64>, memref, memref, memref, memref, index } @@ -178,7 +178,7 @@ func.func @sparse_convert_dim_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10x // ----- func.func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr) { - // expected-error@+1 {{expected a sparse tensor for output}} + // expected-error@+1 {{'sparse_tensor.out' op operand #0 must be sparse tensor of any type values, but got 'tensor<10xf64>'}} sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr return }