diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index d087e98ac42f3a..4141c68a5e3799 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -26,24 +26,39 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; /// Dimension level type for a tensor (undef means index does not appear). enum class Dim { kSparse, kDense, kSingle, kUndef }; +/// Children expressions of a binary TensorExp. +struct Children { + unsigned e0; + unsigned e1; +}; + /// Tensor expression. Represents a MLIR expression in tensor index notation. /// For tensors, e0 denotes the tensor index. For invariants, the IR value is /// stored directly. For binary operations, e0 and e1 denote the index of the /// children tensor expressions. struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y, Value v) - : kind(k), e0(x), e1(y), val(v) { - assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) || - (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) || - (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val)); + TensorExp(Kind k, unsigned x, unsigned y, Value v) : kind(k), val(v) { + assert((kind == Kind::kTensor && x != -1u && y == -1u && !val) || + (kind == Kind::kInvariant && x == -1u && y == -1u && val) || + (kind >= Kind::kMulF && x != -1u && y != -1u && !val)); + if (kind == Kind::kTensor) { + tensor = x; + } else if (kind >= Kind::kMulF) { + children.e0 = x; + children.e1 = y; + } } /// Tensor expression kind. Kind kind; - /// Indices of children expression(s). - unsigned e0; - unsigned e1; + union { + /// Expressions representing tensors simply have a tensor number. + unsigned tensor; + + /// Binary operations hold the indices of their child expressions. + Children children; + }; /// Direct link to IR for an invariant. During code generation, /// field is used to cache "hoisted" loop invariant tensor loads. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 0409a7eabdfb7b..813fe683ae6195 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -214,11 +214,11 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) { switch (merger.exp(exp).kind) { case Kind::kTensor: - return merger.exp(exp).e0 == tensor; + return merger.exp(exp).tensor == tensor; case Kind::kMulF: case Kind::kMulI: - return isConjunction(merger, tensor, merger.exp(exp).e0) || - isConjunction(merger, tensor, merger.exp(exp).e1); + return isConjunction(merger, tensor, merger.exp(exp).children.e0) || + isConjunction(merger, tensor, merger.exp(exp).children.e1); default: return false; } @@ -455,7 +455,7 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, } // Actual load. SmallVector args; - OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; + OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; unsigned tensor = t->getOperandNumber(); auto map = op.getTiedIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); @@ -628,8 +628,8 @@ static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, return genTensorLoad(merger, codegen, rewriter, op, exp); else if (merger.exp(exp).kind == Kind::kInvariant) return genInvariantValue(merger, codegen, rewriter, exp); - Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0); - Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1); + Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); + Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); switch (merger.exp(exp).kind) { case Kind::kTensor: case Kind::kInvariant: @@ -653,7 +653,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen, if (merger.exp(exp).kind == Kind::kTensor) { // Inspect tensor indices. bool atLevel = ldx == -1u; - OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; + OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; auto map = op.getTiedIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { @@ -675,8 +675,8 @@ static void genInvariants(Merger &merger, CodeGen &codegen, // Traverse into the binary operations. Note that we only hoist // tensor loads, since subsequent MLIR/LLVM passes know how to // deal with all other kinds of derived loop invariants. - unsigned e0 = merger.exp(exp).e0; - unsigned e1 = merger.exp(exp).e1; + unsigned e0 = merger.exp(exp).children.e0; + unsigned e1 = merger.exp(exp).children.e1; genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist); genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist); } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 0c869be07a1250..6150c15a0ad180 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -72,7 +72,8 @@ unsigned Merger::optimizeSet(unsigned s0) { if (p0 != p1) { // Is this a straightforward copy? unsigned e = latPoints[p1].exp; - if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor) + if (tensorExps[e].kind == Kind::kTensor && + tensorExps[e].tensor == outTensor) continue; // Conjunction already covered? for (unsigned p2 : latSets[s]) { @@ -150,11 +151,11 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { void Merger::dumpExp(unsigned e) const { switch (tensorExps[e].kind) { case Kind::kTensor: - if (tensorExps[e].e0 == syntheticTensor) + if (tensorExps[e].tensor == syntheticTensor) llvm::dbgs() << "synthetic_"; - else if (tensorExps[e].e0 == outTensor) + else if (tensorExps[e].tensor == outTensor) llvm::dbgs() << "output_"; - llvm::dbgs() << "tensor_" << tensorExps[e].e0; + llvm::dbgs() << "tensor_" << tensorExps[e].tensor; break; case Kind::kInvariant: llvm::dbgs() << "invariant"; @@ -162,17 +163,17 @@ void Merger::dumpExp(unsigned e) const { default: case Kind::kMulI: llvm::dbgs() << "("; - dumpExp(tensorExps[e].e0); + dumpExp(tensorExps[e].children.e0); llvm::dbgs() << " * "; - dumpExp(tensorExps[e].e1); + dumpExp(tensorExps[e].children.e1); llvm::dbgs() << ")"; break; case Kind::kAddF: case Kind::kAddI: llvm::dbgs() << "("; - dumpExp(tensorExps[e].e0); + dumpExp(tensorExps[e].children.e0); llvm::dbgs() << " + "; - dumpExp(tensorExps[e].e1); + dumpExp(tensorExps[e].children.e1); llvm::dbgs() << ")"; break; } @@ -234,12 +235,13 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) { // set to the undefined index in that dimension. An invariant expression // is set to a synthetic tensor with undefined indices only. unsigned s = addSet(); - unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor; + unsigned t = + kind == Kind::kTensor ? tensorExps[e].children.e0 : syntheticTensor; latSets[s].push_back(addLat(t, idx, e)); return s; } - unsigned s0 = buildLattices(tensorExps[e].e0, idx); - unsigned s1 = buildLattices(tensorExps[e].e1, idx); + unsigned s0 = buildLattices(tensorExps[e].children.e0, idx); + unsigned s1 = buildLattices(tensorExps[e].children.e1, idx); switch (kind) { case Kind::kTensor: case Kind::kInvariant: