Skip to content

Commit

Permalink
[mlir] Add a DialectAsmParser::getChecked method
Browse files Browse the repository at this point in the history
This function simplifies calling the getChecked methods on Attributes and Types from within the parser, and removes any need to use `getEncodedSourceLocation` for these methods (by using an SMLoc instead). This is much more efficient than using an mlir::Location, as the encoding process to produce an mlir::Location is inefficient and undesirable for parsing (locations used during parsing should not persist afterwards unless otherwise necessary).

Differential Revision: https://reviews.llvm.org/D97900
  • Loading branch information
River707 committed Mar 4, 2021
1 parent 9783e20 commit 6bc767c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 33 deletions.
20 changes: 20 additions & 0 deletions mlir/include/mlir/IR/DialectImplementation.h
Expand Up @@ -121,6 +121,10 @@ class DialectAsmParser {
virtual llvm::SMLoc getNameLoc() const = 0;

/// Re-encode the given source location as an MLIR location and return it.
/// Note: This method should only be used when a `Location` is necessary, as
/// the encoding process is not efficient. In other cases a more suitable
/// alternative should be used, such as the `getChecked` methods defined
/// below.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;

/// Returns the full specification of the symbol being parsed. This allows for
Expand Down Expand Up @@ -163,6 +167,22 @@ class DialectAsmParser {
return success();
}

/// Invoke the `getChecked` method of the given Attribute or Type class, using
/// the provided location to emit errors in the case of failure. Note that
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
/// context parameter.
template <typename T, typename... ParamsT>
T getChecked(llvm::SMLoc loc, ParamsT &&...params) {
return T::getChecked([&] { return emitError(loc); },
std::forward<ParamsT>(params)...);
}
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
/// errors.
template <typename T, typename... ParamsT> T getChecked(ParamsT &&...params) {
return T::getChecked([&] { return emitError(getNameLoc()); },
std::forward<ParamsT>(params)...);
}

//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
Expand Down
28 changes: 14 additions & 14 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
Expand Up @@ -178,7 +178,7 @@ static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
/// Parses an LLVM dialect function type.
/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
llvm::SMLoc loc = parser.getCurrentLocation();
Type returnType;
if (parser.parseLess() || dispatchParse(parser, returnType) ||
parser.parseLParen())
Expand All @@ -187,8 +187,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
// Function type without arguments.
if (succeeded(parser.parseOptionalRParen())) {
if (succeeded(parser.parseGreater()))
return LLVMFunctionType::getChecked(loc, returnType, llvm::None,
/*isVarArg=*/false);
return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
/*isVarArg=*/false);
return LLVMFunctionType();
}

Expand All @@ -198,8 +198,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
if (succeeded(parser.parseOptionalEllipsis())) {
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
return LLVMFunctionType();
return LLVMFunctionType::getChecked(loc, returnType, argTypes,
/*isVarArg=*/true);
return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
/*isVarArg=*/true);
}

Type arg;
Expand All @@ -210,14 +210,14 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {

if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
return LLVMFunctionType();
return LLVMFunctionType::getChecked(loc, returnType, argTypes,
/*isVarArg=*/false);
return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
/*isVarArg=*/false);
}

/// Parses an LLVM dialect pointer type.
/// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
llvm::SMLoc loc = parser.getCurrentLocation();
Type elementType;
if (parser.parseLess() || dispatchParse(parser, elementType))
return LLVMPointerType();
Expand All @@ -228,7 +228,7 @@ static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
return LLVMPointerType();
if (failed(parser.parseGreater()))
return LLVMPointerType();
return LLVMPointerType::getChecked(loc, elementType, addressSpace);
return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
}

/// Parses an LLVM dialect vector type.
Expand All @@ -238,7 +238,7 @@ static Type parseVectorType(DialectAsmParser &parser) {
SmallVector<int64_t, 2> dims;
llvm::SMLoc dimPos, typePos;
Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
parser.getCurrentLocation(&typePos) ||
Expand All @@ -259,13 +259,13 @@ static Type parseVectorType(DialectAsmParser &parser) {

bool isScalable = dims.size() == 2;
if (isScalable)
return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]);
return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
if (elementType.isSignlessIntOrFloat()) {
parser.emitError(typePos)
<< "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
return Type();
}
return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]);
return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
}

