diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 5847fecc45404..2837361a3a3f5 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -1402,7 +1402,6 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) { // is sparse. For a disjunctive operation, it yields a "sparse" result if // all operands are sparse. bool conjSpVals = xSpVals || ySpVals; - bool disjSpVals = xSpVals && ySpVals; if (x.has_value() && y.has_value()) { const ExprId e0 = *x; const ExprId e1 = *y; @@ -1421,23 +1420,23 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) { if (isa(def) && !maybeZero(e1)) return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals}; if (isa(def)) - return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals}; + return {addExp(TensorExp::Kind::kAddF, e0, e1), conjSpVals}; if (isa(def)) - return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals}; + return {addExp(TensorExp::Kind::kAddC, e0, e1), conjSpVals}; if (isa(def)) - return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals}; + return {addExp(TensorExp::Kind::kAddI, e0, e1), conjSpVals}; if (isa(def)) - return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals}; + return {addExp(TensorExp::Kind::kSubF, e0, e1), conjSpVals}; if (isa(def)) - return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals}; + return {addExp(TensorExp::Kind::kSubC, e0, e1), conjSpVals}; if (isa(def)) - return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals}; + return {addExp(TensorExp::Kind::kSubI, e0, e1), conjSpVals}; if (isa(def)) return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals}; if (isa(def)) - return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals}; + return {addExp(TensorExp::Kind::kOrI, e0, e1), conjSpVals}; if (isa(def)) - return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals}; + return {addExp(TensorExp::Kind::kXorI, e0, e1), conjSpVals}; if (isa(def) && isInvariant(e1)) return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals}; if (isa(def) && isInvariant(e1)) diff --git a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir index ea29f1d677eff..14abad77a9177 100644 --- a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir +++ b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir @@ -93,3 +93,35 @@ func.func @dense_op_with_sp_dep(%169: tensor<2x10x8xf32>, } -> tensor<2x10x100xf32> return %179 : tensor<2x10x100xf32> } + +// +// This kernel cannot be sparsified: the unsparsifiable op (math.exp) takes +// a result of arith.subf whose first operand is a sparse tensor. Even though +// arith.subf is a disjunctive op (result is 0 when LHS=0 regardless of RHS), +// it still carries a sparse tensor dependency so kDenseOp wrapping is invalid. +// Regression test for github.com/llvm/llvm-project/issues/114855. +// +#sparse3d = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }> + +// CHECK-LABEL: func @dense_op_with_sparse_out_and_sp_dep +// CHECK: linalg.generic {{.*}} +func.func @dense_op_with_sparse_out_and_sp_dep( + %arg0: tensor<2x3x4xf32, #sparse3d>, + %arg1: tensor<2x4xf32>) -> tensor<2x3x4xf32, #sparse3d> { + %0 = tensor.empty() : tensor<2x3x4xf32, #sparse3d> + %1 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)> + ], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg0, %arg1 : tensor<2x3x4xf32, #sparse3d>, tensor<2x4xf32>) + outs(%0 : tensor<2x3x4xf32, #sparse3d>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %2 = arith.subf %in, %in_0 : f32 + %3 = math.exp %2 : f32 + linalg.yield %3 : f32 + } -> tensor<2x3x4xf32, #sparse3d> + return %1 : tensor<2x3x4xf32, #sparse3d> +}