[mlir][sparse] Parser cleanup #69792

Merged
merged 1 commit into from
Oct 23, 2023
35 changes: 0 additions & 35 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -60,16 +60,6 @@ std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
return k ? std::make_optional(k.getValue()) : std::nullopt;
}

// This helper method is akin to `AffineExpr::operator==(int64_t)`
// except it uses a different implementation, namely the implementation
// used within `AsmPrinter::Impl::printAffineExprInternal`.
//
// wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses
// this implementation because it avoids constructing the intermediate
// `AffineConstantExpr(val)` and thus should in theory be a bit faster.
// However, if it is indeed faster, then the `AffineExpr::operator==`
// method should be updated to do this instead. And if it isn't any
// faster, then we should be using `AffineExpr::operator==` instead.
bool DimLvlExpr::hasConstantValue(int64_t val) const {
const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
return k && k.getValue() == val;
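
Reviewer note: the comment removed above explained that `hasConstantValue` inspects the constant node in place instead of building a second constant expression to compare against. A minimal standalone sketch of that pattern follows. `ToyExpr`, `ConstantNode`, and `VariableNode` are hypothetical stand-ins, not MLIR's `AffineExpr` classes.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <variant>

// Hypothetical stand-ins for the affine expression classes; these are
// NOT the real `AffineExpr` / `AffineConstantExpr` types.
struct ConstantNode { int64_t value; };
struct VariableNode { unsigned position; };
// `std::monostate` models a null expression.
using ToyExpr = std::variant<std::monostate, ConstantNode, VariableNode>;

// Analogue of `dyn_castConstantValue`: yields the constant if the
// expression is one, and `std::nullopt` otherwise (including null).
inline std::optional<int64_t> dynCastConstantValue(const ToyExpr &expr) {
  if (const auto *k = std::get_if<ConstantNode>(&expr))
    return k->value;
  return std::nullopt;
}

// Analogue of `hasConstantValue`: look at the node directly, so no
// temporary constant expression has to be constructed for the comparison.
inline bool hasConstantValue(const ToyExpr &expr, int64_t val) {
  const auto *k = std::get_if<ConstantNode>(&expr);
  return k && k->value == val;
}
```

With these definitions, `hasConstantValue(ToyExpr{ConstantNode{3}}, 3)` holds, while a variable node or a null expression never compares equal to any constant.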
@@ -216,12 +206,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
return ranks.isValid(var) && (!expr || ranks.isValid(expr));
}

bool DimSpec::isFunctionOf(VarSet const &vars) const {
return vars.occursIn(expr);
}

void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }

void DimSpec::dump() const {
print(llvm::errs(), /*wantElision=*/false);
llvm::errs() << "\n";
@@ -262,12 +246,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
return ranks.isValid(var) && ranks.isValid(expr);
}

bool LvlSpec::isFunctionOf(VarSet const &vars) const {
return vars.occursIn(expr);
}

void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }

void LvlSpec::dump() const {
print(llvm::errs(), /*wantElision=*/false);
llvm::errs() << "\n";
@@ -301,19 +279,6 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
// below cannot cause OOB errors.
assert(isWF());

// TODO: Second, we need to infer/validate the `lvlToDim` mapping.
// Along the way we should set every `DimSpec::elideExpr` according
// to whether the given expression is inferable or not. Notably, this
// needs to happen before the code for setting every `LvlSpec::elideVar`,
// since if the LvlVar is only used in elided DimExpr, then the
// LvlVar should also be elided.
// NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
// to ensure that we maintain the invariant established by `isWF` above.

// Third, we set every `LvlSpec::elideVar` according to whether that
// LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
// NOTE: The invariant established by `isWF` ensures that the following
// calls to `VarSet::add` cannot raise OOB errors.
VarSet usedVars(getRanks());
for (const auto &dimSpec : dimSpecs)
if (!dimSpec.canElideExpr())
65 changes: 2 additions & 63 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -5,11 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
// be named as the storage representation of the parameter for the tblgen
// defn of STEA. We may well need to make the other classes public too,
// so that the rest of the compiler can use them when necessary.
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
@@ -23,16 +18,8 @@ namespace sparse_tensor {
namespace ir_detail {

//===----------------------------------------------------------------------===//
// TODO(wrengr): Give this enum a better name, so that it fits together
// with the name of the `DimLvlExpr` class (which may also want a better
// name). Perhaps make this a nested-type too.
//
// NOTE: In the future we will extend this enum to include "counting
// expressions" required for supporting ITPACK/ELL. Therefore the current
// underlying-type and representation values should not be relied upon.
enum class ExprKind : bool { Dimension = false, Level = true };

// TODO(wrengr): still needs a better name....
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
using VK = std::underlying_type_t<VarKind>;
return VarKind{2 * static_cast<VK>(!to_underlying(ek))};
@@ -41,19 +28,8 @@ static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
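
Reviewer note: the arithmetic in `getVarKindAllowedInExpr` is terse, so here is a standalone sketch of how it maps each expression kind to the opposite variable kind. The numeric values given to `VarKind` below are an assumption, chosen only to be consistent with the `static_assert` in the header. The real enum is defined elsewhere and is not shown in this diff.

```cpp
#include <cassert>
#include <type_traits>

// `ExprKind` matches the enum shown in the diff. The numeric values of
// `VarKind` are an ASSUMPTION consistent with the header's static_assert.
enum class ExprKind : bool { Dimension = false, Level = true };
enum class VarKind { Dimension = 0, Symbol = 1, Level = 2 };

constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
  using VK = std::underlying_type_t<VarKind>;
  // Negating the bool swaps Dimension and Level; doubling then maps
  // {0, 1} onto {0, 2}, skipping the value reserved for Symbol.
  const bool flipped = !static_cast<bool>(ek);
  return static_cast<VarKind>(2 * static_cast<VK>(flipped));
}

// A dimension-expression may only mention level-vars, and vice versa.
static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level);
static_assert(getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
```

The trick depends on `ExprKind` staying a two-valued `bool` enum, which is presumably why the removed NOTE warned against relying on the current underlying type.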

//===----------------------------------------------------------------------===//
// TODO(wrengr): The goal of this class is to capture a proof that
// we've verified that the given `AffineExpr` only has variables of the
// appropriate kind(s). So we need to actually prove/verify that in the
// ctor or all its callsites!
class DimLvlExpr {
private:
// FIXME(wrengr): Per <https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html>,
// the `kind` field should be private and const. However, beware
// that if we mark any field as `const` or if the fields have differing
// `private`/`protected` privileges then the `IsZeroCostAbstraction`
// assertion will fail!
// (Also, iirc, if we end up moving the `expr` to the subclasses
// instead, that'll also cause `IsZeroCostAbstraction` to fail.)
ExprKind kind;
AffineExpr expr;

@@ -100,11 +76,6 @@ class DimLvlExpr {
//
// Getters for handling `AffineExpr` subclasses.
//
// TODO(wrengr): is there any way to make these typesafe without too much
// templating?
// TODO(wrengr): Most if not all of these don't actually need to be
// methods, they could be free-functions instead.
//
Var castAnyVar() const;
std::optional<Var> dyn_castAnyVar() const;
SymVar castSymVar() const;
@@ -131,9 +102,6 @@ class DimLvlExpr {
// Variant of `mlir::AsmPrinter::Impl::BindingStrength`
enum class BindingStrength : bool { Weak = false, Strong = true };

// TODO(wrengr): Does our version of `printAffineExprInternal` really
// need to be a method, or could it be a free-function instead? (assuming
// `BindingStrength` goes with it).
void printAffineExprInternal(llvm::raw_ostream &os,
BindingStrength enclosingTightness) const;
void printStrong(llvm::raw_ostream &os) const {
@@ -145,12 +113,7 @@ class DimLvlExpr {
};
static_assert(IsZeroCostAbstraction<DimLvlExpr>);

// FUTURE_CL(wrengr): It would be nice to have the subclasses override
// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
// the proper covariant return types.
//
class DimExpr final : public DimLvlExpr {
// FIXME(wrengr): These two are needed for the current RTTI implementation.
friend class DimLvlExpr;
constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}

@@ -170,7 +133,6 @@ class DimExpr final : public DimLvlExpr {
static_assert(IsZeroCostAbstraction<DimExpr>);

class LvlExpr final : public DimLvlExpr {
// FIXME(wrengr): These two are needed for the current RTTI implementation.
friend class DimLvlExpr;
constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}

@@ -189,7 +151,6 @@ class LvlExpr final : public DimLvlExpr {
};
static_assert(IsZeroCostAbstraction<LvlExpr>);

// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
template <typename U>
constexpr bool DimLvlExpr::isa() const {
if constexpr (std::is_same_v<U, DimExpr>)
@@ -247,18 +208,12 @@ class DimSpec final {
/// the result of this predicate.
[[nodiscard]] bool isValid(Ranks const &ranks) const;

// TODO(wrengr): Use it or lose it.
bool isFunctionOf(Var var) const;
bool isFunctionOf(VarSet const &vars) const;
void getFreeVars(VarSet &vars) const;

std::string str(bool wantElision = true) const;
void print(llvm::raw_ostream &os, bool wantElision = true) const;
void print(AsmPrinter &printer, bool wantElision = true) const;
void dump() const;
};
// Although this class is more than just a newtype/wrapper, we do want
// to ensure that storing them into `SmallVector` is efficient.

static_assert(IsZeroCostAbstraction<DimSpec>);

//===----------------------------------------------------------------------===//
@@ -270,13 +225,6 @@ class LvlSpec final {
/// whereas the `DimLvlMap` ctor will reset this as appropriate.
bool elideVar = false;
/// The level-expression.
//
// NOTE: For now we use `LvlExpr` because all level-expressions must be
// `AffineExpr`; however, in the future we will also want to allow "counting
// expressions", and potentially other kinds of non-affine level-expressions.
// Which kinds of `DimLvlExpr` are allowed will depend on the `DimLevelType`,
// so we may consider defining another class for pairing those two together
// to ensure that the pair is well-formed.
LvlExpr expr;
/// The level-type (== level-format + lvl-properties).
DimLevelType type;
@@ -298,23 +246,14 @@ class LvlSpec final {

/// Checks whether the variables bound/used by this spec are valid
/// with respect to the given ranks.
//
// NOTE: Once we introduce "counting expressions" this will need
// a more sophisticated implementation than `DimSpec::isValid` does.
[[nodiscard]] bool isValid(Ranks const &ranks) const;

// TODO(wrengr): Use it or lose it.
bool isFunctionOf(Var var) const;
bool isFunctionOf(VarSet const &vars) const;
void getFreeVars(VarSet &vars) const;

std::string str(bool wantElision = true) const;
void print(llvm::raw_ostream &os, bool wantElision = true) const;
void print(AsmPrinter &printer, bool wantElision = true) const;
void dump() const;
};
// Although this class is more than just a newtype/wrapper, we do want
// to ensure that storing them into `SmallVector` is efficient.

static_assert(IsZeroCostAbstraction<LvlSpec>);
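
Reviewer note: the header asserts `IsZeroCostAbstraction` on every wrapper class, and one of the removed FIXME comments warned that fields with differing access privileges make the assertion fail. The trait's definition is not part of this diff. The sketch below is a hypothetical approximation (`IsZeroCostAbstractionSketch`, `Wrapper`, and `Mixed` are illustrative names) that shows why mixed access control can trip such a check.

```cpp
#include <cassert>
#include <type_traits>

// HYPOTHETICAL approximation of the trait; the real definition is not
// shown in this diff and may include further conditions.
template <typename T>
constexpr bool IsZeroCostAbstractionSketch =
    std::is_trivially_copyable_v<T> && std::is_trivially_destructible_v<T> &&
    std::is_standard_layout_v<T>;

// All data members share one access level: standard layout holds.
struct Wrapper {
  bool flag;
  int payload;
};

// Data members with differing access levels: no longer standard layout.
class Mixed {
public:
  int a;
  int getB() const { return b; }

private:
  int b;
};

static_assert(IsZeroCostAbstractionSketch<Wrapper>);
static_assert(!IsZeroCostAbstractionSketch<Mixed>);
```

Under this reading, the assertions guard the property that these classes can be copied and stored (for example in a `SmallVector`) as cheaply as the plain values they wrap.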

//===----------------------------------------------------------------------===//
64 changes: 4 additions & 60 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -44,42 +44,20 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
VarInfo::ID &varID,
bool &didCreate) {
// Save the current location so that we can have error messages point to
// the right place. Note that `Parser::emitWrongTokenError` starts off
// with the same location as `AsmParserImpl::getCurrentLocation` returns;
// however, `Parser` will then do some various munging with the location
// before actually using it, so `AsmParser::emitError` can't quite be used
// as a drop-in replacement for `Parser::emitWrongTokenError`.
// the right place.
const auto loc = parser.getCurrentLocation();

// Several things to note.
// (1) the `Parser::isCurrentTokenAKeyword` method checks the exact
// same conditions as the `AffineParser.cpp`-static free-function
// `isIdentifier` which is used by `AffineParser::parseBareIdExpr`.
// (2) the `{Parser,AsmParserImpl}::parseOptionalKeyword(StringRef*)`
// methods do the same song and dance about using
// `isCurrentTokenAKeyword`, `getTokenSpelling`, et `consumeToken` as we
// would want to do if we could use the `Parser` class directly. It
// doesn't provide the nice error handling we want, but we can work around
// that.
StringRef name;
if (failed(parser.parseOptionalKeyword(&name))) {
// If not actually optional, then `emitError`.
ERROR_IF(!isOptional, "expected bare identifier")
// If is actually optional, then return the null `OptionalParseResult`.
return std::nullopt;
}

// I don't know if we need to worry about the possibility of the caller
// recovering from error and then reusing the `DimLvlMapParser` for subsequent
// `parseVar`, but I'm erring on the side of caution by distinguishing
// all three possible creation policies.
if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
varID = res->first;
didCreate = res->second;
return success();
}
// TODO(wrengr): these error messages make sense for our intended usage,
// but not in general; but it's unclear how best to factor that part out.

switch (creationPolicy) {
case Policy::MustNot:
return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
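
Reviewer note: `parseVar` leans on `env.lookupOrCreate(creationPolicy, ...)` returning an `(id, didCreate)` pair, with the policy deciding whether a failed lookup is an error. A minimal sketch of that contract follows. `ToyVarEnv` is hypothetical, and only `Policy::MustNot` appears in this diff, so the names `May` and `Must` and their semantics are assumptions.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Only `Policy::MustNot` is visible in the diff; `May` and `Must` are
// ASSUMED names for the other two creation policies.
enum class Policy { MustNot, May, Must };

// Hypothetical stand-in for the variable environment.
class ToyVarEnv {
  std::map<std::string, unsigned> ids;

public:
  // Returns (id, didCreate) on success, or std::nullopt when the policy
  // forbids the action the lookup would require.
  std::optional<std::pair<unsigned, bool>>
  lookupOrCreate(Policy policy, const std::string &name) {
    const auto it = ids.find(name);
    if (it != ids.end()) {
      // Name already known: only `Must` (must create) rejects this.
      if (policy == Policy::Must)
        return std::nullopt;
      return std::make_pair(it->second, false);
    }
    // Name unknown: only `MustNot` (must not create) rejects this, which
    // corresponds to the "use of undeclared identifier" diagnostic.
    if (policy == Policy::MustNot)
      return std::nullopt;
    const unsigned id = static_cast<unsigned>(ids.size());
    ids.emplace(name, id);
    return std::make_pair(id, true);
  }
};
```

In this sketch a failed lookup leaves the environment untouched, so the caller can pick the diagnostic from the policy alone, as the `switch` in `parseVar` does.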
@@ -167,8 +145,6 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
FAILURE_IF_FAILED(parseDimSpecList())
FAILURE_IF_FAILED(parser.parseArrow())
FAILURE_IF_FAILED(parseLvlSpecList())
// TODO(wrengr): Try to improve the error messages from
// `VarEnv::emitErrorIfAnyUnbound`.
InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
if (failed(ifd))
return ifd;
@@ -182,29 +158,6 @@ ParseResult DimLvlMapParser::parseSymbolBindingList() {
" in symbol binding list");
}

// FIXME: The forward-declaration of level-vars is a stop-gap workaround
// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
// `DimLvlMapParser::parseDimSpec`. (In particular, note that all the
// variables must be bound before entering `AsmParser::parseAffineExpr`,
// since that method requires every variable to already have a fixed/known
// `Var::Num`.)
//
// However, the forward-declaration list duplicates information which is
// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
// the names of the variables themselves, and the order in which the names
// are bound). This redundancy causes bad UX, and also means we must be
// sure to verify consistency between the two sources of information.
//
// Therefore, it would be best to remove the forward-declaration list from
// the syntax. This can be achieved by implementing our own version of
// `AffineParser::parseAffineExpr` which calls
// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
// some `Var::Num` immediately). This would also enable us to use the `VarEnv`
// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
// side, and thus would also enable us to avoid the O(n^2) behavior of copying
// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
// every time `AsmParser::parseAffineExpr` is called.
ParseResult DimLvlMapParser::parseLvlVarBindingList() {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::OptionalBraces,
@@ -233,9 +186,6 @@ ParseResult DimLvlMapParser::parseDimSpec() {
AffineExpr affine;
if (succeeded(parser.parseOptionalEqual())) {
// Parse the dim affine expr, with only any lvl-vars in scope.
// FIXME(wrengr): This still has the O(n^2) behavior of copying
// our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
// field every time `parseDimSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
}
DimExpr expr{affine};
@@ -304,9 +254,6 @@ static inline Twine nth(Var::Num n) {
}
}

// NOTE: This is factored out as a separate method only because `Var`
// lacks a default-ctor, which makes this conditional difficult to inline
// at the one call-site.
FailureOr<LvlVar>
DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
// Nothing to parse, just bind an unnamed variable.
@@ -336,17 +283,14 @@ DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
}

ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
// Parse the optional lvl-var binding. (Actually, `requireLvlVarBinding`
// specifies whether that "optional" is actually Must or MustNot.)
// Parse the optional lvl-var binding. `requireLvlVarBinding`
// specifies whether that "optional" is actually Must or MustNot.
const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
FAILURE_IF_FAILED(varRes)
const LvlVar var = *varRes;

// Parse the lvl affine expr, with only the dim-vars in scope.
AffineExpr affine;
// FIXME(wrengr): This still has the O(n^2) behavior of copying
// our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
// field every time `parseLvlSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
LvlExpr expr{affine};

25 changes: 9 additions & 16 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -42,39 +42,32 @@ class DimLvlMapParser final {
FailureOr<DimLvlMap> parseDimLvlMap();

private:
/// The core code for parsing `Var`. This method abstracts out a lot
/// of complex details to avoid code duplication; however, client code
/// should prefer using `parseVarUsage` and `parseVarBinding` rather than
/// calling this method directly.
/// Client code should prefer using `parseVarUsage`
/// and `parseVarBinding` rather than calling this method directly.
OptionalParseResult parseVar(VarKind vk, bool isOptional,
Policy creationPolicy, VarInfo::ID &id,
bool &didCreate);

/// Parse a variable occurrence which is a *use* of that variable.
/// The `requireKnown` parameter specifies how to handle the case of
/// encountering a valid variable name which is currently unused: when
/// `requireKnown=true`, an error is raised; when `requireKnown=false`,
/// Parses a variable occurrence which is a *use* of that variable.
/// When a valid variable name is currently unused, if
/// `requireKnown=true`, an error is raised; if `requireKnown=false`,
/// a new unbound variable will be created.
///
/// NOTE: Just because a variable is *known* (i.e., the name has been
/// associated with an `VarInfo::ID`), does not mean that the variable
/// is actually *in scope*.
FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);

/// Parse a variable occurrence which is a *binding* of that variable.
/// Parses a variable occurrence which is a *binding* of that variable.
/// The `requireKnown` parameter is for handling the binding of
/// forward-declared variables.
FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);

/// Parse an optional variable binding. When the next token is
/// Parses an optional variable binding. When the next token is
/// not a valid variable name, this will bind a new unnamed variable.
/// The returned `bool` indicates whether a variable name was parsed.
FailureOr<std::pair<Var, bool>>
parseOptionalVarBinding(VarKind vk, bool requireKnown = false);

/// Binds the given variable: both updating the `VarEnv` itself, and
/// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
/// to `AsmParser::parseAffineExpr`). This method is already called by the
/// the `{dims,lvls}AndSymbols` lists (which will be passed
/// to `AsmParser::parseAffineExpr`). This method is already called by the
/// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
/// not need to be called elsewhere.
Var bindVar(llvm::SMLoc loc, VarInfo::ID id);