Skip to content

Commit

Permalink
Remove the need for passing a location to parseAttribute/parseType.
Browse files Browse the repository at this point in the history
Now that a proper parser is passed to these methods, there isn't a need to explicitly pass a source location. The source location can be recovered from the parser as necessary. This removes the need to explicitly decode an SMLoc in the case where we don't need to, which can be expensive.

This requires adding some basic nesting support to the parser for supporting nested parsers to allow for remapping source locations of the nested parsers to the top level parser for accurate diagnostics. This is due to the fact that the attribute and type parsers use different source buffers than the top level parser, as they may be represented in string form.

PiperOrigin-RevId: 278014858
  • Loading branch information
River707 authored and tensorflower-gardener committed Nov 1, 2019
1 parent 445cc3f commit 2ba4d80
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 56 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
Expand Up @@ -173,7 +173,7 @@ class LLVMDialect : public Dialect {
llvm::Module &getLLVMModule();

/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser, Location loc) const override;
Type parseType(DialectAsmParser &parser) const override;

/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
Expand Up @@ -37,7 +37,7 @@ class LinalgDialect : public Dialect {
static StringRef getDialectNamespace() { return "linalg"; }

/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser, Location loc) const override;
Type parseType(DialectAsmParser &parser) const override;

/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/QuantOps/QuantOps.h
Expand Up @@ -35,7 +35,7 @@ class QuantizationDialect : public Dialect {
QuantizationDialect(MLIRContext *context);

/// Parse a type registered to this dialect.
Type parseType(DialectAsmParser &parser, Location loc) const override;
Type parseType(DialectAsmParser &parser) const override;

/// Print a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
Expand Up @@ -46,7 +46,7 @@ class SPIRVDialect : public Dialect {
static std::string getAttributeName(Decoration decoration);

/// Parses a type registered to this dialect.
Type parseType(DialectAsmParser &parser, Location loc) const override;
Type parseType(DialectAsmParser &parser) const override;

/// Prints a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/Diagnostics.h
Expand Up @@ -391,7 +391,7 @@ class InFlightDiagnostic {
friend DiagnosticEngine;

/// The engine that this diagnostic is to report to.
DiagnosticEngine *owner;
DiagnosticEngine *owner = nullptr;

/// The raw diagnostic that is inflight to be reported.
llvm::Optional<Diagnostic> impl;
Expand Down
5 changes: 2 additions & 3 deletions mlir/include/mlir/IR/Dialect.h
Expand Up @@ -117,8 +117,7 @@ class Dialect {

/// Parse an attribute registered to this dialect. If 'type' is nonnull, it
/// refers to the expected type of the attribute.
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type,
Location loc) const;
virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;

/// Print an attribute registered to this dialect. Note: The type of the
/// attribute need not be printed by this method as it is always printed by
Expand All @@ -128,7 +127,7 @@ class Dialect {
}

/// Parse a type registered to this dialect.
virtual Type parseType(DialectAsmParser &parser, Location loc) const;
virtual Type parseType(DialectAsmParser &parser) const;

/// Print a type registered to this dialect.
virtual void printType(Type, DialectAsmPrinter &) const {
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/DialectImplementation.h
Expand Up @@ -129,6 +129,9 @@ class DialectAsmParser {
/// Return the location of the original name token.
virtual llvm::SMLoc getNameLoc() const = 0;

/// Re-encode the given source location as an MLIR location and return it.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;

/// Returns the full specification of the symbol being parsed. This allows for
/// using a separate parser if necessary.
virtual StringRef getFullSymbolSpec() const = 0;
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Expand Up @@ -1250,7 +1250,7 @@ llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }

/// Parse a type registered to this dialect.
Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const {
Type LLVMDialect::parseType(DialectAsmParser &parser) const {
StringRef tyData = parser.getFullSymbolSpec();

// LLVM is not thread-safe, so lock access to it.
Expand All @@ -1259,7 +1259,8 @@ Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const {
llvm::SMDiagnostic errorMessage;
llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
if (!type)
return (emitError(loc, errorMessage.getMessage()), nullptr);
return (parser.emitError(parser.getNameLoc(), errorMessage.getMessage()),
nullptr);
return LLVMType::get(getContext(), type);
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
Expand Up @@ -108,8 +108,8 @@ Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
return getImpl()->getBufferSize();
}

Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser,
Location loc) const {
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
StringRef spec = parser.getFullSymbolSpec();
StringRef origSpec = spec;
MLIRContext *context = getContext();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp
Expand Up @@ -616,8 +616,8 @@ bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) {
}

/// Parse a type registered to this dialect.
Type QuantizationDialect::parseType(DialectAsmParser &parser,
Location loc) const {
Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
TypeParser typeParser(parser.getFullSymbolSpec(), getContext(), loc);
Type parsedType = typeParser.parseType();
if (parsedType == nullptr) {
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
Expand Up @@ -610,8 +610,9 @@ static Type parseStructType(SPIRVDialect const &dialect, StringRef spec,
// | pointer-type
// | runtime-array-type
// | struct-type
Type SPIRVDialect::parseType(DialectAsmParser &parser, Location loc) const {
Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
StringRef spec = parser.getFullSymbolSpec();
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());

if (spec.startswith("array"))
return parseArrayType(*this, spec, loc);
Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/IR/Dialect.cpp
Expand Up @@ -102,23 +102,23 @@ LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
}

/// Parse an attribute registered to this dialect.
Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type,
Location loc) const {
emitError(loc) << "dialect '" << getNamespace()
<< "' provides no attribute parsing hook";
Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
parser.emitError(parser.getNameLoc())
<< "dialect '" << getNamespace()
<< "' provides no attribute parsing hook";
return Attribute();
}

/// Parse a type registered to this dialect.
Type Dialect::parseType(DialectAsmParser &parser, Location loc) const {
Type Dialect::parseType(DialectAsmParser &parser) const {
// If this dialect allows unknown types, then represent this with OpaqueType.
if (allowsUnknownTypes()) {
auto ns = Identifier::get(getNamespace(), getContext());
return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
}

emitError(loc) << "dialect '" << getNamespace()
<< "' provides no type parsing hook";
parser.emitError(parser.getNameLoc())
<< "dialect '" << getNamespace() << "' provides no type parsing hook";
return Type();
}

Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Parser/Lexer.h
Expand Up @@ -45,6 +45,9 @@ class Lexer {
/// at the designated point in the input.
void resetPointer(const char *newPointer) { curPtr = newPointer; }

/// Returns the start of the buffer.
const char *getBufferBegin() { return curBuffer.data(); }

private:
// Helpers.
Token formToken(Token::Kind kind, const char *tokStart) {
Expand Down

0 comments on commit 2ba4d80

Please sign in to comment.