Skip to content

Commit

Permalink
[mlir][sparse] minor merger API simplification
Browse files Browse the repository at this point in the history
Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133821
  • Loading branch information
aartbik committed Sep 14, 2022
1 parent ecb5ea6 commit 47a715d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 14 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
Expand Up @@ -261,8 +261,8 @@ class Merger {
return getDimLevelFormat(t, i).levelType == tp;
}

/// Returns true if any set bit corresponds to given dimension level type.
bool hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const;
/// Returns true if any set bit corresponds to sparse dimension level type.
bool hasAnySparse(const BitVector &bits) const;

/// Dimension level format getter.
DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Expand Up @@ -1641,8 +1641,7 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
unsigned lsize = merger.set(lts).size();
for (unsigned i = 1; i < lsize; i++) {
unsigned li = merger.set(lts)[i];
if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) &&
!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton))
if (!merger.hasAnySparse(merger.lat(li).simple))
return true;
}
}
Expand Down
15 changes: 5 additions & 10 deletions mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Expand Up @@ -262,13 +262,8 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
}
}
// Now apply the two basic rules.
//
// TODO: improve for singleton and properties
//
BitVector simple = latPoints[p0].bits;
bool reset = isSingleton &&
(hasAnyDimLevelTypeOf(simple, DimLvlType::kCompressed) ||
hasAnyDimLevelTypeOf(simple, DimLvlType::kSingleton));
bool reset = isSingleton && hasAnySparse(simple);
for (unsigned b = 0, be = simple.size(); b < be; b++) {
if (simple[b] &&
(!isDimLevelType(b, DimLvlType::kCompressed) &&
Expand Down Expand Up @@ -297,8 +292,7 @@ bool Merger::latGT(unsigned i, unsigned j) const {
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
BitVector tmp = latPoints[j].bits;
tmp ^= latPoints[i].bits;
return !hasAnyDimLevelTypeOf(tmp, DimLvlType::kCompressed) &&
!hasAnyDimLevelTypeOf(tmp, DimLvlType::kSingleton);
return !hasAnySparse(tmp);
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
Expand Down Expand Up @@ -384,9 +378,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
llvm_unreachable("unexpected kind");
}

bool Merger::hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const {
bool Merger::hasAnySparse(const BitVector &bits) const {
for (unsigned b = 0, be = bits.size(); b < be; b++)
if (bits[b] && isDimLevelType(b, tp))
if (bits[b] && (isDimLevelType(b, DimLvlType::kCompressed) ||
isDimLevelType(b, DimLvlType::kSingleton)))
return true;
return false;
}
Expand Down

0 comments on commit 47a715d

Please sign in to comment.