-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][sparse] minor refactoring of sparsification file #74403
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Removed obsolete TODOs and NOTEs, fixed formatting, and removed an unused parameter
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse Author: Aart Bik (aartbik) Changes: Removed obsolete TODOs and NOTEs, fixed formatting, and removed an unused parameter. Full diff: https://github.com/llvm/llvm-project/pull/74403.diff — 1 file affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e0d3ce241e454..d171087f56ab1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -34,6 +34,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"
+
#include <optional>
using namespace mlir;
@@ -43,11 +44,6 @@ using namespace mlir::sparse_tensor;
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//
-// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
-// and those letters are too easy to confuse visually. We should switch
-// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
-// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
-
/// Determines if affine expression is invariant.
static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
bool &isAtLoop) {
@@ -56,11 +52,9 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
const LoopId i = cast<AffineDimExpr>(a).getPosition();
if (i == ldx) {
isAtLoop = true;
- // Must be invariant if we are at the given loop.
- return true;
+ return true; // invariant at given loop
}
- // The DimExpr is invariant the loop has already been generated.
- return i < loopDepth;
+ return i < loopDepth; // invariant when already generated
}
case AffineExprKind::Add:
case AffineExprKind::Mul: {
@@ -85,7 +79,6 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
if (!isUndefLT(merger.getLvlType(tid, idx)))
return false; // used more than once
-
if (setLvlFormat)
merger.setLevelAndType(tid, idx, lvl, lt);
return true;
@@ -195,7 +188,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
}
}
-/// Get the total number of compound affine expressions in the
+/// Gets the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
///
/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
@@ -225,7 +218,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
return num;
}
-/// Get the total number of sparse levels with compound affine
+/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
unsigned num = 0;
@@ -235,6 +228,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
return num;
}
+// Returns true iff output has nontrivial affine indices.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
OpOperand *out = op.getDpsInitOperand(0);
if (getSparseTensorType(out->get()).isAllDense())
@@ -260,11 +254,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
const auto enc = getSparseTensorEncoding(t.get().getType());
if (enc)
annotated = true;
-
const Level lvlRank = map.getNumResults();
assert(!enc || lvlRank == enc.getLvlRank());
assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
-
// We only need to do index reduction if there is at least one non-trivial
// index expression on sparse levels.
// If all non-trivial index expression is on dense levels, we can
@@ -343,9 +335,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
}
/// Generates index for load/store on sparse tensor.
-// FIXME: It's not entirely clear what "index" means here (i.e., is it
-// a "coordinate", or "Ldx", or what). So the function should be renamed
-// and/or the documentation expanded in order to clarify.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
const auto map = env.op().getMatchingIndexingMap(t);
const auto stt = getSparseTensorType(t->get());
@@ -495,7 +484,6 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
Value val = env.exp(exp).val;
if (val)
return val;
-
// Load during insertion.
linalg::GenericOp op = env.op();
OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
@@ -574,7 +562,7 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
- Value e, LoopId ldx) {
+ Value e) {
if (auto arg = dyn_cast<BlockArgument>(e)) {
// Direct arguments of the original linalg op must be converted
// into dense tensor loads. Note that we should not encounter
@@ -598,7 +586,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
rewriter.updateRootInPlace(def, [&]() {
def->setOperand(
- i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+ i, relinkBranch(env, rewriter, block, def->getOperand(i)));
});
}
}
@@ -607,8 +595,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
}
/// Recursively generates tensor expression.
-static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
- LoopId ldx) {
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
if (e == ::mlir::sparse_tensor::detail::kInvalidId)
return Value();
@@ -631,15 +618,15 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
// based on the type of the other operand.
if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
- v1 = genExp(env, rewriter, exp.children.e1, ldx);
+ v1 = genExp(env, rewriter, exp.children.e1);
v0 = constantZero(rewriter, loc, v1.getType());
} else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
- v0 = genExp(env, rewriter, exp.children.e0, ldx);
+ v0 = genExp(env, rewriter, exp.children.e0);
v1 = constantZero(rewriter, loc, v0.getType());
} else {
- v0 = genExp(env, rewriter, exp.children.e0, ldx);
- v1 = genExp(env, rewriter, exp.children.e1, ldx);
+ v0 = genExp(env, rewriter, exp.children.e0);
+ v1 = genExp(env, rewriter, exp.children.e1);
}
Value ee;
@@ -653,7 +640,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
kind == TensorExp::Kind::kReduce ||
kind == TensorExp::Kind::kSelect)) {
OpBuilder::InsertionGuard guard(rewriter);
- ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+ ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
}
}
@@ -806,7 +793,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
return isCompressedLT(lt) || isSingletonLT(lt);
});
-
return isParallelFor(env, isOuter, isSparse);
}
@@ -1112,11 +1098,6 @@ static bool translateBitsToTidLvlPairs(
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
// to avoid putting it inside inner loops.
- // NOTE: It assumes that the levels of the input tensor are
- // initialized in order (and it is also currently guaranteed by
- // computeIterationGraph), another more admissible approach
- // might be accepting out-of-order access between consecutive
- // dense levels.
affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
}
}
@@ -1221,7 +1202,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
LoopOrd at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
if (at == env.getLoopNum()) {
- Value rhs = genExp(env, rewriter, exp, at - 1);
+ Value rhs = genExp(env, rewriter, exp);
genTensorStore(env, rewriter, exp, rhs);
return;
}
@@ -1235,8 +1216,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
- //
- // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+ // We cannot change this to `for (const LatPointId li : env.set(lts))`
// because the loop body causes data-movement which invalidates
// the iterator.
const unsigned lsize = env.set(lts).size();
@@ -1251,7 +1231,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
Value cntInput = env.getExpandCount();
Value insInput = env.getInsertionChain();
Value validIns = env.getValidLexInsert();
- // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+ // We cannot change this to `for (const LatPointId lj : env.set(lts))`
// because the loop body causes data-movement which invalidates the
// iterator.
for (unsigned j = 0; j < lsize; j++) {
@@ -1323,6 +1303,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (hasNonTrivialAffineOnSparseOut(op))
return failure();
+ // Only accept scheduled loops.
if (!op->hasAttr("sorted")) {
return rewriter.notifyMatchFailure(
op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
@@ -1348,9 +1329,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
}
}
- CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
// Detects sparse annotations and translates the per-level sparsity
// information for all tensors to loop indices in the kernel.
+ CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
if (!findSparseAnnotations(env, needIdxRed))
return failure();
|
PeimingLiu
approved these changes
Dec 5, 2023
yinying-lisa-li
approved these changes
Dec 5, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Removed obsolete TODOs and NOTEs, fixed formatting, and removed an unused parameter