/// Parses an LLVM dialect array type.
Expand All @@ -274,7 +274,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
SmallVector<int64_t, 1> dims;
llvm::SMLoc sizePos;
Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
dispatchParse(parser, elementType) || parser.parseGreater())
Expand All @@ -285,7 +285,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
return LLVMArrayType();
}

return LLVMArrayType::getChecked(loc, elementType, dims[0]);
return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
}

/// Attempts to set the body of an identified structure type. Reports a parsing
Expand Down
31 changes: 14 additions & 17 deletions mlir/lib/Dialect/Quant/IR/TypeParser.cpp
Expand Up @@ -117,7 +117,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
/// storage-range ::= integer-literal `:` integer-literal
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
static Type parseAnyType(DialectAsmParser &parser, Location loc) {
static Type parseAnyType(DialectAsmParser &parser) {
IntegerType storageType;
FloatType expressedType;
unsigned typeFlags = 0;
Expand Down Expand Up @@ -155,9 +155,8 @@ static Type parseAnyType(DialectAsmParser &parser, Location loc) {
return nullptr;
}

return AnyQuantizedType::getChecked(loc, typeFlags, storageType,
expressedType, storageTypeMin,
storageTypeMax);
return parser.getChecked<AnyQuantizedType>(
typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
}

static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
Expand Down Expand Up @@ -192,7 +191,7 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
/// axis-spec ::= `:` integer-literal
/// scale-zero ::= float-literal `:` integer-literal
/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
static Type parseUniformType(DialectAsmParser &parser, Location loc) {
static Type parseUniformType(DialectAsmParser &parser) {
IntegerType storageType;
FloatType expressedType;
unsigned typeFlags = 0;
Expand Down Expand Up @@ -279,14 +278,14 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
if (isPerAxis) {
ArrayRef<double> scalesRef(scales.begin(), scales.end());
ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
return UniformQuantizedPerAxisType::getChecked(
loc, typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
return parser.getChecked<UniformQuantizedPerAxisType>(
typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
quantizedDimension, storageTypeMin, storageTypeMax);
}

return UniformQuantizedType::getChecked(
loc, typeFlags, storageType, expressedType, scales.front(),
zeroPoints.front(), storageTypeMin, storageTypeMax);
return parser.getChecked<UniformQuantizedType>(
typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
storageTypeMin, storageTypeMax);
}

/// Parses an CalibratedQuantizedType.
Expand All @@ -295,7 +294,7 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
/// expressed-spec ::= expressed-type `<` calibrated-range `>`
/// expressed-type ::= `f` integer-literal
/// calibrated-range ::= float-literal `:` float-literal
static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
static Type parseCalibratedType(DialectAsmParser &parser) {
FloatType expressedType;
double min;
double max;
Expand All @@ -314,24 +313,22 @@ static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
return nullptr;
}

return CalibratedQuantizedType::getChecked(loc, expressedType, min, max);
return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
}

/// Parse a type registered to this dialect.
Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());

// All types start with an identifier that we switch on.
StringRef typeNameSpelling;
if (failed(parser.parseKeyword(&typeNameSpelling)))
return nullptr;

if (typeNameSpelling == "uniform")
return parseUniformType(parser, loc);
return parseUniformType(parser);
if (typeNameSpelling == "any")
return parseAnyType(parser, loc);
return parseAnyType(parser);
if (typeNameSpelling == "calibrated")
return parseCalibratedType(parser, loc);
return parseCalibratedType(parser);

parser.emitError(parser.getNameLoc(),
"unknown quantized type " + typeNameSpelling);
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Parser/DialectSymbolParser.cpp
Expand Up @@ -524,7 +524,7 @@ Attribute Parser::parseExtendedAttr(Type type) {

// Otherwise, form a new opaque attribute.
return OpaqueAttr::getChecked(
getEncodedSourceLocation(loc),
[&] { return emitError(loc); },
Identifier::get(dialectName, state.context), symbolData,
attrType ? attrType : NoneType::get(state.context));
});
Expand Down Expand Up @@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {

// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
getEncodedSourceLocation(loc),
[&] { return emitError(loc); },
Identifier::get(dialectName, state.context), symbolData);
});
}
Expand Down

0 comments on commit 6bc767c

Please sign in to comment.