[mlir][sparse] Parser cleanup #69792

Merged
yinying-lisa-li merged 1 commit into llvm:main on Oct 23, 2023

Conversation

yinying-lisa-li
Contributor

Removed TODOs, FIXMEs, and long notes that are better suited to a design doc.

@llvmbot
Collaborator

llvmbot commented Oct 23, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Yinying Li (yinying-lisa-li)

Changes

Removed TODOs, FIXMEs, and long notes that are better suited to a design doc.


Patch is 39.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69792.diff

8 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp (-35)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h (+2-63)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp (+4-60)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h (+9-16)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (-18)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h (+5-27)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp (-37)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h (+11-95)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 5f947b67c6d848e..851867926fe679e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -60,16 +60,6 @@ std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
   return k ? std::make_optional(k.getValue()) : std::nullopt;
 }
 
-// This helper method is akin to `AffineExpr::operator==(int64_t)`
-// except it uses a different implementation, namely the implementation
-// used within `AsmPrinter::Impl::printAffineExprInternal`.
-//
-// wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses
-// this implementation because it avoids constructing the intermediate
-// `AffineConstantExpr(val)` and thus should in theory be a bit faster.
-// However, if it is indeed faster, then the `AffineExpr::operator==`
-// method should be updated to do this instead.  And if it isn't any
-// faster, then we should be using `AffineExpr::operator==` instead.
 bool DimLvlExpr::hasConstantValue(int64_t val) const {
   const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
   return k && k.getValue() == val;
@@ -216,12 +206,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
   return ranks.isValid(var) && (!expr || ranks.isValid(expr));
 }
 
-bool DimSpec::isFunctionOf(VarSet const &vars) const {
-  return vars.occursIn(expr);
-}
-
-void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
-
 void DimSpec::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";
@@ -262,12 +246,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
   return ranks.isValid(var) && ranks.isValid(expr);
 }
 
-bool LvlSpec::isFunctionOf(VarSet const &vars) const {
-  return vars.occursIn(expr);
-}
-
-void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
-
 void LvlSpec::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";
@@ -301,19 +279,6 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
   // below cannot cause OOB errors.
   assert(isWF());
 
-  // TODO: Second, we need to infer/validate the `lvlToDim` mapping.
-  // Along the way we should set every `DimSpec::elideExpr` according
-  // to whether the given expression is inferable or not.  Notably, this
-  // needs to happen before the code for setting every `LvlSpec::elideVar`,
-  // since if the LvlVar is only used in elided DimExpr, then the
-  // LvlVar should also be elided.
-  // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
-  // to ensure that we maintain the invariant established by `isWF` above.
-
-  // Third, we set every `LvlSpec::elideVar` according to whether that
-  // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
-  // NOTE: The invariant established by `isWF` ensures that the following
-  // calls to `VarSet::add` cannot raise OOB errors.
   VarSet usedVars(getRanks());
   for (const auto &dimSpec : dimSpecs)
     if (!dimSpec.canElideExpr())
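
Note on the last hunk above: the removed comment said that each `LvlSpec::elideVar` is set according to whether that level-variable occurs in a non-elided dim-expression. The following is a minimal standalone sketch of that rule, not the code in `DimLvlMap.cpp`. `DimSpecSketch`, `computeElideVar`, and the use of `std::vector<bool>` in place of `VarSet` are hypothetical names and choices made for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a DimSpec: whether its expression may be elided
// when printing, and which level-variables (by index) the expression uses.
struct DimSpecSketch {
  bool canElideExpr;
  std::vector<std::size_t> lvlVarsUsed;
};

// Returns one flag per level-variable; true means the variable is not
// referenced by any dim-expression that will be printed, so its name may
// be elided as well.
std::vector<bool> computeElideVar(std::size_t lvlRank,
                                  const std::vector<DimSpecSketch> &dimSpecs) {
  std::vector<bool> used(lvlRank, false);
  for (const auto &dimSpec : dimSpecs) {
    if (dimSpec.canElideExpr)
      continue; // Elided expressions do not keep any variable alive.
    for (std::size_t v : dimSpec.lvlVarsUsed) {
      // Plays the role of the well-formedness check (`isWF`) in the diff:
      // every referenced variable must be within the level rank.
      assert(v < lvlRank);
      used[v] = true;
    }
  }
  std::vector<bool> elide(lvlRank);
  for (std::size_t i = 0; i < lvlRank; ++i)
    elide[i] = !used[i];
  return elide;
}
```

In this sketch, a level-variable that appears only in elidable dim-expressions, or in none at all, ends up with its flag set to true.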
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index 2d02eed2cf9972e..664b49509f070f4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -5,11 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
-// be named as the storage representation of the parameter for the tblgen
-// defn of STEA.  We may well need to make the other classes public too,
-// so that the rest of the compiler can use them when necessary.
-//===----------------------------------------------------------------------===//
 
 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
@@ -23,16 +18,8 @@ namespace sparse_tensor {
 namespace ir_detail {
 
 //===----------------------------------------------------------------------===//
-// TODO(wrengr): Give this enum a better name, so that it fits together
-// with the name of the `DimLvlExpr` class (which may also want a better
-// name).  Perhaps make this a nested-type too.
-//
-// NOTE: In the future we will extend this enum to include "counting
-// expressions" required for supporting ITPACK/ELL.  Therefore the current
-// underlying-type and representation values should not be relied upon.
 enum class ExprKind : bool { Dimension = false, Level = true };
 
-// TODO(wrengr): still needs a better name....
 constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
   using VK = std::underlying_type_t<VarKind>;
   return VarKind{2 * static_cast<VK>(!to_underlying(ek))};
@@ -41,19 +28,8 @@ static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
               getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
 
 //===----------------------------------------------------------------------===//
-// TODO(wrengr): The goal of this class is to capture a proof that
-// we've verified that the given `AffineExpr` only has variables of the
-// appropriate kind(s).  So we need to actually prove/verify that in the
-// ctor or all its callsites!
 class DimLvlExpr {
 private:
-  // FIXME(wrengr): Per <https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html>,
-  // the `kind` field should be private and const.  However, beware
-  // that if we mark any field as `const` or if the fields have differing
-  // `private`/`protected` privileges then the `IsZeroCostAbstraction`
-  // assertion will fail!
-  // (Also, iirc, if we end up moving the `expr` to the subclasses
-  // instead, that'll also cause `IsZeroCostAbstraction` to fail.)
   ExprKind kind;
   AffineExpr expr;
 
@@ -100,11 +76,6 @@ class DimLvlExpr {
   //
   // Getters for handling `AffineExpr` subclasses.
   //
-  // TODO(wrengr): is there any way to make these typesafe without too much
-  // templating?
-  // TODO(wrengr): Most if not all of these don't actually need to be
-  // methods, they could be free-functions instead.
-  //
   Var castAnyVar() const;
   std::optional<Var> dyn_castAnyVar() const;
   SymVar castSymVar() const;
@@ -131,9 +102,6 @@ class DimLvlExpr {
   // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
   enum class BindingStrength : bool { Weak = false, Strong = true };
 
-  // TODO(wrengr): Does our version of `printAffineExprInternal` really
-  // need to be a method, or could it be a free-function instead? (assuming
-  // `BindingStrength` goes with it).
   void printAffineExprInternal(llvm::raw_ostream &os,
                                BindingStrength enclosingTightness) const;
   void printStrong(llvm::raw_ostream &os) const {
@@ -145,12 +113,7 @@ class DimLvlExpr {
 };
 static_assert(IsZeroCostAbstraction<DimLvlExpr>);
 
-// FUTURE_CL(wrengr): It would be nice to have the subclasses override
-// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
-// the proper covariant return types.
-//
 class DimExpr final : public DimLvlExpr {
-  // FIXME(wrengr): These two are needed for the current RTTI implementation.
   friend class DimLvlExpr;
   constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
 
@@ -170,7 +133,6 @@ class DimExpr final : public DimLvlExpr {
 static_assert(IsZeroCostAbstraction<DimExpr>);
 
 class LvlExpr final : public DimLvlExpr {
-  // FIXME(wrengr): These two are needed for the current RTTI implementation.
   friend class DimLvlExpr;
   constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
 
@@ -189,7 +151,6 @@ class LvlExpr final : public DimLvlExpr {
 };
 static_assert(IsZeroCostAbstraction<LvlExpr>);
 
-// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
 template <typename U>
 constexpr bool DimLvlExpr::isa() const {
   if constexpr (std::is_same_v<U, DimExpr>)
@@ -247,18 +208,12 @@ class DimSpec final {
   /// the result of this predicate.
   [[nodiscard]] bool isValid(Ranks const &ranks) const;
 
-  // TODO(wrengr): Use it or loose it.
-  bool isFunctionOf(Var var) const;
-  bool isFunctionOf(VarSet const &vars) const;
-  void getFreeVars(VarSet &vars) const;
-
   std::string str(bool wantElision = true) const;
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;
 };
-// Although this class is more than just a newtype/wrapper, we do want
-// to ensure that storing them into `SmallVector` is efficient.
+
 static_assert(IsZeroCostAbstraction<DimSpec>);
 
 //===----------------------------------------------------------------------===//
@@ -270,13 +225,6 @@ class LvlSpec final {
   /// whereas the `DimLvlMap` ctor will reset this as appropriate.
   bool elideVar = false;
   /// The level-expression.
-  //
-  // NOTE: For now we use `LvlExpr` because all level-expressions must be
-  // `AffineExpr`; however, in the future we will also want to allow "counting
-  // expressions", and potentially other kinds of non-affine level-expressions.
-  // Which kinds of `DimLvlExpr` are allowed will depend on the `DimLevelType`,
-  // so we may consider defining another class for pairing those two together
-  // to ensure that the pair is well-formed.
   LvlExpr expr;
   /// The level-type (== level-format + lvl-properties).
   DimLevelType type;
@@ -298,23 +246,14 @@ class LvlSpec final {
 
   /// Checks whether the variables bound/used by this spec are valid
   /// with respect to the given ranks.
-  //
-  // NOTE: Once we introduce "counting expressions" this will need
-  // a more sophisticated implementation than `DimSpec::isValid` does.
   [[nodiscard]] bool isValid(Ranks const &ranks) const;
 
-  // TODO(wrengr): Use it or loose it.
-  bool isFunctionOf(Var var) const;
-  bool isFunctionOf(VarSet const &vars) const;
-  void getFreeVars(VarSet &vars) const;
-
   std::string str(bool wantElision = true) const;
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;
 };
-// Although this class is more than just a newtype/wrapper, we do want
-// to ensure that storing them into `SmallVector` is efficient.
+
 static_assert(IsZeroCostAbstraction<LvlSpec>);
 
 //===----------------------------------------------------------------------===//
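
Note on `getVarKindAllowedInExpr` in the header above: the arithmetic maps a dimension-expression to level-variables and a level-expression to dimension-variables. The standalone sketch below spells that out next to an explicit conditional. It is an illustration, not the header's code: the `static_assert` in the diff only implies `Dimension == 0` and `Level == 2`, so `Symbol = 1` and the two function names are assumptions.

```cpp
#include <type_traits>

// Assumed enumerator values. Only Dimension == 0 and Level == 2 follow from
// the static_assert shown in the diff; Symbol == 1 is a guess.
enum class VarKind { Dimension = 0, Symbol = 1, Level = 2 };
enum class ExprKind : bool { Dimension = false, Level = true };

// Arithmetic form, mirroring the diff:
//   ExprKind::Dimension (false) -> !false == 1 -> 2 * 1 == 2 -> VarKind::Level
//   ExprKind::Level     (true)  -> !true  == 0 -> 2 * 0 == 0 -> VarKind::Dimension
constexpr VarKind allowedVarKindArithmetic(ExprKind ek) {
  using VK = std::underlying_type_t<VarKind>;
  return static_cast<VarKind>(2 * static_cast<VK>(!static_cast<bool>(ek)));
}

// Equivalent explicit form (hypothetical alternative, easier to read).
constexpr VarKind allowedVarKindExplicit(ExprKind ek) {
  return ek == ExprKind::Dimension ? VarKind::Level : VarKind::Dimension;
}

static_assert(allowedVarKindArithmetic(ExprKind::Dimension) ==
                  allowedVarKindExplicit(ExprKind::Dimension),
              "both forms agree for dimension-expressions");
static_assert(allowedVarKindArithmetic(ExprKind::Level) ==
                  allowedVarKindExplicit(ExprKind::Level),
              "both forms agree for level-expressions");
```

The arithmetic form depends on the enumerator values, which is presumably why the header guards it with a `static_assert`.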
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 680411f8008ea3d..6fb69d1397e6cfb 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -44,42 +44,20 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
                                               VarInfo::ID &varID,
                                               bool &didCreate) {
   // Save the current location so that we can have error messages point to
-  // the right place.  Note that `Parser::emitWrongTokenError` starts off
-  // with the same location as `AsmParserImpl::getCurrentLocation` returns;
-  // however, `Parser` will then do some various munging with the location
-  // before actually using it, so `AsmParser::emitError` can't quite be used
-  // as a drop-in replacement for `Parser::emitWrongTokenError`.
+  // the right place.
   const auto loc = parser.getCurrentLocation();
-
-  // Several things to note.
-  // (1) the `Parser::isCurrentTokenAKeyword` method checks the exact
-  //     same conditions as the `AffineParser.cpp`-static free-function
-  //     `isIdentifier` which is used by `AffineParser::parseBareIdExpr`.
-  // (2) the `{Parser,AsmParserImpl}::parseOptionalKeyword(StringRef*)`
-  //     methods do the same song and dance about using
-  //     `isCurrentTokenAKeyword`, `getTokenSpelling`, et `consumeToken` as we
-  //     would want to do if we could use the `Parser` class directly.  It
-  //     doesn't provide the nice error handling we want, but we can work around
-  //     that.
   StringRef name;
   if (failed(parser.parseOptionalKeyword(&name))) {
-    // If not actually optional, then `emitError`.
     ERROR_IF(!isOptional, "expected bare identifier")
-    // If is actually optional, then return the null `OptionalParseResult`.
     return std::nullopt;
   }
 
-  // I don't know if we need to worry about the possibility of the caller
-  // recovering from error and then reusing the `DimLvlMapParser` for subsequent
-  // `parseVar`, but I'm erring on the side of caution by distinguishing
-  // all three possible creation policies.
   if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
     varID = res->first;
     didCreate = res->second;
     return success();
   }
-  // TODO(wrengr): these error messages make sense for our intended usage,
-  // but not in general; but it's unclear how best to factor that part out.
+
   switch (creationPolicy) {
   case Policy::MustNot:
     return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
@@ -167,8 +145,6 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
   FAILURE_IF_FAILED(parseDimSpecList())
   FAILURE_IF_FAILED(parser.parseArrow())
   FAILURE_IF_FAILED(parseLvlSpecList())
-  // TODO(wrengr): Try to improve the error messages from
-  // `VarEnv::emitErrorIfAnyUnbound`.
   InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
   if (failed(ifd))
     return ifd;
@@ -182,29 +158,6 @@ ParseResult DimLvlMapParser::parseSymbolBindingList() {
       " in symbol binding list");
 }
 
-// FIXME: The forward-declaration of level-vars is a stop-gap workaround
-// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
-// `DimLvlMapParser::parseDimSpec`.  (In particular, note that all the
-// variables must be bound before entering `AsmParser::parseAffineExpr`,
-// since that method requires every variable to already have a fixed/known
-// `Var::Num`.)
-//
-// However, the forward-declaration list duplicates information which is
-// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
-// the names of the variables themselves, and the order in which the names
-// are bound).  This redundancy causes bad UX, and also means we must be
-// sure to verify consistency between the two sources of information.
-//
-// Therefore, it would be best to remove the forward-declaration list from
-// the syntax.  This can be achieved by implementing our own version of
-// `AffineParser::parseAffineExpr` which calls
-// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
-// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
-// some `Var::Num` immediately).  This would also enable us to use the `VarEnv`
-// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
-// side, and thus would also enable us to avoid the O(n^2) behavior of copying
-// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
-// every time `AsmParser::parseAffineExpr` is called.
 ParseResult DimLvlMapParser::parseLvlVarBindingList() {
   return parser.parseCommaSeparatedList(
       OpAsmParser::Delimiter::OptionalBraces,
@@ -233,9 +186,6 @@ ParseResult DimLvlMapParser::parseDimSpec() {
   AffineExpr affine;
   if (succeeded(parser.parseOptionalEqual())) {
     // Parse the dim affine expr, with only any lvl-vars in scope.
-    // FIXME(wrengr): This still has the O(n^2) behavior of copying
-    // our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
-    // field every time `parseDimSpec` is called.
     FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
   }
   DimExpr expr{affine};
@@ -304,9 +254,6 @@ static inline Twine nth(Var::Num n) {
   }
 }
 
-// NOTE: This is factored out as a separate method only because `Var`
-// lacks a default-ctor, which makes this conditional difficult to inline
-// at the one call-site.
 FailureOr<LvlVar>
 DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
   // Nothing to parse, just bind an unnamed variable.
@@ -336,17 +283,14 @@ DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
 }
 
 ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
-  // Parse the optional lvl-var binding.  (Actually, `requireLvlVarBinding`
-  // specifies whether that "optional" is actually Must or MustNot.)
+  // Parse the optional lvl-var binding. `requireLvlVarBinding`
+  // specifies whether that "optional" is actually Must or MustNot.
   const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
   FAILURE_IF_FAILED(varRes)
   const LvlVar var = *varRes;
 
   // Parse the lvl affine expr, with only the dim-vars in scope.
   AffineExpr affine;
-  // FIXME(wrengr): This still has the O(n^2) behavior of copying
-  // our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
-  // field every time `parseLvlSpec` is called.
   FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
   LvlExpr expr{affine};
 
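
Note on `parseVar` above: it passes a creation policy to `env.lookupOrCreate` and turns a missing result into a policy-specific error such as "use of undeclared identifier". The sketch below shows one way such a three-way policy could behave. It is a guess at the semantics, not the `VarEnv` implementation: `VarEnvSketch`, the `May` enumerator, and the plain `unsigned` IDs are assumptions (the diff itself only names `MustNot` and `Must`).

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>

// MustNot: the name must already exist (a pure use).
// May:     look the name up, creating it on first sight.
// Must:    the name must be new (a fresh binding).
enum class Policy { MustNot, May, Must };

class VarEnvSketch {
  std::unordered_map<std::string, unsigned> ids;

public:
  // Returns (id, didCreate), or std::nullopt when the policy is violated;
  // the caller is then expected to emit a policy-specific diagnostic.
  std::optional<std::pair<unsigned, bool>>
  lookupOrCreate(Policy policy, const std::string &name) {
    const auto it = ids.find(name);
    if (it != ids.end()) {
      if (policy == Policy::Must)
        return std::nullopt; // The name exists but had to be new.
      return std::make_pair(it->second, false);
    }
    if (policy == Policy::MustNot)
      return std::nullopt; // The name is unknown but had to exist.
    const unsigned id = static_cast<unsigned>(ids.size());
    ids.emplace(name, id);
    return std::make_pair(id, true);
  }
};
```

Returning an empty optional rather than emitting the error inside the environment keeps the diagnostic text with the parser, which matches the `switch (creationPolicy)` that follows the lookup in the hunk.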
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
index 013a89ea172b0b5..11f727c40644ae1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -42,39 +42,32 @@ class DimLvlMapParser final {
   FailureOr<DimLvlMap> parseDimLvlMap();
 
 private:
-  /// The core code for parsing `Var`.  This method abstracts out a lot
-  /// of complex details to avoid code duplication; however, client code
-  /// should prefer using `parseVarUsage` and `parseVarBinding` rather than
-  /// calling this method directly.
+  /// Client code should prefer using `parseVarUsage`
+  /// and `parseVarBinding` rather than calling this method directly.
   OptionalParseResult parseVar(VarKind vk, bool isOptional,
                                Policy creationPolicy, VarInfo::ID &id,
                                bool &didCreate);
 
-  /// Parse a variable occurence which is a *use* of that variable.
-  /// The `requireKnown` parameter specifies how to handle the case of
-  /// encountering a valid variable name which is currently unused: when
-  /// `requireKnown=true`, an error is raised; when `requireKnown=false`,
+  /// Parses a variable occurence which is a *use* of that variable.
+  /// When a valid variable name is currently unused, if
+  /// `requireKnown=true`, an error is raised; if `requireKnown=false`,
   /// a new unbound variable will be created.
-  ///
-  /// NOTE: Just because a variable is *known* (i.e., the name has been
-  /// associated with an `VarInfo::ID`), does not mean that the variable
-  /// is actually *in scope*.
   FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
 
-  /// Parse a variable occurence which is a *binding* of that variable.
+  /// Parses a variable occurence which is a *binding* of that variable.
   /// The `requireKnown` parameter is for handling the binding of
   /// forward-declared variables.
   FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
 
-  /// Parse an optional variable binding.  When the next token is
+  /// Parses an optional variable binding. When the next token is
   /// not a valid variable name, this will bind a new unnamed variable.
   /// The returned `bool` indicates whether a variable name was parsed.
   FailureOr<std::pair<Var, bool>>
   parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
 
   /// Binds the given variable: both updating the `VarEnv` itself, and
-  /// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
-  /// to `AsmParser::parseAffineExpr`).  This method is already called by the
+  /// the `{dims,lvls}AndSymbols` lists (which will be passed
+  /// to `AsmParser::parseAffineExpr`). This method is already called by the
   /// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
   /// not need to be called elsewhere.
   Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 053e067fff64ddb..8cc7068e3113aff 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/...
[truncated]

@yinying-lisa-li yinying-lisa-li merged commit 34ed07e into llvm:main Oct 23, 2023
6 checks passed
@yinying-lisa-li yinying-lisa-li deleted the cleanup branch October 23, 2023 18:30
Labels
mlir:sparse Sparse compiler in MLIR mlir
4 participants