
[mlir][sparse] minor refactoring of sparsification file #74403

Merged (1 commit, Dec 5, 2023)

Conversation

aartbik (Contributor) commented on Dec 5, 2023

Removed obsolete TODOs and NOTEs, fixed formatting, and removed an unused parameter.

Removed obsoleted TODOs and NOTEs, formatting, removed unused parameter
@llvmbot added the mlir:sparse (Sparse compiler in MLIR) and mlir labels on Dec 5, 2023
@llvmbot (Collaborator) commented on Dec 5, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

Changes

Removed obsolete TODOs and NOTEs, fixed formatting, and removed an unused parameter.
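
The refactor's only signature change is mechanical: genExp and relinkBranch stop threading the LoopId ldx argument through the recursion, since nothing in their bodies reads it anymore, and every recursive call site is updated to match. The standalone C++ sketch below illustrates that cleanup pattern under toy assumptions; ExprNode, ExprId, and the genExp here are invented stand-ins for illustration, not the actual MLIR classes in Sparsification.cpp.

// Minimal sketch of the "drop an unused, threaded-through parameter" cleanup.
// All types here (ExprNode, ExprId) are toy stand-ins, not the MLIR classes.
#include <cstdint>
#include <iostream>
#include <vector>

using ExprId = int64_t;
constexpr ExprId kInvalidId = -1;

struct ExprNode {
  int value = 0;           // leaf payload
  ExprId lhs = kInvalidId; // child expression ids
  ExprId rhs = kInvalidId;
};

// Before: the recursion carried a `prevLoop` argument that no body used.
//   static int genExp(const std::vector<ExprNode> &tree, ExprId e, int prevLoop);
// After: the parameter is dropped and every recursive call site is updated.
static int genExp(const std::vector<ExprNode> &tree, ExprId e) {
  if (e == kInvalidId)
    return 0;
  const ExprNode &node = tree[e];
  // Recurse without threading the now-removed parameter.
  return node.value + genExp(tree, node.lhs) + genExp(tree, node.rhs);
}

int main() {
  // Tiny expression tree encoding (1 + 2) + 3 as ids into a vector.
  std::vector<ExprNode> tree = {{1}, {2}, {3}, {0, 0, 1}, {0, 3, 2}};
  std::cout << genExp(tree, 4) << "\n"; // prints 6
  return 0;
}

The point of the sketch is only that removing an unused parameter touches every call site of the recursion but changes no behavior, which is exactly why the diff below is large yet functionally neutral.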


Full diff: https://github.com/llvm/llvm-project/pull/74403.diff

1 file affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+19-38)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e0d3ce241e454..d171087f56ab1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/SmallBitVector.h"
+
 #include <optional>
 
 using namespace mlir;
@@ -43,11 +44,6 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//
 
-// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
-// and those letters are too easy to confuse visually.  We should switch
-// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
-// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
-
 /// Determines if affine expression is invariant.
 static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
                               bool &isAtLoop) {
@@ -56,11 +52,9 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
     if (i == ldx) {
       isAtLoop = true;
-      // Must be invariant if we are at the given loop.
-      return true;
+      return true; // invariant at given loop
     }
-    // The DimExpr is invariant the loop has already been generated.
-    return i < loopDepth;
+    return i < loopDepth; // invariant when already generated
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
@@ -85,7 +79,6 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
     const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
     if (!isUndefLT(merger.getLvlType(tid, idx)))
       return false; // used more than once
-
     if (setLvlFormat)
       merger.setLevelAndType(tid, idx, lvl, lt);
     return true;
@@ -195,7 +188,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
   }
 }
 
-/// Get the total number of compound affine expressions in the
+/// Gets the total number of compound affine expressions in the
 /// `getMatchingIndexingMap` for the given tensor.  For the following inputs:
 ///
 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
@@ -225,7 +218,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
   return num;
 }
 
-/// Get the total number of sparse levels with compound affine
+/// Gets the total number of sparse levels with compound affine
 /// expressions, summed over all operands of the `GenericOp`.
 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
   unsigned num = 0;
@@ -235,6 +228,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
   return num;
 }
 
+// Returns true iff output has nontrivial affine indices.
 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
   OpOperand *out = op.getDpsInitOperand(0);
   if (getSparseTensorType(out->get()).isAllDense())
@@ -260,11 +254,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
     const auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
-
     const Level lvlRank = map.getNumResults();
     assert(!enc || lvlRank == enc.getLvlRank());
     assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
-
     // We only need to do index reduction if there is at least one non-trivial
     // index expression on sparse levels.
     // If all non-trivial index expression is on dense levels, we can
@@ -343,9 +335,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
 }
 
 /// Generates index for load/store on sparse tensor.
-// FIXME: It's not entirely clear what "index" means here (i.e., is it
-// a "coordinate", or "Ldx", or what).  So the function should be renamed
-// and/or the documentation expanded in order to clarify.
 static Value genIndex(CodegenEnv &env, OpOperand *t) {
   const auto map = env.op().getMatchingIndexingMap(t);
   const auto stt = getSparseTensorType(t->get());
@@ -495,7 +484,6 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   Value val = env.exp(exp).val;
   if (val)
     return val;
-
   // Load during insertion.
   linalg::GenericOp op = env.op();
   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
@@ -574,7 +562,7 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
 /// exception of index computations, which need to be relinked to actual
 /// inlined cloned code.
 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
-                          Value e, LoopId ldx) {
+                          Value e) {
   if (auto arg = dyn_cast<BlockArgument>(e)) {
     // Direct arguments of the original linalg op must be converted
     // into dense tensor loads. Note that we should not encounter
@@ -598,7 +586,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
       for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
         rewriter.updateRootInPlace(def, [&]() {
           def->setOperand(
-              i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+              i, relinkBranch(env, rewriter, block, def->getOperand(i)));
         });
       }
     }
@@ -607,8 +595,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
 }
 
 /// Recursively generates tensor expression.
-static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
-                    LoopId ldx) {
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
   if (e == ::mlir::sparse_tensor::detail::kInvalidId)
     return Value();
 
@@ -631,15 +618,15 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
   // based on the type of the other operand.
   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
-    v1 = genExp(env, rewriter, exp.children.e1, ldx);
+    v1 = genExp(env, rewriter, exp.children.e1);
     v0 = constantZero(rewriter, loc, v1.getType());
   } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
              env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
-    v0 = genExp(env, rewriter, exp.children.e0, ldx);
+    v0 = genExp(env, rewriter, exp.children.e0);
     v1 = constantZero(rewriter, loc, v0.getType());
   } else {
-    v0 = genExp(env, rewriter, exp.children.e0, ldx);
-    v1 = genExp(env, rewriter, exp.children.e1, ldx);
+    v0 = genExp(env, rewriter, exp.children.e0);
+    v1 = genExp(env, rewriter, exp.children.e1);
   }
 
   Value ee;
@@ -653,7 +640,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
          kind == TensorExp::Kind::kReduce ||
          kind == TensorExp::Kind::kSelect)) {
       OpBuilder::InsertionGuard guard(rewriter);
-      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
     }
   }
 
@@ -806,7 +793,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
     const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
     return isCompressedLT(lt) || isSingletonLT(lt);
   });
-
   return isParallelFor(env, isOuter, isSparse);
 }
 
@@ -1112,11 +1098,6 @@ static bool translateBitsToTidLvlPairs(
             // level. We need to generate the address according to the
             // affine expression. This is also the best place we can do it
             // to avoid putting it inside inner loops.
-            // NOTE: It assumes that the levels of the input tensor are
-            // initialized in order (and it is also currently guaranteed by
-            // computeIterationGraph), another more admissible approach
-            // might be accepting out-of-order access between consecutive
-            // dense levels.
             affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
           }
         }
@@ -1221,7 +1202,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopOrd at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.getLoopNum()) {
-    Value rhs = genExp(env, rewriter, exp, at - 1);
+    Value rhs = genExp(env, rewriter, exp);
     genTensorStore(env, rewriter, exp, rhs);
     return;
   }
@@ -1235,8 +1216,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
   bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  //
-  // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+  // We cannot change this to `for (const LatPointId li : env.set(lts))`
   // because the loop body causes data-movement which invalidates
   // the iterator.
   const unsigned lsize = env.set(lts).size();
@@ -1251,7 +1231,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     Value cntInput = env.getExpandCount();
     Value insInput = env.getInsertionChain();
     Value validIns = env.getValidLexInsert();
-    // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+    // We cannot change this to `for (const LatPointId lj : env.set(lts))`
     // because the loop body causes data-movement which invalidates the
     // iterator.
     for (unsigned j = 0; j < lsize; j++) {
@@ -1323,6 +1303,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     if (hasNonTrivialAffineOnSparseOut(op))
       return failure();
 
+    // Only accept scheduled loops.
     if (!op->hasAttr("sorted")) {
       return rewriter.notifyMatchFailure(
           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
@@ -1348,9 +1329,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
       }
     }
 
-    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
     // Detects sparse annotations and translates the per-level sparsity
     // information for all tensors to loop indices in the kernel.
+    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
     if (!findSparseAnnotations(env, needIdxRed))
       return failure();
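
For context on the doc-comment wording touched above, getNumNonTrivialIdxExpOnSparseLvls counts index expressions that are both compound (not a bare dimension like d2) and attached to a sparse level; for the example map (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed) that count should be 1, since only d0 + d1 sits on a compressed level. Below is a minimal standalone C++ sketch of that counting rule; LevelExpr and LevelFormat are an invented toy model for illustration, not the real AffineMap and sparse-encoding APIs the pass inspects.

// Toy model of "count compound affine index expressions on sparse levels".
#include <iostream>
#include <vector>

enum class LevelFormat { Dense, Compressed };

struct LevelExpr {
  bool isPlainDim;    // true for a bare dimension like d2, false for d0 + d1
  LevelFormat format; // storage format of the tensor at this level
};

// Counts levels whose index expression is compound (not a bare dim) and
// whose storage level is sparse (compressed in this toy model).
static unsigned countNonTrivialSparseIdxExprs(const std::vector<LevelExpr> &map) {
  unsigned num = 0;
  for (const LevelExpr &lvl : map)
    if (!lvl.isPlainDim && lvl.format == LevelFormat::Compressed)
      num++;
  return num;
}

int main() {
  // map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
  std::vector<LevelExpr> map = {
      {/*isPlainDim=*/false, LevelFormat::Compressed}, // d0 + d1 : compound
      {/*isPlainDim=*/true, LevelFormat::Compressed},  // d2 : trivial
  };
  std::cout << countNonTrivialSparseIdxExprs(map) << "\n"; // prints 1
  return 0;
}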
 

@aartbik merged commit 067bebb into llvm:main on Dec 5, 2023
5 checks passed
@aartbik deleted the bik branch on December 5, 2023 at 17:31
Labels: mlir:sparse (Sparse compiler in MLIR), mlir
4 participants