diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 80d2dbba187b8..1d81dafcd0eb8 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -25,6 +25,7 @@ namespace sparse_tensor { TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o) : kind(k), val(v), op(o) { switch (kind) { + // Leaf. case kTensor: assert(x != -1u && y == -1u && !v && !o); tensor = x; @@ -36,6 +37,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o) assert(x != -1u && y == -1u && !v && !o); index = x; break; + // Unary operations. case kAbsF: case kAbsC: case kCeilF: @@ -86,13 +88,32 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o) children.e0 = x; children.e1 = y; break; - case kBinary: - assert(x != -1u && y != -1u && !v && o); + // Binary operations. + case kMulF: + case kMulC: + case kMulI: + case kDivF: + case kDivC: + case kDivS: + case kDivU: + case kAddF: + case kAddC: + case kAddI: + case kSubF: + case kSubC: + case kSubI: + case kAndI: + case kOrI: + case kXorI: + case kShrS: + case kShrU: + case kShlI: + assert(x != -1u && y != -1u && !v && !o); children.e0 = x; children.e1 = y; break; - default: - assert(x != -1u && y != -1u && !v && !o); + case kBinary: + assert(x != -1u && y != -1u && !v && o); children.e0 = x; children.e1 = y; break; @@ -280,8 +301,13 @@ bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const { bool Merger::isSingleCondition(unsigned t, unsigned e) const { switch (tensorExps[e].kind) { + // Leaf. case kTensor: return tensorExps[e].tensor == t; + case kInvariant: + case kIndex: + return false; + // Unary operations. case kAbsF: case kAbsC: case kCeilF: @@ -313,6 +339,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const { case kCRe: case kBitCast: return isSingleCondition(t, tensorExps[e].children.e0); + case kBinaryBranch: + case kUnary: + return false; + // Binary operations. case kDivF: // note: x / c only case kDivC: case kDivS: @@ -339,7 +369,12 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const { case kAddI: return isSingleCondition(t, tensorExps[e].children.e0) && isSingleCondition(t, tensorExps[e].children.e1); - default: + case kSubF: + case kSubC: + case kSubI: + case kOrI: + case kXorI: + case kBinary: return false; } } @@ -352,12 +387,14 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const { static const char *kindToOpSymbol(Kind kind) { switch (kind) { + // Leaf. case kTensor: return "tensor"; case kInvariant: return "invariant"; case kIndex: return "index"; + // Unary operations. case kAbsF: case kAbsC: return "abs"; @@ -404,6 +441,7 @@ static const char *kindToOpSymbol(Kind kind) { return "binary_branch"; case kUnary: return "unary"; + // Binary operations. case kMulF: case kMulC: case kMulI: @@ -441,6 +479,7 @@ static const char *kindToOpSymbol(Kind kind) { void Merger::dumpExp(unsigned e) const { switch (tensorExps[e].kind) { + // Leaf. case kTensor: if (tensorExps[e].tensor == syntheticTensor) llvm::dbgs() << "synthetic_"; @@ -454,7 +493,9 @@ void Merger::dumpExp(unsigned e) const { case kIndex: llvm::dbgs() << "index_" << tensorExps[e].index; break; + // Unary operations. case kAbsF: + case kAbsC: case kCeilF: case kFloorF: case kSqrtF: @@ -462,10 +503,13 @@ void Merger::dumpExp(unsigned e) const { case kExpm1F: case kExpm1C: case kLog1pF: + case kLog1pC: case kSinF: + case kSinC: case kTanhF: case kTanhC: case kNegF: + case kNegC: case kNegI: case kTruncF: case kExtF: @@ -477,11 +521,35 @@ void Merger::dumpExp(unsigned e) const { case kCastU: case kCastIdx: case kTruncI: + case kCIm: + case kCRe: case kBitCast: + case kBinaryBranch: + case kUnary: llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; dumpExp(tensorExps[e].children.e0); break; - default: + // Binary operations. + case kMulF: + case kMulC: + case kMulI: + case kDivF: + case kDivC: + case kDivS: + case kDivU: + case kAddF: + case kAddC: + case kAddI: + case kSubF: + case kSubC: + case kSubI: + case kAndI: + case kOrI: + case kXorI: + case kShrS: + case kShrU: + case kShlI: + case kBinary: llvm::dbgs() << "("; dumpExp(tensorExps[e].children.e0); llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " "; @@ -542,6 +610,7 @@ void Merger::dumpBits(const BitVector &bits) const { unsigned Merger::buildLattices(unsigned e, unsigned i) { Kind kind = tensorExps[e].kind; switch (kind) { + // Leaf. case kTensor: case kInvariant: case kIndex: { @@ -560,11 +629,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) { latSets[s].push_back(addLat(t, i, e)); return s; } + // Unary operations. case kAbsF: case kAbsC: case kCeilF: - case kCIm: - case kCRe: case kFloorF: case kSqrtF: case kSqrtC: @@ -589,6 +657,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) { case kCastU: case kCastIdx: case kTruncI: + case kCIm: + case kCRe: case kBitCast: // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the // lattice set of the operand through the operator into a new set. @@ -625,6 +695,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) { unsigned rhs = addExp(kInvariant, absentVal); return takeDisj(kind, child0, buildLattices(rhs, i), unop); } + // Binary operations. case kMulF: case kMulC: case kMulI: @@ -955,16 +1026,17 @@ static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0, Value v1) { switch (tensorExps[e].kind) { + // Leaf. case kTensor: case kInvariant: case kIndex: llvm_unreachable("unexpected non-op"); - // Unary ops. + // Unary operations. case kAbsF: return rewriter.create(loc, v0); case kAbsC: { - auto type = v0.getType().template cast(); - auto eltType = type.getElementType().template cast(); + auto type = v0.getType().cast(); + auto eltType = type.getElementType().cast(); return rewriter.create(loc, eltType, v0); } case kCeilF: @@ -1021,18 +1093,19 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, return rewriter.create(loc, inferType(e, v0), v0); case kTruncI: return rewriter.create(loc, inferType(e, v0), v0); - case kCIm: + case kCIm: { + auto type = v0.getType().cast(); + auto eltType = type.getElementType().cast(); + return rewriter.create(loc, eltType, v0); + } case kCRe: { - auto type = v0.getType().template cast(); - auto eltType = type.getElementType().template cast(); - if (tensorExps[e].kind == kCIm) - return rewriter.create(loc, eltType, v0); - + auto type = v0.getType().cast(); + auto eltType = type.getElementType().cast(); return rewriter.create(loc, eltType, v0); } case kBitCast: return rewriter.create(loc, inferType(e, v0), v0); - // Binary ops. + // Binary operations. case kMulF: return rewriter.create(loc, v0, v1); case kMulC: @@ -1071,8 +1144,7 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, return rewriter.create(loc, v0, v1); case kShlI: return rewriter.create(loc, v0, v1); - // Semiring ops with custom logic. - case kBinaryBranch: + case kBinaryBranch: // semi-ring ops with custom logic. return insertYieldOp(rewriter, loc, *tensorExps[e].op->getBlock()->getParent(), {v0}); case kUnary: diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp index 4bdfa71d8bc49..f64251953c9f5 100644 --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -136,43 +136,78 @@ class MergerTestBase : public ::testing::Test { } /// Compares expressions for equality. Equality is defined recursively as: - /// - Two expressions can only be equal if they have the same Kind. - /// - Two binary expressions are equal if they have the same Kind and their - /// children are equal. - /// - Expressions with Kind invariant or tensor are equal if they have the - /// same expression id. + /// - Operations are equal if they have the same kind and children. + /// - Leaf tensors are equal if they refer to the same tensor. bool compareExpression(unsigned e, const std::shared_ptr &pattern) { auto tensorExp = merger.exp(e); if (tensorExp.kind != pattern->kind) return false; - assert(tensorExp.kind != Kind::kInvariant && - "Invariant comparison not yet supported"); switch (tensorExp.kind) { - case Kind::kTensor: + // Leaf. + case kTensor: return tensorExp.tensor == pattern->tensorNum; - case Kind::kAbsF: - case Kind::kCeilF: - case Kind::kFloorF: - case Kind::kNegF: - case Kind::kNegI: + case kInvariant: + case kIndex: + llvm_unreachable("invariant not handled yet"); + // Unary operations. + case kAbsF: + case kAbsC: + case kCeilF: + case kFloorF: + case kSqrtF: + case kSqrtC: + case kExpm1F: + case kExpm1C: + case kLog1pF: + case kLog1pC: + case kSinF: + case kSinC: + case kTanhF: + case kTanhC: + case kNegF: + case kNegC: + case kNegI: + case kTruncF: + case kExtF: + case kCastFS: + case kCastFU: + case kCastSF: + case kCastUF: + case kCastS: + case kCastU: + case kCastIdx: + case kTruncI: + case kCIm: + case kCRe: + case kBitCast: + case kBinaryBranch: + case kUnary: + case kShlI: + case kBinary: return compareExpression(tensorExp.children.e0, pattern->e0); - case Kind::kMulF: - case Kind::kMulI: - case Kind::kDivF: - case Kind::kDivS: - case Kind::kDivU: - case Kind::kAddF: - case Kind::kAddI: - case Kind::kSubF: - case Kind::kSubI: - case Kind::kAndI: - case Kind::kOrI: - case Kind::kXorI: + // Binary operations. + case kMulF: + case kMulC: + case kMulI: + case kDivF: + case kDivC: + case kDivS: + case kDivU: + case kAddF: + case kAddC: + case kAddI: + case kSubF: + case kSubC: + case kSubI: + case kAndI: + case kOrI: + case kXorI: + case kShrS: + case kShrU: return compareExpression(tensorExp.children.e0, pattern->e0) && compareExpression(tensorExp.children.e1, pattern->e1); - default: - llvm_unreachable("Unhandled Kind"); } + llvm_unreachable("unexpected kind"); } unsigned numTensors;